diff --git a/.coveragerc b/.coveragerc index 3052ac253d..ef69581861 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,28 +2,27 @@ source = urllib3 -omit = - *urllib3/packages/* - *urllib3/contrib/appengine.py - *urllib3/contrib/ntlmpool.py - *urllib3/contrib/pyopenssl.py - *urllib3/contrib/securetransport.py - *urllib3/contrib/_securetransport/* - [paths] source = src/urllib3 - .tox/*/lib/python*/site-packages/urllib3 - .tox\*\Lib\site-packages\urllib3 - .tox/pypy*/site-packages/urllib3 + */urllib3 + *\urllib3 [report] +omit = + src/urllib3/contrib/pyopenssl.py + src/urllib3/contrib/securetransport.py + src/urllib3/contrib/_securetransport/* + exclude_lines = + except ModuleNotFoundError: except ImportError: pass import - raise + raise NotImplementedError .* # Platform-specific.* .*:.* # Python \d.* .* # Abstract .* # Defensive: + if (?:typing.)?TYPE_CHECKING: + ^\s*?\.\.\.\s*$ diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 9dadfa918e..64afa535c7 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ # Restrict all files related to deploying to # require lead maintainer approval. 
-.github/CODEOWNERS @sethmlarson @shazow -src/urllib3/_version.py @sethmlarson @shazow -setup.py @sethmlarson @shazow -ci/ @sethmlarson @shazow -.travis.yml @sethmlarson @shazow +.github/workflows/ @sethmlarson @pquentin @shazow +.github/CODEOWNERS @sethmlarson @pquentin @shazow +src/urllib3/_version.py @sethmlarson @pquentin @shazow +pyproject.toml @sethmlarson @pquentin @shazow +ci/ @sethmlarson @pquentin @shazow diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 7cc1a23e9e..5e16c95532 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,3 +1,3 @@ tidelift: pypi/urllib3 +github: urllib3 open_collective: urllib3 -custom: https://gitcoin.co/grants/65/urllib3 diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 9b435ceda4..6cbf5a88f1 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,7 +1,7 @@ blank_issues_enabled: false contact_links: - name: 📚 Documentation - url: https://urllib3.readthedocs.io/en/latest/ + url: https://urllib3.readthedocs.io about: Make sure you read the relevant docs - name: ❓ Ask on StackOverflow url: https://stackoverflow.com/questions/tagged/urllib3 diff --git a/.github/PULL_REQUEST_TEMPLATE/release.md b/.github/PULL_REQUEST_TEMPLATE/release.md new file mode 100644 index 0000000000..5dc84b28dc --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/release.md @@ -0,0 +1,19 @@ +* [ ] See if all tests, including integration, pass +* [ ] Get the release pull request approved by a [CODEOWNER](https://github.com/urllib3/urllib3/blob/main/.github/CODEOWNERS) +* [ ] Squash merge the release pull request with message "`Release `" +* [ ] Tag with X.Y.Z, push tag on urllib3/urllib3 (not on your fork, update `` accordingly) + * Notice that the `` shouldn't have a `v` prefix (Use `1.26.6` instead of `v.1.26.6`) + * ``` + git tag -s -a '' -m 'Release: ' + git push --tags + ``` +* [ ] Execute the `publish` GitHub workflow. This requires a review from a maintainer. 
+* [ ] Ensure that all expected artifacts are added to the new GitHub release. Should + be one `.whl`, one `.tar.gz`, and one `multiple.intoto.jsonl`. Update the GitHub + release to have the content of the release's changelog. +* [ ] Announce on: + * [ ] Twitter + * [ ] Discord + * [ ] OpenCollective +* [ ] Update Tidelift metadata +* [ ] If this was a 1.26.x release, add changelog to the `main` branch diff --git a/.github/SECURITY.md b/.github/SECURITY.md new file mode 100644 index 0000000000..005467cec0 --- /dev/null +++ b/.github/SECURITY.md @@ -0,0 +1,4 @@ +# Security Disclosures + +To report a security vulnerability, please use the [Tidelift security contact](https://tidelift.com/security). +Tidelift will coordinate the fix and disclosure with maintainers. diff --git a/.github/codeql.yml b/.github/codeql.yml new file mode 100644 index 0000000000..1e62c926ed --- /dev/null +++ b/.github/codeql.yml @@ -0,0 +1,3 @@ +paths: +- "src/" + diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..27e0fa5149 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + ignore: + # Ignore all patch releases as we can manually + # upgrade if we run into a bug and need a fix. 
+ - dependency-name: "*" + update-types: ["version-update:semver-patch"] diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml new file mode 100644 index 0000000000..2a9841d9d8 --- /dev/null +++ b/.github/workflows/changelog.yml @@ -0,0 +1,28 @@ +name: Check + +on: + pull_request: + types: [labeled, unlabeled, opened, reopened, synchronize] + +permissions: "read-all" + +jobs: + check-changelog-entry: + name: changelog entry + runs-on: ubuntu-latest + + steps: + - name: "Checkout repository" + uses: "actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3" + with: + # `towncrier check` runs `git diff --name-only origin/main...`, which + # needs a non-shallow clone. + fetch-depth: 0 + + - name: "Check changelog" + if: "!contains(github.event.pull_request.labels.*.name, 'Skip Changelog')" + run: | + if ! pipx run towncrier check --compare-with origin/${{ github.base_ref }}; then + echo "Please see https://github.com/urllib3/urllib3/blob/main/changelog/README.rst for guidance." 
+ false + fi diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5ca94207b1..500f39f3ca 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,81 +2,137 @@ name: CI on: [push, pull_request] +permissions: "read-all" + defaults: run: shell: bash jobs: - lint: - runs-on: ubuntu-latest - - steps: - - name: Checkout Repository - uses: actions/checkout@v2 - - name: Set up Python 3.8 - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - name: Install dependencies - run: python3.8 -m pip install nox - - name: Lint the code - run: nox -s lint - package: runs-on: ubuntu-latest + timeout-minutes: 10 steps: - - name: Checkout Repository - uses: actions/checkout@v2 - - name: Set up Python 3.7 - uses: actions/setup-python@v2 + - name: "Checkout repository" + uses: "actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3" + + - name: "Setup Python" + uses: "actions/setup-python@57ded4d7d5e986d7296eab16560982c6dd7c923b" with: - python-version: 3.7 - - name: Check packages + python-version: "3.x" + cache: "pip" + + - name: "Check packages" run: | - python3.7 -m pip install wheel twine rstcheck; - python3.7 setup.py sdist bdist_wheel; - rstcheck README.rst CHANGES.rst - python3.7 -m twine check dist/* + python -m pip install -U pip setuptools wheel build twine rstcheck + python -m build + rstcheck CHANGES.rst + python -m twine check dist/* + test: strategy: fail-fast: false matrix: - python-version: [2.7, 3.5, 3.6, 3.7, 3.8] - os: [macos-latest, windows-latest] - experimental: [false] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12-dev"] + os: + - macos-11 + - windows-latest + - ubuntu-20.04 # OpenSSL 1.1.1 + - ubuntu-22.04 # OpenSSL 3.0 + nox-session: [''] include: - - python-version: 3.9-dev - os: macos-latest + - experimental: false + - python-version: "pypy-3.7" + os: ubuntu-latest + experimental: false + nox-session: test-pypy + - python-version: "pypy-3.8" + os: ubuntu-latest + experimental: false + 
nox-session: test-pypy + - python-version: "3.x" + os: ubuntu-latest + experimental: false + nox-session: test_brotlipy + # Test CPython with a broken hostname_checks_common_name (the fix is in 3.9.3) + - python-version: "3.9.2" + os: ubuntu-20.04 # CPython 3.9.2 is not available for ubuntu-22.04. + experimental: false + nox-session: test-3.9 + - python-version: "3.12-dev" experimental: true + exclude: + # Ubuntu 22.04 comes with OpenSSL 3.0, so only CPython 3.9+ is compatible with it + # https://github.com/python/cpython/issues/83001 + - python-version: "3.7" + os: ubuntu-22.04 + - python-version: "3.8" + os: ubuntu-22.04 + # Testing with non-final CPython on macOS is too slow for CI. + - python-version: "3.12-dev" + os: macos-11 runs-on: ${{ matrix.os }} - name: ${{ fromJson('{"macos-latest":"macOS","windows-latest":"Windows"}')[matrix.os] }} (${{ matrix.python-version }}) + name: ${{ fromJson('{"macos-11":"macOS","windows-latest":"Windows","ubuntu-latest":"Ubuntu","ubuntu-20.04":"Ubuntu 20.04 (OpenSSL 1.1.1)","ubuntu-22.04":"Ubuntu 22.04 (OpenSSL 3.0)"}')[matrix.os] }} ${{ matrix.python-version }} ${{ matrix.nox-session}} continue-on-error: ${{ matrix.experimental }} + timeout-minutes: 30 steps: - - name: Checkout Repository - uses: actions/checkout@v2 + - name: "Checkout repository" + uses: "actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3" - - name: Set Up Python - ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - name: "Setup Python ${{ matrix.python-version }}" + uses: "actions/setup-python@57ded4d7d5e986d7296eab16560982c6dd7c923b" with: python-version: ${{ matrix.python-version }} - - name: Set Up Python 3.7 to run nox - if: matrix.python-version != '3.7' - uses: actions/setup-python@v2 - with: - python-version: 3.7 - - - name: Install Dependencies - run: python -m pip install --upgrade nox + - name: "Install dependencies" + run: python -m pip install --upgrade pip setuptools nox - - name: Run Tests + - name: "Run tests" run: 
./ci/run_tests.sh env: PYTHON_VERSION: ${{ matrix.python-version }} + NOX_SESSION: ${{ matrix.nox-session }} - - name: Upload Coverage - run: ./ci/upload_coverage.sh - env: - JOB_NAME: "${{ runner.os }} (${{ matrix.python-version }})" + - name: "Upload artifact" + uses: "actions/upload-artifact@0b7f8abb1508181956e8e162db84b466c27e18ce" + with: + name: coverage-data + path: ".coverage.*" + if-no-files-found: error + + + coverage: + if: always() + runs-on: "ubuntu-latest" + needs: test + steps: + - name: "Checkout repository" + uses: "actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3" + + - name: "Setup Python" + uses: "actions/setup-python@57ded4d7d5e986d7296eab16560982c6dd7c923b" + with: + python-version: "3.x" + + - name: "Install coverage" + run: "python -m pip install --upgrade coverage" + + - name: "Download artifact" + uses: actions/download-artifact@9bc31d5ccc31df68ecc42ccf4149144866c47d8a + with: + name: coverage-data + + - name: "Combine & check coverage" + run: | + python -m coverage combine + python -m coverage html --skip-covered --skip-empty + python -m coverage report --ignore-errors --show-missing --fail-under=100 + + - if: ${{ failure() }} + name: "Upload report if check failed" + uses: actions/upload-artifact@0b7f8abb1508181956e8e162db84b466c27e18ce + with: + name: coverage-report + path: htmlcov diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000000..d27556450e --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,35 @@ +name: "CodeQL" + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + schedule: + - cron: "0 0 * * 5" + +permissions: "read-all" + +jobs: + analyze: + name: "Analyze" + runs-on: "ubuntu-latest" + permissions: + actions: read + contents: read + security-events: write + steps: + - name: "Checkout repository" + uses: "actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3" + + - name: "Run CodeQL init" + uses: 
"github/codeql-action/init@b2c19fb9a2a485599ccf4ed5d65527d94bc57226" + with: + config-file: "./.github/codeql.yml" + languages: "python" + + - name: "Run CodeQL autobuild" + uses: "github/codeql-action/autobuild@b2c19fb9a2a485599ccf4ed5d65527d94bc57226" + + - name: "Run CodeQL analyze" + uses: "github/codeql-action/analyze@b2c19fb9a2a485599ccf4ed5d65527d94bc57226" diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml new file mode 100644 index 0000000000..376391bee8 --- /dev/null +++ b/.github/workflows/integration.yml @@ -0,0 +1,29 @@ +name: Downstream + +on: [push, pull_request] + +permissions: "read-all" + +jobs: + integration: + strategy: + fail-fast: false + matrix: + downstream: [botocore, requests] + runs-on: ubuntu-22.04 + timeout-minutes: 30 + + steps: + - name: "Checkout repository" + uses: "actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3" + + - name: "Setup Python" + uses: "actions/setup-python@57ded4d7d5e986d7296eab16560982c6dd7c923b" + with: + python-version: "3.x" + + - name: "Install dependencies" + run: python -m pip install --upgrade nox + + - name: "Run downstream tests" + run: nox -s downstream_${{ matrix.downstream }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000000..a991884894 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,29 @@ +name: lint + +on: [push, pull_request] + +permissions: "read-all" + +jobs: + lint: + runs-on: ubuntu-20.04 + timeout-minutes: 10 + + steps: + - name: "Checkout repository" + uses: "actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3" + + - name: "Setup Python" + uses: "actions/setup-python@57ded4d7d5e986d7296eab16560982c6dd7c923b" + with: + python-version: "3.x" + cache: pip + + - name: "Run pre-commit" + uses: pre-commit/action@646c83fcd040023954eafda54b4db0192ce70507 + + - name: "Install dependencies" + run: python -m pip install nox + + - name: "Run mypy" + run: nox -s mypy diff --git 
a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000000..7a28c75ce2 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,87 @@ +name: Publish to PyPI + +on: + push: + tags: + - "*" + +permissions: + contents: read + +jobs: + build: + name: "Build dists" + runs-on: "ubuntu-latest" + environment: + name: "publish" + outputs: + hashes: ${{ steps.hash.outputs.hashes }} + + steps: + - name: "Checkout repository" + uses: "actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3" + + - name: "Setup Python" + uses: "actions/setup-python@57ded4d7d5e986d7296eab16560982c6dd7c923b" + with: + python-version: "3.x" + + - name: "Install dependencies" + run: python -m pip install build==0.8.0 + + - name: "Build dists" + run: | + SOURCE_DATE_EPOCH=$(git log -1 --pretty=%ct) \ + python -m build + + - name: "Generate hashes" + id: hash + run: | + cd dist && echo "::set-output name=hashes::$(sha256sum * | base64 -w0)" + + - name: "Upload dists" + uses: "actions/upload-artifact@0b7f8abb1508181956e8e162db84b466c27e18ce" + with: + name: "dist" + path: "dist/" + if-no-files-found: error + retention-days: 5 + + provenance: + needs: [build] + permissions: + actions: read + contents: write + id-token: write # Needed to access the workflow's OIDC identity. 
+ uses: "slsa-framework/slsa-github-generator/.github/workflows/generator_generic_slsa3.yml@v1.5.0" + with: + base64-subjects: "${{ needs.build.outputs.hashes }}" + upload-assets: true + compile-generator: true # Workaround for https://github.com/slsa-framework/slsa-github-generator/issues/1163 + + publish: + name: "Publish" + if: startsWith(github.ref, 'refs/tags/') + needs: ["build", "provenance"] + permissions: + contents: write + runs-on: "ubuntu-latest" + + steps: + - name: "Download dists" + uses: "actions/download-artifact@9bc31d5ccc31df68ecc42ccf4149144866c47d8a" + with: + name: "dist" + path: "dist/" + + - name: "Upload dists to GitHub Release" + env: + GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}" + run: | + gh release upload ${{ github.ref_name }} dist/* --repo ${{ github.repository }} + + - name: "Publish dists to PyPI" + uses: "pypa/gh-action-pypi-publish@48b317d84d5f59668bb13be49d1697e36b3ad009" + with: + user: __token__ + password: ${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/scorecards.yml b/.github/workflows/scorecards.yml new file mode 100644 index 0000000000..72a345a451 --- /dev/null +++ b/.github/workflows/scorecards.yml @@ -0,0 +1,33 @@ +name: "Scorecard" +on: + branch_protection_rule: + schedule: + - cron: "0 0 * * 0" + push: + branches: ["main", "1.26.x"] + +permissions: read-all + +jobs: + analysis: + name: "Scorecard" + runs-on: "ubuntu-latest" + permissions: + security-events: write + id-token: write + contents: read + actions: read + + steps: + - name: "Checkout repository" + uses: "actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3" + with: + persist-credentials: false + + - name: "Run Scorecard" + uses: "ossf/scorecard-action@e38b1902ae4f44df626f11ba0734b14fb91f8f86" + with: + results_file: results.sarif + results_format: sarif + repo_token: ${{ secrets.SCORECARD_TOKEN }} + publish_results: true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..dc05ec6820 --- /dev/null 
+++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +repos: + - repo: https://github.com/asottile/pyupgrade + rev: v3.3.1 + hooks: + - id: pyupgrade + args: ["--py37-plus"] + + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + args: ["--target-version", "py37"] + + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + + - repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + additional_dependencies: [flake8-2020] diff --git a/.readthedocs.yml b/.readthedocs.yml index 1da640f032..78bd2064f6 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -1,14 +1,20 @@ version: 2 +build: + os: ubuntu-22.04 + tools: + python: "3" + python: install: + - requirements: docs/requirements.txt - method: pip path: . extra_requirements: - brotli - secure - socks - - requirements: docs/requirements.txt + - zstd sphinx: fail_on_warning: true diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index ca52bf2681..0000000000 --- a/.travis.yml +++ /dev/null @@ -1,95 +0,0 @@ -language: python -os: linux -dist: xenial - -before_install: - - env - - openssl version - - python -c "import ssl; print(ssl.OPENSSL_VERSION)" - -install: - - ./ci/install.sh - -script: - - ./ci/run.sh - - ./ci/upload_coverage.sh - -cache: - directories: - - ${HOME}/.cache - -notifications: - email: false - -env: - global: - - PYTHONWARNINGS=always::DeprecationWarning - - - PYPI_USERNAME=urllib3 - # PYPI_PASSWORD is set in Travis control panel. 
- -jobs: - allow_failures: - - python: nightly - include: - # Unit tests - - python: 2.7 - env: NOX_SESSION=test-2.7 - - python: 3.5 - env: NOX_SESSION=test-3.5 - - python: 3.6 - env: NOX_SESSION=test-3.6 - - python: 3.7 - env: NOX_SESSION=test-3.7 - - python: 3.8 - env: NOX_SESSION=test-3.8 - - python: 3.9-dev - env: NOX_SESSION=test-3.9 - - python: nightly - env: NOX_SESSION=test-3.10 - - python: pypy2.7-6.0 - env: NOX_SESSION=test-pypy - - python: pypy3.5-6.0 - env: NOX_SESSION=test-pypy - - # Extras - - python: 2.7 - env: NOX_SESSION=app_engine GAE_SDK_PATH=${HOME}/.cache/google_appengine - - python: 2.7 - env: NOX_SESSION=google_brotli-2 - - python: 3.7 - env: NOX_SESSION=google_brotli-3 - - # Downstream integration tests. - - python: 2.7 - env: DOWNSTREAM=requests - stage: integration - - - python: 3.7 - env: DOWNSTREAM=requests - stage: integration - - - python: 2.7 - env: DOWNSTREAM=botocore - stage: integration - - - python: 3.7 - env: DOWNSTREAM=botocore - stage: integration - - - python: 3.7 - stage: deploy - script: - - ./ci/deploy.sh - -stages: - - name: test - if: tag IS blank - - # Run integration tests for release candidates - - name: integration - if: type = pull_request AND head_branch =~ ^release-[\d.]+$ AND tag IS blank - - # Deploy on any tags - - name: deploy - if: tag IS present AND tag =~ /^(\d+\.\d+(?:.\d+)?)$/ AND repo = urllib3/urllib3 diff --git a/CHANGES.rst b/CHANGES.rst index df1a0dbda4..1707eaddc9 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,8 +1,298 @@ -Changes -======= +2.0.0 (2023-04-26) +================== + +Read the `v2.0 migration guide `__ for help upgrading to the latest version of urllib3. + +Removed +------- + +* Removed support for Python 2.7, 3.5, and 3.6 (`#883 `__, `#2336 `__). +* Removed fallback on certificate ``commonName`` in ``match_hostname()`` function. + This behavior was deprecated in May 2000 in RFC 2818. Instead only ``subjectAltName`` + is used to verify the hostname by default. 
To enable verifying the hostname against + ``commonName`` use ``SSLContext.hostname_checks_common_name = True`` (`#2113 `__). +* Removed support for Python with an ``ssl`` module compiled with LibreSSL, CiscoSSL, + wolfSSL, and all other OpenSSL alternatives. Python is moving to require OpenSSL with PEP 644 (`#2168 `__). +* Removed support for OpenSSL versions earlier than 1.1.1 or that don't have SNI support. + When an incompatible OpenSSL version is detected an ``ImportError`` is raised (`#2168 `__). +* Removed the list of default ciphers for OpenSSL 1.1.1+ and SecureTransport as their own defaults are already secure (`#2082 `__). +* Removed ``urllib3.contrib.appengine.AppEngineManager`` and support for Google App Engine Standard Environment (`#2044 `__). +* Removed deprecated ``Retry`` options ``method_whitelist``, ``DEFAULT_REDIRECT_HEADERS_BLACKLIST`` (`#2086 `__). +* Removed ``urllib3.HTTPResponse.from_httplib`` (`#2648 `__). +* Removed default value of ``None`` for the ``request_context`` parameter of ``urllib3.PoolManager.connection_from_pool_key``. This change should have no effect on users as the default value of ``None`` was an invalid option and was never used (`#1897 `__). +* Removed the ``urllib3.request`` module. ``urllib3.request.RequestMethods`` has been made a private API. + This change was made to ensure that ``from urllib3 import request`` imported the top-level ``request()`` + function instead of the ``urllib3.request`` module (`#2269 `__). +* Removed support for SSLv3.0 from the ``urllib3.contrib.pyopenssl`` even when support is available from the compiled OpenSSL library (`#2233 `__). +* Removed the deprecated ``urllib3.contrib.ntlmpool`` module (`#2339 `__). +* Removed ``DEFAULT_CIPHERS``, ``HAS_SNI``, ``USE_DEFAULT_SSLCONTEXT_CIPHERS``, from the private module ``urllib3.util.ssl_`` (`#2168 `__). +* Removed ``urllib3.exceptions.SNIMissingWarning`` (`#2168 `__). +* Removed the ``_prepare_conn`` method from ``HTTPConnectionPool``. 
Previously this was only used to call ``HTTPSConnection.set_cert()`` by ``HTTPSConnectionPool`` (`#1985 `__). +* Removed ``tls_in_tls_required`` property from ``HTTPSConnection``. This is now determined from the ``scheme`` parameter in ``HTTPConnection.set_tunnel()`` (`#1985 `__). + +Deprecated +---------- + +* Deprecated ``HTTPResponse.getheaders()`` and ``HTTPResponse.getheader()`` which will be removed in urllib3 v2.1.0. Instead use ``HTTPResponse.headers`` and ``HTTPResponse.headers.get(name, default)``. (`#1543 `__, `#2814 `__). +* Deprecated ``urllib3.contrib.pyopenssl`` module which will be removed in urllib3 v2.1.0 (`#2691 `__). +* Deprecated ``urllib3.contrib.securetransport`` module which will be removed in urllib3 v2.1.0 (`#2692 `__). +* Deprecated ``ssl_version`` option in favor of ``ssl_minimum_version``. ``ssl_version`` will be removed in urllib3 v2.1.0 (`#2110 `__). +* Deprecated the ``strict`` parameter as it's not longer needed in Python 3.x. It will be removed in urllib3 v2.1.0 (`#2267 `__) +* Deprecated the ``NewConnectionError.pool`` attribute which will be removed in urllib3 v2.1.0 (`#2271 `__). +* Deprecated ``format_header_param_html5`` and ``format_header_param`` in favor of ``format_multipart_header_param`` (`#2257 `__). +* Deprecated ``RequestField.header_formatter`` parameter which will be removed in urllib3 v2.1.0 (`#2257 `__). +* Deprecated ``HTTPSConnection.set_cert()`` method. Instead pass parameters to the ``HTTPSConnection`` constructor (`#1985 `__). +* Deprecated ``HTTPConnection.request_chunked()`` method which will be removed in urllib3 v2.1.0. Instead pass ``chunked=True`` to ``HTTPConnection.request()`` (`#1985 `__). + +Added +----- + +* Added top-level ``urllib3.request`` function which uses a preconfigured module-global ``PoolManager`` instance (`#2150 `__). +* Added the ``json`` parameter to ``urllib3.request()``, ``PoolManager.request()``, and ``ConnectionPool.request()`` methods to send JSON bodies in requests. 
Using this parameter will set the header ``Content-Type: application/json`` if ``Content-Type`` isn't already defined. + Added support for parsing JSON response bodies with ``HTTPResponse.json()`` method (`#2243 `__). +* Added type hints to the ``urllib3`` module (`#1897 `__). +* Added ``ssl_minimum_version`` and ``ssl_maximum_version`` options which set + ``SSLContext.minimum_version`` and ``SSLContext.maximum_version`` (`#2110 `__). +* Added support for Zstandard (RFC 8878) when ``zstandard`` 1.18.0 or later is installed. + Added the ``zstd`` extra which installs the ``zstandard`` package (`#1992 `__). +* Added ``urllib3.response.BaseHTTPResponse`` class. All future response classes will be subclasses of ``BaseHTTPResponse`` (`#2083 `__). +* Added ``FullPoolError`` which is raised when ``PoolManager(block=True)`` and a connection is returned to a full pool (`#2197 `__). +* Added ``HTTPHeaderDict`` to the top-level ``urllib3`` namespace (`#2216 `__). +* Added support for configuring header merging behavior with HTTPHeaderDict + When using a ``HTTPHeaderDict`` to provide headers for a request, by default duplicate + header values will be repeated. But if ``combine=True`` is passed into a call to + ``HTTPHeaderDict.add``, then the added header value will be merged in with an existing + value into a comma-separated list (``X-My-Header: foo, bar``) (`#2242 `__). +* Added ``NameResolutionError`` exception when a DNS error occurs (`#2305 `__). +* Added ``proxy_assert_hostname`` and ``proxy_assert_fingerprint`` kwargs to ``ProxyManager`` (`#2409 `__). +* Added a configurable ``backoff_max`` parameter to the ``Retry`` class. + If a custom ``backoff_max`` is provided to the ``Retry`` class, it + will replace the ``Retry.DEFAULT_BACKOFF_MAX`` (`#2494 `__). +* Added the ``authority`` property to the Url class as per RFC 3986 3.2. This property should be used in place of ``netloc`` for users who want to include the userinfo (auth) component of the URI (`#2520 `__). 
+* Added the ``scheme`` parameter to ``HTTPConnection.set_tunnel`` to configure the scheme of the origin being tunnelled to (`#1985 `__). +* Added the ``is_closed``, ``is_connected`` and ``has_connected_to_proxy`` properties to ``HTTPConnection`` (`#1985 `__). +* Added optional ``backoff_jitter`` parameter to ``Retry``. (`#2952 `__) + +Changed +------- + +* Changed ``urllib3.response.HTTPResponse.read`` to respect the semantics of ``io.BufferedIOBase`` regardless of compression. Specifically, this method: + + * Only returns an empty bytes object to indicate EOF (that is, the response has been fully consumed). + * Never returns more bytes than requested. + * Can issue any number of system calls: zero, one or multiple. + + If you want each ``urllib3.response.HTTPResponse.read`` call to issue a single system call, you need to disable decompression by setting ``decode_content=False`` (`#2128 `__). +* Changed ``urllib3.HTTPConnection.getresponse`` to return an instance of ``urllib3.HTTPResponse`` instead of ``http.client.HTTPResponse`` (`#2648 `__). +* Changed ``ssl_version`` to instead set the corresponding ``SSLContext.minimum_version`` + and ``SSLContext.maximum_version`` values. Regardless of ``ssl_version`` passed + ``SSLContext`` objects are now constructed using ``ssl.PROTOCOL_TLS_CLIENT`` (`#2110 `__). +* Changed default ``SSLContext.minimum_version`` to be ``TLSVersion.TLSv1_2`` in line with Python 3.10 (`#2373 `__). +* Changed ``ProxyError`` to wrap any connection error (timeout, TLS, DNS) that occurs when connecting to the proxy (`#2482 `__). +* Changed ``urllib3.util.create_urllib3_context`` to not override the system cipher suites with + a default value. The new default will be cipher suites configured by the operating system (`#2168 `__). +* Changed ``multipart/form-data`` header parameter formatting matches the WHATWG HTML Standard as of 2021-06-10. Control characters in filenames are no longer percent encoded (`#2257 `__). 
+* Changed the error raised when connecting via HTTPS when the ``ssl`` module isn't available from ``SSLError`` to ``ImportError`` (`#2589 `__). +* Changed ``HTTPConnection.request()`` to always use lowercase chunk boundaries when sending requests with ``Transfer-Encoding: chunked`` (`#2515 `__). +* Changed ``enforce_content_length`` default to True, preventing silent data loss when reading streamed responses (`#2514 `__). +* Changed internal implementation of ``HTTPHeaderDict`` to use ``dict`` instead of ``collections.OrderedDict`` for better performance (`#2080 `__). +* Changed the ``urllib3.contrib.pyopenssl`` module to wrap ``OpenSSL.SSL.Error`` with ``ssl.SSLError`` in ``PyOpenSSLContext.load_cert_chain`` (`#2628 `__). +* Changed usage of the deprecated ``socket.error`` to ``OSError`` (`#2120 `__). +* Changed all parameters in the ``HTTPConnection`` and ``HTTPSConnection`` constructors to be keyword-only except ``host`` and ``port`` (`#1985 `__). +* Changed ``HTTPConnection.getresponse()`` to set the socket timeout from ``HTTPConnection.timeout`` value before reading + data from the socket. This previously was done manually by the ``HTTPConnectionPool`` calling ``HTTPConnection.sock.settimeout(...)`` (`#1985 `__). +* Changed the ``_proxy_host`` property to ``_tunnel_host`` in ``HTTPConnectionPool`` to more closely match how the property is used (value in ``HTTPConnection.set_tunnel()``) (`#1985 `__). +* Changed name of ``Retry.BACK0FF_MAX`` to be ``Retry.DEFAULT_BACKOFF_MAX``. +* Changed TLS handshakes to use ``SSLContext.check_hostname`` when possible (`#2452 `__). +* Changed ``server_hostname`` to behave like other parameters only used by ``HTTPSConnectionPool`` (`#2537 `__). +* Changed the default ``blocksize`` to 16KB to match OpenSSL's default read amounts (`#2348 `__). +* Changed ``HTTPResponse.read()`` to raise an error when calling with ``decode_content=False`` after using ``decode_content=True`` to prevent data loss (`#2800 `__). 
+ +Fixed +----- + +* Fixed thread-safety issue where accessing a ``PoolManager`` with many distinct origins would cause connection pools to be closed while requests are in progress (`#1252 `__). +* Fixed an issue where an ``HTTPConnection`` instance would erroneously reuse the socket read timeout value from reading the previous response instead of a newly configured connect timeout. + Instead now if ``HTTPConnection.timeout`` is updated before sending the next request the new timeout value will be used (`#2645 `__). +* Fixed ``socket.error.errno`` when raised from pyOpenSSL's ``OpenSSL.SSL.SysCallError`` (`#2118 `__). +* Fixed the default value of ``HTTPSConnection.socket_options`` to match ``HTTPConnection`` (`#2213 `__). +* Fixed a bug where ``headers`` would be modified by the ``remove_headers_on_redirect`` feature (`#2272 `__). +* Fixed a reference cycle bug in ``urllib3.util.connection.create_connection()`` (`#2277 `__). +* Fixed a socket leak if ``HTTPConnection.connect()`` fails (`#2571 `__). +* Fixed ``urllib3.contrib.pyopenssl.WrappedSocket`` and ``urllib3.contrib.securetransport.WrappedSocket`` close methods (`#2970 `__) + +1.26.15 (2023-03-10) +==================== + +* Fix socket timeout value when ``HTTPConnection`` is reused (`#2645 `__) +* Remove "!" character from the unreserved characters in IPv6 Zone ID parsing + (`#2899 `__) +* Fix IDNA handling of '\x80' byte (`#2901 `__) + +1.26.14 (2023-01-11) +==================== + +* Fixed parsing of port 0 (zero) returning None, instead of 0. (`#2850 `__) +* Removed deprecated getheaders() calls in contrib module. Fixed the type hint of ``PoolKey.key_retries`` by adding ``bool`` to the union. (`#2865 `__) + +1.26.13 (2022-11-23) +==================== + +* Deprecated the ``HTTPResponse.getheaders()`` and ``HTTPResponse.getheader()`` methods. +* Fixed an issue where parsing a URL with leading zeroes in the port would be rejected + even when the port number after removing the zeroes was valid. 
+* Fixed a deprecation warning when using cryptography v39.0.0. +* Removed the ``<4`` in the ``Requires-Python`` packaging metadata field. + +1.26.12 (2022-08-22) +==================== + +* Deprecated the `urllib3[secure]` extra and the `urllib3.contrib.pyopenssl` module. + Both will be removed in v2.x. See this `GitHub issue `_ + for justification and info on how to migrate. + +1.26.11 (2022-07-25) +==================== + +* Fixed an issue where reading more than 2 GiB in a call to ``HTTPResponse.read`` would + raise an ``OverflowError`` on Python 3.9 and earlier. + +1.26.10 (2022-07-07) +==================== + +* Removed support for Python 3.5 +* Fixed an issue where a ``ProxyError`` recommending configuring the proxy as HTTP + instead of HTTPS could appear even when an HTTPS proxy wasn't configured. + +1.26.9 (2022-03-16) +=================== + +* Changed ``urllib3[brotli]`` extra to favor installing Brotli libraries that are still + receiving updates like ``brotli`` and ``brotlicffi`` instead of ``brotlipy``. + This change does not impact behavior of urllib3, only which dependencies are installed. +* Fixed a socket leaking when ``HTTPSConnection.connect()`` raises an exception. +* Fixed ``server_hostname`` being forwarded from ``PoolManager`` to ``HTTPConnectionPool`` + when requesting an HTTP URL. Should only be forwarded when requesting an HTTPS URL. + +1.26.8 (2022-01-07) +=================== + +* Added extra message to ``urllib3.exceptions.ProxyError`` when urllib3 detects that + a proxy is configured to use HTTPS but the proxy itself appears to only use HTTP. +* Added a mention of the size of the connection pool when discarding a connection due to the pool being full. +* Added explicit support for Python 3.11. +* Deprecated the ``Retry.MAX_BACKOFF`` class property in favor of ``Retry.DEFAULT_MAX_BACKOFF`` + to better match the rest of the default parameter names. ``Retry.MAX_BACKOFF`` is removed in v2.0. 
+* Changed location of the vendored ``ssl.match_hostname`` function from ``urllib3.packages.ssl_match_hostname`` + to ``urllib3.util.ssl_match_hostname`` to ensure Python 3.10+ compatibility after being repackaged + by downstream distributors. +* Fixed absolute imports, all imports are now relative. + + +1.26.7 (2021-09-22) +=================== + +* Fixed a bug with HTTPS hostname verification involving IP addresses and lack + of SNI. (Issue #2400) +* Fixed a bug where IPv6 braces weren't stripped during certificate hostname + matching. (Issue #2240) + + +1.26.6 (2021-06-25) +=================== + +* Deprecated the ``urllib3.contrib.ntlmpool`` module. urllib3 is not able to support + it properly due to `reasons listed in this issue `_. + If you are a user of this module please leave a comment. +* Changed ``HTTPConnection.request_chunked()`` to not erroneously emit multiple + ``Transfer-Encoding`` headers in the case that one is already specified. +* Fixed typo in deprecation message to recommend ``Retry.DEFAULT_ALLOWED_METHODS``. + + +1.26.5 (2021-05-26) +=================== + +* Fixed deprecation warnings emitted in Python 3.10. +* Updated vendored ``six`` library to 1.16.0. +* Improved performance of URL parser when splitting + the authority component. + + +1.26.4 (2021-03-15) +=================== + +* Changed behavior of the default ``SSLContext`` when connecting to HTTPS proxy + during HTTPS requests. The default ``SSLContext`` now sets ``check_hostname=True``. + + +1.26.3 (2021-01-26) +=================== + +* Fixed bytes and string comparison issue with headers (Pull #2141) + +* Changed ``ProxySchemeUnknown`` error message to be + more actionable if the user supplies a proxy URL without + a scheme. 
(Pull #2107) + + +1.26.2 (2020-11-12) +=================== + +* Fixed an issue where ``wrap_socket`` and ``CERT_REQUIRED`` wouldn't + be imported properly on Python 2.7.8 and earlier (Pull #2052) + + +1.26.1 (2020-11-11) +=================== + +* Fixed an issue where two ``User-Agent`` headers would be sent if a + ``User-Agent`` header key is passed as ``bytes`` (Pull #2047) + + +1.26.0 (2020-11-10) +=================== + +* **NOTE: urllib3 v2.0 will drop support for Python 2**. + `Read more in the v2.0 Roadmap `_. + +* Added support for HTTPS proxies contacting HTTPS servers (Pull #1923, Pull #1806) + +* Deprecated negotiating TLSv1 and TLSv1.1 by default. Users that + still wish to use TLS earlier than 1.2 without a deprecation warning + should opt-in explicitly by setting ``ssl_version=ssl.PROTOCOL_TLSv1_1`` (Pull #2002) + **Starting in urllib3 v2.0: Connections that receive a ``DeprecationWarning`` will fail** + +* Deprecated ``Retry`` options ``Retry.DEFAULT_METHOD_WHITELIST``, ``Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST`` + and ``Retry(method_whitelist=...)`` in favor of ``Retry.DEFAULT_ALLOWED_METHODS``, + ``Retry.DEFAULT_REMOVE_HEADERS_ON_REDIRECT``, and ``Retry(allowed_methods=...)`` + (Pull #2000) **Starting in urllib3 v2.0: Deprecated options will be removed** + +* Added default ``User-Agent`` header to every request (Pull #1750) + +* Added ``urllib3.util.SKIP_HEADER`` for skipping ``User-Agent``, ``Accept-Encoding``, + and ``Host`` headers from being automatically emitted with requests (Pull #2018) + +* Collapse ``transfer-encoding: chunked`` request data and framing into + the same ``socket.send()`` call (Pull #1906) + +* Send ``http/1.1`` ALPN identifier with every TLS handshake by default (Pull #1894) + +* Properly terminate SecureTransport connections when CA verification fails (Pull #1977) + +* Don't emit an ``SNIMissingWarning`` when passing ``server_hostname=None`` + to SecureTransport (Pull #1903) + +* Disabled requesting TLSv1.2 session tickets 
as they weren't being used by urllib3 (Pull #1970) + +* Suppress ``BrokenPipeError`` when writing request body after the server + has closed the socket (Pull #1524) + +* Wrap ``ssl.SSLError`` that can be raised from reading a socket (e.g. "bad MAC") + into an ``urllib3.exceptions.SSLError`` (Pull #1939) + 1.25.11 (2020-10-19) --------------------- +==================== * Fix retry backoff time parsed from ``Retry-After`` header when given in the HTTP date format. The HTTP date was parsed as the local timezone @@ -15,7 +305,7 @@ Changes 1.25.10 (2020-07-22) --------------------- +==================== * Added support for ``SSLKEYLOGFILE`` environment variable for logging TLS session keys with use with programs like @@ -34,7 +324,7 @@ Changes 1.25.9 (2020-04-16) -------------------- +=================== * Added ``InvalidProxyConfigurationWarning`` which is raised when erroneously specifying an HTTPS proxy URL. urllib3 doesn't currently @@ -58,7 +348,7 @@ Changes 1.25.8 (2020-01-20) -------------------- +=================== * Drop support for EOL Python 3.4 (Pull #1774) @@ -66,7 +356,7 @@ Changes 1.25.7 (2019-11-11) -------------------- +=================== * Preserve ``chunked`` parameter on retries (Pull #1715, Pull #1734) @@ -80,14 +370,14 @@ Changes 1.25.6 (2019-09-24) -------------------- +=================== * Fix issue where tilde (``~``) characters were incorrectly percent-encoded in the path. (Pull #1692) 1.25.5 (2019-09-19) -------------------- +=================== * Add mitigation for BPO-37428 affecting Python <3.7.4 and OpenSSL 1.1.1+ which caused certificate verification to be enabled when using ``cert_reqs=CERT_NONE``. @@ -95,7 +385,7 @@ Changes 1.25.4 (2019-09-19) -------------------- +=================== * Propagate Retry-After header settings to subsequent retries. 
(Pull #1607) @@ -114,7 +404,7 @@ Changes 1.25.3 (2019-05-23) -------------------- +=================== * Change ``HTTPSConnection`` to load system CA certificates when ``ca_certs``, ``ca_cert_dir``, and ``ssl_context`` are @@ -124,7 +414,7 @@ Changes 1.25.2 (2019-04-28) -------------------- +=================== * Change ``is_ipaddress`` to not detect IPvFuture addresses. (Pull #1583) @@ -133,7 +423,7 @@ Changes 1.25.1 (2019-04-24) -------------------- +=================== * Add support for Google's ``Brotli`` package. (Pull #1572, Pull #1579) @@ -141,7 +431,7 @@ Changes 1.25 (2019-04-22) ------------------ +================= * Require and validate certificates by default when using HTTPS (Pull #1507) @@ -153,7 +443,7 @@ Changes * Add TLSv1.3 support to CPython, pyOpenSSL, and SecureTransport ``SSLContext`` implementations. (Pull #1496) -* Switched the default multipart header encoder from RFC 2231 to HTML 5 working draft. (Issue #303, PR #1492) +* Switched the default multipart header encoder from RFC 2231 to HTML 5 working draft. (Issue #303, Pull #1492) * Fixed issue where OpenSSL would block if an encrypted client private key was given and no password was given. Instead an ``SSLError`` is raised. (Pull #1489) @@ -168,12 +458,12 @@ Changes * Implemented a more efficient ``HTTPResponse.__iter__()`` method. (Issue #1483) 1.24.3 (2019-05-01) -------------------- +=================== * Apply fix for CVE-2019-9740. (Pull #1591) 1.24.2 (2019-04-17) -------------------- +=================== * Don't load system certificates by default when any other ``ca_certs``, ``ca_certs_dir`` or ``ssl_context`` parameters are specified. 
@@ -184,7 +474,7 @@ Changes 1.24.1 (2018-11-02) -------------------- +=================== * Remove quadratic behavior within ``GzipDecoder.decompress()`` (Issue #1467) @@ -192,7 +482,7 @@ Changes 1.24 (2018-10-16) ------------------ +================= * Allow key_server_hostname to be specified when initializing a PoolManager to allow custom SNI to be overridden. (Pull #1449) @@ -218,7 +508,7 @@ Changes 1.23 (2018-06-04) ------------------ +================= * Allow providing a list of headers to strip from requests when redirecting to a different host. Defaults to the ``Authorization`` header. Different @@ -246,7 +536,7 @@ Changes 1.22 (2017-07-20) ------------------ +================= * Fixed missing brackets in ``HTTP CONNECT`` when connecting to IPv6 address via IPv6 proxy. (Issue #1222) @@ -263,7 +553,7 @@ Changes 1.21.1 (2017-05-02) -------------------- +=================== * Fixed SecureTransport issue that would cause long delays in response body delivery. (Pull #1154) @@ -277,7 +567,7 @@ Changes 1.21 (2017-04-25) ------------------ +================= * Improved performance of certain selector system calls on Python 3.5 and later. (Pull #1095) @@ -310,7 +600,7 @@ Changes 1.20 (2017-01-19) ------------------ +================= * Added support for waiting for I/O using selectors other than select, improving urllib3's behaviour with large numbers of concurrent connections. @@ -347,13 +637,13 @@ Changes 1.19.1 (2016-11-16) -------------------- +=================== * Fixed AppEngine import that didn't function on Python 3.5. (Pull #1025) 1.19 (2016-11-03) ------------------ +================= * urllib3 now respects Retry-After headers on 413, 429, and 503 responses when using the default retry logic. (Pull #955) @@ -373,7 +663,7 @@ Changes 1.18.1 (2016-10-27) -------------------- +=================== * CVE-2016-9015. Users who are using urllib3 version 1.17 or 1.18 along with PyOpenSSL injection and OpenSSL 1.1.0 *must* upgrade to this version. 
This @@ -384,13 +674,13 @@ Changes interprets the presence of any flag as requesting certificate validation. There is no PR for this patch, as it was prepared for simultaneous disclosure - and release. The master branch received the same fix in PR #1010. + and release. The master branch received the same fix in Pull #1010. 1.18 (2016-09-26) ------------------ +================= -* Fixed incorrect message for IncompleteRead exception. (PR #973) +* Fixed incorrect message for IncompleteRead exception. (Pull #973) * Accept ``iPAddress`` subject alternative name fields in TLS certificates. (Issue #258) @@ -402,7 +692,7 @@ Changes 1.17 (2016-09-06) ------------------ +================= * Accept ``SSLContext`` objects for use in SSL/TLS negotiation. (Issue #835) @@ -419,36 +709,36 @@ Changes contains retries history. (Issue #848) * Timeout can no longer be set as boolean, and must be greater than zero. - (PR #924) + (Pull #924) * Removed pyasn1 and ndg-httpsclient from dependencies used for PyOpenSSL. We now use cryptography and idna, both of which are already dependencies of - PyOpenSSL. (PR #930) + PyOpenSSL. (Pull #930) * Fixed infinite loop in ``stream`` when amt=None. (Issue #928) * Try to use the operating system's certificates when we are using an - ``SSLContext``. (PR #941) + ``SSLContext``. (Pull #941) * Updated cipher suite list to allow ChaCha20+Poly1305. AES-GCM is preferred to - ChaCha20, but ChaCha20 is then preferred to everything else. (PR #947) + ChaCha20, but ChaCha20 is then preferred to everything else. (Pull #947) -* Updated cipher suite list to remove 3DES-based cipher suites. (PR #958) +* Updated cipher suite list to remove 3DES-based cipher suites. (Pull #958) -* Removed the cipher suite fallback to allow HIGH ciphers. (PR #958) +* Removed the cipher suite fallback to allow HIGH ciphers. (Pull #958) * Implemented ``length_remaining`` to determine remaining content - to be read. (PR #949) + to be read. 
(Pull #949) * Implemented ``enforce_content_length`` to enable exceptions when - incomplete data chunks are received. (PR #949) + incomplete data chunks are received. (Pull #949) * Dropped connection start, dropped connection reset, redirect, forced retry, - and new HTTPS connection log levels to DEBUG, from INFO. (PR #967) + and new HTTPS connection log levels to DEBUG, from INFO. (Pull #967) 1.16 (2016-06-11) ------------------ +================= * Disable IPv6 DNS when IPv6 connections are not possible. (Issue #840) @@ -473,13 +763,13 @@ Changes 1.15.1 (2016-04-11) -------------------- +=================== * Fix packaging to include backports module. (Issue #841) 1.15 (2016-04-06) ------------------ +================= * Added Retry(raise_on_status=False). (Issue #720) @@ -503,7 +793,7 @@ Changes 1.14 (2015-12-29) ------------------ +================= * contrib: SOCKS proxy support! (Issue #762) @@ -512,13 +802,13 @@ Changes 1.13.1 (2015-12-18) -------------------- +=================== * Fixed regression in IPv6 + SSL for match_hostname. (Issue #761) 1.13 (2015-12-14) ------------------ +================= * Fixed ``pip install urllib3[secure]`` on modern pip. (Issue #706) @@ -535,7 +825,7 @@ Changes 1.12 (2015-09-03) ------------------ +================= * Rely on ``six`` for importing ``httplib`` to work around conflicts with other Python 3 shims. (Issue #688) @@ -548,7 +838,7 @@ Changes 1.11 (2015-07-21) ------------------ +================= * When ``ca_certs`` is given, ``cert_reqs`` defaults to ``'CERT_REQUIRED'``. (Issue #650) @@ -593,7 +883,7 @@ Changes (Issue #674) 1.10.4 (2015-05-03) -------------------- +=================== * Migrate tests to Tornado 4. (Issue #594) @@ -609,7 +899,7 @@ Changes 1.10.3 (2015-04-21) -------------------- +=================== * Emit ``InsecurePlatformWarning`` when SSLContext object is missing. 
(Issue #558) @@ -630,7 +920,7 @@ Changes 1.10.2 (2015-02-25) -------------------- +=================== * Fix file descriptor leakage on retries. (Issue #548) @@ -642,7 +932,7 @@ Changes 1.10.1 (2015-02-10) -------------------- +=================== * Pools can be used as context managers. (Issue #545) @@ -656,7 +946,7 @@ Changes 1.10 (2014-12-14) ------------------ +================= * Disabled SSLv3. (Issue #473) @@ -688,7 +978,7 @@ Changes 1.9.1 (2014-09-13) ------------------- +================== * Apply socket arguments before binding. (Issue #427) @@ -709,7 +999,7 @@ Changes 1.9 (2014-07-04) ----------------- +================ * Shuffled around development-related files. If you're maintaining a distro package of urllib3, you may need to tweak things. (Issue #415) @@ -746,7 +1036,7 @@ Changes 1.8.3 (2014-06-23) ------------------- +================== * Fix TLS verification when using a proxy in Python 3.4.1. (Issue #385) @@ -768,13 +1058,13 @@ Changes 1.8.2 (2014-04-17) ------------------- +================== * Fix ``urllib3.util`` not being included in the package. 1.8.1 (2014-04-17) ------------------- +================== * Fix AppEngine bug of HTTPS requests going out as HTTP. (Issue #356) @@ -785,7 +1075,7 @@ Changes 1.8 (2014-03-04) ----------------- +================ * Improved url parsing in ``urllib3.util.parse_url`` (properly parse '@' in username, and blank ports like 'hostname:'). @@ -837,7 +1127,7 @@ Changes 1.7.1 (2013-09-25) ------------------- +================== * Added granular timeout support with new ``urllib3.util.Timeout`` class. (Issue #231) @@ -846,7 +1136,7 @@ Changes 1.7 (2013-08-14) ----------------- +================ * More exceptions are now pickle-able, with tests. (Issue #174) @@ -885,7 +1175,7 @@ Changes 1.6 (2013-04-25) ----------------- +================ * Contrib: Optional SNI support for Py2 using PyOpenSSL. 
(Issue #156) @@ -945,7 +1235,7 @@ Changes 1.5 (2012-08-02) ----------------- +================ * Added ``urllib3.add_stderr_logger()`` for quickly enabling STDERR debug logging in urllib3. @@ -970,7 +1260,7 @@ Changes 1.4 (2012-06-16) ----------------- +================ * Minor AppEngine-related fixes. @@ -982,7 +1272,7 @@ Changes 1.3 (2012-03-25) ----------------- +================ * Removed pre-1.0 deprecated API. @@ -1001,13 +1291,13 @@ Changes 1.2.2 (2012-02-06) ------------------- +================== * Fixed packaging bug of not shipping ``test-requirements.txt``. (Issue #47) 1.2.1 (2012-02-05) ------------------- +================== * Fixed another bug related to when ``ssl`` module is not available. (Issue #41) @@ -1016,7 +1306,7 @@ Changes 1.2 (2012-01-29) ----------------- +================ * Added Python 3 support (tested on 3.2.2) @@ -1042,7 +1332,7 @@ Changes 1.1 (2012-01-07) ----------------- +================ * Refactored ``dummyserver`` to its own root namespace module (used for testing). @@ -1059,7 +1349,7 @@ Changes 1.0.2 (2011-11-04) ------------------- +================== * Fixed typo in ``VerifiedHTTPSConnection`` which would only present as a bug if you're using the object manually. (Thanks pyos) @@ -1072,14 +1362,14 @@ Changes 1.0.1 (2011-10-10) ------------------- +================== * Fixed a bug where the same connection would get returned into the pool twice, causing extraneous "HttpConnectionPool is full" log warnings. 1.0 (2011-10-08) ----------------- +================ * Added ``PoolManager`` with LRU expiration of connections (tested and documented). @@ -1102,13 +1392,13 @@ Changes 0.4.1 (2011-07-17) ------------------- +================== * Minor bug fixes, code cleanup. 0.4 (2011-03-01) ----------------- +================ * Better unicode support. * Added ``VerifiedHTTPSConnection``. @@ -1117,13 +1407,13 @@ Changes 0.3.1 (2010-07-13) ------------------- +================== * Added ``assert_host_name`` optional parameter. 
Now compatible with proxies. 0.3 (2009-12-10) ----------------- +================ * Added HTTPS support. * Minor bug fixes. @@ -1132,13 +1422,13 @@ Changes 0.2 (2008-11-17) ----------------- +================ * Added unit tests. * Bug fixes. 0.1 (2008-11-16) ----------------- +================ * First release. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt deleted file mode 100644 index 51bbc2b227..0000000000 --- a/CONTRIBUTORS.txt +++ /dev/null @@ -1,310 +0,0 @@ -# Contributions to the urllib3 project - -## Creator & Maintainer - -* Andrey Petrov - - -## Contributors - -In chronological order: - -* victor.vde - * HTTPS patch (which inspired HTTPSConnectionPool) - -* erikcederstrand - * NTLM-authenticated HTTPSConnectionPool - * Basic-authenticated HTTPSConnectionPool (merged into make_headers) - -* niphlod - * Client-verified SSL certificates for HTTPSConnectionPool - * Response gzip and deflate encoding support - * Better unicode support for filepost using StringIO buffers - -* btoconnor - * Non-multipart encoding for POST requests - -* p.dobrogost - * Code review, PEP8 compliance, benchmark fix - -* kennethreitz - * Bugfixes, suggestions, Requests integration - -* georgemarshall - * Bugfixes, Improvements and Test coverage - -* Thomas Kluyver - * Python 3 support - -* brandon-rhodes - * Design review, bugfixes, test coverage. - -* studer - * IPv6 url support and test coverage - -* Shivaram Lingamneni - * Support for explicitly closing pooled connections - -* hartator - * Corrected multipart behavior for params - -* Thomas Weißschuh - * Support for TLS SNI - * API unification of ssl_version/cert_reqs - * SSL fingerprint and alternative hostname verification - * Bugfixes in testsuite - -* Sune Kirkeby - * Optional SNI-support for Python 2 via PyOpenSSL. - -* Marc Schlaich - * Various bugfixes and test improvements. 
- -* Bryce Boe - * Correct six.moves conflict - * Fixed pickle support of some exceptions - -* Boris Figovsky - * Allowed to skip SSL hostname verification - -* Cory Benfield - * Stream method for Response objects. - * Return native strings in header values. - * Generate 'Host' header when using proxies. - -* Jason Robinson - * Add missing WrappedSocket.fileno method in PyOpenSSL - -* Audrius Butkevicius - * Fixed a race condition - -* Stanislav Vitkovskiy - * Added HTTPS (CONNECT) proxy support - -* Stephen Holsapple - * Added abstraction for granular control of request fields - -* Martin von Gagern - * Support for non-ASCII header parameters - -* Kevin Burke and Pavel Kirichenko - * Support for separate connect and request timeouts - -* Peter Waller - * HTTPResponse.tell() for determining amount received over the wire - -* Nipunn Koorapati - * Ignore default ports when comparing hosts for equality - -* Danilo @dbrgn - * Disabled TLS compression by default on Python 3.2+ - * Disabled TLS compression in pyopenssl contrib module - * Configurable cipher suites in pyopenssl contrib module - -* Roman Bogorodskiy - * Account retries on proxy errors - -* Nicolas Delaby - * Use the platform-specific CA certificate locations - -* Josh Schneier - * HTTPHeaderDict and associated tests and docs - * Bugfixes, docs, test coverage - -* Tahia Khan - * Added Timeout examples in docs - -* Arthur Grunseid - * source_address support and tests (with https://github.com/bui) - -* Ian Cordasco - * PEP8 Compliance and Linting - * Add ability to pass socket options to an HTTP Connection - -* Erik Tollerud - * Support for standard library io module. - -* Krishna Prasad - * Google App Engine documentation - -* Aaron Meurer - * Added Url.url, which unparses a Url - -* Evgeny Kapun - * Bugfixes - -* Benjamen Meyer - * Security Warning Documentation update for proper capture - -* Shivan Sornarajah - * Support for using ConnectionPool and PoolManager as context managers. 
- -* Alex Gaynor - * Updates to the default SSL configuration - -* Tomas Tomecek - * Implemented generator for getting chunks from chunked responses. - -* tlynn - * Respect the warning preferences at import. - -* David D. Riddle - * IPv6 bugfixes in testsuite - -* Thea Flowers - * App Engine environment tests. - * Documentation re-write. - -* John Krauss - * Clues to debugging problems with `cryptography` dependency in docs - -* Disassem - * Fix pool-default headers not applying for url-encoded requests like GET. - -* James Atherfold - * Bugfixes relating to cleanup of connections during errors. - -* Christian Pedersen - * IPv6 HTTPS proxy bugfix - -* Jordan Moldow - * Fix low-level exceptions leaking from ``HTTPResponse.stream()``. - * Bugfix for ``ConnectionPool.urlopen(release_conn=False)``. - * Creation of ``HTTPConnectionPool.ResponseCls``. - -* Predrag Gruevski - * Made cert digest comparison use a constant-time algorithm. - -* Adam Talsma - * Bugfix to ca_cert file paths. - -* Evan Meagher - * Bugfix related to `memoryview` usage in PyOpenSSL adapter - -* John Vandenberg - * Python 2.6 fixes; pyflakes and pep8 compliance - -* Andy Caldwell - * Bugfix related to reusing connections in indeterminate states. - -* Ville Skyttä - * Logging efficiency improvements, spelling fixes, Travis config. - -* Shige Takeda - * Started Recipes documentation and added a recipe about handling concatenated gzip data in HTTP response - -* Jess Shapiro - * Various character-encoding fixes/tweaks - * Disabling IPv6 DNS when IPv6 connections not supported - -* David Foster - * Ensure order of request and response headers are preserved. - -* Jeremy Cline - * Added connection pool keys by scheme - -* Aviv Palivoda - * History list to Retry object. - * HTTPResponse contains the last Retry object. - -* Nate Prewitt - * Ensure timeouts are not booleans and greater than zero. - * Fixed infinite loop in ``stream`` when amt=None. 
- * Added length_remaining to determine remaining data to be read. - * Added enforce_content_length to raise exception when incorrect content-length received. - -* Seth Michael Larson - * Created selectors backport that supports PEP 475. - -* Alexandre Dias - * Don't retry on timeout if method not in whitelist - -* Moinuddin Quadri - * Lazily load idna package - -* Tom White - * Made SOCKS handler differentiate socks5h from socks5 and socks4a from socks4. - -* Tim Burke - * Stop buffering entire deflate-encoded responses. - -* Tuukka Mustonen - * Add counter for status_forcelist retries. - -* Erik Rose - * Bugfix to pyopenssl vendoring - -* Wolfgang Richter - * Bugfix related to loading full certificate chains with PyOpenSSL backend. - -* Mike Miller - * Logging improvements to include the HTTP(S) port when opening a new connection - -* Ioannis Tziakos - * Fix ``util.selectors._fileobj_to_fd`` to accept ``long``. - * Update appveyor tox setup to use the 64bit python. - -* Akamai (through Jess Shapiro) - * Ongoing maintenance; 2017-2018 - -* Dominique Leuenberger - * Minor fixes in the test suite - -* Will Bond - * Add Python 2.6 support to ``contrib.securetransport`` - -* Aleksei Alekseev - * using auth info for socks proxy - -* Chris Wilcox - * Improve contribution guide - * Add ``HTTPResponse.geturl`` method to provide ``urllib2.urlopen().geturl()`` behavior - -* Bruce Merry - * Fix leaking exceptions when system calls are interrupted with zero timeout - -* Hugo van Kemenade - * Drop support for EOL Python 2.6 - -* Tim Bell - * Bugfix for responses with Content-Type: message/* logging warnings - -* Justin Bramley - * Add ability to handle multiple Content-Encodings - -* Katsuhiko YOSHIDA - * Remove Authorization header regardless of case when redirecting to cross-site - -* James Meickle - * Improve handling of Retry-After header - -* Chris Jerdonek - * Remove a spurious TypeError from the exception chain inside - HTTPConnectionPool._make_request(), also for 
BaseExceptions. - -* Jorge Lopez Silva - * Added support for forwarding requests through HTTPS proxies. - -* Benno Rice - * Allow cadata parameter to be passed to underlying ``SSLContext.load_verify_locations()``. - -* Keiichi Kobayashi - * Rename VerifiedHTTPSConnection to HTTPSConnection - -* Himanshu Garg - * DOC & LICENSE Update - -* Hod Bin Noon - * Test improvements - -* Chris Olufson - * Fix for connection not being released on HTTP redirect and response not preloaded - -* [Bastiaan Bakker] - * Support for logging session keys via environment variable ``SSLKEYLOGFILE`` (Python 3.8+) - -* [Ezzeri Esa] - * Ports and extends on types from typeshed - -* [Your name or handle] <[email or website]> - * [Brief summary of your changes] diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 4edfedde27..0000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,5 +0,0 @@ -include README.rst CHANGES.rst LICENSE.txt CONTRIBUTORS.txt dev-requirements.txt Makefile -recursive-include dummyserver * -recursive-include test * -recursive-include docs * -recursive-exclude docs/_build * diff --git a/README.md b/README.md new file mode 100644 index 0000000000..695ba4163a --- /dev/null +++ b/README.md @@ -0,0 +1,114 @@ +

+ +![urllib3](https://github.com/urllib3/urllib3/raw/main/docs/_static/banner_github.svg) + +

+ +

+ PyPI Version + Python Versions + Join our Discord + Coverage Status + Build Status on GitHub + Documentation Status
+ OpenSSF Scorecard + SLSA 3 + CII Best Practices +

+ +urllib3 is a powerful, *user-friendly* HTTP client for Python. Much of the +Python ecosystem already uses urllib3 and you should too. +urllib3 brings many critical features that are missing from the Python +standard libraries: + +- Thread safety. +- Connection pooling. +- Client-side SSL/TLS verification. +- File uploads with multipart encoding. +- Helpers for retrying requests and dealing with HTTP redirects. +- Support for gzip, deflate, brotli, and zstd encoding. +- Proxy support for HTTP and SOCKS. +- 100% test coverage. + +urllib3 is powerful and easy to use: + +```python3 +>>> import urllib3 +>>> http = urllib3.PoolManager() +>>> resp = http.request("GET", "http://httpbin.org/robots.txt") +>>> resp.status +200 +>>> resp.data +b"User-agent: *\nDisallow: /deny\n" +``` + +## Installing + +urllib3 can be installed with [pip](https://pip.pypa.io): + +```bash +$ python -m pip install urllib3 +``` + +Alternatively, you can grab the latest source code from [GitHub](https://github.com/urllib3/urllib3): + +```bash +$ git clone https://github.com/urllib3/urllib3.git +$ cd urllib3 +$ pip install . +``` + + +## Documentation + +urllib3 has usage and reference documentation at [urllib3.readthedocs.io](https://urllib3.readthedocs.io). + + +## Community + +urllib3 has a [community Discord channel](https://discord.gg/urllib3) for asking questions and +collaborating with other contributors. Drop by and say hello 👋 + + +## Contributing + +urllib3 happily accepts contributions. Please see our +[contributing documentation](https://urllib3.readthedocs.io/en/latest/contributing.html) +for some tips on getting started. + + +## Security Disclosures + +To report a security vulnerability, please use the +[Tidelift security contact](https://tidelift.com/security). +Tidelift will coordinate the fix and disclosure with maintainers. + + +## Maintainers + +- [@sethmlarson](https://github.com/sethmlarson) (Seth M. 
Larson) +- [@pquentin](https://github.com/pquentin) (Quentin Pradet) +- [@theacodes](https://github.com/theacodes) (Thea Flowers) +- [@haikuginger](https://github.com/haikuginger) (Jess Shapiro) +- [@lukasa](https://github.com/lukasa) (Cory Benfield) +- [@sigmavirus24](https://github.com/sigmavirus24) (Ian Stapleton Cordasco) +- [@shazow](https://github.com/shazow) (Andrey Petrov) + +👋 + + +## Sponsorship + +If your company benefits from this library, please consider [sponsoring its +development](https://urllib3.readthedocs.io/en/latest/sponsors.html). + + +## For Enterprise + +Professional support for urllib3 is available as part of the [Tidelift +Subscription][1]. Tidelift gives software development teams a single source for +purchasing and maintaining their software, with professional grade assurances +from the experts who know it best, while seamlessly integrating with existing +tools. + +[1]: https://tidelift.com/subscription/pkg/pypi-urllib3?utm_source=pypi-urllib3&utm_medium=referral&utm_campaign=readme diff --git a/README.rst b/README.rst deleted file mode 100644 index 75b33750b6..0000000000 --- a/README.rst +++ /dev/null @@ -1,118 +0,0 @@ -.. raw:: html - -

- - urllib3 - -

-

- PyPI Version - Python Versions - Join our Discord - Coverage Status - Build Status on GitHub - Build Status on Travis - Documentation Status -

- -urllib3 is a powerful, *user-friendly* HTTP client for Python. Much of the -Python ecosystem already uses urllib3 and you should too. -urllib3 brings many critical features that are missing from the Python -standard libraries: - -- Thread safety. -- Connection pooling. -- Client-side SSL/TLS verification. -- File uploads with multipart encoding. -- Helpers for retrying requests and dealing with HTTP redirects. -- Support for gzip, deflate, and brotli encoding. -- Proxy support for HTTP and SOCKS. -- 100% test coverage. - -urllib3 is powerful and easy to use: - -.. code-block:: python - - >>> import urllib3 - >>> http = urllib3.PoolManager() - >>> r = http.request('GET', 'http://httpbin.org/robots.txt') - >>> r.status - 200 - >>> r.data - 'User-agent: *\nDisallow: /deny\n' - - -Installing ----------- - -urllib3 can be installed with `pip `_:: - - $ python -m pip install urllib3 - -Alternatively, you can grab the latest source code from `GitHub `_:: - - $ git clone git://github.com/urllib3/urllib3.git - $ python setup.py install - - -Documentation -------------- - -urllib3 has usage and reference documentation at `urllib3.readthedocs.io `_. - - -Contributing ------------- - -urllib3 happily accepts contributions. Please see our -`contributing documentation `_ -for some tips on getting started. - - -Security Disclosures --------------------- - -To report a security vulnerability, please use the -`Tidelift security contact `_. -Tidelift will coordinate the fix and disclosure with maintainers. - - -Maintainers ------------ - -- `@sethmlarson `__ (Seth M. Larson) -- `@pquentin `__ (Quentin Pradet) -- `@theacodes `__ (Thea Flowers) -- `@haikuginger `__ (Jess Shapiro) -- `@lukasa `__ (Cory Benfield) -- `@sigmavirus24 `__ (Ian Stapleton Cordasco) -- `@shazow `__ (Andrey Petrov) - -👋 - - -Sponsorship ------------ - -If your company benefits from this library, please consider `sponsoring its -development `_. - - -For Enterprise --------------- - -.. 
|tideliftlogo| image:: https://nedbatchelder.com/pix/Tidelift_Logos_RGB_Tidelift_Shorthand_On-White_small.png - :width: 75 - :alt: Tidelift - -.. list-table:: - :widths: 10 100 - - * - |tideliftlogo| - - Professional support for urllib3 is available as part of the `Tidelift - Subscription`_. Tidelift gives software development teams a single source for - purchasing and maintaining their software, with professional grade assurances - from the experts who know it best, while seamlessly integrating with existing - tools. - -.. _Tidelift Subscription: https://tidelift.com/subscription/pkg/pypi-urllib3?utm_source=pypi-urllib3&utm_medium=referral&utm_campaign=readme diff --git a/changelog/.gitignore b/changelog/.gitignore new file mode 100644 index 0000000000..f935021a8f --- /dev/null +++ b/changelog/.gitignore @@ -0,0 +1 @@ +!.gitignore diff --git a/changelog/README.rst b/changelog/README.rst new file mode 100644 index 0000000000..93c19bf2fc --- /dev/null +++ b/changelog/README.rst @@ -0,0 +1,31 @@ +This directory contains changelog entries: short files that contain a small +**ReST**-formatted text that will be added to ``CHANGES.rst`` by `towncrier +`__. + +The ``CHANGES.rst`` will be read by **users**, so this description should be aimed to +urllib3 users instead of describing internal changes which are only relevant to the +developers. + +Make sure to use full sentences in the **past tense** and use punctuation, examples:: + + Added support for HTTPS proxies contacting HTTPS servers. + + Upgraded ``urllib3.utils.parse_url()`` to be RFC 3986 compliant. + +Each file should be named like ``..rst``, where ```` is an issue +number, and ```` is one of the `five towncrier default types +`_ + +So for example: ``123.feature.rst``, ``456.bugfix.rst``. + +If your pull request fixes an issue, use that number here. If there is no issue, then +after you submit the pull request and get the pull request number you can add a +changelog using that instead. 
+ +If your change does not deserve a changelog entry, apply the `Skip Changelog` GitHub +label to your pull request. + +You can also run ``nox -s docs`` to build the documentation with the draft changelog +(``docs/_build/html/changelog.html``) if you want to get a preview of how your change +will look in the final release notes. You can also see a preview from the Read the Docs +check in pull requests. diff --git a/ci/0001-Mark-100-Continue-tests-as-failing.patch b/ci/0001-Mark-100-Continue-tests-as-failing.patch new file mode 100644 index 0000000000..184b53bdf0 --- /dev/null +++ b/ci/0001-Mark-100-Continue-tests-as-failing.patch @@ -0,0 +1,60 @@ +diff --git a/tests/unit/test_awsrequest.py b/tests/unit/test_awsrequest.py +index 22bd9a7..862a244 100644 +--- a/tests/unit/test_awsrequest.py ++++ b/tests/unit/test_awsrequest.py +@@ -34,6 +34,7 @@ from botocore.compat import file_type, six + from botocore.exceptions import UnseekableStreamError + from tests import mock, unittest + ++import pytest + + class IgnoreCloseBytesIO(io.BytesIO): + def close(self): +@@ -370,6 +371,7 @@ class TestAWSHTTPConnection(unittest.TestCase): + conn.response_class.return_value = self.mock_response + return conn + ++ @pytest.mark.xfail(reason="https://github.com/urllib3/urllib3/pull/2565") + def test_expect_100_continue_returned(self): + with mock.patch('urllib3.util.wait_for_read') as wait_mock: + # Shows the server first sending a 100 continue response +@@ -387,6 +389,7 @@ class TestAWSHTTPConnection(unittest.TestCase): + # Now we should verify that our final response is the 200 OK + self.assertEqual(response.status, 200) + ++ @pytest.mark.xfail(reason="https://github.com/urllib3/urllib3/pull/2565") + def test_handles_expect_100_with_different_reason_phrase(self): + with mock.patch('urllib3.util.wait_for_read') as wait_mock: + # Shows the server first sending a 100 continue response +@@ -412,6 +415,7 @@ class TestAWSHTTPConnection(unittest.TestCase): + # continue. 
+ self.assertIn(b'body', s.sent_data) + ++ @pytest.mark.xfail(reason="https://github.com/urllib3/urllib3/pull/2565") + def test_expect_100_sends_connection_header(self): + # When using squid as an HTTP proxy, it will also send + # a Connection: keep-alive header back with the 100 continue +@@ -439,6 +443,7 @@ class TestAWSHTTPConnection(unittest.TestCase): + response = conn.getresponse() + self.assertEqual(response.status, 500) + ++ @pytest.mark.xfail(reason="https://github.com/urllib3/urllib3/pull/2565") + def test_expect_100_continue_sends_307(self): + # This is the case where we send a 100 continue and the server + # immediately sends a 307 +@@ -461,6 +466,7 @@ class TestAWSHTTPConnection(unittest.TestCase): + # Now we should verify that our final response is the 307. + self.assertEqual(response.status, 307) + ++ @pytest.mark.xfail(reason="https://github.com/urllib3/urllib3/pull/2565") + def test_expect_100_continue_no_response_from_server(self): + with mock.patch('urllib3.util.wait_for_read') as wait_mock: + # Shows the server first sending a 100 continue response +@@ -566,6 +572,7 @@ class TestAWSHTTPConnection(unittest.TestCase): + response = conn.getresponse() + self.assertEqual(response.status, 200) + ++ @pytest.mark.xfail(reason="https://github.com/urllib3/urllib3/pull/2565") + def test_state_reset_on_connection_close(self): + # This simulates what urllib3 does with connections + # in its connection pool logic. 
diff --git a/ci/0002-Stop-relying-on-removed-DEFAULT_CIPHERS.patch b/ci/0002-Stop-relying-on-removed-DEFAULT_CIPHERS.patch new file mode 100644 index 0000000000..e533a8c1d4 --- /dev/null +++ b/ci/0002-Stop-relying-on-removed-DEFAULT_CIPHERS.patch @@ -0,0 +1,34 @@ +From dcc55a54fe2ba3b403923e95ab329009a9f430e2 Mon Sep 17 00:00:00 2001 +From: Quentin Pradet +Date: Fri, 19 Aug 2022 11:02:11 +0400 +Subject: [PATCH] Stop relying on removed DEFAULT_CIPHERS + +--- + botocore/httpsession.py | 4 ++-- + 1 file changed, 2 insertions(+), 2 deletions(-) + +diff --git a/botocore/httpsession.py b/botocore/httpsession.py +index 29b210377..aaecb454b 100644 +--- a/botocore/httpsession.py ++++ b/botocore/httpsession.py +@@ -19,7 +19,6 @@ from urllib3.exceptions import ReadTimeoutError as URLLib3ReadTimeoutError + from urllib3.exceptions import SSLError as URLLib3SSLError + from urllib3.util.retry import Retry + from urllib3.util.ssl_ import ( +- DEFAULT_CIPHERS, + OP_NO_COMPRESSION, + PROTOCOL_TLS, + OP_NO_SSLv2, +@@ -99,7 +98,8 @@ def create_urllib3_context( + + context = SSLContext(ssl_version) + +- context.set_ciphers(ciphers or DEFAULT_CIPHERS) ++ if ciphers: ++ context.set_ciphers(ciphers) + + # Setting the default here, as we may have no ssl module on import + cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs +-- +2.37.2 + diff --git a/ci/deploy.sh b/ci/deploy.sh index 503f072381..322ca26fb9 100755 --- a/ci/deploy.sh +++ b/ci/deploy.sh @@ -2,6 +2,6 @@ set -exo pipefail -python3 -m pip install --upgrade twine wheel -python3 setup.py sdist bdist_wheel +python3 -m pip install --upgrade twine wheel build +python3 -m build python3 -m twine upload dist/* -u $PYPI_USERNAME -p $PYPI_PASSWORD --skip-existing diff --git a/ci/downstream/botocore.sh b/ci/downstream/botocore.sh deleted file mode 100755 index 49ec3fe8ba..0000000000 --- a/ci/downstream/botocore.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -set -exo pipefail - -case "${1}" in - install) - git clone --depth 
1 https://github.com/boto/botocore - cd botocore - git rev-parse HEAD - python scripts/ci/install - ;; - run) - cd botocore - python scripts/ci/run-tests - ;; - *) - exit 1 - ;; -esac diff --git a/ci/downstream/requests-requirements.txt b/ci/downstream/requests-requirements.txt deleted file mode 100644 index 82e33e9ae2..0000000000 --- a/ci/downstream/requests-requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -pytest-mock -pysocks -httpbin - -# psf/requests#5049 -pytest<4.1 - -# psf/requests#5004 -pytest-httpbin==0.3.0 diff --git a/ci/downstream/requests.sh b/ci/downstream/requests.sh deleted file mode 100755 index f7afdb8bdb..0000000000 --- a/ci/downstream/requests.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -set -exo pipefail - -case "${1}" in - install) - git clone --depth 1 https://github.com/psf/requests - cd requests - git rev-parse HEAD - python -m pip install -r ${TRAVIS_BUILD_DIR}/ci/downstream/requests-requirements.txt - python -m pip install . - ;; - run) - cd requests - pytest tests/ - ;; - *) - exit 1 - ;; -esac diff --git a/ci/install.sh b/ci/install.sh deleted file mode 100755 index adda3d13cb..0000000000 --- a/ci/install.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -set -exo pipefail - - -# Linux Setup -# Even when testing on Python 2, we need Python 3 for Nox. This detects if -# we're in one of the Travis Python 2 sessions and sets up the Python 3 install -# for Nox. -if ! python3 -m pip --version; then - curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py - sudo python3 get-pip.py - # https://github.com/theacodes/nox/issues/328 - sudo python3 -m pip install nox==2019.11.9 -else - # We're not in "dual Python" mode, so we can just install Nox normally. 
- python3 -m pip install nox -fi - -if [[ "${NOX_SESSION}" == "app_engine" ]]; then - python -m pip install gcp-devrel-py-tools - gcp-devrel-py-tools download-appengine-sdk "$(dirname ${GAE_SDK_PATH})" -fi diff --git a/ci/run.sh b/ci/run.sh deleted file mode 100755 index bfed9f9c3e..0000000000 --- a/ci/run.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -set -exo pipefail - -if [ -n "${NOX_SESSION}" ]; then - nox -s "${NOX_SESSION}" --error-on-missing-interpreters -else - downstream_script="${TRAVIS_BUILD_DIR}/ci/downstream/${DOWNSTREAM}.sh" - if [ ! -x "$downstream_script" ]; then - exit 1 - fi - $downstream_script install - python -m pip install . - $downstream_script run -fi diff --git a/ci/run_tests.sh b/ci/run_tests.sh index df8726a758..fa152d6d52 100755 --- a/ci/run_tests.sh +++ b/ci/run_tests.sh @@ -1,4 +1,6 @@ #!/bin/bash -NOX_SESSION=test-${PYTHON_VERSION%-dev} +if [[ -z "$NOX_SESSION" ]]; then + NOX_SESSION=test-${PYTHON_VERSION%-dev} +fi nox -s $NOX_SESSION --error-on-missing-interpreters diff --git a/ci/upload_coverage.sh b/ci/upload_coverage.sh deleted file mode 100755 index 7d0fee0114..0000000000 --- a/ci/upload_coverage.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash - -set -exo pipefail - -# Cribbed from Trio's ci.sh -function curl-harder() { - for BACKOFF in 0 1 2 4 8 15 15 15 15; do - sleep $BACKOFF - if curl -fL --connect-timeout 5 "$@"; then - return 0 - fi - done - return 1 -} - -if [ "$JOB_NAME" = "" ]; then - JOB_NAME="${TRAVIS_OS_NAME}-${TRAVIS_PYTHON_VERSION:-unknown}" -fi - -curl-harder -o codecov.sh https://codecov.io/bash -bash codecov.sh -f coverage.xml -n $JOB_NAME diff --git a/codecov.yml b/codecov.yml deleted file mode 100644 index 3920f5c152..0000000000 --- a/codecov.yml +++ /dev/null @@ -1,8 +0,0 @@ -coverage: - status: - patch: - default: - target: '100' - project: - default: - target: '100' diff --git a/dev-requirements.txt b/dev-requirements.txt index ec1e705f66..10b2f2baf8 100644 --- a/dev-requirements.txt +++ 
b/dev-requirements.txt @@ -1,19 +1,11 @@ -mock==3.0.5 -coverage~=5.0 -tornado==5.1.1;python_version<="2.7" -tornado==6.0.3;python_version>="3.5" +coverage==7.0.4 +freezegun==1.2.2 +tornado==6.2 PySocks==1.7.1 -# https://github.com/Anorov/PySocks/issues/131 -win-inet-pton==1.1.0 -pytest==4.6.9 -pytest-timeout==1.3.4 -pytest-freezegun==0.4.2 -flaky==3.6.1 -trustme==0.5.3 -cryptography==2.8 -gcp-devrel-py-tools==0.0.15 - -# https://github.com/GoogleCloudPlatform/python-repo-tools/issues/23 -pylint<2.0;python_version<="2.7" - -python-dateutil==2.8.1 +pytest==7.2.0 +pytest-timeout==2.1.0 +trustme==0.9.0 +cryptography==39.0.1 +backports.zoneinfo==0.2.1;python_version<"3.9" +towncrier==21.9.0 +pytest-memray==1.4.0;python_version>="3.8" and python_version<"3.12" and sys_platform!="win32" and implementation_name=="cpython" diff --git a/docs/images/banner.svg b/docs/_static/banner.svg similarity index 100% rename from docs/images/banner.svg rename to docs/_static/banner.svg diff --git a/docs/_static/banner_github.svg b/docs/_static/banner_github.svg new file mode 100644 index 0000000000..069aa19870 --- /dev/null +++ b/docs/_static/banner_github.svg @@ -0,0 +1,13 @@ + + + Layer 1 + + + + + + + + + + \ No newline at end of file diff --git a/docs/_static/dark-logo.svg b/docs/_static/dark-logo.svg new file mode 100644 index 0000000000..a11012fb16 --- /dev/null +++ b/docs/_static/dark-logo.svg @@ -0,0 +1 @@ + diff --git a/docs/advanced-usage.rst b/docs/advanced-usage.rst index a0e68e4e46..6d6c176826 100644 --- a/docs/advanced-usage.rst +++ b/docs/advanced-usage.rst @@ -11,10 +11,13 @@ The :class:`~poolmanager.PoolManager` class automatically handles creating :class:`~connectionpool.ConnectionPool` instances for each host as needed. By default, it will keep a maximum of 10 :class:`~connectionpool.ConnectionPool` instances. If you're making requests to many different hosts it might improve -performance to increase this number:: +performance to increase this number. 
- >>> import urllib3 - >>> http = urllib3.PoolManager(num_pools=50) +.. code-block:: python + + import urllib3 + + http = urllib3.PoolManager(num_pools=50) However, keep in mind that this does increase memory and socket consumption. @@ -23,12 +26,15 @@ of individual :class:`~connection.HTTPConnection` instances. These connections are used during an individual request and returned to the pool when the request is complete. By default only one connection will be saved for re-use. If you are making many requests to the same host simultaneously it might improve -performance to increase this number:: +performance to increase this number. + +.. code-block:: python + + import urllib3 - >>> import urllib3 - >>> http = urllib3.PoolManager(maxsize=10) + http = urllib3.PoolManager(maxsize=10) # Alternatively - >>> http = urllib3.HTTPConnectionPool('google.com', maxsize=10) + pool = urllib3.HTTPConnectionPool("google.com", maxsize=10) The behavior of the pooling for :class:`~connectionpool.ConnectionPool` is different from :class:`~poolmanager.PoolManager`. By default, if a new @@ -37,11 +43,14 @@ connection will be created. However, this connection will not be saved if more than ``maxsize`` connections exist. This means that ``maxsize`` does not determine the maximum number of connections that can be open to a particular host, just the maximum number of connections to keep in the pool. However, if you specify ``block=True`` then there can be at most ``maxsize`` connections -open to a particular host:: +open to a particular host. + +.. code-block:: python + + http = urllib3.PoolManager(maxsize=10, block=True) - >>> http = urllib3.PoolManager(maxsize=10, block=True) # Alternatively - >>> http = urllib3.HTTPConnectionPool('google.com', maxsize=10, block=True) + pool = urllib3.HTTPConnectionPool("google.com", maxsize=10, block=True) Any new requests will block until a connection is available from the pool. 
This is a great way to prevent flooding a host with too many connections in @@ -53,60 +62,107 @@ multi-threaded applications. Streaming and I/O ----------------- -When dealing with large responses it's often better to stream the response -content:: - - >>> import urllib3 - >>> http = urllib3.PoolManager() - >>> r = http.request( - ... 'GET', - ... 'http://httpbin.org/bytes/1024', - ... preload_content=False) - >>> for chunk in r.stream(32): - ... print(chunk) - b'...' - b'...' - ... - >>> r.release_conn() - -Setting ``preload_content`` to ``False`` means that urllib3 will stream the -response content. :meth:`~response.HTTPResponse.stream` lets you iterate over -chunks of the response content. - -.. note:: When using ``preload_content=False``, you should call - :meth:`~response.HTTPResponse.release_conn` to release the http connection - back to the connection pool so that it can be re-used. +When using ``preload_content=True`` (the default setting) the +response body will be read immediately into memory and the HTTP connection +will be released back into the pool without manual intervention. + +However, when dealing with large responses it's often better to stream the response +content using ``preload_content=False``. Setting ``preload_content`` to ``False`` means +that urllib3 will only read from the socket when data is requested. + +.. note:: When using ``preload_content=False``, you need to manually release + the HTTP connection back to the connection pool so that it can be re-used. + To ensure the HTTP connection is in a valid state before being re-used + all data should be read off the wire. + + You can call the :meth:`~response.HTTPResponse.drain_conn` to throw away + unread data still on the wire. This call isn't necessary if data has already + been completely read from the response. + + After all data is read you can call :meth:`~response.HTTPResponse.release_conn` + to release the connection into the pool. 
+ + You can call the :meth:`~response.HTTPResponse.close` to close the connection, + but this call doesn’t return the connection to the pool, throws away the unread + data on the wire, and leaves the connection in an undefined protocol state. + This is desirable if you prefer not reading data from the socket to re-using the + HTTP connection. + +:meth:`~response.HTTPResponse.stream` lets you iterate over chunks of the response content. + +.. code-block:: python + + import urllib3 + + resp = urllib3.request( + "GET", + "https://httpbin.org/bytes/1024", + preload_content=False + ) + + for chunk in resp.stream(32): + print(chunk) + # b"\x9e\xa97'\x8e\x1eT .... + + resp.release_conn() However, you can also treat the :class:`~response.HTTPResponse` instance as -a file-like object. This allows you to do buffering:: +a file-like object. This allows you to do buffering: + +.. code-block:: python - >>> r = http.request( - ... 'GET', - ... 'http://httpbin.org/bytes/1024', - ... preload_content=False) - >>> r.read(4) - b'\x88\x1f\x8b\xe5' + import urllib3 + + resp = urllib3.request( + "GET", + "https://httpbin.org/bytes/1024", + preload_content=False + ) + + print(resp.read(4)) + # b"\x88\x1f\x8b\xe5" Calls to :meth:`~response.HTTPResponse.read()` will block until more response data is available. - >>> import io - >>> reader = io.BufferedReader(r, 8) - >>> reader.read(4) - >>> r.release_conn() +.. code-block:: python + + import io + import urllib3 + + resp = urllib3.request( + "GET", + "https://httpbin.org/bytes/1024", + preload_content=False + ) + + reader = io.BufferedReader(resp, 8) + print(reader.read(4)) + # b"\xbf\x9c\xd6" + + resp.release_conn() You can use this file-like object to do things like decode the content using -:mod:`codecs`:: - - >>> import codecs - >>> reader = codecs.getreader('utf-8') - >>> r = http.request( - ... 'GET', - ... 'http://httpbin.org/ip', - ... 
preload_content=False) - >>> json.load(reader(r)) - {'origin': '127.0.0.1'} - >>> r.release_conn() +:mod:`codecs`: + +.. code-block:: python + + import codecs + import json + import urllib3 + + reader = codecs.getreader("utf-8") + + resp = urllib3.request( + "GET", + "https://httpbin.org/ip", + preload_content=False + ) + + print(json.load(reader(resp))) + # {"origin": "127.0.0.1"} + + resp.release_conn() .. _proxies: @@ -114,11 +170,14 @@ Proxies ------- You can use :class:`~poolmanager.ProxyManager` to tunnel requests through an -HTTP proxy:: +HTTP proxy: + +.. code-block:: python + + import urllib3 - >>> import urllib3 - >>> proxy = urllib3.ProxyManager('http://localhost:3128/') - >>> proxy.request('GET', 'http://google.com/') + proxy = urllib3.ProxyManager("https://localhost:3128/") + proxy.request("GET", "https://google.com/") The usage of :class:`~poolmanager.ProxyManager` is the same as :class:`~poolmanager.PoolManager`. @@ -168,6 +227,78 @@ an `absolute URI `_ if the **only use this option with trusted or corporate proxies** as the proxy will have full visibility of your requests. +.. _https_proxy_error_http_proxy: + +Your proxy appears to only use HTTP and not HTTPS +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you're receiving the :class:`~urllib3.exceptions.ProxyError` and it mentions +your proxy only speaks HTTP and not HTTPS here's what to do to solve your issue: + +If you're using ``urllib3`` directly, make sure the URL you're passing into :class:`urllib3.ProxyManager` +starts with ``http://`` instead of ``https://``: + +.. code-block:: python + + # Do this: + http = urllib3.ProxyManager("http://...") + + # Not this: + http = urllib3.ProxyManager("https://...") + +If instead you're using ``urllib3`` through another library like Requests +there are multiple ways your proxy could be mis-configured. You need to figure out +where the configuration isn't correct and make the fix there. 
Some common places +to look are environment variables like ``HTTP_PROXY``, ``HTTPS_PROXY``, and ``ALL_PROXY``. + +Ensure that the values for all of these environment variables starts with ``http://`` +and not ``https://``: + +.. code-block:: bash + + # Check your existing environment variables in bash + $ env | grep "_PROXY" + HTTP_PROXY=http://127.0.0.1:8888 + HTTPS_PROXY=https://127.0.0.1:8888 # <--- This setting is the problem! + + # Make the fix in your current session and test your script + $ export HTTPS_PROXY="http://127.0.0.1:8888" + $ python test-proxy.py # This should now pass. + + # Persist your change in your shell 'profile' (~/.bashrc, ~/.profile, ~/.bash_profile, etc) + # You may need to logout and log back in to ensure this works across all programs. + $ vim ~/.bashrc + +If you're on Windows or macOS your proxy may be getting set at a system level. +To check this first ensure that the above environment variables aren't set +then run the following: + +.. code-block:: bash + + $ python -c 'import urllib.request; print(urllib.request.getproxies())' + +If the output of the above command isn't empty and looks like this: + +.. code-block:: python + + { + "http": "http://127.0.0.1:8888", + "https": "https://127.0.0.1:8888" # <--- This setting is the problem! + } + +Search how to configure proxies on your operating system and change the ``https://...`` URL into ``http://``. +After you make the change the return value of ``urllib.request.getproxies()`` should be: + +.. code-block:: python + + { # Everything is good here! :) + "http": "http://127.0.0.1:8888", + "https": "http://127.0.0.1:8888" + } + +If you still can't figure out how to configure your proxy after all these steps +please `join our community Discord `_ and we'll try to help you with your issue. + SOCKS Proxies ~~~~~~~~~~~~~ @@ -175,16 +306,21 @@ SOCKS Proxies For SOCKS, you can use :class:`~contrib.socks.SOCKSProxyManager` to connect to SOCKS4 or SOCKS5 proxies. 
In order to use SOCKS proxies you will need to install `PySocks `_ or install urllib3 with -the ``socks`` extra:: +the ``socks`` extra: + +.. code-block:: bash python -m pip install urllib3[socks] Once PySocks is installed, you can use -:class:`~contrib.socks.SOCKSProxyManager`:: +:class:`~contrib.socks.SOCKSProxyManager`: + +.. code-block:: python - >>> from urllib3.contrib.socks import SOCKSProxyManager - >>> proxy = SOCKSProxyManager('socks5h://localhost:8889/') - >>> proxy.request('GET', 'http://google.com/') + from urllib3.contrib.socks import SOCKSProxyManager + + proxy = SOCKSProxyManager("socks5h://localhost:8889/") + proxy.request("GET", "https://google.com/") .. note:: It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in @@ -201,12 +337,17 @@ Instead of using `certifi `_ you can provide your own certificate authority bundle. This is useful for cases where you've generated your own certificates or when you're using a private certificate authority. Just provide the full path to the certificate bundle when creating a -:class:`~poolmanager.PoolManager`:: +:class:`~poolmanager.PoolManager`: + +.. code-block:: python - >>> import urllib3 - >>> http = urllib3.PoolManager( - ... cert_reqs='CERT_REQUIRED', - ... ca_certs='/path/to/your/certificate_bundle') + import urllib3 + + http = urllib3.PoolManager( + cert_reqs="CERT_REQUIRED", + ca_certs="/path/to/your/certificate_bundle" + ) + resp = http.request("GET", "https://example.com") When you specify your own certificate bundle only requests that can be verified with that bundle will succeed. It's recommended to use a separate @@ -228,20 +369,22 @@ Normally, urllib3 takes care of setting and checking these values for you when you connect to a host by name. However, it's sometimes useful to set a connection's expected Host header and certificate hostname (subject), especially when you are connecting without using name resolution. 
For example, -you could connect to a server by IP using HTTPS like so:: - - >>> import urllib3 - >>> pool = urllib3.HTTPSConnectionPool( - ... "10.0.0.10", - ... assert_hostname="example.org", - ... server_hostname="example.org" - ... ) - >>> pool.urlopen( - ... "GET", - ... "/", - ... headers={"Host": "example.org"}, - ... assert_same_host=False - ... ) +you could connect to a server by IP using HTTPS like so: + +.. code-block:: python + + import urllib3 + + pool = urllib3.HTTPSConnectionPool( + "104.154.89.105", + server_hostname="badssl.com" + ) + pool.request( + "GET", + "/", + headers={"Host": "badssl.com"}, + assert_same_host=False + ) Note that when you use a connection in this way, you must specify @@ -252,6 +395,25 @@ address that you would like to use. The IP may be for a private interface, or you may want to use a specific host under round-robin DNS. +.. _assert_hostname: + +Verifying TLS against a different host +-------------------------------------- + +If the server you're connecting to presents a different certificate than the +hostname or the SNI hostname, you can use ``assert_hostname``: + +.. code-block:: python + + import urllib3 + + pool = urllib3.HTTPSConnectionPool( + "wrong.host.badssl.com", + assert_hostname="badssl.com", + ) + pool.request("GET", "/") + + .. _ssl_client: Client Certificates @@ -260,24 +422,62 @@ Client Certificates You can also specify a client certificate. This is useful when both the server and the client need to verify each other's identity. Typically these certificates are issued from the same authority. To use a client certificate, -provide the full path when creating a :class:`~poolmanager.PoolManager`:: +provide the full path when creating a :class:`~poolmanager.PoolManager`: + +.. code-block:: python - >>> http = urllib3.PoolManager( - ... cert_file='/path/to/your/client_cert.pem', - ... cert_reqs='CERT_REQUIRED', - ... 
ca_certs='/path/to/your/certificate_bundle') + http = urllib3.PoolManager( + cert_file="/path/to/your/client_cert.pem", + cert_reqs="CERT_REQUIRED", + ca_certs="/path/to/your/certificate_bundle" + ) If you have an encrypted client certificate private key you can use -the ``key_password`` parameter to specify a password to decrypt the key. :: +the ``key_password`` parameter to specify a password to decrypt the key. - >>> http = urllib3.PoolManager( - ... cert_file='/path/to/your/client_cert.pem', - ... cert_reqs='CERT_REQUIRED', - ... key_file='/path/to/your/client.key', - ... key_password='keyfile_password') +.. code-block:: python + + http = urllib3.PoolManager( + cert_file="/path/to/your/client_cert.pem", + cert_reqs="CERT_REQUIRED", + key_file="/path/to/your/client.key", + key_password="keyfile_password" + ) If your key isn't encrypted the ``key_password`` parameter isn't required. +TLS minimum and maximum versions +-------------------------------- + +When the configured TLS versions by urllib3 aren't compatible with the TLS versions that +the server is willing to use you'll likely see an error like this one: + +.. code-block:: + + SSLError(1, '[SSL: UNSUPPORTED_PROTOCOL] unsupported protocol (_ssl.c:1124)') + +Starting in v2.0 by default urllib3 uses TLS 1.2 and later so servers that only support TLS 1.1 +or earlier will not work by default with urllib3. + +To fix the issue you'll need to use the ``ssl_minimum_version`` option along with the `TLSVersion enum`_ +in the standard library ``ssl`` module to configure urllib3 to accept a wider range of TLS versions. + +For the best security it's a good idea to set this value to the version of TLS that's being used by the +server. For example if the server requires TLS 1.0 you'd configure urllib3 like so: + +.. code-block:: python + + import ssl + import urllib3 + + http = urllib3.PoolManager( + ssl_minimum_version=ssl.TLSVersion.TLSv1 + ) + # This request works! 
+ resp = http.request("GET", "https://tls-v1-0.badssl.com:1010") + +.. _TLSVersion enum: https://docs.python.org/3/library/ssl.html#ssl.TLSVersion + .. _ssl_mac: .. _certificate_validation_and_mac_os_x: @@ -308,94 +508,83 @@ be resolved in different ways. This happens when a request is made to an HTTPS URL without certificate verification enabled. Follow the :ref:`certificate verification ` guide to resolve this warning. -* :class:`~exceptions.InsecurePlatformWarning` - This happens on Python 2 platforms that have an outdated :mod:`ssl` module. - These older :mod:`ssl` modules can cause some insecure requests to succeed - where they should fail and secure requests to fail where they should - succeed. Follow the :ref:`pyOpenSSL ` guide to resolve this - warning. - -.. _sni_warning: - -* :class:`~exceptions.SNIMissingWarning` - This happens on Python 2 versions older than 2.7.9. These older versions - lack `SNI `_ support. - This can cause servers to present a certificate that the client thinks is - invalid. Follow the :ref:`pyOpenSSL ` guide to resolve this - warning. .. _disable_ssl_warnings: Making unverified HTTPS requests is **strongly** discouraged, however, if you understand the risks and wish to disable these warnings, you can use :func:`~urllib3.disable_warnings`: -.. code-block:: pycon +.. code-block:: python - >>> import urllib3 - >>> urllib3.disable_warnings() + import urllib3 + + urllib3.disable_warnings() Alternatively you can capture the warnings with the standard :mod:`logging` module: -.. code-block:: pycon +.. code-block:: python - >>> logging.captureWarnings(True) + logging.captureWarnings(True) Finally, you can suppress the warnings at the interpreter level by setting the ``PYTHONWARNINGS`` environment variable or by using the `-W flag `_. -Google App Engine ------------------ +Brotli Encoding +--------------- -urllib3 supports `Google App Engine `_ with -some caveats. 
+Brotli is a compression algorithm created by Google with better compression +than gzip and deflate and is supported by urllib3 if the +`Brotli `_ package or +`brotlicffi `_ package is installed. +You may also request the package be installed via the ``urllib3[brotli]`` extra: -If you're using the `Flexible environment -`_, you do not have to do -any configuration- urllib3 will just work. However, if you're using the -`Standard environment `_ then -you either have to use :mod:`urllib3.contrib.appengine`'s -:class:`~urllib3.contrib.appengine.AppEngineManager` or use the `Sockets API -`_ +.. code-block:: bash -To use :class:`~urllib3.contrib.appengine.AppEngineManager`: + $ python -m pip install urllib3[brotli] -.. code-block:: pycon +Here's an example using brotli encoding via the ``Accept-Encoding`` header: - >>> from urllib3.contrib.appengine import AppEngineManager - >>> http = AppEngineManager() - >>> http.request('GET', 'https://google.com/') +.. code-block:: python -To use the Sockets API, add the following to your app.yaml and use -:class:`~urllib3.poolmanager.PoolManager` as usual: + import urllib3 -.. code-block:: yaml + urllib3.request( + "GET", + "https://www.google.com/", + headers={"Accept-Encoding": "br"} + ) - env_variables: - GAE_USE_SOCKETS_HTTPLIB : 'true' +Zstandard Encoding +------------------ -For more details on the limitations and gotchas, see -:mod:`urllib3.contrib.appengine`. +`Zstandard `_ +is a compression algorithm created by Facebook with better compression +than brotli, gzip and deflate (see `benchmarks `_) +and is supported by urllib3 if the `zstandard package `_ is installed. +You may also request the package be installed via the ``urllib3[zstd]`` extra: -Brotli Encoding ---------------- +.. code-block:: bash -Brotli is a compression algorithm created by Google with better compression -than gzip and deflate and is supported by urllib3 if the -`brotlipy `_ package is installed. 
-You may also request the package be installed via the ``urllib3[brotli]`` extra: + $ python -m pip install urllib3[zstd] -.. code-block:: bash +.. note:: - $ python -m pip install urllib3[brotli] + Zstandard support in urllib3 requires using v0.18.0 or later of the ``zstandard`` package. + If the version installed is less than v0.18.0 then Zstandard support won't be enabled. -Here's an example using brotli encoding via the ``Accept-Encoding`` header: +Here's an example using zstd encoding via the ``Accept-Encoding`` header: + +.. code-block:: python + + import urllib3 -.. code-block:: pycon + urllib3.request( + "GET", + "https://www.facebook.com/", + headers={"Accept-Encoding": "zstd"} + ) - >>> from urllib3 import PoolManager - >>> http = PoolManager() - >>> http.request('GET', 'https://www.google.com/', headers={'Accept-Encoding': 'br'}) Decrypting Captured TLS Sessions with Wireshark ----------------------------------------------- @@ -411,3 +600,34 @@ To enable this simply define environment variable `SSLKEYLOGFILE`: Then configure the key logfile in `Wireshark `_, see `Wireshark TLS Decryption `_ for instructions. + +Custom SSL Contexts +------------------- + +You can exercise fine-grained control over the urllib3 SSL configuration by +providing a :class:`ssl.SSLContext ` object. For purposes +of compatibility, we recommend you obtain one from +:func:`~urllib3.util.create_urllib3_context`. + +Once you have a context object, you can mutate it to achieve whatever effect +you'd like. For example, the code below loads the default SSL certificates, sets +the :data:`ssl.OP_ENABLE_MIDDLEBOX_COMPAT` +flag that isn't set by default, and then makes a HTTPS request: + +.. 
code-block:: python + + import ssl + + from urllib3 import PoolManager + from urllib3.util import create_urllib3_context + + ctx = create_urllib3_context() + ctx.load_default_certs() + ctx.options |= ssl.OP_ENABLE_MIDDLEBOX_COMPAT + + with PoolManager(ssl_context=ctx) as pool: + pool.request("GET", "https://www.google.com/") + +Note that this is different from passing an ``options`` argument to +:func:`~urllib3.util.create_urllib3_context` because we don't overwrite +the default options: we only add a new one. diff --git a/docs/changelog.rst b/docs/changelog.rst new file mode 100644 index 0000000000..26a877a07e --- /dev/null +++ b/docs/changelog.rst @@ -0,0 +1,5 @@ +========= +Changelog +========= + +.. include:: ../CHANGES.rst diff --git a/docs/conf.py b/docs/conf.py index 1f695dc215..5ab6a68d8b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations import os import sys @@ -11,22 +11,17 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, root_path) -# Mock some expensive/platform-specific modules so build will work. -# (https://read-the-docs.readthedocs.io/en/latest/faq.html#\ -# i-get-import-errors-on-libraries-that-depend-on-c-modules) -import mock +# https://docs.readthedocs.io/en/stable/builds.html#build-environment +if "READTHEDOCS" in os.environ: + import glob + if glob.glob("../changelog/*.*.rst"): + print("-- Found changes; running towncrier --", flush=True) + import subprocess -class MockModule(mock.Mock): - @classmethod - def __getattr__(cls, name): - return MockModule() - - -MOCK_MODULES = ("ntlm",) - -sys.modules.update((mod_name, MockModule()) for mod_name in MOCK_MODULES) - + subprocess.run( + ["towncrier", "--yes", "--date", "not released yet"], cwd="..", check=True + ) import urllib3 @@ -37,6 +32,7 @@ def __getattr__(cls, name): # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 
extensions = [ "sphinx.ext.autodoc", + "sphinx_copybutton", "sphinx.ext.doctest", "sphinx.ext.intersphinx", ] @@ -55,7 +51,7 @@ def __getattr__(cls, name): # General information about the project. project = "urllib3" -copyright = "{year}, Andrey Petrov".format(year=date.today().year) +copyright = f"{date.today().year}, Andrey Petrov" # The short X.Y version. version = urllib3.__version__ @@ -73,16 +69,49 @@ def __getattr__(cls, name): # a list of builtin themes. html_theme = "furo" html_favicon = "images/favicon.png" -html_logo = "images/banner.svg" +html_static_path = ["_static"] html_theme_options = { "announcement": """ - Sponsor urllib3 v2.0 on Open Collective + href=\"https://github.com/sponsors/urllib3\"> + Support urllib3 on GitHub Sponsors """, "sidebar_hide_name": True, + "light_logo": "banner.svg", + "dark_logo": "dark-logo.svg", } intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} + +# Show typehints as content of the function or method +autodoc_typehints = "description" + +# Warn about all references to unknown targets +nitpicky = True +# Except for these ones, which we expect to point to unknown targets: +nitpick_ignore = [ + ("py:class", "_TYPE_SOCKS_OPTIONS"), + ("py:class", "_TYPE_SOCKET_OPTIONS"), + ("py:class", "_TYPE_TIMEOUT"), + ("py:class", "_TYPE_FIELD_VALUE"), + ("py:class", "_TYPE_BODY"), + ("py:class", "_HttplibHTTPResponse"), + ("py:class", "_HttplibHTTPMessage"), + ("py:class", "TracebackType"), + ("py:class", "Literal"), + ("py:class", "email.errors.MessageDefect"), + ("py:class", "MessageDefect"), + ("py:class", "http.client.HTTPMessage"), + ("py:class", "RequestHistory"), + ("py:class", "SSLTransportType"), + ("py:class", "VerifyMode"), + ("py:class", "_ssl._SSLContext"), + ("py:class", "urllib3._collections.HTTPHeaderDict"), + ("py:class", "urllib3._collections.RecentlyUsedContainer"), + ("py:class", "urllib3._request_methods.RequestMethods"), + ("py:class", "urllib3.contrib.socks._TYPE_SOCKS_OPTIONS"), + 
("py:class", "urllib3.util.timeout._TYPE_DEFAULT"), + ("py:class", "BaseHTTPConnection"), +] diff --git a/docs/contributing.rst b/docs/contributing.rst index 05cac9ca0e..a75e5f6901 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -15,8 +15,9 @@ If you wish to add a new feature or fix a bug: as expected. #. Format your changes with black using command `$ nox -rs format` and lint your changes using command `nox -rs lint`. +#. Add a `changelog entry + `__. #. Send a pull request and bug the maintainer until it gets merged and published. - :) Make sure to add yourself to ``CONTRIBUTORS.txt``. Setting up your development environment @@ -25,7 +26,7 @@ Setting up your development environment In order to setup the development environment all that you need is `nox `_ installed in your machine:: - $ pip install --user --upgrade nox + $ python -m pip install --user --upgrade nox Running the tests @@ -35,18 +36,18 @@ We use some external dependencies, multiple interpreters and code coverage analysis while running test suite. Our ``noxfile.py`` handles much of this for you:: - $ nox --reuse-existing-virtualenvs --sessions test-2.7 test-3.7 + $ nox --reuse-existing-virtualenvs --sessions test-3.7 test-3.9 [ Nox will create virtualenv if needed, install the specified dependencies, and run the commands in order.] - nox > Running session test-2.7 + nox > Running session test-3.7 ....... ....... - nox > Session test-2.7 was successful. + nox > Session test-3.7 was successful. ....... ....... - nox > Running session test-3.7 + nox > Running session test-3.9 ....... ....... - nox > Session test-3.7 was successful. + nox > Session test-3.9 was successful. There is also a nox command for running all of our tests and multiple python versions.:: @@ -55,23 +56,21 @@ versions.:: Note that code coverage less than 100% is regarded as a failing run. Some platform-specific tests are skipped unless run in that platform. 
To make sure -the code works in all of urllib3's supported platforms, you can run our ``tox`` +the code works in all of urllib3's supported platforms, you can run our ``nox`` suite:: $ nox --reuse-existing-virtualenvs --sessions test [ Nox will create virtualenv if needed, install the specified dependencies, and run the commands in order.] ....... ....... - nox > Session test-2.7 was successful. - nox > Session test-3.4 was successful. - nox > Session test-3.5 was successful. - nox > Session test-3.6 was successful. nox > Session test-3.7 was successful. nox > Session test-3.8 was successful. + nox > Session test-3.9 was successful. + nox > Session test-3.10 was successful. nox > Session test-pypy was successful. -Our test suite `runs continuously on Travis CI -`_ with every pull request. +Our test suite `runs continuously on GitHub Actions +`_ with every pull request. To run specific tests or quickly re-run without nox recreating the env, do the following:: @@ -95,25 +94,102 @@ further parameterize pytest for local testing. For all valid arguments, check `the pytest documentation `_. +Getting paid for your contributions +----------------------------------- + +urllib3 has a `pool of money hosted on Open Collective `_ +which the team uses to pay contributors for their work. **That could be you, too!** If you complete all tasks in an issue +that is marked with the `"💰 Bounty $X00" label `_ then you're eligible to be paid for your work. + +- Ensure that you're able to `receive funds from Open Collective for working on OSS `_. + Consider your employment contract and taxes for possible restrictions. +- If an issue is already assigned to someone on GitHub then it's likely they've + made substantial progress on the issue and will be given the bounty. + If you're interested in bounties you should look for issues which + aren't assigned to anyone. 
+- **Don't "claim" issues or ask whether someone is already working on an issue.** + Instead, focus on researching and working on the tasks in the issue. Once you + have made considerable progress on the tasks in an issue we can assign your + account to the issue to ensure others don't start working on it in parallel. +- If you've been assigned to an issue and haven't made progress or given an update + in over a week you will be unassigned from the issue to allow others a chance + at solving the issue. +- The amount you will be paid for completing an issue is shown in the label (either $100, $200, $300, etc). +- If you have questions about how to create an invoice on Open Collective + `try reading their FAQ `_. +- If you have a proposal to work on urllib3 that's not listed in the issue tracker please open an issue + with your proposal and our team will discuss whether we'd pay for your work on your proposal. +- If you have other questions get in contact with a maintainer in the `urllib3 Discord channel `_ or via email. +- The list above isn't an exhaustive list of criteria or rules for how/when money is distributed. + **The final say on whether money will be distributed is up to maintainers.** + +This program is an experiment so if you have positive or negative feedback on the process you can contact the maintainers through one of the above channels. + +Note that this program isn't a "bug bounty" program, we don't distribute funds to reporters of bugs or security vulnerabilities at this time. + +Running local proxies +--------------------- + +If the feature you are developing involves a proxy, you can rely on scripts we have developed to run a proxy locally. + +Run an HTTP proxy locally: + +.. code-block:: bash + + $ python -m dummyserver.proxy + +Run an HTTPS proxy locally: + +.. code-block:: bash + + $ python -m dummyserver.https_proxy + +Contributing to documentation +----------------------------- + +You can build the docs locally using ``nox``: + +.. 
code-block:: bash + + $ nox -rs docs + +While writing documentation you should follow these guidelines: + +- Use the top-level ``urllib3.request()`` function for smaller code examples. For more involved examples use PoolManager, etc. +- Use double quotes for all strings. (Output, Declaration etc.) +- Use keyword arguments everywhere except for method and url. (ie ``http.request("GET", "https://example.com", headers={...})`` ) +- Use HTTPS in URLs everywhere unless HTTP is needed. +- Rules for code examples and naming variables: + + - ``PoolManager`` instances should be named ``http``. (ie ``http = urllib3.PoolManager(...)``) + - ``ProxyManager`` instances should be named ``proxy``. + - ``ConnectionPool`` instances should be named ``pool``. + - ``Connection`` instances should be named ``conn``. + - ``HTTPResponse`` instances should be named ``resp``. + - Only use ``example.com`` or ``httpbin.org`` for example URLs + +- Comments within snippets should be useful, if what's being done is apparent + (such as parsing JSON, making a request) then it can be skipped for that section. +- Comments should always go above a code section rather than below with the exception of print + statements where the comment containing the result goes below. +- Imports should be their own section separated from the rest of the example with a line of whitespace. +- Imports should be minimized if possible. Use import urllib3 instead of from urllib3 import X. +- Sort imports similarly to isort, standard library first and third-party (like urllib3) come after. +- No whitespace is required between the sections as normally would be in case of isort. +- Add print statements along with a comment below them showing the output, potentially compressed. +- This helps users using the copy-paste button immediately see the results from a script. 
+ Releases -------- -A release candidate can be created by any contributor by creating a branch -named ``release-x.x`` where ``x.x`` is the version of the proposed release. - -- Update ``CHANGES.rst`` and ``urllib3/__init__.py`` with the proper version number - and commit the changes to ``release-x.x``. -- Open a pull request to merge the ``release-x.x`` branch into the ``master`` branch. -- Integration tests are run against the release candidate on Travis. From here on all - the steps below will be handled by a maintainer so unless you receive review comments - you are done here. -- Once the pull request is squash merged into master the merging maintainer - will tag the merge commit with the version number: - - - ``git tag -a 1.24.1 [commit sha]`` - - ``git push origin master --tags`` - -- After the commit is tagged Travis will build the tagged commit and upload the sdist and wheel - to PyPI and create a draft release on GitHub for the tag. The merging maintainer will - ensure that the PyPI sdist and wheel are properly uploaded. -- The merging maintainer will mark the draft release on GitHub as an approved release. +A release candidate can be created by any contributor. + +- Announce intent to release on Discord, see if anyone wants to include last minute + changes. +- Run ``towncrier build`` to update ``CHANGES.rst`` with the release notes, adjust as + necessary. +- Update ``urllib3/__init__.py`` with the proper version number +- Commit the changes to a ``release-X.Y.Z`` branch. +- Create a pull request and append ``&expand=1&template=release.md`` to the URL before + submitting in order to include our release checklist in the pull request description. +- Follow the checklist! 
diff --git a/docs/index.rst b/docs/index.rst index bfb13f0c61..1407431a5c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,12 +6,14 @@ urllib3 :maxdepth: 3 For Enterprise - v2-roadmap + Community Discord + v2-migration-guide sponsors user-guide advanced-usage reference/index contributing + changelog urllib3 is a powerful, *user-friendly* HTTP client for Python. :ref:`Much of the Python ecosystem already uses ` urllib3 and you should too. @@ -24,21 +26,20 @@ standard libraries: - Client-side TLS/SSL verification. - File uploads with multipart encoding. - Helpers for retrying requests and dealing with HTTP redirects. -- Support for gzip, deflate, and brotli encoding. +- Support for gzip, deflate, brotli, and zstd encoding. - Proxy support for HTTP and SOCKS. - 100% test coverage. urllib3 is powerful and easy to use: -.. code-block:: python +.. code-block:: pycon - >>> import urllib3 - >>> http = urllib3.PoolManager() - >>> r = http.request('GET', 'http://httpbin.org/robots.txt') - >>> r.status - 200 - >>> r.data - 'User-agent: *\nDisallow: /deny\n' + >>> import urllib3 + >>> resp = urllib3.request("GET", "https://httpbin.org/robots.txt") + >>> resp.status + 200 + >>> resp.data + b"User-agent: *\nDisallow: /deny\n" For Enterprise -------------- @@ -85,8 +86,9 @@ Alternatively, you can grab the latest source code from `GitHub `_. .. automodule:: urllib3.contrib.pyopenssl :members: diff --git a/docs/reference/contrib/securetransport.rst b/docs/reference/contrib/securetransport.rst index 12a6ddcf2b..d4af1b8ba0 100644 --- a/docs/reference/contrib/securetransport.rst +++ b/docs/reference/contrib/securetransport.rst @@ -1,5 +1,8 @@ macOS SecureTransport ===================== +.. warning:: + DEPRECATED: This module is deprecated and will be removed in urllib3 v2.1.0. + Read more in this `issue `_. `SecureTranport `_ support for urllib3 via ctypes. 
diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 2d21e3c9f8..582b8f719f 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -3,12 +3,12 @@ API Reference .. toctree:: + urllib3.request urllib3.poolmanager urllib3.connectionpool urllib3.connection urllib3.exceptions urllib3.response urllib3.fields - urllib3.request urllib3.util contrib/index diff --git a/docs/reference/urllib3.connection.rst b/docs/reference/urllib3.connection.rst index 472ddc54f5..a5dc65bf23 100644 --- a/docs/reference/urllib3.connection.rst +++ b/docs/reference/urllib3.connection.rst @@ -9,3 +9,7 @@ Connections .. autoclass:: urllib3.connection.HTTPSConnection :members: :show-inheritance: + +.. autoclass:: urllib3.connection.ProxyConfig + :members: + :show-inheritance: diff --git a/docs/reference/urllib3.connectionpool.rst b/docs/reference/urllib3.connectionpool.rst index 9b10144607..9e5b38e7ac 100644 --- a/docs/reference/urllib3.connectionpool.rst +++ b/docs/reference/urllib3.connectionpool.rst @@ -15,3 +15,5 @@ Connection Pools :members: :undoc-members: :show-inheritance: + +.. autofunction:: urllib3.connectionpool.connection_from_url diff --git a/docs/reference/urllib3.exceptions.rst b/docs/reference/urllib3.exceptions.rst index f139f5e5c5..84603a7f6e 100644 --- a/docs/reference/urllib3.exceptions.rst +++ b/docs/reference/urllib3.exceptions.rst @@ -1,7 +1,9 @@ -Exceptions -========== +Exceptions and Warnings +======================= .. automodule:: urllib3.exceptions :members: :undoc-members: :show-inheritance: + +.. autofunction:: urllib3.disable_warnings diff --git a/docs/reference/urllib3.poolmanager.rst b/docs/reference/urllib3.poolmanager.rst index d796dafbf5..200f140ce3 100644 --- a/docs/reference/urllib3.poolmanager.rst +++ b/docs/reference/urllib3.poolmanager.rst @@ -5,8 +5,14 @@ Pool Manager :members: :undoc-members: :show-inheritance: + :inherited-members: .. autoclass:: urllib3.ProxyManager :members: :undoc-members: :show-inheritance: + +.. 
autoclass:: urllib3.poolmanager.PoolKey + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/reference/urllib3.request.rst b/docs/reference/urllib3.request.rst index 39a1236741..ea39de894e 100644 --- a/docs/reference/urllib3.request.rst +++ b/docs/reference/urllib3.request.rst @@ -1,7 +1,4 @@ -Request Methods -=============== +urllib3.request() +================= -.. automodule:: urllib3.request - :members: - :undoc-members: - :show-inheritance: +.. autofunction:: urllib3.request diff --git a/docs/reference/urllib3.response.rst b/docs/reference/urllib3.response.rst index aa87d3460f..d00b8af65c 100644 --- a/docs/reference/urllib3.response.rst +++ b/docs/reference/urllib3.response.rst @@ -4,10 +4,20 @@ Response and Decoders Response -------- +.. autoclass:: urllib3.response.BaseHTTPResponse + :members: + :undoc-members: + :show-inheritance: + .. autoclass:: urllib3.response.HTTPResponse :members: :undoc-members: :show-inheritance: + :inherited-members: json + + .. autoattribute:: auto_close + .. autoattribute:: status + .. autoattribute:: headers Decoders -------- @@ -19,4 +29,5 @@ representation. .. autoclass:: urllib3.response.BrotliDecoder .. autoclass:: urllib3.response.DeflateDecoder .. autoclass:: urllib3.response.GzipDecoder +.. autoclass:: urllib3.response.ZstdDecoder .. autoclass:: urllib3.response.MultiDecoder diff --git a/docs/reference/urllib3.util.rst b/docs/reference/urllib3.util.rst index cb51215e80..d837b85fd4 100644 --- a/docs/reference/urllib3.util.rst +++ b/docs/reference/urllib3.util.rst @@ -4,11 +4,11 @@ Utilities Useful methods for working with :mod:`http.client`, completely decoupled from code specific to **urllib3**. -At the very core, just like its predecessors, :mod:`urllib3` is built on top of +At the very core, just like its predecessors, urllib3 is built on top of :mod:`http.client` -- the lowest level HTTP library included in the Python standard library. 
-To aid the limited functionality of the :mod:`http.client` module, :mod:`urllib3` +To aid the limited functionality of the :mod:`http.client` module, urllib3 provides various helper methods which are used with the higher level components but can also be used independently. diff --git a/docs/requirements.txt b/docs/requirements.txt index cdb71a7a58..882137502d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,5 @@ -r ../dev-requirements.txt sphinx>3.0.0 -requests>=2,<2.16 +requests furo +sphinx-copybutton diff --git a/docs/sponsors.rst b/docs/sponsors.rst index 50a67b8c80..33e9ce7f33 100644 --- a/docs/sponsors.rst +++ b/docs/sponsors.rst @@ -7,45 +7,6 @@ benefits from this library. Your contribution will go towards adding new features to urllib3 and making sure all functionality continues to meet our high quality standards. - -v2.0 Sponsor Perks ------------------- - -.. important:: - - `Get in contact `_ for additional - details on sponsorship and perks before making a contribution - through `Open Collective `_ if you have questions. - - -Silver v2.0 Sponsor Perks -~~~~~~~~~~~~~~~~~~~~~~~~~ - -- Your organization name and URL permanently added - to the **Sponsors and Grants** section below -- Thank you within the v2.0 release announcement - and on Twitter from urllib3 maintainers - -➤ `Contribute to the "Silver v2.0 Sponsor" tier `_ -on Open Collective. 
- - -Gold v2.0 Sponsor Perks -~~~~~~~~~~~~~~~~~~~~~~~~ - -- Organization logo and URL listed on top of the v2.0 Roadmap -- Call with one or more urllib3 maintainer(s) to discuss - the v2.0 release and how it impacts your organization -- Your organization will be thanked within the v2.0 release - announcement, within all blog posts and public updates related to v2.0 - development, and multiple thank-you's on Twitter from - urllib3 maintainers throughout v2.0 development -- All perks from the **Silver v2.0 Sponsors Perks** above - -➤ `Contribute to the "Gold v2.0 Sponsor" tier `_ -on Open Collective. - - Sponsors and Grants ------------------- @@ -61,6 +22,8 @@ We also welcome sponsorship in the form of time. We greatly appreciate companies who encourage employees to contribute on an ongoing basis during their work hours. Let us know and we'll be glad to add you to our sponsors list. +* `Spotify `_ (June 2nd, 2022) + * `GitCoin Grants `_ (2019-2020), sponsored `@sethmlarson `_ and `@pquentin `_ @@ -76,12 +39,3 @@ Let us know and we'll be glad to add you to our sponsors list. `@Lukasa `_ * `Stripe `_ (June 23, 2014) - - -Open Collective Supporters --------------------------- - -All donations are currently going towards the development of new features for urllib3 v2.0. -Donate $5 or more as an individual or $50 or more as an organization to be added to the list of supporters below (coming soon). - -`Thanks to all our supporters on Open Collective `_! diff --git a/docs/user-guide.rst b/docs/user-guide.rst index abd53233cb..133579fc11 100644 --- a/docs/user-guide.rst +++ b/docs/user-guide.rst @@ -18,43 +18,64 @@ Making Requests First things first, import the urllib3 module: -.. code-block:: pycon +.. code-block:: python - >>> import urllib3 + import urllib3 You'll need a :class:`~poolmanager.PoolManager` instance to make requests. This object handles all of the details of connection pooling and thread safety so that you don't have to: -.. code-block:: pycon +.. 
code-block:: python - >>> http = urllib3.PoolManager() + http = urllib3.PoolManager() -To make a request use :meth:`~poolmanager.PoolManager.request`: +To make a request use :meth:`~urllib3.PoolManager.request`: -.. code-block:: pycon +.. code-block:: python - >>> r = http.request('GET', 'http://httpbin.org/robots.txt') - >>> r.data - b'User-agent: *\nDisallow: /deny\n' + import urllib3 + + # Creating a PoolManager instance for sending requests. + http = urllib3.PoolManager() + + # Sending a GET request and getting back response as HTTPResponse object. + resp = http.request("GET", "https://httpbin.org/robots.txt") + + # Print the returned data. + print(resp.data) + # b"User-agent: *\nDisallow: /deny\n" ``request()`` returns a :class:`~response.HTTPResponse` object, the :ref:`response_content` section explains how to handle various responses. -You can use :meth:`~poolmanager.PoolManager.request` to make requests using any +You can use :meth:`~urllib3.PoolManager.request` to make requests using any HTTP verb: -.. code-block:: pycon +.. code-block:: python - >>> r = http.request( - ... 'POST', - ... 'http://httpbin.org/post', - ... fields={'hello': 'world'} - ... ) + import urllib3 + + http = urllib3.PoolManager() + resp = http.request( + "POST", + "https://httpbin.org/post", + fields={"hello": "world"} # Add custom form fields + ) + + print(resp.data) + # b"{\n "form": {\n "hello": "world"\n }, ... } The :ref:`request_data` section covers sending other kinds of requests data, including JSON, files, and binary data. +.. note:: For quick scripts and experiments you can also use a top-level ``urllib3.request()``. + It uses a module-global ``PoolManager`` instance. + Because of that, its side effects could be shared across dependencies relying on it. + To avoid side effects, create a new ``PoolManager`` instance and use it instead. + In addition, the method does not accept the low-level ``**urlopen_kw`` keyword arguments. + System CA certificates are loaded by default. 
+ .. _response_content: Response Content @@ -64,28 +85,48 @@ The :class:`~response.HTTPResponse` object provides :attr:`~response.HTTPResponse.status`, :attr:`~response.HTTPResponse.data`, and :attr:`~response.HTTPResponse.headers` attributes: -.. code-block:: pycon +.. code-block:: python - >>> r = http.request('GET', 'http://httpbin.org/ip') - >>> r.status - 200 - >>> r.data - b'{\n "origin": "104.232.115.37"\n}\n' - >>> r.headers - HTTPHeaderDict({'Content-Length': '33', ...}) + import urllib3 + + # Making the request (The request function returns HTTPResponse object) + resp = urllib3.request("GET", "https://httpbin.org/ip") + + print(resp.status) + # 200 + print(resp.data) + # b"{\n "origin": "104.232.115.37"\n}\n" + print(resp.headers) + # HTTPHeaderDict({"Content-Length": "32", ...}) JSON Content ~~~~~~~~~~~~ +JSON content can be loaded by :meth:`~response.HTTPResponse.json` +method of the response: + +.. code-block:: python + + import urllib3 + + resp = urllib3.request("GET", "https://httpbin.org/ip") + + print(resp.json()) + # {"origin": "127.0.0.1"} -JSON content can be loaded by decoding and deserializing the -:attr:`~response.HTTPResponse.data` attribute of the request: +Alternatively, Custom JSON libraries such as `orjson` can be used to encode data, +retrieve data by decoding and deserializing the :attr:`~response.HTTPResponse.data` +attribute of the request: -.. code-block:: pycon +.. 
code-block:: python - >>> import json - >>> r = http.request('GET', 'http://httpbin.org/ip') - >>> json.loads(r.data.decode('utf-8')) - {'origin': '127.0.0.1'} + import orjson + import urllib3 + + encoded_data = orjson.dumps({"attribute": "value"}) + resp = urllib3.request(method="POST", url="http://httpbin.org/post", body=encoded_data) + + print(orjson.loads(resp.data)["json"]) + # {'attribute': 'value'} Binary Content ~~~~~~~~~~~~~~ @@ -93,11 +134,14 @@ Binary Content The :attr:`~response.HTTPResponse.data` attribute of the response is always set to a byte string representing the response content: -.. code-block:: pycon +.. code-block:: python + + import urllib3 - >>> r = http.request('GET', 'http://httpbin.org/bytes/8') - >>> r.data - b'\xaa\xa5H?\x95\xe9\x9b\x11' + resp = urllib3.request("GET", "https://httpbin.org/bytes/8") + + print(resp.data) + # b"\xaa\xa5H?\x95\xe9\x9b\x11" .. note:: For larger responses, it's sometimes better to :ref:`stream ` the response. @@ -110,13 +154,22 @@ directly with :class:`~response.HTTPResponse` data. Making these two interfaces together requires using the :attr:`~response.HTTPResponse.auto_close` attribute by setting it to ``False``. By default HTTP responses are closed after reading all bytes, this disables that behavior: -.. code-block:: pycon +.. code-block:: python + + import io + import urllib3 + + resp = urllib3.request("GET", "https://example.com", preload_content=False) + resp.auto_close = False - >>> import io - >>> r = http.request('GET', 'https://example.com', preload_content=False) - >>> r.auto_close = False - >>> for line in io.TextIOWrapper(r): - >>> print(line) + for line in io.TextIOWrapper(resp): + print(line) + # + # + # + # .... + # + # .. 
_request_data: @@ -126,48 +179,117 @@ Request Data Headers ~~~~~~~ -You can specify headers as a dictionary in the ``headers`` argument in :meth:`~poolmanager.PoolManager.request`: +You can specify headers as a dictionary in the ``headers`` argument in :meth:`~urllib3.PoolManager.request`: + +.. code-block:: python + + import urllib3 + + resp = urllib3.request( + "GET", + "https://httpbin.org/headers", + headers={ + "X-Something": "value" + } + ) + + print(resp.json()["headers"]) + # {"X-Something": "value", ...} + +Or you can use the ``HTTPHeaderDict`` class to create multi-valued HTTP headers: + +.. code-block:: python + + import urllib3 + + # Create an HTTPHeaderDict and add headers + headers = urllib3.HTTPHeaderDict() + headers.add("Accept", "application/json") + headers.add("Accept", "text/plain") + + # Make the request using the headers + resp = urllib3.request( + "GET", + "https://httpbin.org/headers", + headers=headers + ) + + print(resp.json()["headers"]) + # {"Accept": "application/json, text/plain", ...} + +Cookies +~~~~~~~ + +Cookies are specified using the ``Cookie`` header with a string containing +the ``;`` delimited key-value pairs: + +.. code-block:: python + + import urllib3 + + resp = urllib3.request( + "GET", + "https://httpbin.org/cookies", + headers={ + "Cookie": "session=f3efe9db; id=30" + } + ) + + print(resp.json()) + # {"cookies": {"id": "30", "session": "f3efe9db"}} + +Cookies provided by the server are stored in the ``Set-Cookie`` header: -.. code-block:: pycon +.. code-block:: python - >>> r = http.request( - ... 'GET', - ... 'http://httpbin.org/headers', - ... headers={ - ... 'X-Something': 'value' - ... } - ... 
) - >>> json.loads(r.data.decode('utf-8'))['headers'] - {'X-Something': 'value', ...} + import urllib3 + + resp = urllib3.request( + "GET", + "https://httpbin.org/cookies/set/session/f3efe9db", + redirect=False + ) + + print(resp.headers["Set-Cookie"]) + # session=f3efe9db; Path=/ Query Parameters ~~~~~~~~~~~~~~~~ For ``GET``, ``HEAD``, and ``DELETE`` requests, you can simply pass the arguments as a dictionary in the ``fields`` argument to -:meth:`~poolmanager.PoolManager.request`: +:meth:`~urllib3.PoolManager.request`: + +.. code-block:: python -.. code-block:: pycon + import urllib3 - >>> r = http.request( - ... 'GET', - ... 'http://httpbin.org/get', - ... fields={'arg': 'value'} - ... ) - >>> json.loads(r.data.decode('utf-8'))['args'] - {'arg': 'value'} + resp = urllib3.request( + "GET", + "https://httpbin.org/get", + fields={"arg": "value"} + ) + + print(resp.json()["args"]) + # {"arg": "value"} For ``POST`` and ``PUT`` requests, you need to manually encode query parameters in the URL: -.. code-block:: pycon +.. code-block:: python + + from urllib.parse import urlencode + import urllib3 + + # Encode the args into url grammar. + encoded_args = urlencode({"arg": "value"}) - >>> from urllib.parse import urlencode - >>> encoded_args = urlencode({'arg': 'value'}) - >>> url = 'http://httpbin.org/post?' + encoded_args - >>> r = http.request('POST', url) - >>> json.loads(r.data.decode('utf-8'))['args'] - {'arg': 'value'} + # Create a URL with args encoded. + url = "https://httpbin.org/post?" + encoded_args + resp = urllib3.request("POST", url) + + print(resp.json()["args"]) + # {"arg": "value"} .. _form_data: @@ -177,38 +299,47 @@ Form Data For ``PUT`` and ``POST`` requests, urllib3 will automatically form-encode the dictionary in the ``fields`` argument provided to -:meth:`~poolmanager.PoolManager.request`: +:meth:`~urllib3.PoolManager.request`: + +.. code-block:: python + + import urllib3 -.. 
code-block:: pycon + resp = urllib3.request( + "POST", + "https://httpbin.org/post", + fields={"field": "value"} + ) + + print(resp.json()["form"]) + # {"field": "value"} - >>> r = http.request( - ... 'POST', - ... 'http://httpbin.org/post', - ... fields={'field': 'value'} - ... ) - >>> json.loads(r.data.decode('utf-8'))['form'] - {'field': 'value'} +.. _json: JSON ~~~~ -You can send a JSON request by specifying the encoded data as the ``body`` -argument and setting the ``Content-Type`` header when calling -:meth:`~poolmanager.PoolManager.request`: +You can send a JSON request by specifying the data as ``json`` argument, +urllib3 automatically encodes data using ``json`` module with ``UTF-8`` +encoding. Also by default ``"Content-Type"`` in headers is set to +``"application/json"`` if not specified when calling +:meth:`~urllib3.PoolManager.request`: + +.. code-block:: python + + import urllib3 + + data = {"attribute": "value"} -.. code-block:: pycon + resp = urllib3.request( + "POST", + "https://httpbin.org/post", + body=data, + headers={"Content-Type": "application/json"} + ) - >>> import json - >>> data = {'attribute': 'value'} - >>> encoded_data = json.dumps(data).encode('utf-8') - >>> r = http.request( - ... 'POST', - ... 'http://httpbin.org/post', - ... body=encoded_data, - ... headers={'Content-Type': 'application/json'} - ... ) - >>> json.loads(r.data.decode('utf-8'))['json'] - {'attribute': 'value'} + print(resp.json()) + # {"attribute": "value"} Files & Binary Data ~~~~~~~~~~~~~~~~~~~ @@ -217,49 +348,59 @@ For uploading files using ``multipart/form-data`` encoding you can use the same approach as :ref:`form_data` and specify the file field as a tuple of ``(file_name, file_data)``: -.. code-block:: pycon - - >>> with open('example.txt') as fp: - ... file_data = fp.read() - >>> r = http.request( - ... 'POST', - ... 'http://httpbin.org/post', - ... fields={ - ... 'filefield': ('example.txt', file_data), - ... } - ... 
) - >>> json.loads(r.data.decode('utf-8'))['files'] - {'filefield': '...'} +.. code-block:: python + + import urllib3 + + # Reading the text file from local storage. + with open("example.txt") as fp: + file_data = fp.read() + + # Sending the request. + resp = urllib3.request( + "POST", + "https://httpbin.org/post", + fields={ + "filefield": ("example.txt", file_data), + } + ) + + print(resp.json()["files"]) + # {"filefield": "..."} While specifying the filename is not strictly required, it's recommended in order to match browser behavior. You can also pass a third item in the tuple to specify the file's MIME type explicitly: -.. code-block:: pycon +.. code-block:: python - >>> r = http.request( - ... 'POST', - ... 'http://httpbin.org/post', - ... fields={ - ... 'filefield': ('example.txt', file_data, 'text/plain'), - ... } - ... ) + resp = urllib3.request( + "POST", + "https://httpbin.org/post", + fields={ + "filefield": ("example.txt", file_data, "text/plain"), + } + ) For sending raw binary data simply specify the ``body`` argument. It's also recommended to set the ``Content-Type`` header: -.. code-block:: pycon +.. code-block:: python + + import urllib3 + + with open("/home/samad/example.jpg", "rb") as fp: + binary_data = fp.read() - >>> with open('example.jpg', 'rb') as fp: - ... binary_data = fp.read() - >>> r = http.request( - ... 'POST', - ... 'http://httpbin.org/post', - ... body=binary_data, - ... headers={'Content-Type': 'image/jpeg'} - ... ) - >>> json.loads(r.data.decode('utf-8'))['data'] - b'...' + resp = urllib3.request( + "POST", + "https://httpbin.org/post", + body=binary_data, + headers={"Content-Type": "image/jpeg"} + ) + + print(resp.json()["data"]) + # data:application/octet-stream;base64,... .. _ssl: @@ -268,9 +409,9 @@ Certificate Verification .. note:: *New in version 1.25:* - HTTPS connections are now verified by default (``cert_reqs = 'CERT_REQUIRED'``). + HTTPS connections are now verified by default (``cert_reqs = "CERT_REQUIRED"``). 
-While you can disable certification verification by setting ``cert_reqs = 'CERT_NONE'``, it is highly recommend to leave it on. +While you can disable certification verification by setting ``cert_reqs = "CERT_NONE"``, it is highly recommend to leave it on. Unless otherwise specified urllib3 will try to load the default system certificate stores. The most reliable cross-platform method is to use the `certifi `_ @@ -280,36 +421,37 @@ package which provides Mozilla's root certificate bundle: $ python -m pip install certifi -You can also install certifi along with urllib3 by using the ``secure`` -extra: - -.. code-block:: bash - - $ python -m pip install urllib3[secure] - -.. warning:: If you're using Python 2 you may need additional packages. See the :ref:`section below ` for more details. - Once you have certificates, you can create a :class:`~poolmanager.PoolManager` that verifies certificates when making requests: -.. code-block:: pycon +.. code-block:: python - >>> import certifi - >>> import urllib3 - >>> http = urllib3.PoolManager( - ... cert_reqs='CERT_REQUIRED', - ... ca_certs=certifi.where() - ... ) + import certifi + import urllib3 + + http = urllib3.PoolManager( + cert_reqs="CERT_REQUIRED", + ca_certs=certifi.where() + ) The :class:`~poolmanager.PoolManager` will automatically handle certificate verification and will raise :class:`~exceptions.SSLError` if verification fails: -.. code-block:: pycon +.. code-block:: python + + import certifi + import urllib3 + + http = urllib3.PoolManager( + cert_reqs="CERT_REQUIRED", + ca_certs=certifi.where() + ) - >>> http.request('GET', 'https://google.com') - (No exception) - >>> http.request('GET', 'https://expired.badssl.com') - urllib3.exceptions.SSLError ... + http.request("GET", "https://httpbin.org/") + # (No exception) + + http.request("GET", "https://expired.badssl.com") + # urllib3.exceptions.SSLError ... .. note:: You can use OS-provided certificates if desired. 
Just specify the full path to the certificate bundle as the ``ca_certs`` argument instead of @@ -317,210 +459,187 @@ verification and will raise :class:`~exceptions.SSLError` if verification fails: at ``/etc/ssl/certs/ca-certificates.crt``. Other operating systems can be `difficult `_. -.. _ssl_py2: - -Certificate Verification in Python 2 -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Older versions of Python 2 are built with an :mod:`ssl` module that lacks -:ref:`SNI support ` and can lag behind security updates. For these reasons it's recommended to use -`pyOpenSSL `_. - -If you install urllib3 with the ``secure`` extra, all required packages for -certificate verification on Python 2 will be installed: - -.. code-block:: bash - - $ python -m pip install urllib3[secure] - -If you want to install the packages manually, you will need ``pyOpenSSL``, -``cryptography``, ``idna``, and ``certifi``. - -.. note:: If you are not using macOS or Windows, note that `cryptography - `_ requires additional system packages - to compile. See `building cryptography on Linux - `_ - for the list of packages required. - -Once installed, you can tell urllib3 to use pyOpenSSL by using :mod:`urllib3.contrib.pyopenssl`: - -.. code-block:: pycon - - >>> import urllib3.contrib.pyopenssl - >>> urllib3.contrib.pyopenssl.inject_into_urllib3() - -Finally, you can create a :class:`~poolmanager.PoolManager` that verifies -certificates when performing requests: - -.. code-block:: pycon - - >>> import certifi - >>> import urllib3 - >>> http = urllib3.PoolManager( - ... cert_reqs='CERT_REQUIRED', - ... ca_certs=certifi.where() - ... ) - -If you do not wish to use pyOpenSSL, you can simply omit the call to -:func:`urllib3.contrib.pyopenssl.inject_into_urllib3`. urllib3 will fall back -to the standard-library :mod:`ssl` module. You may experience -:ref:`several warnings ` when doing this. - -.. 
warning:: If you do not use pyOpenSSL, Python must be compiled with ssl - support for certificate verification to work. It is uncommon, but it is - possible to compile Python without SSL support. See this - `StackOverflow thread `_ - for more details. - - If you are on Google App Engine, you must explicitly enable SSL - support in your ``app.yaml``: - - .. code-block:: yaml - - libraries: - - name: ssl - version: latest - Using Timeouts -------------- Timeouts allow you to control how long (in seconds) requests are allowed to run before being aborted. In simple cases, you can specify a timeout as a ``float`` -to :meth:`~poolmanager.PoolManager.request`: +to :meth:`~urllib3.PoolManager.request`: + +.. code-block:: python + + import urllib3 + + resp = urllib3.request( + "GET", + "https://httpbin.org/delay/3", + timeout=4.0 + ) -.. code-block:: pycon + print(type(resp)) + # - >>> http.request( - ... 'GET', 'http://httpbin.org/delay/3', timeout=4.0 - ... ) - - >>> http.request( - ... 'GET', 'http://httpbin.org/delay/3', timeout=2.5 - ... ) - MaxRetryError caused by ReadTimeoutError + # This request will take more time to process than timeout. + urllib3.request( + "GET", + "https://httpbin.org/delay/3", + timeout=2.5 + ) + # MaxRetryError caused by ReadTimeoutError For more granular control you can use a :class:`~util.timeout.Timeout` instance which lets you specify separate connect and read timeouts: -.. code-block:: pycon +.. code-block:: python - >>> http.request( - ... 'GET', - ... 'http://httpbin.org/delay/3', - ... timeout=urllib3.Timeout(connect=1.0) - ... ) - - >>> http.request( - ... 'GET', - ... 'http://httpbin.org/delay/3', - ... timeout=urllib3.Timeout(connect=1.0, read=2.0) - ... 
) - MaxRetryError caused by ReadTimeoutError + import urllib3 + + resp = urllib3.request( + "GET", + "https://httpbin.org/delay/3", + timeout=urllib3.Timeout(connect=1.0) + ) + + print(type(resp)) + # + + urllib3.request( + "GET", + "https://httpbin.org/delay/3", + timeout=urllib3.Timeout(connect=1.0, read=2.0) + ) + # MaxRetryError caused by ReadTimeoutError If you want all requests to be subject to the same timeout, you can specify the timeout at the :class:`~urllib3.poolmanager.PoolManager` level: -.. code-block:: pycon +.. code-block:: python - >>> http = urllib3.PoolManager(timeout=3.0) - >>> http = urllib3.PoolManager( - ... timeout=urllib3.Timeout(connect=1.0, read=2.0) - ... ) + import urllib3 + + http = urllib3.PoolManager(timeout=3.0) + + http = urllib3.PoolManager( + timeout=urllib3.Timeout(connect=1.0, read=2.0) + ) You still override this pool-level timeout by specifying ``timeout`` to -:meth:`~poolmanager.PoolManager.request`. +:meth:`~urllib3.PoolManager.request`. Retrying Requests ----------------- urllib3 can automatically retry idempotent requests. This same mechanism also handles redirects. You can control the retries using the ``retries`` parameter -to :meth:`~poolmanager.PoolManager.request`. By default, urllib3 will retry +to :meth:`~urllib3.PoolManager.request`. By default, urllib3 will retry requests 3 times and follow up to 3 redirects. To change the number of retries just specify an integer: -.. code-block:: pycon +.. code-block:: python + + import urllib3 - >>> http.requests('GET', 'http://httpbin.org/ip', retries=10) + urllib3.request("GET", "https://httpbin.org/ip", retries=10) To disable all retry and redirect logic specify ``retries=False``: -.. code-block:: pycon +.. code-block:: python + + import urllib3 - >>> http.request( - ... 'GET', 'http://nxdomain.example.com', retries=False - ... ) - NewConnectionError - >>> r = http.request( - ... 'GET', 'http://httpbin.org/redirect/1', retries=False - ... 
) - >>> r.status - 302 + urllib3.request( + "GET", + "https://nxdomain.example.com", + retries=False + ) + # NewConnectionError + + resp = urllib3.request( + "GET", + "https://httpbin.org/redirect/1", + retries=False + ) + + print(resp.status) + # 302 To disable redirects but keep the retrying logic, specify ``redirect=False``: -.. code-block:: pycon +.. code-block:: python - >>> r = http.request( - ... 'GET', 'http://httpbin.org/redirect/1', redirect=False - ... ) - >>> r.status - 302 + resp = urllib3.request( + "GET", + "https://httpbin.org/redirect/1", + redirect=False + ) + + print(resp.status) + # 302 For more granular control you can use a :class:`~util.retry.Retry` instance. This class allows you far greater control of how requests are retried. For example, to do a total of 3 retries, but limit to only 2 redirects: -.. code-block:: pycon +.. code-block:: python - >>> http.request( - ... 'GET', - ... 'http://httpbin.org/redirect/3', - ... retries=urllib3.Retry(3, redirect=2) - ... ) - MaxRetryError + urllib3.request( + "GET", + "https://httpbin.org/redirect/3", + retries=urllib3.Retry(3, redirect=2) + ) + # MaxRetryError You can also disable exceptions for too many redirects and just return the ``302`` response: -.. code-block:: pycon +.. code-block:: python - >>> r = http.request( - ... 'GET', - ... 'http://httpbin.org/redirect/3', - ... retries=urllib3.Retry( - ... redirect=2, raise_on_redirect=False) - ... ) - >>> r.status - 302 + resp = urllib3.request( + "GET", + "https://httpbin.org/redirect/3", + retries=urllib3.Retry( + redirect=2, + raise_on_redirect=False + ) + ) + + print(resp.status) + # 302 If you want all requests to be subject to the same retry policy, you can specify the retry at the :class:`~urllib3.poolmanager.PoolManager` level: -.. code-block:: pycon +.. code-block:: python + + import urllib3 - >>> http = urllib3.PoolManager(retries=False) - >>> http = urllib3.PoolManager( - ... retries=urllib3.Retry(5, redirect=2) - ... 
) + http = urllib3.PoolManager(retries=False) + + http = urllib3.PoolManager( + retries=urllib3.Retry(5, redirect=2) + ) You still override this pool-level retry policy by specifying ``retries`` to -:meth:`~poolmanager.PoolManager.request`. +:meth:`~urllib3.PoolManager.request`. Errors & Exceptions ------------------- urllib3 wraps lower-level exceptions, for example: -.. code-block:: pycon +.. code-block:: python + + import urllib3 + + try: + urllib3.request("GET","https://nx.example.com", retries=False) - >>> try: - ... http.request('GET', 'nx.example.com', retries=False) - ... except urllib3.exceptions.NewConnectionError: - ... print('Connection failed.') + except urllib3.exceptions.NewConnectionError: + print("Connection failed.") + # Connection failed. See :mod:`~urllib3.exceptions` for the full list of all exceptions. @@ -531,6 +650,6 @@ If you are using the standard library :mod:`logging` module urllib3 will emit several logs. In some cases this can be undesirable. You can use the standard logger interface to change the log level for urllib3's logger: -.. code-block:: pycon +.. code-block:: python - >>> logging.getLogger("urllib3").setLevel(logging.WARNING) + logging.getLogger("urllib3").setLevel(logging.WARNING) diff --git a/docs/v2-migration-guide.rst b/docs/v2-migration-guide.rst new file mode 100644 index 0000000000..24263dc52f --- /dev/null +++ b/docs/v2-migration-guide.rst @@ -0,0 +1,315 @@ +v2.0 Migration Guide +==================== + +**urllib3 v2.0 is now available!** Read below for how to get started and what is contained in the new major release. + +**🚀 Migrating from 1.x to 2.0** +-------------------------------- + +We're maintaining **functional API compatibility for most users** to make the +migration an easy choice for almost everyone. Most changes are either to default +configurations, supported Python versions, or internal implementation details. +So unless you're in a specific situation you should notice no changes! 🎉 + +.. 
note:: + + If you have difficulty migrating to v2.0 or following this guide + you can `open an issue on GitHub `_ + or reach out in `our community Discord channel `_. + + +Timeline for deprecations and breaking changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The 2.x initial release schedule will look like this: + +* urllib3 ``v2.0.0-alpha1`` will be released in November 2022. This release + contains **minor breaking changes and deprecation warnings for other breaking changes**. + There may be other pre-releases to address fixes before v2.0.0 is released. +* urllib3 ``v2.0.0`` will be released in early 2023 after some initial integration testing + against dependent packages and fixing of bug reports. +* urllib3 ``v2.1.0`` will be released in the summer of 2023 with **all breaking changes + being warned about in v2.0.0**. + +.. warning:: + + Please take the ``DeprecationWarnings`` you receive when migrating from v1.x to v2.0 seriously + as they will become errors after 2.1.0 is released. + + +What are the important changes? +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Here's a short summary of which changes in urllib3 v2.0 are most important: + +- Python version must be **3.7 or later** (previously supported Python 2.7, 3.5, and 3.6). +- Removed support for non-OpenSSL TLS libraries (like LibreSSL and wolfSSL). +- Removed support for OpenSSL versions older than 1.1.1. +- Removed support for Python implementations that aren't CPython or PyPy3 (previously supported Google App Engine, Jython). +- Removed the ``urllib3.contrib.ntlmpool`` module. +- Deprecated the ``urllib3.contrib.pyopenssl``, ``urllib3.contrib.securetransport`` modules, will be removed in v2.1.0. +- Deprecated the ``urllib3[secure]`` extra, will be removed in v2.1.0. +- Deprecated the ``HTTPResponse.getheaders()`` method in favor of ``HTTPResponse.headers``, will be removed in v2.1.0. 
+- Deprecated the ``HTTPResponse.getheader(name, default)`` method in favor of ``HTTPResponse.headers.get(name, default)``, will be removed in v2.1.0.
+- Changed the default minimum TLS version to TLS 1.2 (previously was TLS 1.0).
+- Removed support for verifying certificate hostnames via ``commonName``, now only ``subjectAltName`` is used.
+- Removed the default set of TLS ciphers, instead now urllib3 uses the list of ciphers configured by the system.
+
+For a full list of changes you can look at `the changelog `_.
+
+
+Migrating as a package maintainer?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If you're a maintainer of a package that uses urllib3 under the hood then this section is for you.
+You may have already seen an issue opened from someone on our team about the upcoming release.
+
+The primary goal for migrating to urllib3 v2.x should be to ensure your package supports **both urllib3 v1.26.x and v2.0 for some time**.
+This is to reduce the chance that diamond dependencies are introduced into your users' dependencies which will then cause issues
+with them upgrading to the latest version of **your package**.
+
+The first step to supporting urllib3 v2.0 is to make sure that urllib3 v2.x is not excluded by ``install_requires``. You should
+ensure your package allows for both urllib3 1.26.x and 2.0 to be used:
+
+.. code-block:: python
+
+    # setup.py (setuptools)
+    setup(
+        ...
+        install_requires=["urllib3>=1.26,<3"]
+    )
+
+    # pyproject.toml (hatch)
+    [project]
+    dependencies = [
+        "urllib3>=1.26,<3"
+    ]
+
+Next you should try installing urllib3 v2.0 locally and run your test suite.
+
+.. code-block:: bash
+
+    $ python -m pip install -U --pre 'urllib3>=2.0.0a1'
+
+
+Because there are many ``DeprecationWarnings`` you should ensure that you're
+able to see those warnings when running your test suite. To do so you can add
+the following to your test setup to ensure even ``DeprecationWarnings`` are
+output to the terminal:
+
+.. 
code-block:: bash
+
+    # Set PYTHONWARNINGS=default to show all warnings.
+    $ export PYTHONWARNINGS="default"
+
+    # Run your test suite and look for failures.
+    # Pytest automatically prints all warnings.
+    $ pytest tests/
+
+or you can opt-in within your Python code:
+
+.. code-block:: python
+
+    # You can change warning filters according to the filter rules:
+    # https://docs.python.org/3/library/warnings.html#warning-filter
+    import warnings
+    warnings.filterwarnings("default", category=DeprecationWarning)
+
+Any failures or deprecation warnings you receive should be fixed as urllib3 v2.1.0 will remove all
+deprecated features. Many deprecation warnings will make suggestions about what to do to avoid the deprecated feature.
+
+Warnings will look something like this:
+
+.. code-block:: bash
+
+    DeprecationWarning: 'ssl_version' option is deprecated and will be removed
+    in urllib3 v2.1.0. Instead use 'ssl_minimum_version'
+
+Continue removing deprecation warnings until there are no more. After this you can publish a new release of your package
+that supports both urllib3 v1.26.x and v2.x.
+
+.. note::
+
+    If you're not able to support both 1.26.x and v2.0 of urllib3 at the same time with your package please
+    `open an issue on GitHub `_ or reach out in
+    `our community Discord channel `_.
+
+
+Migrating as an application developer?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If you're someone who writes Python but doesn't ship as a package (things like web services, data science, tools, and more) this section is for you.
+
+Python environments only allow for one version of a dependency to be installed per environment which means
+that **all of your dependencies using urllib3 need to support v2.0 for you to upgrade**.
+
+The best way to visualize relationships between your dependencies is using `pipdeptree `_ and ``$ pipdeptree --reverse``:
+
+.. 
code-block:: bash
+
+    # From inside your Python environment:
+    $ python -m pip install pipdeptree
+    # We only care about packages requiring urllib3
+    $ pipdeptree --reverse | grep "requires: urllib3"
+
+    - botocore==1.29.8 [requires: urllib3>=1.25.4,<1.27]
+    - requests==2.28.1 [requires: urllib3>=1.21.1,<1.27]
+
+Reading the output from above, there are two packages which depend on urllib3: ``botocore`` and ``requests``.
+The versions of these two packages both require urllib3 that is less than v2.0 (i.e. ``<1.27``).
+
+Because both of these packages require urllib3 before v2.0 the new version of urllib3 can't be installed
+by default. There are ways to force installing the newer version of urllib3 v2.0 (i.e. pinning to ``urllib3==2.0.0``)
+which you can do to test your application.
+
+It's important to know that even if you don't upgrade all of your services to 2.x
+immediately you will `receive security fixes on the 1.26.x release stream <#security-fixes-for-urllib3-v1-26-x>`_ for some time.
+
+
+Security fixes for urllib3 v1.26.x
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Thanks to support from `Tidelift `_
+we're able to continue supporting the v1.26.x release stream with
+security fixes for the foreseeable future 💖
+
+However, upgrading is still recommended as **no new feature developments or non-critical
+bug fixes will be shipped to the 1.26.x release stream**.
+
+If your organization relies on urllib3 and is interested in continuing support you can learn
+more about the `Tidelift Subscription for Enterprise `_.
+
+
+**💪 User-friendly features**
+-----------------------------
+
+urllib3 has always billed itself as a **user-friendly HTTP client library**.
+In the spirit of being even more user-friendly we've added two features
+which should make using urllib3 for tinkering sessions, throw-away scripts,
+and smaller projects a breeze! 
+ +urllib3.request() +~~~~~~~~~~~~~~~~~ + +Previously the highest-level API available for urllib3 was a ``PoolManager``, +but for many cases configuring a poolmanager is extra steps for no benefit. +To make using urllib3 as simple as possible we've added a top-level function +for sending requests from a global poolmanager instance: + +.. code-block:: python + + >>> import urllib3 + >>> resp = urllib3.request("GET", "https://example.com") + >>> resp.status + 200 + +JSON support for requests and responses +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +JSON is everywhere – and now it's in urllib3, too! + +If you'd like to send JSON in a request body or deserialize a response body +from JSON into Python objects you can now use the new ``json=`` parameter +for requests and ``HTTPResponse.json()`` method on responses: + +.. code-block:: python + + import urllib3 + + # Send a request with a JSON body. + # This adds 'Content-Type: application/json' by default. + resp = urllib3.request( + "POST", "https://example.api.com", + json={"key": "value"} + ) + + # Receive a JSON body in the response. + resp = urllib3.request("GET", "https://xkcd.com/2347/info.0.json") + + # There's always an XKCD... + resp.json() + { + "num": 2347, + "img": "https://imgs.xkcd.com/comics/dependency.png", + "title": "Dependency", + ... + } + + +**✨ Optimized for Python 3.7+** +-------------------------------- + +In v2.0 we'll be specifically targeting +CPython 3.7+ and PyPy 7.0+ (compatible with CPython 3.7) +and dropping support for Python versions 2.7, 3.5, and 3.6. + +By dropping end-of-life Python versions we're able to optimize +the codebase for Python 3.7+ by using new features to improve +performance and reduce the amount of code that needs to be executed +in order to support legacy versions. + + +**📜 Type-hinted APIs** +----------------------- + +You're finally able to run Mypy or other type-checkers +on code using urllib3. 
This also means that for IDEs
+that support type hints you'll receive better suggestions
+from auto-complete. No more confusion with ``**kwargs``!
+
+We've also added API interfaces like ``BaseHTTPResponse``
+and ``BaseHTTPConnection`` to ensure that when you're sub-classing
+an interface you're only using supported public APIs to ensure
+compatibility and minimize breakages down the road.
+
+.. note::
+
+    If you're one of the rare few who is subclassing connections
+    or responses you should take a closer look at detailed changes
+    in `the changelog `_.
+
+
+**🔐 Modern security by default**
+---------------------------------
+
+HTTPS requires TLS 1.2+
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Greater than 95% of websites support TLS 1.2 or above.
+At this point we're comfortable switching the default
+minimum TLS version to be 1.2 to ensure high security
+for users without breaking services.
+
+Dropping TLS 1.0 and 1.1 by default means you
+won't be vulnerable to TLS downgrade attacks
+if a vulnerability in TLS 1.0 or 1.1 were discovered in
+the future. Extra security for free! By dropping TLS 1.0
+and TLS 1.1 we also tighten the list of ciphers we need
+to support to ensure high security for data traveling
+over the wire.
+
+If you still need to use TLS 1.0 or 1.1 in your application
+you can still upgrade to v2.0, you'll only need to set
+``ssl_minimum_version`` to the proper value to continue using
+legacy TLS versions.
+
+
+Stop verifying commonName in certificates
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Dropping support for the long-deprecated ``commonName``
+field on certificates in favor of only verifying
+``subjectAltName`` to put us in line with browsers and
+other HTTP client libraries and to improve security for our users. 
+ + +Certificate verification via SSLContext +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +By default certificate verification is handled by urllib3 +to support legacy Python versions, but now we can +rely on Python's certificate verification instead! This +should result in a speedup for verifying certificates +and means that any improvements made to certificate +verification in Python or OpenSSL will be immediately +available. diff --git a/docs/v2-roadmap.rst b/docs/v2-roadmap.rst deleted file mode 100644 index 2a5183b255..0000000000 --- a/docs/v2-roadmap.rst +++ /dev/null @@ -1,180 +0,0 @@ -v2.0 Roadmap -============ - -.. important:: - - We're seeking `sponsors and supporters for urllib3 v2.0 on Open Collective `_. - There's a lot of work to be done for our small team and we want to make sure - development can get completed on-time while also fairly compensating contributors - for the additional effort required for a large release like ``v2.0``. - - Additional information available within the :doc:`sponsors` section of our documentation. - - -**🚀 Functional API Compatibility** ------------------------------------ - -We're maintaining **99% functional API compatibility** to make the -migration an easy choice for most users. Migration from v1.x to v2.x -should be the simplest major version upgrade you've ever completed. - -Most changes are either to default configurations, supported Python versions, -and internal implementation details. So unless you're in a specific situation -you should notice no changes! 🎉 - - -v1.26.x Security and Bug Fixes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Thanks to support from `Tidelift `_ -we're able to continue supporting v1.26.x releases with -both security and bug fixes for the forseeable future 💖 - -If your organization relies on urllib3 and is interested in continuing support you can learn -more about the `Tidelift Subscription for Enterprise `_. 
- - -**🔐 Modern Security by Default** ---------------------------------- - -HTTPS requires TLS 1.2+ -~~~~~~~~~~~~~~~~~~~~~~~ - -Greater than 95% of websites support TLS 1.2 or above. -At this point we're comfortable switching the default -minimum TLS version to be 1.2 to ensure high security -for users without breaking services. - -Dropping TLS 1.0 and 1.1 by default means you -won't be vulnerable to TLS downgrade attacks -if a vulnerability in TLS 1.0 or 1.1 were discovered in -the future. Extra security for free! By dropping TLS 1.0 -and TLS 1.1 we also tighten the list of ciphers we need -to support to ensure high security for data traveling -over the wire. - -If you still need to use TLS 1.0 or 1.1 in your application -you can still upgrade to v2.0, you'll only need to set -``ssl_version`` to the proper values to continue using -legacy TLS versions. - - -Stop Verifying CommonName in Certificates -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Dropping support the long deprecated ``commonName`` -field on certificates in favor of only verifying -``subjectAltName`` to put us in line with browsers and -other HTTP client libraries and to improve security for our users. - - -Certificate Verification via SSLContext -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -By default certificate verification is handled by urllib3 -to support legacy Python versions, but now we can -rely on Python's certificate verification instead! This -should result in a speedup for verifying certificates -and means that any improvements made to certificate -verification in Python or OpenSSL will be immediately -available. - - -**✨ Optimized for Python 3.6+** --------------------------------- - -In v2.0 we'll be specifically be targeting -CPython 3.6+ and PyPy 7.0+ (compatible with CPython 3.6) -and dropping support Python versions 2.7 and 3.5. 
- -By dropping end-of-life Python versions we're able to optimize -the codebase for Python 3.6+ by using new features to improve -performance and reduce the amount of code that needs to be executed -in order to support legacy versions. - - -**🔮 Tracing** --------------- - -Currently with urllib3 it's tough to get low-level insights into what -how your HTTP client is performing and what your connection information -looks like. In v2.0 we'll be adding tracing and telemetry information -to HTTP response objects including: - -- Connection ID -- IP Address resolved by DNS -- Request Method, Target, and Headers -- TLS Version and Cipher -- Certificate Fingerprint, subjectAltName, and Validity Information -- Timings for DNS, Request Data, First Byte in Response - - -**📜 Type-Hinted APIs** ------------------------ - -You'll finally be able to run Mypy or other type-checkers -on code using urllib3. This also means that for IDEs -that support type hints you'll receive better suggestions -from auto-complete. No more confusing with ``**kwargs``! - -We'll also add API interfaces to ensure that when -you're sub-classing an interface you're only using -supported public APIs to ensure compatibility and -minimize breakages down the road. - - -**🎁 ...and many more features!** ---------------------------------- - -- Top-level ``urllib3.request()`` API -- Open Possibility to Alternate HTTP Implementations -- Translated Guides -- Support Zstandard Compression -- Streaming ``multipart/form-encoded`` Request Data -- More Powerful and Configurable Retry Logic - -If there's a feature you don't see here but would like to see -in urllib3 v2.0, there's an open GitHub issue for making -feature suggestions. - - -**📅 Release and Migration Schedule** -------------------------------------- - -We're aiming for all ``v2.x`` features to be released in **mid-to-late 2021**. 
- -Here's what the release and migration schedule will look like leading up -to v2.0 being released: - -- Development of ``v2.x`` breaking changes starts. -- Release ``v1.26.0`` with deprecation warnings for ``v2.0.0`` breaking changes. - This will be the last non-patch release within the ``v1.x`` stream. -- Release ``v2.0.0-alpha1`` once all breaking changes have been completed. - We'll wait for users to report issues, bugs, and unexpected - breakages at this stage to ensure the release ``v2.0.0`` goes smoothly. -- Development of remaining ``v2.x`` features starts. -- Release ``v2.0.0`` which will be identical to ``v2.0.0-alpha1``. -- Release ``v2.1.0`` with remaining ``v2.x`` features. - -Deprecation warnings within ``v1.26.x`` will be opt-in by default. - -**More detailed Application Migration Guide coming soon.** - -For Package Maintainers -~~~~~~~~~~~~~~~~~~~~~~~ - -Since this is the first major release in almost 9 years some users may -be caught off-guard by a new major release of urllib3. We're mitigating this by -trying to make ``v2.x`` API-compatible with ``v1.x``. - -If your application or library uses urllib3 and you'd like to be extra -cautious about not breaking your users, you can pin urllib3 like so -until you ensure compatibility with ``v2.x``: - -.. code-block:: python - - # 'install_requires' or 'requirements.txt' - "urllib3>=1.25,<2" - -We'd really appreciate testing compatibility -and providing feedback on ``v2.0.0-alpha1`` once released. 
diff --git a/dummyserver/certs/README.rst b/dummyserver/certs/README.rst index 7c712b6e15..3ee127a02b 100644 --- a/dummyserver/certs/README.rst +++ b/dummyserver/certs/README.rst @@ -6,7 +6,7 @@ Here's how you can regenerate the certificates:: import trustme ca = trustme.CA() - server_cert = ca.issue_cert(u"localhost") + server_cert = ca.issue_cert("localhost") ca.cert_pem.write_to_path("cacert.pem") ca.private_key_pem.write_to_path("cacert.key") diff --git a/dummyserver/handlers.py b/dummyserver/handlers.py index c047094c4f..f49aeb4c9c 100644 --- a/dummyserver/handlers.py +++ b/dummyserver/handlers.py @@ -1,4 +1,4 @@ -from __future__ import print_function +from __future__ import annotations import collections import contextlib @@ -6,49 +6,64 @@ import json import logging import sys -import time +import typing import zlib from datetime import datetime, timedelta +from http.client import responses from io import BytesIO +from urllib.parse import urlsplit from tornado import httputil from tornado.web import RequestHandler -from urllib3.packages.six import binary_type, ensure_str -from urllib3.packages.six.moves.http_client import responses -from urllib3.packages.six.moves.urllib.parse import urlsplit +from urllib3.util.util import to_str log = logging.getLogger(__name__) -class Response(object): - def __init__(self, body="", status="200 OK", headers=None): +class Response: + def __init__( + self, + body: str | bytes | typing.Sequence[str | bytes] = "", + status: str = "200 OK", + headers: typing.Sequence[tuple[str, str | bytes]] | None = None, + json: typing.Any | None = None, + ) -> None: self.body = body self.status = status - self.headers = headers or [("Content-type", "text/plain")] + if json is not None: + self.headers = headers or [("Content-type", "application/json")] + self.body = json + else: + self.headers = headers or [("Content-type", "text/plain")] - def __call__(self, request_handler): + def __call__(self, request_handler: RequestHandler) -> None: 
status, reason = self.status.split(" ", 1) request_handler.set_status(int(status), reason) for header, value in self.headers: request_handler.add_header(header, value) + if isinstance(self.body, str): + request_handler.write(self.body.encode()) + elif isinstance(self.body, bytes): + request_handler.write(self.body) # chunked - if isinstance(self.body, list): + else: for item in self.body: if not isinstance(item, bytes): item = item.encode("utf8") request_handler.write(item) request_handler.flush() - else: - body = self.body - if not isinstance(body, bytes): - body = body.encode("utf8") - request_handler.write(body) + +RETRY_TEST_NAMES: dict[str, int] = collections.defaultdict(int) -RETRY_TEST_NAMES = collections.defaultdict(int) +def request_params(request: httputil.HTTPServerRequest) -> dict[str, bytes]: + params = {} + for k, v in request.arguments.items(): + params[k] = next(iter(v)) + return params class TestingApp(RequestHandler): @@ -61,32 +76,29 @@ class TestingApp(RequestHandler): method has its own conditions for success/failure. 
""" - def get(self): - """ Handle GET requests """ + def get(self) -> None: + """Handle GET requests""" self._call_method() - def post(self): - """ Handle POST requests """ + def post(self) -> None: + """Handle POST requests""" self._call_method() - def put(self): - """ Handle PUT requests """ + def put(self) -> None: + """Handle PUT requests""" self._call_method() - def options(self): - """ Handle OPTIONS requests """ + def options(self) -> None: + """Handle OPTIONS requests""" self._call_method() - def head(self): - """ Handle HEAD requests """ + def head(self) -> None: + """Handle HEAD requests""" self._call_method() - def _call_method(self): - """ Call the correct method in this class based on the incoming URI """ + def _call_method(self) -> None: + """Call the correct method in this class based on the incoming URI""" req = self.request - req.params = {} - for k, v in req.arguments.items(): - req.params[k] = next(iter(v)) path = req.path[:] if not path.startswith("/"): @@ -103,60 +115,65 @@ def _call_method(self): resp(self) - def index(self, _request): + def index(self, _request: httputil.HTTPServerRequest) -> Response: "Render simple message" return Response("Dummy server!") - def certificate(self, request): + def certificate(self, request: httputil.HTTPServerRequest) -> Response: """Return the requester's certificate.""" cert = request.get_ssl_certificate() - subject = dict() + assert isinstance(cert, dict) + subject = {} if cert is not None: - subject = dict((k, v) for (k, v) in [y for z in cert["subject"] for y in z]) + subject = {k: v for (k, v) in [y for z in cert["subject"] for y in z]} return Response(json.dumps(subject)) - def alpn_protocol(self, request): + def alpn_protocol(self, request: httputil.HTTPServerRequest) -> Response: """Return the selected ALPN protocol.""" - proto = request.connection.stream.socket.selected_alpn_protocol() - return Response(proto.encode("utf8") if proto is not None else u"") + assert request.connection is not None + 
proto = request.connection.stream.socket.selected_alpn_protocol() # type: ignore[attr-defined] + return Response(proto.encode("utf8") if proto is not None else "") - def source_address(self, request): + def source_address(self, request: httputil.HTTPServerRequest) -> Response: """Return the requester's IP address.""" - return Response(request.remote_ip) + return Response(request.remote_ip) # type: ignore[arg-type] - def set_up(self, request): - test_type = request.params.get("test_type") - test_id = request.params.get("test_id") + def set_up(self, request: httputil.HTTPServerRequest) -> Response: + params = request_params(request) + test_type = params.get("test_type") + test_id = params.get("test_id") if test_id: - print("\nNew test %s: %s" % (test_type, test_id)) + print(f"\nNew test {test_type!r}: {test_id!r}") else: - print("\nNew test %s" % test_type) + print(f"\nNew test {test_type!r}") return Response("Dummy server is ready!") - def specific_method(self, request): + def specific_method(self, request: httputil.HTTPServerRequest) -> Response: "Confirm that the request matches the desired method type" - method = request.params.get("method") - if method and not isinstance(method, str): - method = method.decode("utf8") + params = request_params(request) + method = params.get("method") + method_str = method.decode() if method else None - if request.method != method: + if request.method != method_str: return Response( - "Wrong method: %s != %s" % (method, request.method), + f"Wrong method: {method_str} != {request.method}", status="400 Bad Request", ) return Response() - def upload(self, request): + def upload(self, request: httputil.HTTPServerRequest) -> Response: "Confirm that the uploaded file conforms to specification" + params = request_params(request) # FIXME: This is a huge broken mess - param = request.params.get("upload_param", b"myfile").decode("ascii") - filename = request.params.get("upload_filename", b"").decode("utf-8") - size = 
int(request.params.get("upload_size", "0"))
+        param = params.get("upload_param", b"myfile").decode("ascii")
+        filename = params.get("upload_filename", b"").decode("utf-8")
+        size = int(params.get("upload_size", "0"))
 
         files_ = request.files.get(param)
+        assert files_ is not None
         if len(files_) != 1:
             return Response(
-                "Expected 1 file for '%s', not %d" % (param, len(files_)),
+                f"Expected 1 file for '{param}', not {len(files_)}",
                 status="400 Bad Request",
             )
 
         file_ = files_[0]
@@ -164,80 +181,80 @@ def upload(self, request):
         data = file_["body"]
         if int(size) != len(data):
             return Response(
-                "Wrong size: %d != %d" % (size, len(data)), status="400 Bad Request"
+                f"Wrong size: {int(size)} != {len(data)}", status="400 Bad Request"
             )
 
         got_filename = file_["filename"]
-        if isinstance(got_filename, binary_type):
+        if isinstance(got_filename, bytes):
             got_filename = got_filename.decode("utf-8")
 
         # Tornado can leave the trailing \n in place on the filename.
         if filename != got_filename:
             return Response(
-                u"Wrong filename: %s != %s" % (filename, file_.filename),
+                f"Wrong filename: {filename} != {file_.filename}",
                 status="400 Bad Request",
             )
 
         return Response()
 
-    def redirect(self, request):
+    def redirect(self, request: httputil.HTTPServerRequest) -> Response:  # type: ignore[override]
         "Perform a redirect to ``target``"
-        target = request.params.get("target", "/")
-        status = request.params.get("status", "303 See Other")
+        params = request_params(request)
+        target = params.get("target", "/")
+        status = params.get("status", b"303 See Other").decode("latin-1")
         if len(status) == 3:
-            status = "%s Redirect" % status.decode("latin-1")
+            status = f"{status} Redirect"
 
         headers = [("Location", target)]
         return Response(status=status, headers=headers)
 
-    def not_found(self, request):
+    def not_found(self, request: httputil.HTTPServerRequest) -> Response:
        return Response("Not found", status="404 Not Found")
 
-    def multi_redirect(self, request):
+    def multi_redirect(self, request: 
httputil.HTTPServerRequest) -> Response: "Performs a redirect chain based on ``redirect_codes``" - codes = request.params.get("redirect_codes", b"200").decode("utf-8") + params = request_params(request) + codes = params.get("redirect_codes", b"200").decode("utf-8") head, tail = codes.split(",", 1) if "," in codes else (codes, None) - status = "{0} {1}".format(head, responses[int(head)]) + assert head is not None + status = f"{head} {responses[int(head)]}" if not tail: return Response("Done redirecting", status=status) - headers = [("Location", "/multi_redirect?redirect_codes=%s" % tail)] + headers = [("Location", f"/multi_redirect?redirect_codes={tail}")] return Response(status=status, headers=headers) - def keepalive(self, request): - if request.params.get("close", b"0") == b"1": + def keepalive(self, request: httputil.HTTPServerRequest) -> Response: + params = request_params(request) + if params.get("close", b"0") == b"1": headers = [("Connection", "close")] return Response("Closing", headers=headers) headers = [("Connection", "keep-alive")] return Response("Keeping alive", headers=headers) - def echo_params(self, request): - params = sorted( - [(ensure_str(k), ensure_str(v)) for k, v in request.params.items()] - ) - return Response(repr(params)) - - def sleep(self, request): - "Sleep for a specified amount of ``seconds``" - # DO NOT USE THIS, IT'S DEPRECATED. - # FIXME: Delete this once appengine tests are fixed to not use this handler. 
- seconds = float(request.params.get("seconds", "1")) - time.sleep(seconds) - return Response() + def echo_params(self, request: httputil.HTTPServerRequest) -> Response: + params = request_params(request) + echod = sorted((to_str(k), to_str(v)) for k, v in params.items()) + return Response(repr(echod)) - def echo(self, request): + def echo(self, request: httputil.HTTPServerRequest) -> Response: "Echo back the params" if request.method == "GET": return Response(request.query) return Response(request.body) - def echo_uri(self, request): + def echo_json(self, request: httputil.HTTPServerRequest) -> Response: + "Echo back the JSON" + return Response(json=request.body, headers=list(request.headers.items())) + + def echo_uri(self, request: httputil.HTTPServerRequest) -> Response: "Echo back the requested URI" + assert request.uri is not None return Response(request.uri) - def encodingrequest(self, request): + def encodingrequest(self, request: httputil.HTTPServerRequest) -> Response: "Check for UA accepting gzip/deflate encoding" data = b"hello, world!" 
encoding = request.headers.get("Accept-Encoding", "") @@ -255,16 +272,19 @@ def encodingrequest(self, request): data = zlib.compress(data) elif encoding == "garbage-gzip": headers = [("Content-Encoding", "gzip")] - data = "garbage" + data = b"garbage" elif encoding == "garbage-deflate": headers = [("Content-Encoding", "deflate")] - data = "garbage" + data = b"garbage" return Response(data, headers=headers) - def headers(self, request): + def headers(self, request: httputil.HTTPServerRequest) -> Response: return Response(json.dumps(dict(request.headers))) - def successful_retry(self, request): + def multi_headers(self, request: httputil.HTTPServerRequest) -> Response: + return Response(json.dumps({"headers": list(request.headers.get_all())})) + + def successful_retry(self, request: httputil.HTTPServerRequest) -> Response: """Handler which will return an error and then success It's not currently very flexible as the number of retries is hard-coded. @@ -280,10 +300,10 @@ def successful_retry(self, request): else: return Response("need to keep retrying!", status="418 I'm A Teapot") - def chunked(self, request): + def chunked(self, request: httputil.HTTPServerRequest) -> Response: return Response(["123"] * 4) - def chunked_gzip(self, request): + def chunked_gzip(self, request: httputil.HTTPServerRequest) -> Response: chunks = [] compressor = zlib.compressobj(6, zlib.DEFLATED, 16 + zlib.MAX_WBITS) @@ -294,39 +314,43 @@ def chunked_gzip(self, request): return Response(chunks, headers=[("Content-Encoding", "gzip")]) - def nbytes(self, request): - length = int(request.params.get("length")) + def nbytes(self, request: httputil.HTTPServerRequest) -> Response: + params = request_params(request) + length = int(params["length"]) data = b"1" * length return Response(data, headers=[("Content-Type", "application/octet-stream")]) - def status(self, request): - status = request.params.get("status", "200 OK") + def status(self, request: httputil.HTTPServerRequest) -> Response: + 
params = request_params(request) + status = params.get("status", b"200 OK").decode("latin-1") return Response(status=status) - def retry_after(self, request): - if datetime.now() - self.application.last_req < timedelta(seconds=1): - status = request.params.get("status", b"429 Too Many Requests") + def retry_after(self, request: httputil.HTTPServerRequest) -> Response: + params = request_params(request) + if datetime.now() - self.application.last_req < timedelta(seconds=1): # type: ignore[attr-defined] + status = params.get("status", b"429 Too Many Requests") return Response( status=status.decode("utf-8"), headers=[("Retry-After", "1")] ) - self.application.last_req = datetime.now() + self.application.last_req = datetime.now() # type: ignore[attr-defined] return Response(status="200 OK") - def redirect_after(self, request): + def redirect_after(self, request: httputil.HTTPServerRequest) -> Response: "Perform a redirect to ``target``" - date = request.params.get("date") + params = request_params(request) + date = params.get("date") if date: retry_after = str( httputil.format_timestamp(datetime.utcfromtimestamp(float(date))) ) else: retry_after = "1" - target = request.params.get("target", "/") + target = params.get("target", "/") headers = [("Location", target), ("Retry-After", retry_after)] return Response(status="303 See Other", headers=headers) - def shutdown(self, request): + def shutdown(self, request: httputil.HTTPServerRequest) -> typing.NoReturn: sys.exit() diff --git a/dummyserver/https_proxy.py b/dummyserver/https_proxy.py new file mode 100755 index 0000000000..79dae1cd03 --- /dev/null +++ b/dummyserver/https_proxy.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import sys +import typing + +import tornado.httpserver +import tornado.ioloop +import tornado.web + +from dummyserver.proxy import ProxyHandler +from dummyserver.server import DEFAULT_CERTS, ssl_options_to_context + + +def run_proxy(port: int, certs: dict[str, 
typing.Any] = DEFAULT_CERTS) -> None: + """ + Run proxy on the specified port using the provided certs. + + Example usage: + + python -m dummyserver.https_proxy + + You'll need to ensure you have access to certain packages such as trustme, + tornado, urllib3. + """ + upstream_ca_certs = certs.get("ca_certs") + app = tornado.web.Application( + [(r".*", ProxyHandler)], upstream_ca_certs=upstream_ca_certs + ) + ssl_opts = ssl_options_to_context(**certs) + http_server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_opts) + http_server.listen(port) + + ioloop = tornado.ioloop.IOLoop.instance() + try: + ioloop.start() + except KeyboardInterrupt: + ioloop.stop() + + +if __name__ == "__main__": + port = 8443 + if len(sys.argv) > 1: + port = int(sys.argv[1]) + + print(f"Starting HTTPS proxy on port {port}") + run_proxy(port) diff --git a/dummyserver/proxy.py b/dummyserver/proxy.py index 0cd8dedd26..686506da6b 100755 --- a/dummyserver/proxy.py +++ b/dummyserver/proxy.py @@ -25,6 +25,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
+from __future__ import annotations + import socket import ssl import sys @@ -40,17 +42,16 @@ class ProxyHandler(tornado.web.RequestHandler): - SUPPORTED_METHODS = ["GET", "POST", "CONNECT"] + SUPPORTED_METHODS = ["GET", "POST", "CONNECT"] # type: ignore[assignment] - @tornado.gen.coroutine - def get(self): - def handle_response(response): + async def get(self) -> None: + async def handle_response(response: tornado.httpclient.HTTPResponse) -> None: if response.error and not isinstance( response.error, tornado.httpclient.HTTPError ): self.set_status(500) self.write("Internal server error:\n" + str(response.error)) - self.finish() + await self.finish() else: self.set_status(response.code) for header in ( @@ -65,7 +66,7 @@ def handle_response(response): self.set_header(header, v) if response.body: self.write(response.body) - self.finish() + await self.finish() upstream_ca_certs = self.application.settings.get("upstream_ca_certs", None) ssl_options = None @@ -73,6 +74,8 @@ def handle_response(response): if upstream_ca_certs: ssl_options = ssl.create_default_context(cafile=upstream_ca_certs) + assert self.request.uri is not None + assert self.request.method is not None req = tornado.httpclient.HTTPRequest( url=self.request.uri, method=self.request.method, @@ -85,30 +88,31 @@ def handle_response(response): client = tornado.httpclient.AsyncHTTPClient() try: - response = yield client.fetch(req) - yield handle_response(response) + response = await client.fetch(req) + await handle_response(response) except tornado.httpclient.HTTPError as e: if hasattr(e, "response") and e.response: - yield handle_response(e.response) + await handle_response(e.response) else: self.set_status(500) self.write("Internal server error:\n" + str(e)) self.finish() - @tornado.gen.coroutine - def post(self): - yield self.get() + async def post(self) -> None: + await self.get() - @tornado.gen.coroutine - def connect(self): + async def connect(self) -> None: + assert self.request.uri is not None host, 
port = self.request.uri.split(":") - client = self.request.connection.stream + assert self.request.connection is not None + client: tornado.iostream.IOStream = self.request.connection.stream # type: ignore[attr-defined] - @tornado.gen.coroutine - def start_forward(reader, writer): + async def start_forward( + reader: tornado.iostream.IOStream, writer: tornado.iostream.IOStream + ) -> None: while True: try: - data = yield reader.read_bytes(4096, partial=True) + data = await reader.read_bytes(4096, partial=True) except tornado.iostream.StreamClosedError: break if not data: @@ -118,15 +122,15 @@ def start_forward(reader, writer): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) upstream = tornado.iostream.IOStream(s) - yield upstream.connect((host, int(port))) + await upstream.connect((host, int(port))) client.write(b"HTTP/1.0 200 Connection established\r\n\r\n") fu1 = start_forward(client, upstream) fu2 = start_forward(upstream, client) - yield [fu1, fu2] + await tornado.gen.multi([fu1, fu2]) -def run_proxy(port, start_ioloop=True): +def run_proxy(port: int, start_ioloop: bool = True) -> None: """ Run proxy on the specified port. If start_ioloop is True (default), the tornado IOLoop will be started immediately. @@ -143,5 +147,5 @@ def run_proxy(port, start_ioloop=True): if len(sys.argv) > 1: port = int(sys.argv[1]) - print("Starting HTTP proxy on port %d" % port) + print(f"Starting HTTP proxy on port {port}") run_proxy(port) diff --git a/dummyserver/server.py b/dummyserver/server.py index 9ecde97f35..2683c9ec62 100755 --- a/dummyserver/server.py +++ b/dummyserver/server.py @@ -3,15 +3,22 @@ """ Dummy server used for unit testing. 
""" -from __future__ import print_function +from __future__ import annotations + +import asyncio +import concurrent.futures +import contextlib +import errno import logging import os import socket import ssl import sys import threading +import typing import warnings +from collections.abc import Coroutine, Generator from datetime import datetime import tornado.httpserver @@ -25,10 +32,15 @@ from urllib3.exceptions import HTTPWarning from urllib3.util import ALPN_PROTOCOLS, resolve_cert_reqs, resolve_ssl_version +if typing.TYPE_CHECKING: + from typing_extensions import ParamSpec + + P = ParamSpec("P") + log = logging.getLogger(__name__) CERTS_PATH = os.path.join(os.path.dirname(__file__), "certs") -DEFAULT_CERTS = { +DEFAULT_CERTS: dict[str, typing.Any] = { "certfile": os.path.join(CERTS_PATH, "server.crt"), "keyfile": os.path.join(CERTS_PATH, "server.key"), "cert_reqs": ssl.CERT_OPTIONAL, @@ -39,8 +51,8 @@ DEFAULT_CA_KEY = os.path.join(CERTS_PATH, "cacert.key") -def _resolves_to_ipv6(host): - """ Returns True if the system resolves host to an IPv6 address by default. """ +def _resolves_to_ipv6(host: str) -> bool: + """Returns True if the system resolves host to an IPv6 address by default.""" resolves_to_ipv6 = False try: for res in socket.getaddrinfo(host, None, socket.AF_UNSPEC): @@ -53,8 +65,8 @@ def _resolves_to_ipv6(host): return resolves_to_ipv6 -def _has_ipv6(host): - """ Returns True if the system can bind an IPv6 address. 
""" +def _has_ipv6(host: str) -> bool: + """Returns True if the system can bind an IPv6 address.""" sock = None has_ipv6 = False @@ -89,7 +101,6 @@ def _has_ipv6(host): class NoIPv6Warning(HTTPWarning): "IPv6 is not available" - pass class SocketServerThread(threading.Thread): @@ -102,15 +113,20 @@ class SocketServerThread(threading.Thread): USE_IPV6 = HAS_IPV6_AND_DNS - def __init__(self, socket_handler, host="localhost", port=8081, ready_event=None): - threading.Thread.__init__(self) + def __init__( + self, + socket_handler: typing.Callable[[socket.socket], None], + host: str = "localhost", + ready_event: threading.Event | None = None, + ) -> None: + super().__init__() self.daemon = True self.socket_handler = socket_handler self.host = host self.ready_event = ready_event - def _start_server(self): + def _start_server(self) -> None: if self.USE_IPV6: sock = socket.socket(socket.AF_INET6) else: @@ -130,22 +146,22 @@ def _start_server(self): self.socket_handler(sock) sock.close() - def run(self): - self.server = self._start_server() + def run(self) -> None: + self._start_server() -def ssl_options_to_context( +def ssl_options_to_context( # type: ignore[no-untyped-def] keyfile=None, certfile=None, server_side=None, cert_reqs=None, - ssl_version=None, + ssl_version: str | int | None = None, ca_certs=None, do_handshake_on_connect=None, suppress_ragged_eofs=None, ciphers=None, alpn_protocols=None, -): +) -> ssl.SSLContext: """Return an equivalent SSLContext based on ssl.wrap_socket args.""" ssl_version = resolve_ssl_version(ssl_version) cert_none = resolve_cert_reqs("CERT_NONE") @@ -167,54 +183,52 @@ def ssl_options_to_context( return ctx -def run_tornado_app(app, io_loop, certs, scheme, host): - assert io_loop == tornado.ioloop.IOLoop.current() - +def run_tornado_app( + app: tornado.web.Application, + certs: dict[str, typing.Any] | None, + scheme: str, + host: str, +) -> tuple[tornado.httpserver.HTTPServer, int]: # We can't use fromtimestamp(0) because of CPython issue 
29097, so we'll # just construct the datetime object directly. - app.last_req = datetime(1970, 1, 1) + app.last_req = datetime(1970, 1, 1) # type: ignore[attr-defined] if scheme == "https": - if sys.version_info < (2, 7, 9): - ssl_opts = certs - else: - ssl_opts = ssl_options_to_context(**certs) + assert certs is not None + ssl_opts = ssl_options_to_context(**certs) http_server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_opts) else: http_server = tornado.httpserver.HTTPServer(app) - sockets = tornado.netutil.bind_sockets(None, address=host) + # When we request a socket with host localhost and port zero (None in Python), then + # Tornado gets a free IPv4 port and requests that same port in IPv6. But that port + # could easily be taken with IPv6, especially in crowded CI environments. For this + # reason we put bind_sockets in a retry loop. Full details: + # * https://github.com/urllib3/urllib3/issues/2171 + # * https://github.com/tornadoweb/tornado/issues/1860 + for i in range(10): + try: + sockets = tornado.netutil.bind_sockets(None, address=host) # type: ignore[arg-type] + except OSError as e: + if e.errno == errno.EADDRINUSE: + # TODO this should be a warning if there's a way for pytest to print it + print( + f"Retrying bind_sockets({host}) after EADDRINUSE", file=sys.stderr + ) + continue + break + port = sockets[0].getsockname()[1] http_server.add_sockets(sockets) return http_server, port -def run_loop_in_thread(io_loop): - t = threading.Thread(target=io_loop.start) - t.start() - return t - - -def get_unreachable_address(): +def get_unreachable_address() -> tuple[str, int]: # reserved as per rfc2606 return ("something.invalid", 54321) -if __name__ == "__main__": - # For debugging dummyserver itself - python -m dummyserver.server - from .testcase import TestingApp - - host = "127.0.0.1" - - io_loop = tornado.ioloop.IOLoop.current() - app = tornado.web.Application([(r".*", TestingApp)]) - server, port = run_tornado_app(app, io_loop, None, "http", host) - 
server_thread = run_loop_in_thread(io_loop) - - print("Listening on http://{host}:{port}".format(host=host, port=port)) - - -def encrypt_key_pem(private_key_pem, password): +def encrypt_key_pem(private_key_pem: trustme.Blob, password: bytes) -> trustme.Blob: private_key = serialization.load_pem_private_key( private_key_pem.bytes(), password=None, backend=default_backend() ) @@ -224,3 +238,79 @@ def encrypt_key_pem(private_key_pem, password): serialization.BestAvailableEncryption(password), ) return trustme.Blob(encrypted_key) + + +R = typing.TypeVar("R") + + +def _run_and_close_tornado( + async_fn: typing.Callable[P, Coroutine[typing.Any, typing.Any, R]], + *args: P.args, + **kwargs: P.kwargs, +) -> R: + tornado_loop = None + + async def inner_fn() -> R: + nonlocal tornado_loop + tornado_loop = tornado.ioloop.IOLoop.current() + return await async_fn(*args, **kwargs) + + try: + return asyncio.run(inner_fn()) + finally: + tornado_loop.close(all_fds=True) # type: ignore[union-attr] + + +@contextlib.contextmanager +def run_loop_in_thread() -> Generator[tornado.ioloop.IOLoop, None, None]: + loop_started: concurrent.futures.Future[ + tuple[tornado.ioloop.IOLoop, asyncio.Event] + ] = concurrent.futures.Future() + with concurrent.futures.ThreadPoolExecutor( + 1, thread_name_prefix="test IOLoop" + ) as tpe: + + async def run() -> None: + io_loop = tornado.ioloop.IOLoop.current() + stop_event = asyncio.Event() + loop_started.set_result((io_loop, stop_event)) + await stop_event.wait() + + # run asyncio.run in a thread and collect exceptions from *either* + # the loop failing to start, or failing to close + ran = tpe.submit(_run_and_close_tornado, run) # type: ignore[arg-type] + for f in concurrent.futures.as_completed((loop_started, ran)): # type: ignore[misc] + if f is loop_started: + io_loop, stop_event = loop_started.result() + try: + yield io_loop + finally: + io_loop.add_callback(stop_event.set) + + elif f is ran: + # if this is the first iteration the loop failed to 
start + # if it's the second iteration the loop has finished or + # the loop failed to close and we need to raise the exception + ran.result() + return + + +def main() -> int: + # For debugging dummyserver itself - python -m dummyserver.server + from .handlers import TestingApp + + host = "127.0.0.1" + + async def amain() -> int: + app = tornado.web.Application([(r".*", TestingApp)]) + server, port = run_tornado_app(app, None, "http", host) + + print(f"Listening on http://{host}:{port}") + await asyncio.Event().wait() + return 0 + + return asyncio.run(amain()) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dummyserver/testcase.py b/dummyserver/testcase.py index 6a49e36cd2..a681106732 100644 --- a/dummyserver/testcase.py +++ b/dummyserver/testcase.py @@ -1,8 +1,14 @@ +from __future__ import annotations + +import asyncio +import contextlib +import socket +import ssl import threading -from contextlib import contextmanager +import typing import pytest -from tornado import ioloop, web +from tornado import httpserver, ioloop, web from dummyserver.handlers import TestingApp from dummyserver.proxy import ProxyHandler @@ -14,19 +20,23 @@ run_tornado_app, ) from urllib3.connection import HTTPConnection +from urllib3.util.ssltransport import SSLTransport -def consume_socket(sock, chunks=65536): +def consume_socket( + sock: SSLTransport | socket.socket, chunks: int = 65536 +) -> bytearray: consumed = bytearray() while True: b = sock.recv(chunks) + assert isinstance(b, bytes) consumed += b if b.endswith(b"\r\n\r\n"): break return consumed -class SocketDummyServerTestCase(object): +class SocketDummyServerTestCase: """ A simple socket-based server is created for this class that is good for exactly one request. 
@@ -35,8 +45,25 @@ class SocketDummyServerTestCase(object): scheme = "http" host = "localhost" + server_thread: typing.ClassVar[SocketServerThread] + port: typing.ClassVar[int] + + tmpdir: typing.ClassVar[str] + ca_path: typing.ClassVar[str] + cert_combined_path: typing.ClassVar[str] + cert_path: typing.ClassVar[str] + key_path: typing.ClassVar[str] + password_key_path: typing.ClassVar[str] + + server_context: typing.ClassVar[ssl.SSLContext] + client_context: typing.ClassVar[ssl.SSLContext] + + proxy_server: typing.ClassVar[SocketDummyServerTestCase] + @classmethod - def _start_server(cls, socket_handler): + def _start_server( + cls, socket_handler: typing.Callable[[socket.socket], None] + ) -> None: ready_event = threading.Event() cls.server_thread = SocketServerThread( socket_handler=socket_handler, ready_event=ready_event, host=cls.host @@ -48,10 +75,12 @@ def _start_server(cls, socket_handler): cls.port = cls.server_thread.port @classmethod - def start_response_handler(cls, response, num=1, block_send=None): + def start_response_handler( + cls, response: bytes, num: int = 1, block_send: threading.Event | None = None + ) -> threading.Event: ready_event = threading.Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: for _ in range(num): ready_event.set() @@ -67,34 +96,45 @@ def socket_handler(listener): return ready_event @classmethod - def start_basic_handler(cls, **kw): + def start_basic_handler( + cls, num: int = 1, block_send: threading.Event | None = None + ) -> threading.Event: return cls.start_response_handler( - b"HTTP/1.1 200 OK\r\n" b"Content-Length: 0\r\n" b"\r\n", **kw + b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n", + num, + block_send, ) @classmethod - def teardown_class(cls): + def teardown_class(cls) -> None: if hasattr(cls, "server_thread"): cls.server_thread.join(0.1) def assert_header_received( - self, received_headers, header_name, expected_value=None - ): - header_name = header_name.encode("ascii") 
- if expected_value is not None: - expected_value = expected_value.encode("ascii") + self, + received_headers: typing.Iterable[bytes], + header_name: str, + expected_value: str | None = None, + ) -> None: + header_name_bytes = header_name.encode("ascii") + if expected_value is None: + expected_value_bytes = None + else: + expected_value_bytes = expected_value.encode("ascii") header_titles = [] for header in received_headers: key, value = header.split(b": ") header_titles.append(key) - if key == header_name and expected_value is not None: - assert value == expected_value - assert header_name in header_titles + if key == header_name_bytes and expected_value_bytes is not None: + assert value == expected_value_bytes + assert header_name_bytes in header_titles class IPV4SocketDummyServerTestCase(SocketDummyServerTestCase): @classmethod - def _start_server(cls, socket_handler): + def _start_server( + cls, socket_handler: typing.Callable[[socket.socket], None] + ) -> None: ready_event = threading.Event() cls.server_thread = SocketServerThread( socket_handler=socket_handler, ready_event=ready_event, host=cls.host @@ -107,7 +147,7 @@ def _start_server(cls, socket_handler): cls.port = cls.server_thread.port -class HTTPDummyServerTestCase(object): +class HTTPDummyServerTestCase: """A simple HTTP server that runs when your test class runs Have your test class inherit from this one, and then a simple server @@ -120,28 +160,39 @@ class HTTPDummyServerTestCase(object): host = "localhost" host_alt = "127.0.0.1" # Some tests need two hosts certs = DEFAULT_CERTS + base_url: typing.ClassVar[str] + base_url_alt: typing.ClassVar[str] + + io_loop: typing.ClassVar[ioloop.IOLoop] + server: typing.ClassVar[httpserver.HTTPServer] + port: typing.ClassVar[int] + server_thread: typing.ClassVar[threading.Thread] + _stack: typing.ClassVar[contextlib.ExitStack] @classmethod - def _start_server(cls): - cls.io_loop = ioloop.IOLoop.current() - app = web.Application([(r".*", TestingApp)]) - 
cls.server, cls.port = run_tornado_app( - app, cls.io_loop, cls.certs, cls.scheme, cls.host - ) - cls.server_thread = run_loop_in_thread(cls.io_loop) + def _start_server(cls) -> None: + with contextlib.ExitStack() as stack: + io_loop = stack.enter_context(run_loop_in_thread()) + + async def run_app() -> None: + app = web.Application([(r".*", TestingApp)]) + cls.server, cls.port = run_tornado_app( + app, cls.certs, cls.scheme, cls.host + ) + + asyncio.run_coroutine_threadsafe(run_app(), io_loop.asyncio_loop).result() # type: ignore[attr-defined] + cls._stack = stack.pop_all() @classmethod - def _stop_server(cls): - cls.io_loop.add_callback(cls.server.stop) - cls.io_loop.add_callback(cls.io_loop.stop) - cls.server_thread.join() + def _stop_server(cls) -> None: + cls._stack.close() @classmethod - def setup_class(cls): + def setup_class(cls) -> None: cls._start_server() @classmethod - def teardown_class(cls): + def teardown_class(cls) -> None: cls._stop_server() @@ -149,57 +200,78 @@ class HTTPSDummyServerTestCase(HTTPDummyServerTestCase): scheme = "https" host = "localhost" certs = DEFAULT_CERTS - - -class HTTPDummyProxyTestCase(object): - - http_host = "localhost" - http_host_alt = "127.0.0.1" - - https_host = "localhost" - https_host_alt = "127.0.0.1" - https_certs = DEFAULT_CERTS - - proxy_host = "localhost" - proxy_host_alt = "127.0.0.1" + certs_dir = "" + bad_ca_path = "" + + +class HTTPDummyProxyTestCase: + io_loop: typing.ClassVar[ioloop.IOLoop] + + http_host: typing.ClassVar[str] = "localhost" + http_host_alt: typing.ClassVar[str] = "127.0.0.1" + http_server: typing.ClassVar[httpserver.HTTPServer] + http_port: typing.ClassVar[int] + http_url: typing.ClassVar[str] + http_url_alt: typing.ClassVar[str] + + https_host: typing.ClassVar[str] = "localhost" + https_host_alt: typing.ClassVar[str] = "127.0.0.1" + https_certs: typing.ClassVar[dict[str, typing.Any]] = DEFAULT_CERTS + https_server: typing.ClassVar[httpserver.HTTPServer] + https_port: typing.ClassVar[int] + 
https_url: typing.ClassVar[str] + https_url_alt: typing.ClassVar[str] + + proxy_host: typing.ClassVar[str] = "localhost" + proxy_host_alt: typing.ClassVar[str] = "127.0.0.1" + proxy_server: typing.ClassVar[httpserver.HTTPServer] + proxy_port: typing.ClassVar[int] + proxy_url: typing.ClassVar[str] + https_proxy_server: typing.ClassVar[httpserver.HTTPServer] + https_proxy_port: typing.ClassVar[int] + https_proxy_url: typing.ClassVar[str] + + certs_dir: typing.ClassVar[str] = "" + bad_ca_path: typing.ClassVar[str] = "" + + server_thread: typing.ClassVar[threading.Thread] + _stack: typing.ClassVar[contextlib.ExitStack] @classmethod - def setup_class(cls): - cls.io_loop = ioloop.IOLoop.current() - - app = web.Application([(r".*", TestingApp)]) - cls.http_server, cls.http_port = run_tornado_app( - app, cls.io_loop, None, "http", cls.http_host - ) - - app = web.Application([(r".*", TestingApp)]) - cls.https_server, cls.https_port = run_tornado_app( - app, cls.io_loop, cls.https_certs, "https", cls.http_host - ) - - app = web.Application([(r".*", ProxyHandler)]) - cls.proxy_server, cls.proxy_port = run_tornado_app( - app, cls.io_loop, None, "http", cls.proxy_host - ) - - upstream_ca_certs = cls.https_certs.get("ca_certs", None) - app = web.Application( - [(r".*", ProxyHandler)], upstream_ca_certs=upstream_ca_certs - ) - cls.https_proxy_server, cls.https_proxy_port = run_tornado_app( - app, cls.io_loop, cls.https_certs, "https", cls.proxy_host - ) - - cls.server_thread = run_loop_in_thread(cls.io_loop) + def setup_class(cls) -> None: + with contextlib.ExitStack() as stack: + io_loop = stack.enter_context(run_loop_in_thread()) + + async def run_app() -> None: + app = web.Application([(r".*", TestingApp)]) + cls.http_server, cls.http_port = run_tornado_app( + app, None, "http", cls.http_host + ) + + app = web.Application([(r".*", TestingApp)]) + cls.https_server, cls.https_port = run_tornado_app( + app, cls.https_certs, "https", cls.http_host + ) + + app = 
web.Application([(r".*", ProxyHandler)]) + cls.proxy_server, cls.proxy_port = run_tornado_app( + app, None, "http", cls.proxy_host + ) + + upstream_ca_certs = cls.https_certs.get("ca_certs") + app = web.Application( + [(r".*", ProxyHandler)], upstream_ca_certs=upstream_ca_certs + ) + cls.https_proxy_server, cls.https_proxy_port = run_tornado_app( + app, cls.https_certs, "https", cls.proxy_host + ) + + asyncio.run_coroutine_threadsafe(run_app(), io_loop.asyncio_loop).result() # type: ignore[attr-defined] + cls._stack = stack.pop_all() @classmethod - def teardown_class(cls): - cls.io_loop.add_callback(cls.http_server.stop) - cls.io_loop.add_callback(cls.https_server.stop) - cls.io_loop.add_callback(cls.proxy_server.stop) - cls.io_loop.add_callback(cls.https_proxy_server.stop) - cls.io_loop.add_callback(cls.io_loop.stop) - cls.server_thread.join() + def teardown_class(cls) -> None: + cls._stack.close() @pytest.mark.skipif(not HAS_IPV6, reason="IPv6 not available") @@ -209,7 +281,6 @@ class IPv6HTTPDummyServerTestCase(HTTPDummyServerTestCase): @pytest.mark.skipif(not HAS_IPV6, reason="IPv6 not available") class IPv6HTTPDummyProxyTestCase(HTTPDummyProxyTestCase): - http_host = "localhost" http_host_alt = "127.0.0.1" @@ -221,7 +292,7 @@ class IPv6HTTPDummyProxyTestCase(HTTPDummyProxyTestCase): proxy_host_alt = "127.0.0.1" -class ConnectionMarker(object): +class ConnectionMarker: """ Marks an HTTP(S)Connection's socket after a request was made. @@ -232,32 +303,33 @@ class ConnectionMarker(object): MARK_FORMAT = b"$#MARK%04x*!" @classmethod - @contextmanager - def mark(cls, monkeypatch): + @contextlib.contextmanager + def mark( + cls, monkeypatch: pytest.MonkeyPatch + ) -> typing.Generator[None, None, None]: """ Mark connections made in that context. 
""" orig_request = HTTPConnection.request - orig_request_chunked = HTTPConnection.request_chunked - def call_and_mark(target): - def part(self, *args, **kwargs): - result = target(self, *args, **kwargs) + def call_and_mark( + target: typing.Callable[..., None] + ) -> typing.Callable[..., None]: + def part( + self: HTTPConnection, *args: typing.Any, **kwargs: typing.Any + ) -> None: + target(self, *args, **kwargs) self.sock.sendall(cls._get_socket_mark(self.sock, False)) - return result return part with monkeypatch.context() as m: m.setattr(HTTPConnection, "request", call_and_mark(orig_request)) - m.setattr( - HTTPConnection, "request_chunked", call_and_mark(orig_request_chunked) - ) yield @classmethod - def consume_request(cls, sock, chunks=65536): + def consume_request(cls, sock: socket.socket, chunks: int = 65536) -> bytearray: """ Consume a socket until after the HTTP request is sent. """ @@ -273,7 +345,7 @@ def consume_request(cls, sock, chunks=65536): return consumed @classmethod - def _get_socket_mark(cls, sock, server): + def _get_socket_mark(cls, sock: socket.socket, server: bool) -> bytes: if server: port = sock.getpeername()[1] else: diff --git a/mypy-requirements.txt b/mypy-requirements.txt new file mode 100644 index 0000000000..0d41bf1fa6 --- /dev/null +++ b/mypy-requirements.txt @@ -0,0 +1,9 @@ +mypy==1.2.0 +idna>=2.0.0 +cryptography>=1.3.4 +tornado>=6.1 +pytest>=6.2 +trustme==0.9.0 +types-backports +types-requests +nox diff --git a/notes/connection-lifecycle.md b/notes/connection-lifecycle.md new file mode 100644 index 0000000000..e7c007fce6 --- /dev/null +++ b/notes/connection-lifecycle.md @@ -0,0 +1,130 @@ +# Connection lifecycle + +## Current implementation + +`HTTPConnection` should be instantiated with `host` and `port` of the +**first origin being connected to** to reach the target origin. This either means +the target origin itself or the proxy origin if one is desired. 
+ +```python +import urllib3.connection + +# Initialize the HTTPSConnection ('https://...') +conn = urllib3.connection.HTTPSConnection( + host="example.com", + # Here you can configure other options like + # 'ssl_minimum_version', 'ca_certs', etc. +) + +# Set the connect timeout either in the +# constructor above or via the property. +conn.timeout = 3.0 # (connect timeout) +``` + +If using CONNECT tunneling with the proxy, call `HTTPConnection.set_tunnel()` +with the tunneled host, port, and headers. This should be called before calling +`HTTPConnection.connect()` or sending a request. + +```python +conn = urllib3.connection.HTTPConnection( + # Remember that the *first* origin we want to connect to should + # be configured as 'host' and 'port', *not* the target origin. + host="myproxy.net", + port=8080, + proxy="http://myproxy.net:8080" +) + +conn.set_tunnel("example.com", scheme="http", headers={"Proxy-Header": "value"}) +``` + +Connect to the first origin by calling the `HTTPConnection.connect()` method. +If an error occurs here you can check whether the error occurred during the +connection to the proxy if `HTTPConnection.has_connected_to_proxy` is false. +If the value is true then the error didn't occur while connecting to a proxy. + +```python +# Explicitly connect to the origin. This isn't +# required as sending the first request will +# automatically connect if not done explicitly. +conn.connect() +``` + +After connecting to the origin, the connection can be checked to see if `is_verified` is set to true. If not the `HTTPConnectionPool` would emit a warning. The warning only matters for when verification is disabled, because otherwise an error is raised on unverified TLS handshake. + +```python +if not conn.is_verified: + # There isn't a verified TLS connection to target origin. +if not conn.is_proxy_verified: + # There isn't a verified TLS connection to proxy origin. 
+``` + +If the read timeout is different from the connect timeout then the +`HTTPConnection.timeout` property can be changed at this point. + +```python +conn.timeout = 5.0 # (read timeout) +``` + +Then the HTTP request can be sent with `HTTPConnection.request()`. If a `BrokenPipeError` is raised while sending the request body it can be swallowed as a response can still be received from the origin even when the request isn't completely sent. + +```python +try: + conn.request("GET", "/") +except BrokenPipeError: + pass # We can still try to get a response! + +resp = conn.getresponse() +``` + +Then response headers (and other info) are read from the connection via `HTTPConnection.getresponse()` and returned as a `urllib3.HTTPResponse`. The `HTTPResponse` instance carries a reference to the `HTTPConnection` instance so the connection can be closed if the connection gets into an undefined protocol state. + +```python +assert resp.connection is conn +``` + +If pooling is in use the `HTTPConnectionPool` will set `_pool` on the `HTTPResponse` instance. This will return the connection to the pool once the response is exhausted. If retries are in use set `retries` on the `HTTPResponse` instance. + +```python +# Set by the HTTPConnectionPool before returning to the caller. +resp = conn.getresponse() +resp._pool = pool + +# This will call resp._pool._put_conn(resp.connection) +# Connection can get auto-released by exhausting. +resp.release_conn() +``` + +If any error is received from connecting to the origin, sending the request, or receiving the response, the caller will call `HTTPConnection.close()` and discard the connection. Connections can be re-used after being closed; a new TCP connection to proxies and origins will be established. 
+ +If instead of a tunneling proxy we were using a forwarding proxy then we configure the `HTTPConnection` similarly, except instead of `set_tunnel()` we send absolute URLs to `HTTPConnection.request()`: + +```python +import urllib3.connection + +# Initialize the HTTPConnection. +conn = urllib3.connection.HTTPConnection( + host="myproxy.net", + port=8080, + proxy="http://myproxy.net:8080" +) + +# You can request HTTP or HTTPS resources over the proxy +# using the absolute URL. +conn.request("GET", "http://example.com") +resp = conn.getresponse() + +conn.request("GET", "https://example.com") +resp = conn.getresponse() +``` + +### HTTP/HTTPS/proxies + +This is how `HTTPConnection` instances will be configured and used when a `PoolManager` or `ProxyManager` receives a given config: + +- No proxy, HTTP origin -> `HTTPConnection` +- No proxy, HTTPS origin -> `HTTPSConnection` +- HTTP proxy, HTTP origin -> `HTTPConnection` in forwarding mode +- HTTP proxy, HTTPS origin -> `HTTPSConnection` in tunnel mode +- HTTPS proxy, HTTP origin -> `HTTPSConnection` in forwarding mode +- HTTPS proxy, HTTPS origin -> `HTTPSConnection` in tunnel mode +- HTTPS proxy, HTTPS origin, `ProxyConfig.use_forwarding_for_https=True` -> `HTTPSConnection` in forwarding mode diff --git a/notes/public-and-private-apis.md b/notes/public-and-private-apis.md new file mode 100644 index 0000000000..6321696dbb --- /dev/null +++ b/notes/public-and-private-apis.md @@ -0,0 +1,28 @@ +# Public and private APIs + +## Public APIs + +- `urllib3.request()` +- `urllib3.PoolManager` +- `urllib3.ProxyManager` +- `urllib3.HTTPConnectionPool` +- `urllib3.HTTPSConnectionPool` +- `urllib3.BaseHTTPResponse` +- `urllib3.HTTPResponse` +- `urllib3.HTTPHeaderDict` +- `urllib3.filepost` +- `urllib3.fields` +- `urllib3.exceptions` +- `urllib3.contrib.*` +- `urllib3.util` + +Only public way to configure proxies is through `ProxyManager`? 
+ +## Private APIs + +- `urllib3.connection` +- `urllib3.connection.BaseHTTPConnection` +- `urllib3.connection.BaseHTTPSConnection` +- `urllib3.connection.HTTPConnection` +- `urllib3.connection.HTTPSConnection` +- `urllib3.util.*` (submodules) diff --git a/noxfile.py b/noxfile.py index 169f7384ad..4b08f7edd2 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,38 +1,20 @@ +from __future__ import annotations + import os import shutil -import subprocess +import sys import nox -# Whenever type-hints are completed on a file it should be added here so that -# this file will continue to be checked by mypy. Errors from other files are -# ignored. -TYPED_FILES = { - "src/urllib3/contrib/__init__.py", - "src/urllib3/exceptions.py", - "src/urllib3/fields.py", - "src/urllib3/filepost.py", - "src/urllib3/packages/__init__.py", - "src/urllib3/packages/six.py", - "src/urllib3/packages/ssl_match_hostname/__init__.py", - "src/urllib3/packages/ssl_match_hostname/_implementation.py", - "src/urllib3/util/queue.py", - "src/urllib3/util/url.py", -} -SOURCE_FILES = [ - "docs/", - "dummyserver/", - "src/", - "test/", - "noxfile.py", - "setup.py", -] - - -def tests_impl(session, extras="socks,secure,brotli"): + +def tests_impl( + session: nox.Session, + extras: str = "socks,secure,brotli,zstd", + byte_string_comparisons: bool = True, +) -> None: # Install deps and the package itself. session.install("-r", "dev-requirements.txt") - session.install(".[{extras}]".format(extras=extras)) + session.install(f".[{extras}]") # Show the pip version. session.run("pip", "--version") @@ -42,110 +24,148 @@ def tests_impl(session, extras="socks,secure,brotli"): # Print OpenSSL information. session.run("python", "-m", "OpenSSL.debug") - # Inspired from https://github.com/pyca/cryptography - # We use parallel mode and then combine here so that coverage.py will take - # the paths like .tox/pyXY/lib/pythonX.Y/site-packages/urllib3/__init__.py - # and collapse them into src/urllib3/__init__.py. 
- + memray_supported = True + if ( + sys.implementation.name != "cpython" + or sys.version_info < (3, 8) + or sys.version_info.releaselevel != "final" + ): + memray_supported = False # pytest-memray requires CPython 3.8+ + elif sys.platform == "win32": + memray_supported = False + + # Inspired from https://hynek.me/articles/ditch-codecov-python/ + # We use parallel mode and then combine in a later CI step session.run( + "python", + *(("-bb",) if byte_string_comparisons else ()), + "-m", "coverage", "run", "--parallel-mode", "-m", "pytest", - "-r", - "a", + *("--memray", "--hide-memray-summary") if memray_supported else (), + "-v", + "-ra", + f"--color={'yes' if 'GITHUB_ACTIONS' in os.environ else 'auto'}", "--tb=native", - "--no-success-flaky-report", + "--durations=10", + "--strict-config", + "--strict-markers", *(session.posargs or ("test/",)), env={"PYTHONWARNINGS": "always::DeprecationWarning"}, ) - session.run("coverage", "combine") - session.run("coverage", "report", "-m") - session.run("coverage", "xml") -@nox.session(python=["2.7", "3.5", "3.6", "3.7", "3.8", "3.9", "3.10", "pypy"]) -def test(session): +@nox.session(python=["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "pypy"]) +def test(session: nox.Session) -> None: tests_impl(session) -@nox.session(python=["2", "3"]) -def google_brotli(session): - # https://pypi.org/project/Brotli/ is the Google version of brotli, so - # install it separately and don't install our brotli extra (which installs - # brotlipy). - session.install("brotli") - tests_impl(session, extras="socks,secure") +@nox.session(python=["3"]) +def test_brotlipy(session: nox.Session) -> None: + """Check that if 'brotlipy' is installed instead of 'brotli' or + 'brotlicffi' that we still don't blow up. 
+ """ + session.install("brotlipy") + tests_impl(session, extras="socks,secure", byte_string_comparisons=False) -@nox.session(python="2.7") -def app_engine(session): - session.install("-r", "dev-requirements.txt") - session.install(".") - session.run( - "coverage", - "run", - "--parallel-mode", - "-m", - "pytest", - "-r", - "sx", - "test/appengine", - *session.posargs, - ) - session.run("coverage", "combine") - session.run("coverage", "report", "-m") - session.run("coverage", "xml") +def git_clone(session: nox.Session, git_url: str) -> None: + """We either clone the target repository or if already exist + simply reset the state and pull. + """ + expected_directory = git_url.split("/")[-1] + + if expected_directory.endswith(".git"): + expected_directory = expected_directory[:-4] + + if not os.path.isdir(expected_directory): + session.run("git", "clone", "--depth", "1", git_url, external=True) + else: + session.run( + "git", "-C", expected_directory, "reset", "--hard", "HEAD", external=True + ) + session.run("git", "-C", expected_directory, "pull", external=True) @nox.session() -def format(session): - """Run code formatters.""" - session.install("black", "isort") - session.run("black", *SOURCE_FILES) - session.run("isort", *SOURCE_FILES) +def downstream_botocore(session: nox.Session) -> None: + root = os.getcwd() + tmp_dir = session.create_tmp() + + session.cd(tmp_dir) + git_clone(session, "https://github.com/boto/botocore") + session.chdir("botocore") + for patch in [ + "0001-Mark-100-Continue-tests-as-failing.patch", + "0002-Stop-relying-on-removed-DEFAULT_CIPHERS.patch", + ]: + session.run("git", "apply", f"{root}/ci/{patch}", external=True) + session.run("git", "rev-parse", "HEAD", external=True) + session.run("python", "scripts/ci/install") + session.cd(root) + session.install(".", silent=False) + session.cd(f"{tmp_dir}/botocore") + + session.run("python", "-c", "import urllib3; print(urllib3.__version__)") + session.run("python", "scripts/ci/run-tests") + + 
+@nox.session() +def downstream_requests(session: nox.Session) -> None: + root = os.getcwd() + tmp_dir = session.create_tmp() + + session.cd(tmp_dir) + git_clone(session, "https://github.com/psf/requests") + session.chdir("requests") + session.run("git", "rev-parse", "HEAD", external=True) + session.install(".[socks]", silent=False) + session.install("-r", "requirements-dev.txt", silent=False) + + session.cd(root) + session.install(".", silent=False) + session.cd(f"{tmp_dir}/requests") + + session.run("python", "-c", "import urllib3; print(urllib3.__version__)") + session.run("pytest", "tests") + + +@nox.session() +def format(session: nox.Session) -> None: + """Run code formatters.""" lint(session) @nox.session -def lint(session): - session.install("flake8", "flake8-2020", "black", "isort", "mypy") - session.run("flake8", "--version") - session.run("black", "--version") - session.run("isort", "--version") +def lint(session: nox.Session) -> None: + session.install("pre-commit") + session.run("pre-commit", "run", "--all-files") + + mypy(session) + + +@nox.session(python="3.8") +def mypy(session: nox.Session) -> None: + """Run mypy.""" + session.install("-r", "mypy-requirements.txt") session.run("mypy", "--version") - session.run("black", "--check", *SOURCE_FILES) - session.run("isort", "--check", *SOURCE_FILES) - session.run("flake8", *SOURCE_FILES) - - session.log("mypy --strict src/urllib3") - all_errors, errors = [], [] - process = subprocess.run( - ["mypy", "--strict", "src/urllib3"], - env=session.env, - text=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, + session.run( + "mypy", + "dummyserver", + "noxfile.py", + "src/urllib3", + "test", ) - # Ensure that mypy itself ran successfully - assert process.returncode in (0, 1) - - for line in process.stdout.split("\n"): - all_errors.append(line) - filepath = line.partition(":")[0] - if filepath.replace(".pyi", ".py") in TYPED_FILES: - errors.append(line) - session.log("all errors count: 
{}".format(len(all_errors))) - if errors: - session.error("\n" + "\n".join(sorted(set(errors)))) @nox.session -def docs(session): +def docs(session: nox.Session) -> None: session.install("-r", "docs/requirements.txt") - session.install(".[socks,secure,brotli]") + session.install(".[socks,secure,brotli,zstd]") session.chdir("docs") if os.path.exists("_build"): diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..7046f7726b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,123 @@ +# This file is protected via CODEOWNERS + +[build-system] +requires = ["hatchling>=1.6.0,<2"] +build-backend = "hatchling.build" + +[project] +name = "urllib3" +description = "HTTP library with thread-safe connection pooling, file post, and more." +readme = "README.md" +keywords = ["urllib", "httplib", "threadsafe", "filepost", "http", "https", "ssl", "pooling"] +authors = [ + {name = "Andrey Petrov", email = "andrey.petrov@shazow.net"} +] +maintainers = [ + {name = "Seth Michael Larson", email="sethmichaellarson@gmail.com"}, + {name = "Quentin Pradet", email="quentin@pradet.me"}, +] +classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Topic :: Internet :: WWW/HTTP", + "Topic :: Software Development :: Libraries", +] +requires-python = ">=3.7" +dynamic = ["version"] + +[project.optional-dependencies] +brotli = [ + "brotli>=1.0.9; 
platform_python_implementation == 'CPython'", + "brotlicffi>=0.8.0; platform_python_implementation != 'CPython'" +] +zstd = [ + "zstandard>=0.18.0", +] +secure = [ + "pyOpenSSL>=17.1.0", + "cryptography>=1.9", + "idna>=2.0.0", + "certifi", + "urllib3-secure-extra", +] +socks = [ + "PySocks>=1.5.6,<2.0,!=1.5.7", +] + +[project.urls] +"Changelog" = "https://github.com/urllib3/urllib3/blob/main/CHANGES.rst" +"Documentation" = "https://urllib3.readthedocs.io" +"Code" = "https://github.com/urllib3/urllib3" +"Issue tracker" = "https://github.com/urllib3/urllib3/issues" + +[tool.hatch.version] +path = "src/urllib3/_version.py" + +[tool.hatch.build.targets.sdist] +include = [ + "/docs", + "/dummyserver", + "/src", + "/test", + "/dev-requirements.txt", + "/CHANGES.rst", + "/README.md", + "/LICENSE.txt", +] + +[tool.pytest.ini_options] +xfail_strict = true +python_classes = ["Test", "*TestCase"] +markers = ["limit_memory"] +log_level = "DEBUG" +filterwarnings = [ + "error", + '''default:'urllib3\[secure\]' extra is deprecated and will be removed in urllib3 v2\.1\.0.*:DeprecationWarning''', + '''default:'urllib3\.contrib\.pyopenssl' module is deprecated and will be removed in urllib3 v2\.1\.0.*:DeprecationWarning''', + '''default:'urllib3\.contrib\.securetransport' module is deprecated and will be removed in urllib3 v2\.1\.0.*:DeprecationWarning''', + '''default:No IPv6 support. Falling back to IPv4:urllib3.exceptions.HTTPWarning''', + '''default:No IPv6 support. 
skipping:urllib3.exceptions.HTTPWarning''', + '''default:ssl\.TLSVersion\.TLSv1 is deprecated:DeprecationWarning''', + '''default:ssl\.PROTOCOL_TLS is deprecated:DeprecationWarning''', + '''default:ssl\.PROTOCOL_TLSv1 is deprecated:DeprecationWarning''', + '''default:ssl\.TLSVersion\.TLSv1_1 is deprecated:DeprecationWarning''', + '''default:ssl\.PROTOCOL_TLSv1_1 is deprecated:DeprecationWarning''', + '''default:ssl\.PROTOCOL_TLSv1_2 is deprecated:DeprecationWarning''', + '''default:unclosed .*:ResourceWarning''', + '''default:ssl NPN is deprecated, use ALPN instead:DeprecationWarning''', +] + +[tool.isort] +profile = "black" +add_imports = "from __future__ import annotations" + +[tool.mypy] +mypy_path = "src" +check_untyped_defs = true +disallow_any_generics = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +no_implicit_optional = true +no_implicit_reexport = true +show_error_codes = true +strict_equality = true +warn_redundant_casts = true +warn_return_any = true +warn_unused_configs = true +warn_unused_ignores = true diff --git a/setup.cfg b/setup.cfg index e8d273b5ab..5dd55b4b71 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,29 +1,4 @@ [flake8] ignore = E501, E203, W503, W504 -exclude=./docs/conf.py,./src/urllib3/packages/* +exclude=./docs/conf.py max-line-length=99 - -[bdist_wheel] -universal = 1 - -[metadata] -license_file = LICENSE.txt -provides-extra = - secure - socks - brotli -requires-dist = - pyOpenSSL>=0.14; extra == 'secure' - cryptography>=1.3.4; extra == 'secure' - idna>=2.0.0; extra == 'secure' - certifi; extra == 'secure' - ipaddress; python_version=="2.7" and extra == 'secure' - PySocks>=1.5.6,<2.0,!=1.5.7; extra == 'socks' - brotlipy>=0.6.0; extra == 'brotli' - -[tool:pytest] -xfail_strict = true -python_classes = Test *TestCase - -[isort] -profile=black diff --git a/setup.py b/setup.py deleted file mode 100755 index 
d5030fbd79..0000000000 --- a/setup.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python -# This file is protected via CODEOWNERS - -import codecs -import os -import re - -from setuptools import setup - -base_path = os.path.dirname(__file__) - -# Get the version (borrowed from SQLAlchemy) -with open(os.path.join(base_path, "src", "urllib3", "_version.py")) as fp: - VERSION = ( - re.compile(r""".*__version__ = ["'](.*?)['"]""", re.S).match(fp.read()).group(1) - ) - - -with codecs.open("README.rst", encoding="utf-8") as fp: - # Remove reST raw directive from README as they're not allowed on PyPI - # Those blocks start with a newline and continue until the next newline - mode = None - lines = [] - for line in fp: - if line.startswith(".. raw::"): - mode = "ignore_nl" - elif line == "\n": - mode = "wait_nl" if mode == "ignore_nl" else None - - if mode is None: - lines.append(line) - readme = "".join(lines) - -with codecs.open("CHANGES.rst", encoding="utf-8") as fp: - changes = fp.read() - -version = VERSION - -setup( - name="urllib3", - version=version, - description="HTTP library with thread-safe connection pooling, file post, and more.", - long_description=u"\n\n".join([readme, changes]), - long_description_content_type="text/x-rst", - classifiers=[ - "Environment :: Web Environment", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - "Programming Language :: Python", - "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.7", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - "Topic :: Internet :: WWW/HTTP", - "Topic :: Software Development 
:: Libraries", - ], - keywords="urllib httplib threadsafe filepost http https ssl pooling", - author="Andrey Petrov", - author_email="andrey.petrov@shazow.net", - url="https://urllib3.readthedocs.io/", - project_urls={ - "Documentation": "https://urllib3.readthedocs.io/", - "Code": "https://github.com/urllib3/urllib3", - "Issue tracker": "https://github.com/urllib3/urllib3/issues", - }, - license="MIT", - packages=[ - "urllib3", - "urllib3.packages", - "urllib3.packages.ssl_match_hostname", - "urllib3.packages.backports", - "urllib3.contrib", - "urllib3.contrib._securetransport", - "urllib3.util", - ], - package_dir={"": "src"}, - requires=[], - python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4", - extras_require={ - "brotli": ["brotlipy>=0.6.0"], - "secure": [ - "pyOpenSSL>=0.14", - "cryptography>=1.3.4", - "idna>=2.0.0", - "certifi", - "ipaddress; python_version=='2.7'", - ], - "socks": ["PySocks>=1.5.6,<2.0,!=1.5.7"], - }, -) diff --git a/src/urllib3/__init__.py b/src/urllib3/__init__.py index fe86b59d78..43e79fa2bb 100644 --- a/src/urllib3/__init__.py +++ b/src/urllib3/__init__.py @@ -1,23 +1,63 @@ """ Python HTTP library with thread-safe connection pooling, file post support, user friendly, and more """ -from __future__ import absolute_import + +from __future__ import annotations # Set default logging handler to avoid "No handler found" warnings. import logging +import typing import warnings from logging import NullHandler from . 
import exceptions +from ._base_connection import _TYPE_BODY +from ._collections import HTTPHeaderDict from ._version import __version__ from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, connection_from_url -from .filepost import encode_multipart_formdata +from .filepost import _TYPE_FIELDS, encode_multipart_formdata from .poolmanager import PoolManager, ProxyManager, proxy_from_url -from .response import HTTPResponse +from .response import BaseHTTPResponse, HTTPResponse from .util.request import make_headers from .util.retry import Retry from .util.timeout import Timeout -from .util.url import get_host + +# Ensure that Python is compiled with OpenSSL 1.1.1+ +# If the 'ssl' module isn't available at all that's +# fine, we only care if the module is available. +try: + import ssl +except ImportError: + pass +else: + # fmt: off + if ( + not ssl.OPENSSL_VERSION.startswith("OpenSSL ") + or ssl.OPENSSL_VERSION_INFO < (1, 1, 1) + ): # Defensive: + raise ImportError( + "urllib3 v2.0 only supports OpenSSL 1.1.1+, currently " + f"the 'ssl' module is compiled with {ssl.OPENSSL_VERSION}. " + "See: https://github.com/urllib3/urllib3/issues/2168" + ) + # fmt: on + +# === NOTE TO REPACKAGERS AND VENDORS === +# Please delete this block, this logic is only +# for urllib3 being distributed via PyPI. +# See: https://github.com/urllib3/urllib3/issues/2680 +try: + import urllib3_secure_extra # type: ignore # noqa: F401 +except ModuleNotFoundError: + pass +else: + warnings.warn( + "'urllib3[secure]' extra is deprecated and will be removed " + "in urllib3 v2.1.0. 
Read more in this issue: " + "https://github.com/urllib3/urllib3/issues/2680", + category=DeprecationWarning, + stacklevel=2, + ) __author__ = "Andrey Petrov (andrey.petrov@shazow.net)" __license__ = "MIT" @@ -25,6 +65,7 @@ __all__ = ( "HTTPConnectionPool", + "HTTPHeaderDict", "HTTPSConnectionPool", "PoolManager", "ProxyManager", @@ -35,15 +76,17 @@ "connection_from_url", "disable_warnings", "encode_multipart_formdata", - "get_host", "make_headers", "proxy_from_url", + "request", ) logging.getLogger(__name__).addHandler(NullHandler()) -def add_stderr_logger(level=logging.DEBUG): +def add_stderr_logger( + level: int = logging.DEBUG, +) -> logging.StreamHandler[typing.TextIO]: """ Helper for quickly adding a StreamHandler to the logger. Useful for debugging. @@ -70,16 +113,51 @@ def add_stderr_logger(level=logging.DEBUG): # mechanisms to silence them. # SecurityWarning's always go off by default. warnings.simplefilter("always", exceptions.SecurityWarning, append=True) -# SubjectAltNameWarning's should go off once per host -warnings.simplefilter("default", exceptions.SubjectAltNameWarning, append=True) # InsecurePlatformWarning's don't vary between requests, so we keep it default. warnings.simplefilter("default", exceptions.InsecurePlatformWarning, append=True) -# SNIMissingWarnings should go off only once. -warnings.simplefilter("default", exceptions.SNIMissingWarning, append=True) -def disable_warnings(category=exceptions.HTTPWarning): +def disable_warnings(category: type[Warning] = exceptions.HTTPWarning) -> None: """ Helper for quickly disabling all urllib3 warnings. 
""" warnings.simplefilter("ignore", category) + + +_DEFAULT_POOL = PoolManager() + + +def request( + method: str, + url: str, + *, + body: _TYPE_BODY | None = None, + fields: _TYPE_FIELDS | None = None, + headers: typing.Mapping[str, str] | None = None, + preload_content: bool | None = True, + decode_content: bool | None = True, + redirect: bool | None = True, + retries: Retry | bool | int | None = None, + timeout: Timeout | float | int | None = 3, + json: typing.Any | None = None, +) -> BaseHTTPResponse: + """ + A convenience, top-level request method. It uses a module-global ``PoolManager`` instance. + Therefore, its side effects could be shared across dependencies relying on it. + To avoid side effects create a new ``PoolManager`` instance and use it instead. + The method does not accept low-level ``**urlopen_kw`` keyword arguments. + """ + + return _DEFAULT_POOL.request( + method, + url, + body=body, + fields=fields, + headers=headers, + preload_content=preload_content, + decode_content=decode_content, + redirect=redirect, + retries=retries, + timeout=timeout, + json=json, + ) diff --git a/src/urllib3/_base_connection.py b/src/urllib3/_base_connection.py new file mode 100644 index 0000000000..3afed765bc --- /dev/null +++ b/src/urllib3/_base_connection.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import typing + +from .util.connection import _TYPE_SOCKET_OPTIONS +from .util.timeout import _DEFAULT_TIMEOUT, _TYPE_TIMEOUT +from .util.url import Url + +_TYPE_BODY = typing.Union[bytes, typing.IO[typing.Any], typing.Iterable[bytes], str] + + +class ProxyConfig(typing.NamedTuple): + ssl_context: ssl.SSLContext | None + use_forwarding_for_https: bool + assert_hostname: None | str | Literal[False] + assert_fingerprint: str | None + + +class _ResponseOptions(typing.NamedTuple): + # TODO: Remove this in favor of a better + # HTTP request/response lifecycle tracking. 
+ request_method: str + request_url: str + preload_content: bool + decode_content: bool + enforce_content_length: bool + + +if typing.TYPE_CHECKING: + import ssl + + from typing_extensions import Literal, Protocol + + from .response import BaseHTTPResponse + + class BaseHTTPConnection(Protocol): + default_port: typing.ClassVar[int] + default_socket_options: typing.ClassVar[_TYPE_SOCKET_OPTIONS] + + host: str + port: int + timeout: None | ( + float + ) # Instance doesn't store _DEFAULT_TIMEOUT, must be resolved. + blocksize: int + source_address: tuple[str, int] | None + socket_options: _TYPE_SOCKET_OPTIONS | None + + proxy: Url | None + proxy_config: ProxyConfig | None + + is_verified: bool + proxy_is_verified: bool | None + + def __init__( + self, + host: str, + port: int | None = None, + *, + timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + source_address: tuple[str, int] | None = None, + blocksize: int = 8192, + socket_options: _TYPE_SOCKET_OPTIONS | None = ..., + proxy: Url | None = None, + proxy_config: ProxyConfig | None = None, + ) -> None: + ... + + def set_tunnel( + self, + host: str, + port: int | None = None, + headers: typing.Mapping[str, str] | None = None, + scheme: str = "http", + ) -> None: + ... + + def connect(self) -> None: + ... + + def request( + self, + method: str, + url: str, + body: _TYPE_BODY | None = None, + headers: typing.Mapping[str, str] | None = None, + # We know *at least* botocore is depending on the order of the + # first 3 parameters so to be safe we only mark the later ones + # as keyword-only to ensure we have space to extend. + *, + chunked: bool = False, + preload_content: bool = True, + decode_content: bool = True, + enforce_content_length: bool = True, + ) -> None: + ... + + def getresponse(self) -> BaseHTTPResponse: + ... + + def close(self) -> None: + ... + + @property + def is_closed(self) -> bool: + """Whether the connection either is brand new or has been previously closed. 
+ If this property is True then both ``is_connected`` and ``has_connected_to_proxy`` + properties must be False. + """ + + @property + def is_connected(self) -> bool: + """Whether the connection is actively connected to any origin (proxy or target)""" + + @property + def has_connected_to_proxy(self) -> bool: + """Whether the connection has successfully connected to its proxy. + This returns False if no proxy is in use. Used to determine whether + errors are coming from the proxy layer or from tunnelling to the target origin. + """ + + class BaseHTTPSConnection(BaseHTTPConnection, Protocol): + default_port: typing.ClassVar[int] + default_socket_options: typing.ClassVar[_TYPE_SOCKET_OPTIONS] + + # Certificate verification methods + cert_reqs: int | str | None + assert_hostname: None | str | Literal[False] + assert_fingerprint: str | None + ssl_context: ssl.SSLContext | None + + # Trusted CAs + ca_certs: str | None + ca_cert_dir: str | None + ca_cert_data: None | str | bytes + + # TLS version + ssl_minimum_version: int | None + ssl_maximum_version: int | None + ssl_version: int | str | None # Deprecated + + # Client certificates + cert_file: str | None + key_file: str | None + key_password: str | None + + def __init__( + self, + host: str, + port: int | None = None, + *, + timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + source_address: tuple[str, int] | None = None, + blocksize: int = 8192, + socket_options: _TYPE_SOCKET_OPTIONS | None = ..., + proxy: Url | None = None, + proxy_config: ProxyConfig | None = None, + cert_reqs: int | str | None = None, + assert_hostname: None | str | Literal[False] = None, + assert_fingerprint: str | None = None, + server_hostname: str | None = None, + ssl_context: ssl.SSLContext | None = None, + ca_certs: str | None = None, + ca_cert_dir: str | None = None, + ca_cert_data: None | str | bytes = None, + ssl_minimum_version: int | None = None, + ssl_maximum_version: int | None = None, + ssl_version: int | str | None = None, # Deprecated + 
cert_file: str | None = None, + key_file: str | None = None, + key_password: str | None = None, + ) -> None: + ... diff --git a/src/urllib3/_collections.py b/src/urllib3/_collections.py index da9857e986..3e43afbc90 100644 --- a/src/urllib3/_collections.py +++ b/src/urllib3/_collections.py @@ -1,34 +1,66 @@ -from __future__ import absolute_import - -try: - from collections.abc import Mapping, MutableMapping -except ImportError: - from collections import Mapping, MutableMapping -try: - from threading import RLock -except ImportError: # Platform-specific: No threads available - - class RLock: - def __enter__(self): - pass - - def __exit__(self, exc_type, exc_value, traceback): - pass - +from __future__ import annotations +import typing from collections import OrderedDict +from enum import Enum, auto +from threading import RLock -from .exceptions import InvalidHeader -from .packages import six -from .packages.six import iterkeys, itervalues +if typing.TYPE_CHECKING: + # We can only import Protocol if TYPE_CHECKING because it's a development + # dependency, and is not available at runtime. + from typing_extensions import Protocol -__all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"] + class HasGettableStringKeys(Protocol): + def keys(self) -> typing.Iterator[str]: + ... + + def __getitem__(self, key: str) -> str: + ... 
-_Null = object() +__all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"] -class RecentlyUsedContainer(MutableMapping): +# Key type +_KT = typing.TypeVar("_KT") +# Value type +_VT = typing.TypeVar("_VT") +# Default type +_DT = typing.TypeVar("_DT") + +ValidHTTPHeaderSource = typing.Union[ + "HTTPHeaderDict", + typing.Mapping[str, str], + typing.Iterable[typing.Tuple[str, str]], + "HasGettableStringKeys", +] + + +class _Sentinel(Enum): + not_passed = auto() + + +def ensure_can_construct_http_header_dict( + potential: object, +) -> ValidHTTPHeaderSource | None: + if isinstance(potential, HTTPHeaderDict): + return potential + elif isinstance(potential, typing.Mapping): + # Full runtime checking of the contents of a Mapping is expensive, so for the + # purposes of typechecking, we assume that any Mapping is the right shape. + return typing.cast(typing.Mapping[str, str], potential) + elif isinstance(potential, typing.Iterable): + # Similarly to Mapping, full runtime checking of the contents of an Iterable is + # expensive, so for the purposes of typechecking, we assume that any Iterable + # is the right shape. + return typing.cast(typing.Iterable[typing.Tuple[str, str]], potential) + elif hasattr(potential, "keys") and hasattr(potential, "__getitem__"): + return typing.cast("HasGettableStringKeys", potential) + else: + return None + + +class RecentlyUsedContainer(typing.Generic[_KT, _VT], typing.MutableMapping[_KT, _VT]): """ Provides a thread-safe dict-like container which maintains up to ``maxsize`` keys while throwing away the least-recently-used keys beyond @@ -42,69 +74,134 @@ class RecentlyUsedContainer(MutableMapping): ``dispose_func(value)`` is called. 
Callback which will get called """ - ContainerCls = OrderedDict - - def __init__(self, maxsize=10, dispose_func=None): + _container: typing.OrderedDict[_KT, _VT] + _maxsize: int + dispose_func: typing.Callable[[_VT], None] | None + lock: RLock + + def __init__( + self, + maxsize: int = 10, + dispose_func: typing.Callable[[_VT], None] | None = None, + ) -> None: + super().__init__() self._maxsize = maxsize self.dispose_func = dispose_func - - self._container = self.ContainerCls() + self._container = OrderedDict() self.lock = RLock() - def __getitem__(self, key): + def __getitem__(self, key: _KT) -> _VT: # Re-insert the item, moving it to the end of the eviction line. with self.lock: item = self._container.pop(key) self._container[key] = item return item - def __setitem__(self, key, value): - evicted_value = _Null + def __setitem__(self, key: _KT, value: _VT) -> None: + evicted_item = None with self.lock: # Possibly evict the existing value of 'key' - evicted_value = self._container.get(key, _Null) - self._container[key] = value - - # If we didn't evict an existing value, we might have to evict the - # least recently used item from the beginning of the container. - if len(self._container) > self._maxsize: - _key, evicted_value = self._container.popitem(last=False) - - if self.dispose_func and evicted_value is not _Null: + try: + # If the key exists, we'll overwrite it, which won't change the + # size of the pool. Because accessing a key should move it to + # the end of the eviction line, we pop it out first. 
+ evicted_item = key, self._container.pop(key) + self._container[key] = value + except KeyError: + # When the key does not exist, we insert the value first so that + # evicting works in all cases, including when self._maxsize is 0 + self._container[key] = value + if len(self._container) > self._maxsize: + # If we didn't evict an existing value, and we've hit our maximum + # size, then we have to evict the least recently used item from + # the beginning of the container. + evicted_item = self._container.popitem(last=False) + + # After releasing the lock on the pool, dispose of any evicted value. + if evicted_item is not None and self.dispose_func: + _, evicted_value = evicted_item self.dispose_func(evicted_value) - def __delitem__(self, key): + def __delitem__(self, key: _KT) -> None: with self.lock: value = self._container.pop(key) if self.dispose_func: self.dispose_func(value) - def __len__(self): + def __len__(self) -> int: with self.lock: return len(self._container) - def __iter__(self): + def __iter__(self) -> typing.NoReturn: raise NotImplementedError( "Iteration over this class is unlikely to be threadsafe." ) - def clear(self): + def clear(self) -> None: with self.lock: # Copy pointers to all values, then wipe the mapping - values = list(itervalues(self._container)) + values = list(self._container.values()) self._container.clear() if self.dispose_func: for value in values: self.dispose_func(value) - def keys(self): + def keys(self) -> set[_KT]: # type: ignore[override] with self.lock: - return list(iterkeys(self._container)) + return set(self._container.keys()) -class HTTPHeaderDict(MutableMapping): +class HTTPHeaderDictItemView(typing.Set[typing.Tuple[str, str]]): + """ + HTTPHeaderDict is unusual for a Mapping[str, str] in that it has two modes of + address. 
+ + If we directly try to get an item with a particular name, we will get a string + back that is the concatenated version of all the values: + + >>> d['X-Header-Name'] + 'Value1, Value2, Value3' + + However, if we iterate over an HTTPHeaderDict's items, we will optionally combine + these values based on whether combine=True was called when building up the dictionary + + >>> d = HTTPHeaderDict({"A": "1", "B": "foo"}) + >>> d.add("A", "2", combine=True) + >>> d.add("B", "bar") + >>> list(d.items()) + [ + ('A', '1, 2'), + ('B', 'foo'), + ('B', 'bar'), + ] + + This class conforms to the interface required by the MutableMapping ABC while + also giving us the nonstandard iteration behavior we want; items with duplicate + keys, ordered by time of first insertion. + """ + + _headers: HTTPHeaderDict + + def __init__(self, headers: HTTPHeaderDict) -> None: + self._headers = headers + + def __len__(self) -> int: + return len(list(self._headers.iteritems())) + + def __iter__(self) -> typing.Iterator[tuple[str, str]]: + return self._headers.iteritems() + + def __contains__(self, item: object) -> bool: + if isinstance(item, tuple) and len(item) == 2: + passed_key, passed_val = item + if isinstance(passed_key, str) and isinstance(passed_val, str): + return self._headers._has_value_for_header(passed_key, passed_val) + return False + + +class HTTPHeaderDict(typing.MutableMapping[str, str]): """ :param headers: An iterable of field-value pairs. 
Must not contain multiple field names @@ -138,9 +235,11 @@ class HTTPHeaderDict(MutableMapping): '7' """ - def __init__(self, headers=None, **kwargs): - super(HTTPHeaderDict, self).__init__() - self._container = OrderedDict() + _container: typing.MutableMapping[str, list[str]] + + def __init__(self, headers: ValidHTTPHeaderSource | None = None, **kwargs: str): + super().__init__() + self._container = {} # 'dict' is insert-ordered in Python 3.7+ if headers is not None: if isinstance(headers, HTTPHeaderDict): self._copy_from(headers) @@ -149,123 +248,147 @@ def __init__(self, headers=None, **kwargs): if kwargs: self.extend(kwargs) - def __setitem__(self, key, val): + def __setitem__(self, key: str, val: str) -> None: + # avoid a bytes/str comparison by decoding before httplib + if isinstance(key, bytes): + key = key.decode("latin-1") self._container[key.lower()] = [key, val] - return self._container[key.lower()] - def __getitem__(self, key): + def __getitem__(self, key: str) -> str: val = self._container[key.lower()] return ", ".join(val[1:]) - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: del self._container[key.lower()] - def __contains__(self, key): - return key.lower() in self._container + def __contains__(self, key: object) -> bool: + if isinstance(key, str): + return key.lower() in self._container + return False - def __eq__(self, other): - if not isinstance(other, Mapping) and not hasattr(other, "keys"): - return False - if not isinstance(other, type(self)): - other = type(self)(other) - return dict((k.lower(), v) for k, v in self.itermerged()) == dict( - (k.lower(), v) for k, v in other.itermerged() - ) + def setdefault(self, key: str, default: str = "") -> str: + return super().setdefault(key, default) - def __ne__(self, other): - return not self.__eq__(other) + def __eq__(self, other: object) -> bool: + maybe_constructable = ensure_can_construct_http_header_dict(other) + if maybe_constructable is None: + return False + else: + 
other_as_http_header_dict = type(self)(maybe_constructable) - if six.PY2: # Python 2 - iterkeys = MutableMapping.iterkeys - itervalues = MutableMapping.itervalues + return {k.lower(): v for k, v in self.itermerged()} == { + k.lower(): v for k, v in other_as_http_header_dict.itermerged() + } - __marker = object() + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) - def __len__(self): + def __len__(self) -> int: return len(self._container) - def __iter__(self): + def __iter__(self) -> typing.Iterator[str]: # Only provide the originally cased names for vals in self._container.values(): yield vals[0] - def pop(self, key, default=__marker): - """D.pop(k[,d]) -> v, remove specified key and return the corresponding value. - If key is not found, d is returned if given, otherwise KeyError is raised. - """ - # Using the MutableMapping function directly fails due to the private marker. - # Using ordinary dict.pop would expose the internal structures. - # So let's reinvent the wheel. - try: - value = self[key] - except KeyError: - if default is self.__marker: - raise - return default - else: - del self[key] - return value - - def discard(self, key): + def discard(self, key: str) -> None: try: del self[key] except KeyError: pass - def add(self, key, val): + def add(self, key: str, val: str, *, combine: bool = False) -> None: """Adds a (name, value) pair, doesn't overwrite the value if it already exists. + If this is called with combine=True, instead of adding a new header value + as a distinct item during iteration, this will instead append the value to + any existing header value with a comma. If no existing header value exists + for the key, then the value will simply be added, ignoring the combine parameter. 
+ >>> headers = HTTPHeaderDict(foo='bar') >>> headers.add('Foo', 'baz') >>> headers['foo'] 'bar, baz' + >>> list(headers.items()) + [('foo', 'bar'), ('foo', 'baz')] + >>> headers.add('foo', 'quz', combine=True) + >>> list(headers.items()) + [('foo', 'bar, baz, quz')] """ + # avoid a bytes/str comparison by decoding before httplib + if isinstance(key, bytes): + key = key.decode("latin-1") key_lower = key.lower() new_vals = [key, val] # Keep the common case aka no item present as fast as possible vals = self._container.setdefault(key_lower, new_vals) if new_vals is not vals: - vals.append(val) + # if there are values here, then there is at least the initial + # key/value pair + assert len(vals) >= 2 + if combine: + vals[-1] = vals[-1] + ", " + val + else: + vals.append(val) - def extend(self, *args, **kwargs): + def extend(self, *args: ValidHTTPHeaderSource, **kwargs: str) -> None: """Generic import function for any type of header-like object. Adapted version of MutableMapping.update in order to insert items with self.add instead of self.__setitem__ """ if len(args) > 1: raise TypeError( - "extend() takes at most 1 positional " - "arguments ({0} given)".format(len(args)) + f"extend() takes at most 1 positional arguments ({len(args)} given)" ) other = args[0] if len(args) >= 1 else () if isinstance(other, HTTPHeaderDict): for key, val in other.iteritems(): self.add(key, val) - elif isinstance(other, Mapping): - for key in other: - self.add(key, other[key]) - elif hasattr(other, "keys"): - for key in other.keys(): - self.add(key, other[key]) - else: + elif isinstance(other, typing.Mapping): + for key, val in other.items(): + self.add(key, val) + elif isinstance(other, typing.Iterable): + other = typing.cast(typing.Iterable[typing.Tuple[str, str]], other) for key, value in other: self.add(key, value) + elif hasattr(other, "keys") and hasattr(other, "__getitem__"): + # THIS IS NOT A TYPESAFE BRANCH + # In this branch, the object has a `keys` attr but is not a Mapping or 
any of + # the other types indicated in the method signature. We do some stuff with + # it as though it partially implements the Mapping interface, but we're not + # doing that stuff safely AT ALL. + for key in other.keys(): + self.add(key, other[key]) for key, value in kwargs.items(): self.add(key, value) - def getlist(self, key, default=__marker): + @typing.overload + def getlist(self, key: str) -> list[str]: + ... + + @typing.overload + def getlist(self, key: str, default: _DT) -> list[str] | _DT: + ... + + def getlist( + self, key: str, default: _Sentinel | _DT = _Sentinel.not_passed + ) -> list[str] | _DT: """Returns a list of all the values for the named field. Returns an empty list if the key doesn't exist.""" try: vals = self._container[key.lower()] except KeyError: - if default is self.__marker: + if default is _Sentinel.not_passed: + # _DT is unbound; empty list is instance of List[str] return [] + # _DT is bound; default is instance of _DT return default else: + # _DT may or may not be bound; vals[1:] is instance of List[str], which + # meets our external interface requirement of `Union[List[str], _DT]`. 
return vals[1:] # Backwards compatibility for httplib @@ -276,62 +399,36 @@ def getlist(self, key, default=__marker): # Backwards compatibility for http.cookiejar get_all = getlist - def __repr__(self): - return "%s(%s)" % (type(self).__name__, dict(self.itermerged())) + def __repr__(self) -> str: + return f"{type(self).__name__}({dict(self.itermerged())})" - def _copy_from(self, other): + def _copy_from(self, other: HTTPHeaderDict) -> None: for key in other: val = other.getlist(key) - if isinstance(val, list): - # Don't need to convert tuples - val = list(val) - self._container[key.lower()] = [key] + val + self._container[key.lower()] = [key, *val] - def copy(self): + def copy(self) -> HTTPHeaderDict: clone = type(self)() clone._copy_from(self) return clone - def iteritems(self): + def iteritems(self) -> typing.Iterator[tuple[str, str]]: """Iterate over all header lines, including duplicate ones.""" for key in self: vals = self._container[key.lower()] for val in vals[1:]: yield vals[0], val - def itermerged(self): + def itermerged(self) -> typing.Iterator[tuple[str, str]]: """Iterate over all headers, merging duplicate ones together.""" for key in self: val = self._container[key.lower()] yield val[0], ", ".join(val[1:]) - def items(self): - return list(self.iteritems()) - - @classmethod - def from_httplib(cls, message): # Python 2 - """Read headers from a Python 2 httplib message object.""" - # python2.7 does not expose a proper API for exporting multiheaders - # efficiently. This function re-reads raw lines from the message - # object and extracts the multiheaders properly. - obs_fold_continued_leaders = (" ", "\t") - headers = [] - - for line in message.headers: - if line.startswith(obs_fold_continued_leaders): - if not headers: - # We received a header line that starts with OWS as described - # in RFC-7230 S3.2.4. This indicates a multiline header, but - # there exists no previous header to which we can attach it. 
- raise InvalidHeader( - "Header continuation with no previous header: %s" % line - ) - else: - key, value = headers[-1] - headers[-1] = (key, value + " " + line.strip()) - continue - - key, value = line.split(":", 1) - headers.append((key, value.strip())) - - return cls(headers) + def items(self) -> HTTPHeaderDictItemView: # type: ignore[override] + return HTTPHeaderDictItemView(self) + + def _has_value_for_header(self, header_name: str, potential_value: str) -> bool: + if header_name in self: + return potential_value in self._container[header_name.lower()][1:] + return False diff --git a/src/urllib3/request.py b/src/urllib3/_request_methods.py similarity index 63% rename from src/urllib3/request.py rename to src/urllib3/_request_methods.py index 398386a5b9..1d0f3465ad 100644 --- a/src/urllib3/request.py +++ b/src/urllib3/_request_methods.py @@ -1,12 +1,23 @@ -from __future__ import absolute_import +from __future__ import annotations -from .filepost import encode_multipart_formdata -from .packages.six.moves.urllib.parse import urlencode +import json as _json +import typing +from urllib.parse import urlencode + +from ._base_connection import _TYPE_BODY +from ._collections import HTTPHeaderDict +from .filepost import _TYPE_FIELDS, encode_multipart_formdata +from .response import BaseHTTPResponse __all__ = ["RequestMethods"] +_TYPE_ENCODE_URL_FIELDS = typing.Union[ + typing.Sequence[typing.Tuple[str, typing.Union[str, bytes]]], + typing.Mapping[str, typing.Union[str, bytes]], +] + -class RequestMethods(object): +class RequestMethods: """ Convenience mixin for classes who implement a :meth:`urlopen` method, such as :class:`urllib3.HTTPConnectionPool` and @@ -37,25 +48,34 @@ class RequestMethods(object): _encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"} - def __init__(self, headers=None): + def __init__(self, headers: typing.Mapping[str, str] | None = None) -> None: self.headers = headers or {} def urlopen( self, - method, - url, - body=None, - headers=None, - 
encode_multipart=True, - multipart_boundary=None, - **kw - ): # Abstract + method: str, + url: str, + body: _TYPE_BODY | None = None, + headers: typing.Mapping[str, str] | None = None, + encode_multipart: bool = True, + multipart_boundary: str | None = None, + **kw: typing.Any, + ) -> BaseHTTPResponse: # Abstract raise NotImplementedError( "Classes extending RequestMethods must implement " "their own ``urlopen`` method." ) - def request(self, method, url, fields=None, headers=None, **urlopen_kw): + def request( + self, + method: str, + url: str, + body: _TYPE_BODY | None = None, + fields: _TYPE_FIELDS | None = None, + headers: typing.Mapping[str, str] | None = None, + json: typing.Any | None = None, + **urlopen_kw: typing.Any, + ) -> BaseHTTPResponse: """ Make a request using :meth:`urlopen` with the appropriate encoding of ``fields`` based on the ``method`` used. @@ -68,18 +88,45 @@ def request(self, method, url, fields=None, headers=None, **urlopen_kw): """ method = method.upper() - urlopen_kw["request_url"] = url + if json is not None and body is not None: + raise TypeError( + "request got values for both 'body' and 'json' parameters which are mutually exclusive" + ) + + if json is not None: + if headers is None: + headers = self.headers.copy() # type: ignore + if not ("content-type" in map(str.lower, headers.keys())): + headers["Content-Type"] = "application/json" # type: ignore + + body = _json.dumps(json, separators=(",", ":"), ensure_ascii=False).encode( + "utf-8" + ) + + if body is not None: + urlopen_kw["body"] = body if method in self._encode_url_methods: return self.request_encode_url( - method, url, fields=fields, headers=headers, **urlopen_kw + method, + url, + fields=fields, # type: ignore[arg-type] + headers=headers, + **urlopen_kw, ) else: return self.request_encode_body( method, url, fields=fields, headers=headers, **urlopen_kw ) - def request_encode_url(self, method, url, fields=None, headers=None, **urlopen_kw): + def request_encode_url( + self, 
+ method: str, + url: str, + fields: _TYPE_ENCODE_URL_FIELDS | None = None, + headers: typing.Mapping[str, str] | None = None, + **urlopen_kw: str, + ) -> BaseHTTPResponse: """ Make a request using :meth:`urlopen` with the ``fields`` encoded in the url. This is useful for request methods like GET, HEAD, DELETE, etc. @@ -87,7 +134,7 @@ def request_encode_url(self, method, url, fields=None, headers=None, **urlopen_k if headers is None: headers = self.headers - extra_kw = {"headers": headers} + extra_kw: dict[str, typing.Any] = {"headers": headers} extra_kw.update(urlopen_kw) if fields: @@ -97,14 +144,14 @@ def request_encode_url(self, method, url, fields=None, headers=None, **urlopen_k def request_encode_body( self, - method, - url, - fields=None, - headers=None, - encode_multipart=True, - multipart_boundary=None, - **urlopen_kw - ): + method: str, + url: str, + fields: _TYPE_FIELDS | None = None, + headers: typing.Mapping[str, str] | None = None, + encode_multipart: bool = True, + multipart_boundary: str | None = None, + **urlopen_kw: str, + ) -> BaseHTTPResponse: """ Make a request using :meth:`urlopen` with the ``fields`` encoded in the body. This is useful for request methods like POST, PUT, PATCH, etc. 
@@ -143,7 +190,8 @@ def request_encode_body( if headers is None: headers = self.headers - extra_kw = {"headers": {}} + extra_kw: dict[str, typing.Any] = {"headers": HTTPHeaderDict(headers)} + body: bytes | str if fields: if "body" in urlopen_kw: @@ -157,14 +205,13 @@ def request_encode_body( ) else: body, content_type = ( - urlencode(fields), + urlencode(fields), # type: ignore[arg-type] "application/x-www-form-urlencoded", ) extra_kw["body"] = body - extra_kw["headers"] = {"Content-Type": content_type} + extra_kw["headers"].setdefault("Content-Type", content_type) - extra_kw["headers"].update(headers) extra_kw.update(urlopen_kw) return self.urlopen(method, url, **extra_kw) diff --git a/src/urllib3/_version.py b/src/urllib3/_version.py index 3110e75103..6c6d26d428 100644 --- a/src/urllib3/_version.py +++ b/src/urllib3/_version.py @@ -1,2 +1,4 @@ # This file is protected via CODEOWNERS -__version__ = "1.26.0.dev0" +from __future__ import annotations + +__version__ = "2.0.0" diff --git a/src/urllib3/connection.py b/src/urllib3/connection.py index 52487417c9..1f13af0b63 100644 --- a/src/urllib3/connection.py +++ b/src/urllib3/connection.py @@ -1,64 +1,70 @@ -from __future__ import absolute_import +from __future__ import annotations import datetime import logging import os import re import socket +import typing import warnings -from socket import error as SocketError +from http.client import HTTPConnection as _HTTPConnection +from http.client import HTTPException as HTTPException # noqa: F401 +from http.client import ResponseNotReady from socket import timeout as SocketTimeout -from .packages import six -from .packages.six.moves.http_client import HTTPConnection as _HTTPConnection -from .packages.six.moves.http_client import HTTPException # noqa: F401 -from .util.proxy import create_proxy_ssl_context +if typing.TYPE_CHECKING: + from typing_extensions import Literal + + from .response import HTTPResponse + from .util.ssl_ import _TYPE_PEER_CERT_RET_DICT + from 
.util.ssltransport import SSLTransport + +from ._collections import HTTPHeaderDict +from .util.response import assert_header_parsing +from .util.timeout import _DEFAULT_TIMEOUT, _TYPE_TIMEOUT, Timeout +from .util.util import to_str +from .util.wait import wait_for_read try: # Compiled with SSL? import ssl BaseSSLError = ssl.SSLError -except (ImportError, AttributeError): # Platform-specific: No SSL. - ssl = None - - class BaseSSLError(BaseException): - pass - - -try: - # Python 3: not a no-op, we're adding this to the namespace so it can be imported. - ConnectionError = ConnectionError -except NameError: - # Python 2 - class ConnectionError(Exception): - pass - +except (ImportError, AttributeError): + ssl = None # type: ignore[assignment] -try: # Python 3: - # Not a no-op, we're adding this to the namespace so it can be imported. - BrokenPipeError = BrokenPipeError -except NameError: # Python 2: - - class BrokenPipeError(Exception): + class BaseSSLError(BaseException): # type: ignore[no-redef] pass +from ._base_connection import _TYPE_BODY +from ._base_connection import ProxyConfig as ProxyConfig +from ._base_connection import _ResponseOptions as _ResponseOptions from ._version import __version__ from .exceptions import ( ConnectTimeoutError, + HeaderParsingError, + NameResolutionError, NewConnectionError, - SubjectAltNameWarning, + ProxyError, SystemTimeWarning, ) -from .packages.ssl_match_hostname import CertificateError, match_hostname -from .util import SKIP_HEADER, SKIPPABLE_HEADERS, connection +from .util import SKIP_HEADER, SKIPPABLE_HEADERS, connection, ssl_ +from .util.request import body_to_chunks +from .util.ssl_ import assert_fingerprint as _assert_fingerprint from .util.ssl_ import ( - assert_fingerprint, create_urllib3_context, + is_ipaddress, resolve_cert_reqs, resolve_ssl_version, ssl_wrap_socket, ) +from .util.ssl_match_hostname import CertificateError, match_hostname +from .util.url import Url + +# Not a no-op, we're adding this to the namespace 
so it can be imported. +ConnectionError = ConnectionError +BrokenPipeError = BrokenPipeError + log = logging.getLogger(__name__) @@ -66,12 +72,12 @@ class BrokenPipeError(Exception): # When it comes time to update this value as a part of regular maintenance # (ie test_recent_date is failing) update it to ~6 months before the current date. -RECENT_DATE = datetime.date(2019, 1, 1) +RECENT_DATE = datetime.date(2022, 1, 1) _CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]") -class HTTPConnection(_HTTPConnection, object): +class HTTPConnection(_HTTPConnection): """ Based on :class:`http.client.HTTPConnection` but provides an extra constructor backwards-compatibility layer between older and newer Pythons. @@ -79,7 +85,6 @@ class HTTPConnection(_HTTPConnection, object): Additional keyword parameters are used to configure attributes of the connection. Accepted parameters include: - - ``strict``: See the documentation on :class:`urllib3.connectionpool.HTTPConnectionPool` - ``source_address``: Set the source address for the current connection. - ``socket_options``: Set specific options on the underlying socket. If not specified, then defaults are loaded from ``HTTPConnection.default_socket_options`` which includes disabling @@ -97,34 +102,68 @@ class HTTPConnection(_HTTPConnection, object): Or you may want to disable the defaults by passing an empty list (e.g., ``[]``). """ - default_port = port_by_scheme["http"] + default_port: typing.ClassVar[int] = port_by_scheme["http"] # type: ignore[misc] #: Disable Nagle's algorithm by default. #: ``[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]`` - default_socket_options = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] + default_socket_options: typing.ClassVar[connection._TYPE_SOCKET_OPTIONS] = [ + (socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + ] #: Whether this connection verifies the host's certificate. 
- is_verified = False - - def __init__(self, *args, **kw): - if not six.PY2: - kw.pop("strict", None) + is_verified: bool = False - # Pre-set source_address. - self.source_address = kw.get("source_address") + #: Whether this proxy connection verified the proxy host's certificate. + # If no proxy is currently connected to the value will be ``None``. + proxy_is_verified: bool | None = None - #: The socket options provided by the user. If no options are - #: provided, we use the default options. - self.socket_options = kw.pop("socket_options", self.default_socket_options) + blocksize: int + source_address: tuple[str, int] | None + socket_options: connection._TYPE_SOCKET_OPTIONS | None - # Proxy options provided by the user. - self.proxy = kw.pop("proxy", None) - self.proxy_config = kw.pop("proxy_config", None) - - _HTTPConnection.__init__(self, *args, **kw) + _has_connected_to_proxy: bool + _response_options: _ResponseOptions | None + _tunnel_host: str | None + _tunnel_port: int | None + _tunnel_scheme: str | None + def __init__( + self, + host: str, + port: int | None = None, + *, + timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + source_address: tuple[str, int] | None = None, + blocksize: int = 8192, + socket_options: None + | (connection._TYPE_SOCKET_OPTIONS) = default_socket_options, + proxy: Url | None = None, + proxy_config: ProxyConfig | None = None, + ) -> None: + super().__init__( + host=host, + port=port, + timeout=Timeout.resolve_default_timeout(timeout), + source_address=source_address, + blocksize=blocksize, + ) + self.socket_options = socket_options + self.proxy = proxy + self.proxy_config = proxy_config + + self._has_connected_to_proxy = False + self._response_options = None + self._tunnel_host: str | None = None + self._tunnel_port: int | None = None + self._tunnel_scheme: str | None = None + + # https://github.com/python/mypy/issues/4125 + # Mypy treats this as LSP violation, which is considered a bug. 
+ # If `host` is made a property it violates LSP, because a writeable attribute is overridden with a read-only one. + # However, there is also a `host` setter so LSP is not violated. + # Potentially, a `@host.deleter` might be needed depending on how this issue will be fixed. @property - def host(self): + def host(self) -> str: """ Getter method to remove any trailing dots that indicate the hostname is an FQDN. @@ -143,7 +182,7 @@ def host(self): return self._dns_host.rstrip(".") @host.setter - def host(self, value): + def host(self, value: str) -> None: """ Setter for the `host` property. @@ -152,124 +191,294 @@ def host(self, value): """ self._dns_host = value - def _new_conn(self): + def _new_conn(self) -> socket.socket: """Establish a socket connection and set nodelay settings on it. :return: New socket connection. """ - extra_kw = {} - if self.source_address: - extra_kw["source_address"] = self.source_address - - if self.socket_options: - extra_kw["socket_options"] = self.socket_options - try: - conn = connection.create_connection( - (self._dns_host, self.port), self.timeout, **extra_kw + sock = connection.create_connection( + (self._dns_host, self.port), + self.timeout, + source_address=self.source_address, + socket_options=self.socket_options, ) - - except SocketTimeout: + except socket.gaierror as e: + raise NameResolutionError(self.host, self, e) from e + except SocketTimeout as e: raise ConnectTimeoutError( self, - "Connection to %s timed out. (connect timeout=%s)" - % (self.host, self.timeout), - ) + f"Connection to {self.host} timed out. 
(connect timeout={self.timeout})", + ) from e - except SocketError as e: + except OSError as e: raise NewConnectionError( - self, "Failed to establish a new connection: %s" % e - ) + self, f"Failed to establish a new connection: {e}" + ) from e - return conn + return sock - def _is_using_tunnel(self): - # Google App Engine's httplib does not define _tunnel_host - return getattr(self, "_tunnel_host", None) + def set_tunnel( + self, + host: str, + port: int | None = None, + headers: typing.Mapping[str, str] | None = None, + scheme: str = "http", + ) -> None: + if scheme not in ("http", "https"): + raise ValueError( + f"Invalid proxy scheme for tunneling: {scheme!r}, must be either 'http' or 'https'" + ) + super().set_tunnel(host, port=port, headers=headers) + self._tunnel_scheme = scheme + + def connect(self) -> None: + self.sock = self._new_conn() + if self._tunnel_host: + # If we're tunneling it means we're connected to our proxy. + self._has_connected_to_proxy = True - def _prepare_conn(self, conn): - self.sock = conn - if self._is_using_tunnel(): # TODO: Fix tunnel so it doesn't depend on self.sock state. - self._tunnel() - # Mark this connection as not reusable - self.auto_open = 0 + self._tunnel() # type: ignore[attr-defined] + + # If there's a proxy to be connected to we are fully connected. + # This is set twice (once above and here) due to forwarding proxies + # not using tunnelling. 
+ self._has_connected_to_proxy = bool(self.proxy) + + @property + def is_closed(self) -> bool: + return self.sock is None - def connect(self): - conn = self._new_conn() - self._prepare_conn(conn) + @property + def is_connected(self) -> bool: + if self.sock is None: + return False + return not wait_for_read(self.sock, timeout=0.0) + + @property + def has_connected_to_proxy(self) -> bool: + return self._has_connected_to_proxy - def putrequest(self, method, url, *args, **kwargs): + def close(self) -> None: + try: + super().close() + finally: + # Reset all stateful properties so connection + # can be re-used without leaking prior configs. + self.sock = None + self.is_verified = False + self.proxy_is_verified = None + self._has_connected_to_proxy = False + self._response_options = None + self._tunnel_host = None + self._tunnel_port = None + self._tunnel_scheme = None + + def putrequest( + self, + method: str, + url: str, + skip_host: bool = False, + skip_accept_encoding: bool = False, + ) -> None: """""" # Empty docstring because the indentation of CPython's implementation # is broken but we don't want this method in our documentation. 
match = _CONTAINS_CONTROL_CHAR_RE.search(method) if match: raise ValueError( - "Method cannot contain non-token characters %r (found at least %r)" - % (method, match.group()) + f"Method cannot contain non-token characters {method!r} (found at least {match.group()!r})" ) - return _HTTPConnection.putrequest(self, method, url, *args, **kwargs) + return super().putrequest( + method, url, skip_host=skip_host, skip_accept_encoding=skip_accept_encoding + ) - def putheader(self, header, *values): + def putheader(self, header: str, *values: str) -> None: """""" - if SKIP_HEADER not in values: - _HTTPConnection.putheader(self, header, *values) - elif six.ensure_str(header.lower()) not in SKIPPABLE_HEADERS: + if not any(isinstance(v, str) and v == SKIP_HEADER for v in values): + super().putheader(header, *values) + elif to_str(header.lower()) not in SKIPPABLE_HEADERS: + skippable_headers = "', '".join( + [str.title(header) for header in sorted(SKIPPABLE_HEADERS)] + ) raise ValueError( - "urllib3.util.SKIP_HEADER only supports '%s'" - % ("', '".join(map(str.title, sorted(SKIPPABLE_HEADERS))),) + f"urllib3.util.SKIP_HEADER only supports '{skippable_headers}'" ) - def request(self, method, url, body=None, headers=None): + # `request` method's signature intentionally violates LSP. + # urllib3's API is different from `http.client.HTTPConnection` and the subclassing is only incidental. + def request( # type: ignore[override] + self, + method: str, + url: str, + body: _TYPE_BODY | None = None, + headers: typing.Mapping[str, str] | None = None, + *, + chunked: bool = False, + preload_content: bool = True, + decode_content: bool = True, + enforce_content_length: bool = True, + ) -> None: + # Update the inner socket's timeout value to send the request. + # This only triggers if the connection is re-used. + if self.sock is not None: + self.sock.settimeout(self.timeout) + + # Store these values to be fed into the HTTPResponse + # object later. 
TODO: Remove this in favor of a real + # HTTP lifecycle mechanism. + + # We have to store these before we call .request() + # because sometimes we can still salvage a response + # off the wire even if we aren't able to completely + # send the request body. + self._response_options = _ResponseOptions( + request_method=method, + request_url=url, + preload_content=preload_content, + decode_content=decode_content, + enforce_content_length=enforce_content_length, + ) + if headers is None: headers = {} - else: - # Avoid modifying the headers passed into .request() - headers = headers.copy() - if "user-agent" not in (k.lower() for k in headers): - headers["User-Agent"] = _get_default_user_agent() - super(HTTPConnection, self).request(method, url, body=body, headers=headers) - - def request_chunked(self, method, url, body=None, headers=None): - """ - Alternative to the common request method, which sends the - body with chunked encoding and not as one block - """ - headers = headers or {} - header_keys = set([six.ensure_str(k.lower()) for k in headers]) + header_keys = frozenset(to_str(k.lower()) for k in headers) skip_accept_encoding = "accept-encoding" in header_keys skip_host = "host" in header_keys self.putrequest( method, url, skip_accept_encoding=skip_accept_encoding, skip_host=skip_host ) + + # Transform the body into an iterable of sendall()-able chunks + # and detect if an explicit Content-Length is doable. + chunks_and_cl = body_to_chunks(body, method=method, blocksize=self.blocksize) + chunks = chunks_and_cl.chunks + content_length = chunks_and_cl.content_length + + # When chunked is explicitly set to 'True' we respect that. + if chunked: + if "transfer-encoding" not in header_keys: + self.putheader("Transfer-Encoding", "chunked") + else: + # Detect whether a framing mechanism is already in use. If so + # we respect that value, otherwise we pick chunked vs content-length + # depending on the type of 'body'. 
+ if "content-length" in header_keys: + chunked = False + elif "transfer-encoding" in header_keys: + chunked = True + + # Otherwise we go off the recommendation of 'body_to_chunks()'. + else: + chunked = False + if content_length is None: + if chunks is not None: + chunked = True + self.putheader("Transfer-Encoding", "chunked") + else: + self.putheader("Content-Length", str(content_length)) + + # Now that framing headers are out of the way we send all the other headers. if "user-agent" not in header_keys: self.putheader("User-Agent", _get_default_user_agent()) for header, value in headers.items(): self.putheader(header, value) - if "transfer-encoding" not in headers: - self.putheader("Transfer-Encoding", "chunked") self.endheaders() - if body is not None: - stringish_types = six.string_types + (bytes,) - if isinstance(body, stringish_types): - body = (body,) - for chunk in body: + # If we're given a body we start sending that in chunks. + if chunks is not None: + for chunk in chunks: + # Sending empty chunks isn't allowed for TE: chunked + # as it indicates the end of the body. if not chunk: continue - if not isinstance(chunk, bytes): - chunk = chunk.encode("utf8") - len_str = hex(len(chunk))[2:] - to_send = bytearray(len_str.encode()) - to_send += b"\r\n" - to_send += chunk - to_send += b"\r\n" - self.send(to_send) + if isinstance(chunk, str): + chunk = chunk.encode("utf-8") + if chunked: + self.send(b"%x\r\n%b\r\n" % (len(chunk), chunk)) + else: + self.send(chunk) + + # Regardless of whether we have a body or not, if we're in + # chunked mode we want to send an explicit empty chunk. 
+ if chunked: + self.send(b"0\r\n\r\n") + + def request_chunked( + self, + method: str, + url: str, + body: _TYPE_BODY | None = None, + headers: typing.Mapping[str, str] | None = None, + ) -> None: + """ + Alternative to the common request method, which sends the + body with chunked encoding and not as one block + """ + warnings.warn( + "HTTPConnection.request_chunked() is deprecated and will be removed " + "in urllib3 v2.1.0. Instead use HTTPConnection.request(..., chunked=True).", + category=DeprecationWarning, + stacklevel=2, + ) + self.request(method, url, body=body, headers=headers, chunked=True) + + def getresponse( # type: ignore[override] + self, + ) -> HTTPResponse: + """ + Get the response from the server. + + If the HTTPConnection is in the correct state, returns an instance of HTTPResponse or of whatever object is returned by the response_class variable. + + If a request has not been sent or if a previous response has not been handled, ResponseNotReady is raised. If the HTTP response indicates that the connection should be closed, then it will be closed before the response is returned. When the connection is closed, the underlying socket is closed. + """ + # Raise the same error as http.client.HTTPConnection + if self._response_options is None: + raise ResponseNotReady() + + # Reset this attribute for being used again. + resp_options = self._response_options + self._response_options = None - # After the if clause, to always have a closed body - self.send(b"0\r\n\r\n") + # Since the connection's timeout value may have been updated + # we need to set the timeout on the socket. 
+ self.sock.settimeout(self.timeout) + + # This is needed here to avoid circular import errors + from .response import HTTPResponse + + # Get the response from http.client.HTTPConnection + httplib_response = super().getresponse() + + try: + assert_header_parsing(httplib_response.msg) + except (HeaderParsingError, TypeError) as hpe: + log.warning( + "Failed to parse headers (url=%s): %s", + _url_from_connection(self, resp_options.request_url), + hpe, + exc_info=True, + ) + + headers = HTTPHeaderDict(httplib_response.msg.items()) + + response = HTTPResponse( + body=httplib_response, + headers=headers, + status=httplib_response.status, + version=httplib_response.version, + reason=httplib_response.reason, + preload_content=resp_options.preload_content, + decode_content=resp_options.decode_content, + original_response=httplib_response, + enforce_content_length=resp_options.enforce_content_length, + request_method=resp_options.request_method, + request_url=resp_options.request_url, + ) + return response class HTTPSConnection(HTTPConnection): @@ -278,57 +487,100 @@ class HTTPSConnection(HTTPConnection): socket by means of :py:func:`urllib3.util.ssl_wrap_socket`. 
""" - default_port = port_by_scheme["https"] + default_port = port_by_scheme["https"] # type: ignore[misc] - cert_reqs = None - ca_certs = None - ca_cert_dir = None - ca_cert_data = None - ssl_version = None - assert_fingerprint = None - tls_in_tls_required = False + cert_reqs: int | str | None = None + ca_certs: str | None = None + ca_cert_dir: str | None = None + ca_cert_data: None | str | bytes = None + ssl_version: int | str | None = None + ssl_minimum_version: int | None = None + ssl_maximum_version: int | None = None + assert_fingerprint: str | None = None def __init__( self, - host, - port=None, - key_file=None, - cert_file=None, - key_password=None, - strict=None, - timeout=socket._GLOBAL_DEFAULT_TIMEOUT, - ssl_context=None, - server_hostname=None, - **kw - ): - - HTTPConnection.__init__(self, host, port, strict=strict, timeout=timeout, **kw) + host: str, + port: int | None = None, + *, + timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + source_address: tuple[str, int] | None = None, + blocksize: int = 8192, + socket_options: None + | (connection._TYPE_SOCKET_OPTIONS) = HTTPConnection.default_socket_options, + proxy: Url | None = None, + proxy_config: ProxyConfig | None = None, + cert_reqs: int | str | None = None, + assert_hostname: None | str | Literal[False] = None, + assert_fingerprint: str | None = None, + server_hostname: str | None = None, + ssl_context: ssl.SSLContext | None = None, + ca_certs: str | None = None, + ca_cert_dir: str | None = None, + ca_cert_data: None | str | bytes = None, + ssl_minimum_version: int | None = None, + ssl_maximum_version: int | None = None, + ssl_version: int | str | None = None, # Deprecated + cert_file: str | None = None, + key_file: str | None = None, + key_password: str | None = None, + ) -> None: + super().__init__( + host, + port=port, + timeout=timeout, + source_address=source_address, + blocksize=blocksize, + socket_options=socket_options, + proxy=proxy, + proxy_config=proxy_config, + ) self.key_file = key_file 
self.cert_file = cert_file self.key_password = key_password self.ssl_context = ssl_context self.server_hostname = server_hostname + self.assert_hostname = assert_hostname + self.assert_fingerprint = assert_fingerprint + self.ssl_version = ssl_version + self.ssl_minimum_version = ssl_minimum_version + self.ssl_maximum_version = ssl_maximum_version + self.ca_certs = ca_certs and os.path.expanduser(ca_certs) + self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir) + self.ca_cert_data = ca_cert_data - # Required property for Google AppEngine 1.9.0 which otherwise causes - # HTTPS requests to go out as HTTP. (See Issue #356) - self._protocol = "https" + # cert_reqs depends on ssl_context so calculate last. + if cert_reqs is None: + if self.ssl_context is not None: + cert_reqs = self.ssl_context.verify_mode + else: + cert_reqs = resolve_cert_reqs(None) + self.cert_reqs = cert_reqs def set_cert( self, - key_file=None, - cert_file=None, - cert_reqs=None, - key_password=None, - ca_certs=None, - assert_hostname=None, - assert_fingerprint=None, - ca_cert_dir=None, - ca_cert_data=None, - ): + key_file: str | None = None, + cert_file: str | None = None, + cert_reqs: int | str | None = None, + key_password: str | None = None, + ca_certs: str | None = None, + assert_hostname: None | str | Literal[False] = None, + assert_fingerprint: str | None = None, + ca_cert_dir: str | None = None, + ca_cert_data: None | str | bytes = None, + ) -> None: """ This method should only be called once, before the connection is used. """ + warnings.warn( + "HTTPSConnection.set_cert() is deprecated and will be removed " + "in urllib3 v2.1.0. Instead provide the parameters to the " + "HTTPSConnection constructor.", + category=DeprecationWarning, + stacklevel=2, + ) + # If cert_reqs is not provided we'll assume CERT_REQUIRED unless we also # have an SSLContext object in which case we'll use its verify_mode. 
if cert_reqs is None: @@ -347,29 +599,26 @@ def set_cert( self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir) self.ca_cert_data = ca_cert_data - def connect(self): - # Add certificate verification - conn = self._new_conn() - hostname = self.host + def connect(self) -> None: + sock: socket.socket | ssl.SSLSocket + self.sock = sock = self._new_conn() + server_hostname: str = self.host tls_in_tls = False - if self._is_using_tunnel(): - if self.tls_in_tls_required: - conn = self._connect_tls_proxy(hostname, conn) + # Do we need to establish a tunnel? + if self._tunnel_host is not None: + # We're tunneling to an HTTPS origin so need to do TLS-in-TLS. + if self._tunnel_scheme == "https": + self.sock = sock = self._connect_tls_proxy(self.host, sock) tls_in_tls = True - self.sock = conn - - # Calls self._set_hostport(), so self.host is - # self._tunnel_host below. - self._tunnel() - # Mark this connection as not reusable - self.auto_open = 0 + # If we're tunneling it means we're connected to our proxy. + self._has_connected_to_proxy = True + self._tunnel() # type: ignore[attr-defined] # Override the host with the one we're requesting data from. - hostname = self._tunnel_host + server_hostname = self._tunnel_host - server_hostname = hostname if self.server_hostname is not None: server_hostname = self.server_hostname @@ -377,134 +626,210 @@ def connect(self): if is_time_off: warnings.warn( ( - "System time is way off (before {0}). This will probably " + f"System time is way off (before {RECENT_DATE}). 
This will probably " "lead to SSL verification errors" - ).format(RECENT_DATE), + ), SystemTimeWarning, ) - # Wrap socket using verification with the root certs in - # trusted_root_certs - default_ssl_context = False - if self.ssl_context is None: - default_ssl_context = True - self.ssl_context = create_urllib3_context( - ssl_version=resolve_ssl_version(self.ssl_version), - cert_reqs=resolve_cert_reqs(self.cert_reqs), - ) - - context = self.ssl_context - context.verify_mode = resolve_cert_reqs(self.cert_reqs) - - # Try to load OS default certs if none are given. - # Works well on Windows (requires Python3.4+) - if ( - not self.ca_certs - and not self.ca_cert_dir - and not self.ca_cert_data - and default_ssl_context - and hasattr(context, "load_default_certs") - ): - context.load_default_certs() - - self.sock = ssl_wrap_socket( - sock=conn, - keyfile=self.key_file, - certfile=self.cert_file, - key_password=self.key_password, + sock_and_verified = _ssl_wrap_socket_and_match_hostname( + sock=sock, + cert_reqs=self.cert_reqs, + ssl_version=self.ssl_version, + ssl_minimum_version=self.ssl_minimum_version, + ssl_maximum_version=self.ssl_maximum_version, ca_certs=self.ca_certs, ca_cert_dir=self.ca_cert_dir, ca_cert_data=self.ca_cert_data, + cert_file=self.cert_file, + key_file=self.key_file, + key_password=self.key_password, server_hostname=server_hostname, - ssl_context=context, + ssl_context=self.ssl_context, tls_in_tls=tls_in_tls, + assert_hostname=self.assert_hostname, + assert_fingerprint=self.assert_fingerprint, ) + self.sock = sock_and_verified.socket + self.is_verified = sock_and_verified.is_verified - # If we're using all defaults and the connection - # is TLSv1 or TLSv1.1 we throw a DeprecationWarning - # for the host. 
- if ( - default_ssl_context - and self.ssl_version is None - and hasattr(self.sock, "version") - and self.sock.version() in {"TLSv1", "TLSv1.1"} - ): - warnings.warn( - "Negotiating TLSv1/TLSv1.1 by default is deprecated " - "and will be disabled in urllib3 v2.0.0. Connecting to " - "'%s' with '%s' can be enabled by explicitly opting-in " - "with 'ssl_version'" % (self.host, self.sock.version()), - DeprecationWarning, - ) + # If there's a proxy to be connected to we are fully connected. + # This is set twice (once above and here) due to forwarding proxies + # not using tunnelling. + self._has_connected_to_proxy = bool(self.proxy) - if self.assert_fingerprint: - assert_fingerprint( - self.sock.getpeercert(binary_form=True), self.assert_fingerprint - ) - elif ( - context.verify_mode != ssl.CERT_NONE - and not getattr(context, "check_hostname", False) - and self.assert_hostname is not False - ): - # While urllib3 attempts to always turn off hostname matching from - # the TLS library, this cannot always be done. So we check whether - # the TLS Library still thinks it's matching hostnames. - cert = self.sock.getpeercert() - if not cert.get("subjectAltName", ()): - warnings.warn( - ( - "Certificate for {0} has no `subjectAltName`, falling back to check for a " - "`commonName` for now. This feature is being removed by major browsers and " - "deprecated by RFC 2818. (See https://github.com/urllib3/urllib3/issues/497 " - "for details.)".format(hostname) - ), - SubjectAltNameWarning, - ) - _match_hostname(cert, self.assert_hostname or server_hostname) - - self.is_verified = ( - context.verify_mode == ssl.CERT_REQUIRED - or self.assert_fingerprint is not None - ) - - def _connect_tls_proxy(self, hostname, conn): + def _connect_tls_proxy(self, hostname: str, sock: socket.socket) -> ssl.SSLSocket: """ Establish a TLS connection to the proxy using the provided SSL context. 
""" - proxy_config = self.proxy_config + # `_connect_tls_proxy` is called when self._tunnel_host is truthy. + proxy_config = typing.cast(ProxyConfig, self.proxy_config) ssl_context = proxy_config.ssl_context - if ssl_context: - # If the user provided a proxy context, we assume CA and client - # certificates have already been set - return ssl_wrap_socket( - sock=conn, - server_hostname=hostname, - ssl_context=ssl_context, - ) - - ssl_context = create_proxy_ssl_context( - self.ssl_version, - self.cert_reqs, - self.ca_certs, - self.ca_cert_dir, - self.ca_cert_data, - ) - - # If no cert was provided, use only the default options for server - # certificate validation - return ssl_wrap_socket( - sock=conn, + sock_and_verified = _ssl_wrap_socket_and_match_hostname( + sock, + cert_reqs=self.cert_reqs, + ssl_version=self.ssl_version, + ssl_minimum_version=self.ssl_minimum_version, + ssl_maximum_version=self.ssl_maximum_version, ca_certs=self.ca_certs, ca_cert_dir=self.ca_cert_dir, ca_cert_data=self.ca_cert_data, server_hostname=hostname, ssl_context=ssl_context, + assert_hostname=proxy_config.assert_hostname, + assert_fingerprint=proxy_config.assert_fingerprint, + # Features that aren't implemented for proxies yet: + cert_file=None, + key_file=None, + key_password=None, + tls_in_tls=False, + ) + self.proxy_is_verified = sock_and_verified.is_verified + return sock_and_verified.socket # type: ignore[return-value] + + +class _WrappedAndVerifiedSocket(typing.NamedTuple): + """ + Wrapped socket and whether the connection is + verified after the TLS handshake + """ + + socket: ssl.SSLSocket | SSLTransport + is_verified: bool + + +def _ssl_wrap_socket_and_match_hostname( + sock: socket.socket, + *, + cert_reqs: None | str | int, + ssl_version: None | str | int, + ssl_minimum_version: int | None, + ssl_maximum_version: int | None, + cert_file: str | None, + key_file: str | None, + key_password: str | None, + ca_certs: str | None, + ca_cert_dir: str | None, + ca_cert_data: None | 
str | bytes, + assert_hostname: None | str | Literal[False], + assert_fingerprint: str | None, + server_hostname: str | None, + ssl_context: ssl.SSLContext | None, + tls_in_tls: bool = False, +) -> _WrappedAndVerifiedSocket: + """Logic for constructing an SSLContext from all TLS parameters, passing + that down into ssl_wrap_socket, and then doing certificate verification + either via hostname or fingerprint. This function exists to guarantee + that both proxies and targets have the same behavior when connecting via TLS. + """ + default_ssl_context = False + if ssl_context is None: + default_ssl_context = True + context = create_urllib3_context( + ssl_version=resolve_ssl_version(ssl_version), + ssl_minimum_version=ssl_minimum_version, + ssl_maximum_version=ssl_maximum_version, + cert_reqs=resolve_cert_reqs(cert_reqs), + ) + else: + context = ssl_context + + context.verify_mode = resolve_cert_reqs(cert_reqs) + + # In some cases, we want to verify hostnames ourselves + if ( + # `ssl` can't verify fingerprints or alternate hostnames + assert_fingerprint + or assert_hostname + # We still support OpenSSL 1.0.2, which prevents us from verifying + # hostnames easily: https://github.com/pyca/pyopenssl/pull/933 + or ssl_.IS_PYOPENSSL + or not ssl_.HAS_NEVER_CHECK_COMMON_NAME + ): + context.check_hostname = False + + # Try to load OS default certs if none are given. + # We need to do the hasattr() check for our custom + # pyOpenSSL and SecureTransport SSLContext objects + # because neither support load_default_certs(). + if ( + not ca_certs + and not ca_cert_dir + and not ca_cert_data + and default_ssl_context + and hasattr(context, "load_default_certs") + ): + context.load_default_certs() + + # Ensure that IPv6 addresses are in the proper format and don't have a + # scope ID. Python's SSL module fails to recognize scoped IPv6 addresses + # and interprets them as DNS hostnames. 
+ if server_hostname is not None: + normalized = server_hostname.strip("[]") + if "%" in normalized: + normalized = normalized[: normalized.rfind("%")] + if is_ipaddress(normalized): + server_hostname = normalized + + ssl_sock = ssl_wrap_socket( + sock=sock, + keyfile=key_file, + certfile=cert_file, + key_password=key_password, + ca_certs=ca_certs, + ca_cert_dir=ca_cert_dir, + ca_cert_data=ca_cert_data, + server_hostname=server_hostname, + ssl_context=context, + tls_in_tls=tls_in_tls, + ) + + if assert_fingerprint: + _assert_fingerprint(ssl_sock.getpeercert(binary_form=True), assert_fingerprint) + elif ( + context.verify_mode != ssl.CERT_NONE + and not context.check_hostname + and assert_hostname is not False + ): + cert: _TYPE_PEER_CERT_RET_DICT = ssl_sock.getpeercert() # type: ignore[assignment] + + # Need to signal to our match_hostname whether to use 'commonName' or not. + # If we're using our own constructed SSLContext we explicitly set 'False' + # because PyPy hard-codes 'True' from SSLContext.hostname_checks_common_name. + if default_ssl_context: + hostname_checks_common_name = False + else: + hostname_checks_common_name = ( + getattr(context, "hostname_checks_common_name", False) or False + ) + + _match_hostname( + cert, + assert_hostname or server_hostname, # type: ignore[arg-type] + hostname_checks_common_name, ) + return _WrappedAndVerifiedSocket( + socket=ssl_sock, + is_verified=context.verify_mode == ssl.CERT_REQUIRED + or bool(assert_fingerprint), + ) + + +def _match_hostname( + cert: _TYPE_PEER_CERT_RET_DICT | None, + asserted_hostname: str, + hostname_checks_common_name: bool = False, +) -> None: + # Our upstream implementation of ssl.match_hostname() + # only applies this normalization to IP addresses so it doesn't + # match DNS SANs so we do the same thing! 
+ stripped_hostname = asserted_hostname.strip("[]") + if is_ipaddress(stripped_hostname): + asserted_hostname = stripped_hostname -def _match_hostname(cert, asserted_hostname): try: - match_hostname(cert, asserted_hostname) + match_hostname(cert, asserted_hostname, hostname_checks_common_name) except CertificateError as e: log.warning( "Certificate did not match expected hostname: %s. Certificate: %s", @@ -513,22 +838,54 @@ def _match_hostname(cert, asserted_hostname): ) # Add cert to exception and reraise so client code can inspect # the cert when catching the exception, if they want to - e._peer_cert = cert + e._peer_cert = cert # type: ignore[attr-defined] raise -def _get_default_user_agent(): - return "python-urllib3/%s" % __version__ - - -class DummyConnection(object): +def _wrap_proxy_error(err: Exception, proxy_scheme: str | None) -> ProxyError: + # Look for the phrase 'wrong version number', if found + # then we should warn the user that we're very sure that + # this proxy is HTTP-only and they have a configuration issue. + error_normalized = " ".join(re.split("[^a-z]", str(err).lower())) + is_likely_http_proxy = ( + "wrong version number" in error_normalized + or "unknown protocol" in error_normalized + ) + http_proxy_warning = ( + ". Your proxy appears to only use HTTP and not HTTPS, " + "try changing your proxy URL to be HTTP. 
See: " + "https://urllib3.readthedocs.io/en/latest/advanced-usage.html" + "#https-proxy-error-http-proxy" + ) + new_err = ProxyError( + f"Unable to connect to proxy" + f"{http_proxy_warning if is_likely_http_proxy and proxy_scheme == 'https' else ''}", + err, + ) + new_err.__cause__ = err + return new_err + + +def _get_default_user_agent() -> str: + return f"python-urllib3/{__version__}" + + +class DummyConnection: """Used to detect a failed ConnectionCls import.""" - pass - if not ssl: - HTTPSConnection = DummyConnection # noqa: F811 + HTTPSConnection = DummyConnection # type: ignore[misc, assignment] # noqa: F811 VerifiedHTTPSConnection = HTTPSConnection + + +def _url_from_connection( + conn: HTTPConnection | HTTPSConnection, path: str | None = None +) -> str: + """Returns the URL from a given connection. This is mainly used for testing and logging.""" + + scheme = "https" if isinstance(conn, HTTPSConnection) else "http" + + return Url(scheme=scheme, host=conn.host, port=conn.port, path=path).url diff --git a/src/urllib3/connectionpool.py b/src/urllib3/connectionpool.py index 4708c5bfc7..2479405bd5 100644 --- a/src/urllib3/connectionpool.py +++ b/src/urllib3/connectionpool.py @@ -1,13 +1,17 @@ -from __future__ import absolute_import +from __future__ import annotations import errno import logging -import socket +import queue import sys +import typing import warnings -from socket import error as SocketError +import weakref from socket import timeout as SocketTimeout +from types import TracebackType +from ._base_connection import _TYPE_BODY +from ._request_methods import RequestMethods from .connection import ( BaseSSLError, BrokenPipeError, @@ -15,13 +19,14 @@ HTTPConnection, HTTPException, HTTPSConnection, - VerifiedHTTPSConnection, - port_by_scheme, + ProxyConfig, + _wrap_proxy_error, ) +from .connection import port_by_scheme as port_by_scheme from .exceptions import ( ClosedPoolError, EmptyPoolError, - HeaderParsingError, + FullPoolError, HostChangedError, 
InsecureRequestWarning, LocationValueError, @@ -33,31 +38,34 @@ SSLError, TimeoutError, ) -from .packages import six -from .packages.six.moves import queue -from .packages.ssl_match_hostname import CertificateError -from .request import RequestMethods -from .response import HTTPResponse +from .response import BaseHTTPResponse from .util.connection import is_connection_dropped from .util.proxy import connection_requires_http_tunnel -from .util.queue import LifoQueue -from .util.request import set_file_position -from .util.response import assert_header_parsing +from .util.request import _TYPE_BODY_POSITION, set_file_position from .util.retry import Retry -from .util.timeout import Timeout +from .util.ssl_match_hostname import CertificateError +from .util.timeout import _DEFAULT_TIMEOUT, _TYPE_DEFAULT, Timeout from .util.url import Url, _encode_target from .util.url import _normalize_host as normalize_host -from .util.url import get_host, parse_url +from .util.url import parse_url +from .util.util import to_str -xrange = six.moves.xrange +if typing.TYPE_CHECKING: + import ssl + + from typing_extensions import Literal + + from ._base_connection import BaseHTTPConnection, BaseHTTPSConnection log = logging.getLogger(__name__) -_Default = object() +_TYPE_TIMEOUT = typing.Union[Timeout, float, _TYPE_DEFAULT, None] + +_SelfT = typing.TypeVar("_SelfT") # Pool objects -class ConnectionPool(object): +class ConnectionPool: """ Base class for all connection pools, such as :class:`.HTTPConnectionPool` and :class:`.HTTPSConnectionPool`. @@ -68,33 +76,42 @@ class ConnectionPool(object): target URIs. 
""" - scheme = None - QueueCls = LifoQueue + scheme: str | None = None + QueueCls = queue.LifoQueue - def __init__(self, host, port=None): + def __init__(self, host: str, port: int | None = None) -> None: if not host: raise LocationValueError("No host specified.") self.host = _normalize_host(host, scheme=self.scheme) - self._proxy_host = host.lower() self.port = port - def __str__(self): - return "%s(host=%r, port=%r)" % (type(self).__name__, self.host, self.port) + # This property uses 'normalize_host()' (not '_normalize_host()') + # to avoid removing square braces around IPv6 addresses. + # This value is sent to `HTTPConnection.set_tunnel()` if called + # because square braces are required for HTTP CONNECT tunneling. + self._tunnel_host = normalize_host(host, scheme=self.scheme).lower() - def __enter__(self): + def __str__(self) -> str: + return f"{type(self).__name__}(host={self.host!r}, port={self.port!r})" + + def __enter__(self: _SelfT) -> _SelfT: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> Literal[False]: self.close() # Return False to re-raise any potential exceptions return False - def close(self): + def close(self) -> None: """ Close all pooled connections and disable the pool. """ - pass # This is taken from http://hg.python.org/cpython/file/7aaba721ebc0/Lib/socket.py#l252 @@ -113,14 +130,6 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): Port used for this HTTP Connection (None is equivalent to 80), passed into :class:`http.client.HTTPConnection`. - :param strict: - Causes BadStatusLine to be raised if the status line can't be parsed - as a valid HTTP/1.0 or 1.1 status line, passed into - :class:`http.client.HTTPConnection`. - - .. note:: - Only works in Python 2. This parameter is ignored in Python 3. - :param timeout: Socket timeout in seconds for each individual connection. 
This can be a float or integer, which sets the timeout for the HTTP request, @@ -162,29 +171,27 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): """ scheme = "http" - ConnectionCls = HTTPConnection - ResponseCls = HTTPResponse + ConnectionCls: ( + type[BaseHTTPConnection] | type[BaseHTTPSConnection] + ) = HTTPConnection def __init__( self, - host, - port=None, - strict=False, - timeout=Timeout.DEFAULT_TIMEOUT, - maxsize=1, - block=False, - headers=None, - retries=None, - _proxy=None, - _proxy_headers=None, - _proxy_config=None, - **conn_kw + host: str, + port: int | None = None, + timeout: _TYPE_TIMEOUT | None = _DEFAULT_TIMEOUT, + maxsize: int = 1, + block: bool = False, + headers: typing.Mapping[str, str] | None = None, + retries: Retry | bool | int | None = None, + _proxy: Url | None = None, + _proxy_headers: typing.Mapping[str, str] | None = None, + _proxy_config: ProxyConfig | None = None, + **conn_kw: typing.Any, ): ConnectionPool.__init__(self, host, port) RequestMethods.__init__(self, headers) - self.strict = strict - if not isinstance(timeout, Timeout): timeout = Timeout.from_float(timeout) @@ -194,7 +201,7 @@ def __init__( self.timeout = timeout self.retries = retries - self.pool = self.QueueCls(maxsize) + self.pool: queue.LifoQueue[typing.Any] | None = self.QueueCls(maxsize) self.block = block self.proxy = _proxy @@ -202,7 +209,7 @@ def __init__( self.proxy_config = _proxy_config # Fill the queue up so that doing get() on it will block properly - for _ in xrange(maxsize): + for _ in range(maxsize): self.pool.put(None) # These are mostly for testing and debugging purposes. @@ -219,7 +226,17 @@ def __init__( self.conn_kw["proxy"] = self.proxy self.conn_kw["proxy_config"] = self.proxy_config - def _new_conn(self): + # Do not pass 'self' as callback to 'finalize'. + # Then the 'finalize' would keep an endless living (leak) to self. 
+ # By just passing a reference to the pool allows the garbage collector + # to free self if nobody else has a reference to it. + pool = self.pool + + # Close all the HTTPConnections in the pool before the + # HTTPConnectionPool object is garbage collected. + weakref.finalize(self, _close_pool_connections, pool) + + def _new_conn(self) -> BaseHTTPConnection: """ Return a fresh :class:`HTTPConnection`. """ @@ -235,12 +252,11 @@ def _new_conn(self): host=self.host, port=self.port, timeout=self.timeout.connect_timeout, - strict=self.strict, - **self.conn_kw + **self.conn_kw, ) return conn - def _get_conn(self, timeout=None): + def _get_conn(self, timeout: float | None = None) -> BaseHTTPConnection: """ Get a connection. Will return a pooled connection if one is available. @@ -253,33 +269,32 @@ def _get_conn(self, timeout=None): :prop:`.block` is ``True``. """ conn = None + + if self.pool is None: + raise ClosedPoolError(self, "Pool is closed.") + try: conn = self.pool.get(block=self.block, timeout=timeout) except AttributeError: # self.pool is None - raise ClosedPoolError(self, "Pool is closed.") + raise ClosedPoolError(self, "Pool is closed.") from None # Defensive: except queue.Empty: if self.block: raise EmptyPoolError( self, - "Pool reached maximum size and no more connections are allowed.", - ) + "Pool is empty and a new connection can't be opened due to blocking mode.", + ) from None pass # Oh well, we'll create a new connection then # If this is a persistent connection, check if it got disconnected if conn and is_connection_dropped(conn): log.debug("Resetting dropped connection: %s", self.host) conn.close() - if getattr(conn, "auto_open", 1) == 0: - # This is a proxied connection that has been mutated by - # http.client._tunnel() and cannot be reused (since it would - # attempt to bypass the proxy) - conn = None return conn or self._new_conn() - def _put_conn(self, conn): + def _put_conn(self, conn: BaseHTTPConnection | None) -> None: """ Put a connection back 
into the pool. @@ -293,33 +308,47 @@ def _put_conn(self, conn): If the pool is closed, then the connection will be closed and discarded. """ - try: - self.pool.put(conn, block=False) - return # Everything is dandy, done. - except AttributeError: - # self.pool is None. - pass - except queue.Full: - # This should never happen if self.block == True - log.warning("Connection pool is full, discarding connection: %s", self.host) + if self.pool is not None: + try: + self.pool.put(conn, block=False) + return # Everything is dandy, done. + except AttributeError: + # self.pool is None. + pass + except queue.Full: + # Connection never got put back into the pool, close it. + if conn: + conn.close() + + if self.block: + # This should never happen if you got the conn from self._get_conn + raise FullPoolError( + self, + "Pool reached maximum size and no more connections are allowed.", + ) from None + + log.warning( + "Connection pool is full, discarding connection: %s. Connection pool size: %s", + self.host, + self.pool.qsize(), + ) # Connection never got put back into the pool, close it. if conn: conn.close() - def _validate_conn(self, conn): + def _validate_conn(self, conn: BaseHTTPConnection) -> None: """ Called right before a request is made, after the socket is created. """ - pass - def _prepare_proxy(self, conn): + def _prepare_proxy(self, conn: BaseHTTPConnection) -> None: # Nothing to do for HTTP connections. 
pass - def _get_timeout(self, timeout): - """ Helper that always returns a :class:`urllib3.util.Timeout` """ - if timeout is _Default: + def _get_timeout(self, timeout: _TYPE_TIMEOUT) -> Timeout: + """Helper that always returns a :class:`urllib3.util.Timeout`""" + if timeout is _DEFAULT_TIMEOUT: return self.timeout.clone() if isinstance(timeout, Timeout): @@ -329,34 +358,40 @@ def _get_timeout(self, timeout): # can be removed later return Timeout.from_float(timeout) - def _raise_timeout(self, err, url, timeout_value): + def _raise_timeout( + self, + err: BaseSSLError | OSError | SocketTimeout, + url: str, + timeout_value: _TYPE_TIMEOUT | None, + ) -> None: """Is the error actually a timeout? Will raise a ReadTimeout or pass""" if isinstance(err, SocketTimeout): raise ReadTimeoutError( - self, url, "Read timed out. (read timeout=%s)" % timeout_value - ) + self, url, f"Read timed out. (read timeout={timeout_value})" + ) from err - # See the above comment about EAGAIN in Python 3. In Python 2 we have - # to specifically catch it and throw the timeout error + # See the above comment about EAGAIN in Python 3. if hasattr(err, "errno") and err.errno in _blocking_errnos: raise ReadTimeoutError( - self, url, "Read timed out. (read timeout=%s)" % timeout_value - ) - - # Catch possible read timeouts thrown as SSL errors. If not the - # case, rethrow the original. We need to do this because of: - # http://bugs.python.org/issue10272 - if "timed out" in str(err) or "did not complete (read)" in str( - err - ): # Python < 2.7.4 - raise ReadTimeoutError( - self, url, "Read timed out. (read timeout=%s)" % timeout_value - ) + self, url, f"Read timed out. 
(read timeout={timeout_value})" + ) from err def _make_request( - self, conn, method, url, timeout=_Default, chunked=False, **httplib_request_kw - ): + self, + conn: BaseHTTPConnection, + method: str, + url: str, + body: _TYPE_BODY | None = None, + headers: typing.Mapping[str, str] | None = None, + retries: Retry | None = None, + timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + chunked: bool = False, + response_conn: BaseHTTPConnection | None = None, + preload_content: bool = True, + decode_content: bool = True, + enforce_content_length: bool = True, + ) -> BaseHTTPResponse: """ Perform a request on a given urllib connection object taken from our pool. @@ -364,57 +399,127 @@ def _make_request( :param conn: a connection from one of our connection pools + :param method: + HTTP request method (such as GET, POST, PUT, etc.) + + :param url: + The URL to perform the request on. + + :param body: + Data to send in the request body, either :class:`str`, :class:`bytes`, + an iterable of :class:`str`/:class:`bytes`, or a file-like object. + + :param headers: + Dictionary of custom headers to send, such as User-Agent, + If-None-Match, etc. If None, pool headers are used. If provided, + these headers completely replace any pool-specific headers. + + :param retries: + Configure the number of retries to allow before raising a + :class:`~urllib3.exceptions.MaxRetryError` exception. + + Pass ``None`` to retry until you receive a response. Pass a + :class:`~urllib3.util.retry.Retry` object for fine-grained control + over different types of retries. + Pass an integer number to retry connection errors that many times, + but no other types of errors. Pass zero to never retry. + + If ``False``, then retries are disabled and any exception is raised + immediately. Also, instead of raising a MaxRetryError on redirects, + the redirect response will be returned. + + :type retries: :class:`~urllib3.util.retry.Retry`, False, or an int. + :param timeout: - Socket timeout in seconds for the request. 
This can be a - float or integer, which will set the same timeout value for - the socket connect and the socket read, or an instance of - :class:`urllib3.util.Timeout`, which gives you more fine-grained - control over your timeouts. + If specified, overrides the default timeout for this one + request. It may be a float (in seconds) or an instance of + :class:`urllib3.util.Timeout`. + + :param chunked: + If True, urllib3 will send the body using chunked transfer + encoding. Otherwise, urllib3 will send the body using the standard + content-length form. Defaults to False. + + :param response_conn: + Set this to ``None`` if you will handle releasing the connection or + set the connection to have the response release it. + + :param preload_content: + If True, the response's body will be preloaded during construction. + + :param decode_content: + If True, will attempt to decode the body based on the + 'content-encoding' header. + + :param enforce_content_length: + Enforce content length checking. Body returned by server must match + value of Content-Length header, if present. Otherwise, raise error. """ self.num_requests += 1 timeout_obj = self._get_timeout(timeout) timeout_obj.start_connect() - conn.timeout = timeout_obj.connect_timeout + conn.timeout = Timeout.resolve_default_timeout(timeout_obj.connect_timeout) - # Trigger any extra validation we need to do. try: - self._validate_conn(conn) - except (SocketTimeout, BaseSSLError) as e: - # Py2 raises this as a BaseSSLError, Py3 raises it as socket timeout. - self._raise_timeout(err=e, url=url, timeout_value=conn.timeout) - raise + # Trigger any extra validation we need to do. + try: + self._validate_conn(conn) + except (SocketTimeout, BaseSSLError) as e: + self._raise_timeout(err=e, url=url, timeout_value=conn.timeout) + raise + + # _validate_conn() starts the connection to an HTTPS proxy + # so we need to wrap errors with 'ProxyError' here too. 
+ except ( + OSError, + NewConnectionError, + TimeoutError, + BaseSSLError, + CertificateError, + SSLError, + ) as e: + new_e: Exception = e + if isinstance(e, (BaseSSLError, CertificateError)): + new_e = SSLError(e) + # If the connection didn't successfully connect to it's proxy + # then there + if isinstance( + new_e, (OSError, NewConnectionError, TimeoutError, SSLError) + ) and (conn and conn.proxy and not conn.has_connected_to_proxy): + new_e = _wrap_proxy_error(new_e, conn.proxy.scheme) + raise new_e # conn.request() calls http.client.*.request, not the method in # urllib3.request. It also calls makefile (recv) on the socket. try: - if chunked: - conn.request_chunked(method, url, **httplib_request_kw) - else: - conn.request(method, url, **httplib_request_kw) + conn.request( + method, + url, + body=body, + headers=headers, + chunked=chunked, + preload_content=preload_content, + decode_content=decode_content, + enforce_content_length=enforce_content_length, + ) # We are swallowing BrokenPipeError (errno.EPIPE) since the server is # legitimately able to close the connection after sending a valid response. # With this behaviour, the received response is still readable. 
except BrokenPipeError: - # Python 3 pass - except IOError as e: - # Python 2 and macOS/Linux - # EPIPE and ESHUTDOWN are BrokenPipeError on Python 2, and EPROTOTYPE is needed on macOS + except OSError as e: + # MacOS/Linux + # EPROTOTYPE is needed on macOS # https://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/ - if e.errno not in { - errno.EPIPE, - errno.ESHUTDOWN, - errno.EPROTOTYPE, - }: + if e.errno != errno.EPROTOTYPE: raise # Reset the timeout for the recv() on the socket read_timeout = timeout_obj.read_timeout - # App Engine doesn't have a sock attr - if getattr(conn, "sock", None): + if not conn.is_closed: # In Python 3 socket.py will catch EAGAIN and return None when you # try and read into the file pointer created by http.client, which # instead raises a BadStatusLine exception. Instead of catching @@ -422,33 +527,22 @@ def _make_request( # timeouts, check for a zero timeout before making the request. if read_timeout == 0: raise ReadTimeoutError( - self, url, "Read timed out. (read timeout=%s)" % read_timeout + self, url, f"Read timed out. (read timeout={read_timeout})" ) - if read_timeout is Timeout.DEFAULT_TIMEOUT: - conn.sock.settimeout(socket.getdefaulttimeout()) - else: # None or a value - conn.sock.settimeout(read_timeout) + conn.timeout = read_timeout # Receive the response from the server try: - try: - # Python 2.7, use buffering of HTTP responses - httplib_response = conn.getresponse(buffering=True) - except TypeError: - # Python 3 - try: - httplib_response = conn.getresponse() - except BaseException as e: - # Remove the TypeError from the exception chain in - # Python 3 (including for exceptions like SystemExit). - # Otherwise it looks like a bug in the code. 
- six.raise_from(e, None) - except (SocketTimeout, BaseSSLError, SocketError) as e: + response = conn.getresponse() + except (BaseSSLError, OSError) as e: self._raise_timeout(err=e, url=url, timeout_value=read_timeout) raise - # AppEngine doesn't have a version attr. - http_version = getattr(conn, "_http_vsn_str", "HTTP/?") + # Set properties that are used by the pooling layer. + response.retries = retries + response._connection = response_conn # type: ignore[attr-defined] + response._pool = self # type: ignore[attr-defined] + log.debug( '%s://%s:%s "%s %s %s" %s %s', self.scheme, @@ -456,27 +550,15 @@ def _make_request( self.port, method, url, - http_version, - httplib_response.status, - httplib_response.length, + # HTTP version + conn._http_vsn_str, # type: ignore[attr-defined] + response.status, + response.length_remaining, # type: ignore[attr-defined] ) - try: - assert_header_parsing(httplib_response.msg) - except (HeaderParsingError, TypeError) as hpe: # Platform-specific: Python 3 - log.warning( - "Failed to parse headers (url=%s): %s", - self._absolute_url(url), - hpe, - exc_info=True, - ) - - return httplib_response - - def _absolute_url(self, path): - return Url(scheme=self.scheme, host=self.host, port=self.port, path=path).url + return response - def close(self): + def close(self) -> None: """ Close all pooled connections and disable the pool. """ @@ -485,16 +567,10 @@ def close(self): # Disable access to the pool old_pool, self.pool = self.pool, None - try: - while True: - conn = old_pool.get(block=False) - if conn: - conn.close() - - except queue.Empty: - pass # Done. + # Close all the HTTPConnections in the pool. + _close_pool_connections(old_pool) - def is_same_host(self, url): + def is_same_host(self, url: str) -> bool: """ Check if the given ``url`` is a member of the same host as this connection pool. @@ -503,7 +579,8 @@ def is_same_host(self, url): return True # TODO: Add optional support for socket.gethostbyname checking. 
- scheme, host, port = get_host(url) + scheme, _, host, port, *_ = parse_url(url) + scheme = scheme or "http" if host is not None: host = _normalize_host(host, scheme=scheme) @@ -515,22 +592,24 @@ def is_same_host(self, url): return (scheme, host, port) == (self.scheme, self.host, self.port) - def urlopen( + def urlopen( # type: ignore[override] self, - method, - url, - body=None, - headers=None, - retries=None, - redirect=True, - assert_same_host=True, - timeout=_Default, - pool_timeout=None, - release_conn=None, - chunked=False, - body_pos=None, - **response_kw - ): + method: str, + url: str, + body: _TYPE_BODY | None = None, + headers: typing.Mapping[str, str] | None = None, + retries: Retry | bool | int | None = None, + redirect: bool = True, + assert_same_host: bool = True, + timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + pool_timeout: int | None = None, + release_conn: bool | None = None, + chunked: bool = False, + body_pos: _TYPE_BODY_POSITION | None = None, + preload_content: bool = True, + decode_content: bool = True, + **response_kw: typing.Any, + ) -> BaseHTTPResponse: """ Get a connection from the pool and perform an HTTP request. This is the lowest level call for making a request, so you'll need to specify all @@ -538,8 +617,8 @@ def urlopen( .. note:: - More commonly, it's appropriate to use a convenience method provided - by :class:`.RequestMethods`, such as :meth:`request`. + More commonly, it's appropriate to use a convenience method + such as :meth:`request`. .. note:: @@ -599,6 +678,13 @@ def urlopen( block for ``pool_timeout`` seconds and raise EmptyPoolError if no connection is available within the time period. + :param bool preload_content: + If True, the response's body will be preloaded into memory. + + :param bool decode_content: + If True, will attempt to decode the body based on the + 'content-encoding' header. 
+ :param release_conn: If False, then the urlopen call will not release the connection back into the pool once a response is received (but will release if @@ -606,10 +692,10 @@ def urlopen( `preload_content=True`). This is useful if you're not preloading the response's content immediately. You will need to call ``r.release_conn()`` on the response ``r`` to return the connection - back into the pool. If None, it takes the value of - ``response_kw.get('preload_content', True)``. + back into the pool. If None, it takes the value of ``preload_content`` + which defaults to ``True``. - :param chunked: + :param bool chunked: If True, urllib3 will send the body using chunked transfer encoding. Otherwise, urllib3 will send the body using the standard content-length form. Defaults to False. @@ -618,12 +704,7 @@ def urlopen( Position to seek to in file-like body in the event of a retry or redirect. Typically this won't need to be set because urllib3 will auto-populate the value when needed. - - :param \\**response_kw: - Additional parameters are passed to - :meth:`urllib3.response.HTTPResponse.from_httplib` """ - parsed_url = parse_url(url) destination_scheme = parsed_url.scheme @@ -634,7 +715,7 @@ def urlopen( retries = Retry.from_int(retries, redirect=redirect, default=self.retries) if release_conn is None: - release_conn = response_kw.get("preload_content", True) + release_conn = preload_content # Check host if assert_same_host and not self.is_same_host(url): @@ -642,9 +723,9 @@ def urlopen( # Ensure that the URL we're connecting to is properly encoded if url.startswith("/"): - url = six.ensure_str(_encode_target(url)) + url = to_str(_encode_target(url)) else: - url = six.ensure_str(parsed_url.url) + url = to_str(parsed_url.url) conn = None @@ -667,8 +748,8 @@ def urlopen( # have to copy the headers dict so we can safely change it without those # changes being reflected in anyone else's copy. 
if not http_tunnel_required: - headers = headers.copy() - headers.update(self.proxy_headers) + headers = headers.copy() # type: ignore[attr-defined] + headers.update(self.proxy_headers) # type: ignore[union-attr] # Must keep the exception bound to a separate variable or else Python 3 # complains about UnboundLocalError. @@ -687,16 +768,26 @@ def urlopen( timeout_obj = self._get_timeout(timeout) conn = self._get_conn(timeout=pool_timeout) - conn.timeout = timeout_obj.connect_timeout + conn.timeout = timeout_obj.connect_timeout # type: ignore[assignment] - is_new_proxy_conn = self.proxy is not None and not getattr( - conn, "sock", None - ) - if is_new_proxy_conn and http_tunnel_required: - self._prepare_proxy(conn) + # Is this a closed/new connection that requires CONNECT tunnelling? + if self.proxy is not None and http_tunnel_required and conn.is_closed: + try: + self._prepare_proxy(conn) + except (BaseSSLError, OSError, SocketTimeout) as e: + self._raise_timeout( + err=e, url=self.proxy.url, timeout_value=conn.timeout + ) + raise - # Make the request on the httplib connection object. - httplib_response = self._make_request( + # If we're going to release the connection in ``finally:``, then + # the response doesn't need to know about the connection. Otherwise + # it will also try to release it and we'll have a double-release + # mess. + response_conn = conn if not release_conn else None + + # Make the request on the HTTPConnection object + response = self._make_request( conn, method, url, @@ -704,24 +795,11 @@ def urlopen( body=body, headers=headers, chunked=chunked, - ) - - # If we're going to release the connection in ``finally:``, then - # the response doesn't need to know about the connection. Otherwise - # it will also try to release it and we'll have a double-release - # mess. 
- response_conn = conn if not release_conn else None - - # Pass method to Response for length checking - response_kw["request_method"] = method - - # Import httplib's response into our own wrapper object - response = self.ResponseCls.from_httplib( - httplib_response, - pool=self, - connection=response_conn, retries=retries, - **response_kw + response_conn=response_conn, + preload_content=preload_content, + decode_content=decode_content, + **response_kw, ) # Everything went great! @@ -736,24 +814,35 @@ def urlopen( except ( TimeoutError, HTTPException, - SocketError, + OSError, ProtocolError, BaseSSLError, SSLError, CertificateError, + ProxyError, ) as e: # Discard the connection for these exceptions. It will be # replaced during the next _get_conn() call. clean_exit = False + new_e: Exception = e if isinstance(e, (BaseSSLError, CertificateError)): - e = SSLError(e) - elif isinstance(e, (SocketError, NewConnectionError)) and self.proxy: - e = ProxyError("Cannot connect to proxy.", e) - elif isinstance(e, (SocketError, HTTPException)): - e = ProtocolError("Connection aborted.", e) + new_e = SSLError(e) + if isinstance( + new_e, + ( + OSError, + NewConnectionError, + TimeoutError, + SSLError, + HTTPException, + ), + ) and (conn and conn.proxy and not conn.has_connected_to_proxy): + new_e = _wrap_proxy_error(new_e, conn.proxy.scheme) + elif isinstance(new_e, (OSError, HTTPException)): + new_e = ProtocolError("Connection aborted.", new_e) retries = retries.increment( - method, url, error=e, _pool=self, _stacktrace=sys.exc_info()[2] + method, url, error=new_e, _pool=self, _stacktrace=sys.exc_info()[2] ) retries.sleep() @@ -766,7 +855,9 @@ def urlopen( # to throw the connection away unless explicitly told not to. # Close the connection, set the variable to None, and make sure # we put the None back in the pool to avoid leaking it. 
- conn = conn and conn.close() + if conn: + conn.close() + conn = None release_this_conn = True if release_this_conn: @@ -793,7 +884,9 @@ def urlopen( release_conn=release_conn, chunked=chunked, body_pos=body_pos, - **response_kw + preload_content=preload_content, + decode_content=decode_content, + **response_kw, ) # Handle redirect? @@ -826,11 +919,13 @@ def urlopen( release_conn=release_conn, chunked=chunked, body_pos=body_pos, - **response_kw + preload_content=preload_content, + decode_content=decode_content, + **response_kw, ) # Check if we should retry the HTTP response. - has_retry_after = bool(response.getheader("Retry-After")) + has_retry_after = bool(response.headers.get("Retry-After")) if retries.is_retry(method, response.status, has_retry_after): try: retries = retries.increment(method, url, response=response, _pool=self) @@ -856,7 +951,9 @@ def urlopen( release_conn=release_conn, chunked=chunked, body_pos=body_pos, - **response_kw + preload_content=preload_content, + decode_content=decode_content, + **response_kw, ) return response @@ -877,37 +974,35 @@ class HTTPSConnectionPool(HTTPConnectionPool): """ scheme = "https" - ConnectionCls = HTTPSConnection + ConnectionCls: type[BaseHTTPSConnection] = HTTPSConnection def __init__( self, - host, - port=None, - strict=False, - timeout=Timeout.DEFAULT_TIMEOUT, - maxsize=1, - block=False, - headers=None, - retries=None, - _proxy=None, - _proxy_headers=None, - key_file=None, - cert_file=None, - cert_reqs=None, - key_password=None, - ca_certs=None, - ssl_version=None, - assert_hostname=None, - assert_fingerprint=None, - ca_cert_dir=None, - **conn_kw - ): - - HTTPConnectionPool.__init__( - self, + host: str, + port: int | None = None, + timeout: _TYPE_TIMEOUT | None = _DEFAULT_TIMEOUT, + maxsize: int = 1, + block: bool = False, + headers: typing.Mapping[str, str] | None = None, + retries: Retry | bool | int | None = None, + _proxy: Url | None = None, + _proxy_headers: typing.Mapping[str, str] | None = None, + 
key_file: str | None = None, + cert_file: str | None = None, + cert_reqs: int | str | None = None, + key_password: str | None = None, + ca_certs: str | None = None, + ssl_version: int | str | None = None, + ssl_minimum_version: ssl.TLSVersion | None = None, + ssl_maximum_version: ssl.TLSVersion | None = None, + assert_hostname: str | Literal[False] | None = None, + assert_fingerprint: str | None = None, + ca_cert_dir: str | None = None, + **conn_kw: typing.Any, + ) -> None: + super().__init__( host, port, - strict, timeout, maxsize, block, @@ -915,7 +1010,7 @@ def __init__( retries, _proxy, _proxy_headers, - **conn_kw + **conn_kw, ) self.key_file = key_file @@ -925,47 +1020,29 @@ def __init__( self.ca_certs = ca_certs self.ca_cert_dir = ca_cert_dir self.ssl_version = ssl_version + self.ssl_minimum_version = ssl_minimum_version + self.ssl_maximum_version = ssl_maximum_version self.assert_hostname = assert_hostname self.assert_fingerprint = assert_fingerprint - def _prepare_conn(self, conn): - """ - Prepare the ``connection`` for :meth:`urllib3.util.ssl_wrap_socket` - and establish the tunnel if proxy is used. - """ - - if isinstance(conn, VerifiedHTTPSConnection): - conn.set_cert( - key_file=self.key_file, - key_password=self.key_password, - cert_file=self.cert_file, - cert_reqs=self.cert_reqs, - ca_certs=self.ca_certs, - ca_cert_dir=self.ca_cert_dir, - assert_hostname=self.assert_hostname, - assert_fingerprint=self.assert_fingerprint, - ) - conn.ssl_version = self.ssl_version - return conn - - def _prepare_proxy(self, conn): - """ - Establishes a tunnel connection through HTTP CONNECT. - - Tunnel connection is established early because otherwise httplib would - improperly set Host: header to proxy's IP:port. 
- """ - - conn.set_tunnel(self._proxy_host, self.port, self.proxy_headers) - - if self.proxy.scheme == "https": - conn.tls_in_tls_required = True + def _prepare_proxy(self, conn: HTTPSConnection) -> None: # type: ignore[override] + """Establishes a tunnel connection through HTTP CONNECT.""" + if self.proxy and self.proxy.scheme == "https": + tunnel_scheme = "https" + else: + tunnel_scheme = "http" + conn.set_tunnel( + scheme=tunnel_scheme, + host=self._tunnel_host, + port=self.port, + headers=self.proxy_headers, + ) conn.connect() - def _new_conn(self): + def _new_conn(self) -> BaseHTTPSConnection: """ - Return a fresh :class:`http.client.HTTPSConnection`. + Return a fresh :class:`urllib3.connection.HTTPConnection`. """ self.num_connections += 1 log.debug( @@ -975,53 +1052,58 @@ def _new_conn(self): self.port or "443", ) - if not self.ConnectionCls or self.ConnectionCls is DummyConnection: - raise SSLError( + if not self.ConnectionCls or self.ConnectionCls is DummyConnection: # type: ignore[comparison-overlap] + raise ImportError( "Can't connect to HTTPS URL because the SSL module is not available." 
) - actual_host = self.host + actual_host: str = self.host actual_port = self.port - if self.proxy is not None: + if self.proxy is not None and self.proxy.host is not None: actual_host = self.proxy.host actual_port = self.proxy.port - conn = self.ConnectionCls( + return self.ConnectionCls( host=actual_host, port=actual_port, timeout=self.timeout.connect_timeout, - strict=self.strict, cert_file=self.cert_file, key_file=self.key_file, key_password=self.key_password, - **self.conn_kw + cert_reqs=self.cert_reqs, + ca_certs=self.ca_certs, + ca_cert_dir=self.ca_cert_dir, + assert_hostname=self.assert_hostname, + assert_fingerprint=self.assert_fingerprint, + ssl_version=self.ssl_version, + ssl_minimum_version=self.ssl_minimum_version, + ssl_maximum_version=self.ssl_maximum_version, + **self.conn_kw, ) - return self._prepare_conn(conn) - - def _validate_conn(self, conn): + def _validate_conn(self, conn: BaseHTTPConnection) -> None: """ Called right before a request is made, after the socket is created. """ - super(HTTPSConnectionPool, self)._validate_conn(conn) + super()._validate_conn(conn) # Force connect early to allow us to validate the connection. - if not getattr(conn, "sock", None): # AppEngine might not have `.sock` + if conn.is_closed: conn.connect() if not conn.is_verified: warnings.warn( ( - "Unverified HTTPS request is being made to host '%s'. " + f"Unverified HTTPS request is being made to host '{conn.host}'. " "Adding certificate verification is strongly advised. See: " "https://urllib3.readthedocs.io/en/latest/advanced-usage.html" - "#ssl-warnings" % conn.host + "#tls-warnings" ), InsecureRequestWarning, ) -def connection_from_url(url, **kw): +def connection_from_url(url: str, **kw: typing.Any) -> HTTPConnectionPool: """ Given a url, return an :class:`.ConnectionPool` instance of its host. 
@@ -1041,15 +1123,26 @@ def connection_from_url(url, **kw): >>> conn = connection_from_url('http://google.com/') >>> r = conn.request('GET', '/') """ - scheme, host, port = get_host(url) + scheme, _, host, port, *_ = parse_url(url) + scheme = scheme or "http" port = port or port_by_scheme.get(scheme, 80) if scheme == "https": - return HTTPSConnectionPool(host, port=port, **kw) + return HTTPSConnectionPool(host, port=port, **kw) # type: ignore[arg-type] else: - return HTTPConnectionPool(host, port=port, **kw) + return HTTPConnectionPool(host, port=port, **kw) # type: ignore[arg-type] + + +@typing.overload +def _normalize_host(host: None, scheme: str | None) -> None: + ... + +@typing.overload +def _normalize_host(host: str, scheme: str | None) -> str: + ... -def _normalize_host(host, scheme): + +def _normalize_host(host: str | None, scheme: str | None) -> str | None: """ Normalize hosts for comparisons and use with sockets. """ @@ -1062,6 +1155,24 @@ def _normalize_host(host, scheme): # Instead, we need to make sure we never pass ``None`` as the port. # However, for backward compatibility reasons we can't actually # *assert* that. See http://bugs.python.org/issue28539 - if host.startswith("[") and host.endswith("]"): + if host and host.startswith("[") and host.endswith("]"): host = host[1:-1] return host + + +def _url_from_pool( + pool: HTTPConnectionPool | HTTPSConnectionPool, path: str | None = None +) -> str: + """Returns the URL from a given connection pool. This is mainly used for testing and logging.""" + return Url(scheme=pool.scheme, host=pool.host, port=pool.port, path=path).url + + +def _close_pool_connections(pool: queue.LifoQueue[typing.Any]) -> None: + """Drains a queue of connections and closes each one.""" + try: + while True: + conn = pool.get(block=False) + if conn: + conn.close() + except queue.Empty: + pass # Done. 
diff --git a/src/urllib3/contrib/__init__.pyi b/src/urllib3/contrib/__init__.pyi deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/urllib3/contrib/_appengine_environ.py b/src/urllib3/contrib/_appengine_environ.py deleted file mode 100644 index 8765b907d7..0000000000 --- a/src/urllib3/contrib/_appengine_environ.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -This module provides means to detect the App Engine environment. -""" - -import os - - -def is_appengine(): - return is_local_appengine() or is_prod_appengine() - - -def is_appengine_sandbox(): - """Reports if the app is running in the first generation sandbox. - - The second generation runtimes are technically still in a sandbox, but it - is much less restrictive, so generally you shouldn't need to check for it. - see https://cloud.google.com/appengine/docs/standard/runtimes - """ - return is_appengine() and os.environ["APPENGINE_RUNTIME"] == "python27" - - -def is_local_appengine(): - return "APPENGINE_RUNTIME" in os.environ and os.environ.get( - "SERVER_SOFTWARE", "" - ).startswith("Development/") - - -def is_prod_appengine(): - return "APPENGINE_RUNTIME" in os.environ and os.environ.get( - "SERVER_SOFTWARE", "" - ).startswith("Google App Engine/") - - -def is_prod_appengine_mvms(): - """Deprecated.""" - return False diff --git a/src/urllib3/contrib/_securetransport/bindings.py b/src/urllib3/contrib/_securetransport/bindings.py index 11524d400b..3e4cd466ea 100644 --- a/src/urllib3/contrib/_securetransport/bindings.py +++ b/src/urllib3/contrib/_securetransport/bindings.py @@ -1,3 +1,5 @@ +# type: ignore + """ This module uses ctypes to bind a whole bunch of functions and constants from SecureTransport. The goal here is to provide the low-level API to @@ -29,7 +31,8 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" -from __future__ import absolute_import + +from __future__ import annotations import platform from ctypes import ( @@ -48,8 +51,6 @@ ) from ctypes.util import find_library -from urllib3.packages.six import raise_from - if platform.system() != "Darwin": raise ImportError("Only macOS is supported") @@ -57,16 +58,16 @@ version_info = tuple(map(int, version.split("."))) if version_info < (10, 8): raise OSError( - "Only OS X 10.8 and newer are supported, not %s.%s" - % (version_info[0], version_info[1]) + f"Only OS X 10.8 and newer are supported, not {version_info[0]}.{version_info[1]}" ) -def load_cdll(name, macos10_16_path): +def load_cdll(name: str, macos10_16_path: str) -> CDLL: """Loads a CDLL by name, falling back to known path on 10.16+""" try: # Big Sur is technically 11 but we use 10.16 due to the Big Sur # beta being labeled as 10.16. + path: str | None if version_info >= (10, 16): path = macos10_16_path else: @@ -75,7 +76,7 @@ def load_cdll(name, macos10_16_path): raise OSError # Caught and reraised as 'ImportError' return CDLL(path, use_errno=True) except OSError: - raise_from(ImportError("The library %s failed to load" % name), None) + raise ImportError(f"The library {name} failed to load") from None Security = load_cdll( @@ -416,104 +417,14 @@ def load_cdll(name, macos10_16_path): CoreFoundation.CFStringRef = CFStringRef CoreFoundation.CFDictionaryRef = CFDictionaryRef -except (AttributeError): - raise ImportError("Error initializing ctypes") +except AttributeError: + raise ImportError("Error initializing ctypes") from None -class CFConst(object): +class CFConst: """ A class object that acts as essentially a namespace for CoreFoundation constants. """ kCFStringEncodingUTF8 = CFStringEncoding(0x08000100) - - -class SecurityConst(object): - """ - A class object that acts as essentially a namespace for Security constants. 
- """ - - kSSLSessionOptionBreakOnServerAuth = 0 - - kSSLProtocol2 = 1 - kSSLProtocol3 = 2 - kTLSProtocol1 = 4 - kTLSProtocol11 = 7 - kTLSProtocol12 = 8 - # SecureTransport does not support TLS 1.3 even if there's a constant for it - kTLSProtocol13 = 10 - kTLSProtocolMaxSupported = 999 - - kSSLClientSide = 1 - kSSLStreamType = 0 - - kSecFormatPEMSequence = 10 - - kSecTrustResultInvalid = 0 - kSecTrustResultProceed = 1 - # This gap is present on purpose: this was kSecTrustResultConfirm, which - # is deprecated. - kSecTrustResultDeny = 3 - kSecTrustResultUnspecified = 4 - kSecTrustResultRecoverableTrustFailure = 5 - kSecTrustResultFatalTrustFailure = 6 - kSecTrustResultOtherError = 7 - - errSSLProtocol = -9800 - errSSLWouldBlock = -9803 - errSSLClosedGraceful = -9805 - errSSLClosedNoNotify = -9816 - errSSLClosedAbort = -9806 - - errSSLXCertChainInvalid = -9807 - errSSLCrypto = -9809 - errSSLInternal = -9810 - errSSLCertExpired = -9814 - errSSLCertNotYetValid = -9815 - errSSLUnknownRootCert = -9812 - errSSLNoRootCert = -9813 - errSSLHostNameMismatch = -9843 - errSSLPeerHandshakeFail = -9824 - errSSLPeerUserCancelled = -9839 - errSSLWeakPeerEphemeralDHKey = -9850 - errSSLServerAuthCompleted = -9841 - errSSLRecordOverflow = -9847 - - errSecVerifyFailed = -67808 - errSecNoTrustSettings = -25263 - errSecItemNotFound = -25300 - errSecInvalidTrustSettings = -25262 - - # Cipher suites. We only pick the ones our default cipher string allows. 
- # Source: https://developer.apple.com/documentation/security/1550981-ssl_cipher_suite_values - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xC02C - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xC030 - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xC02B - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA9 - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA8 - TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 = 0x009F - TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 = 0x009E - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xC024 - TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xC028 - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A - TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = 0xC014 - TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 = 0x006B - TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xC023 - TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xC009 - TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xC013 - TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 = 0x0067 - TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033 - TLS_RSA_WITH_AES_256_GCM_SHA384 = 0x009D - TLS_RSA_WITH_AES_128_GCM_SHA256 = 0x009C - TLS_RSA_WITH_AES_256_CBC_SHA256 = 0x003D - TLS_RSA_WITH_AES_128_CBC_SHA256 = 0x003C - TLS_RSA_WITH_AES_256_CBC_SHA = 0x0035 - TLS_RSA_WITH_AES_128_CBC_SHA = 0x002F - TLS_AES_128_GCM_SHA256 = 0x1301 - TLS_AES_256_GCM_SHA384 = 0x1302 - TLS_AES_128_CCM_8_SHA256 = 0x1305 - TLS_AES_128_CCM_SHA256 = 0x1304 diff --git a/src/urllib3/contrib/_securetransport/low_level.py b/src/urllib3/contrib/_securetransport/low_level.py index ed8120190c..e23569972c 100644 --- a/src/urllib3/contrib/_securetransport/low_level.py +++ b/src/urllib3/contrib/_securetransport/low_level.py @@ -7,6 +7,8 @@ are almost entirely about trying to avoid memory leaks and providing appropriate and useful assistance to the higher-level code. 
""" +from __future__ import annotations + import base64 import ctypes import itertools @@ -15,8 +17,20 @@ import ssl import struct import tempfile - -from .bindings import CFConst, CoreFoundation, Security +import typing + +from .bindings import ( # type: ignore[attr-defined] + CFArray, + CFConst, + CFData, + CFDictionary, + CFMutableArray, + CFString, + CFTypeRef, + CoreFoundation, + SecKeychainRef, + Security, +) # This regular expression is used to grab PEM data out of a PEM bundle. _PEM_CERTS_RE = re.compile( @@ -24,7 +38,7 @@ ) -def _cf_data_from_bytes(bytestring): +def _cf_data_from_bytes(bytestring: bytes) -> CFData: """ Given a bytestring, create a CFData object from it. This CFData object must be CFReleased by the caller. @@ -34,7 +48,9 @@ def _cf_data_from_bytes(bytestring): ) -def _cf_dictionary_from_tuples(tuples): +def _cf_dictionary_from_tuples( + tuples: list[tuple[typing.Any, typing.Any]] +) -> CFDictionary: """ Given a list of Python tuples, create an associated CFDictionary. """ @@ -56,7 +72,7 @@ def _cf_dictionary_from_tuples(tuples): ) -def _cfstr(py_bstr): +def _cfstr(py_bstr: bytes) -> CFString: """ Given a Python binary data, create a CFString. The string must be CFReleased by the caller. @@ -70,7 +86,7 @@ def _cfstr(py_bstr): return cf_str -def _create_cfstring_array(lst): +def _create_cfstring_array(lst: list[bytes]) -> CFMutableArray: """ Given a list of Python binary data, create an associated CFMutableArray. The array must be CFReleased by the caller. @@ -97,11 +113,11 @@ def _create_cfstring_array(lst): except BaseException as e: if cf_arr: CoreFoundation.CFRelease(cf_arr) - raise ssl.SSLError("Unable to allocate array: %s" % (e,)) + raise ssl.SSLError(f"Unable to allocate array: {e}") from None return cf_arr -def _cf_string_to_unicode(value): +def _cf_string_to_unicode(value: CFString) -> str | None: """ Creates a Unicode string from a CFString object. Used entirely for error reporting. 
@@ -123,10 +139,12 @@ def _cf_string_to_unicode(value): string = buffer.value if string is not None: string = string.decode("utf-8") - return string + return string # type: ignore[no-any-return] -def _assert_no_error(error, exception_class=None): +def _assert_no_error( + error: int, exception_class: type[BaseException] | None = None +) -> None: """ Checks the return code and throws an exception if there is an error to report @@ -138,8 +156,8 @@ def _assert_no_error(error, exception_class=None): output = _cf_string_to_unicode(cf_error_string) CoreFoundation.CFRelease(cf_error_string) - if output is None or output == u"": - output = u"OSStatus %s" % error + if output is None or output == "": + output = f"OSStatus {error}" if exception_class is None: exception_class = ssl.SSLError @@ -147,7 +165,7 @@ def _assert_no_error(error, exception_class=None): raise exception_class(output) -def _cert_array_from_pem(pem_bundle): +def _cert_array_from_pem(pem_bundle: bytes) -> CFArray: """ Given a bundle of certs in PEM format, turns them into a CFArray of certs that can be used to validate a cert chain. @@ -188,27 +206,28 @@ def _cert_array_from_pem(pem_bundle): # We only want to do that if an error occurs: otherwise, the caller # should free. CoreFoundation.CFRelease(cert_array) + raise return cert_array -def _is_cert(item): +def _is_cert(item: CFTypeRef) -> bool: """ Returns True if a given CFTypeRef is a certificate. """ expected = Security.SecCertificateGetTypeID() - return CoreFoundation.CFGetTypeID(item) == expected + return CoreFoundation.CFGetTypeID(item) == expected # type: ignore[no-any-return] -def _is_identity(item): +def _is_identity(item: CFTypeRef) -> bool: """ Returns True if a given CFTypeRef is an identity. 
""" expected = Security.SecIdentityGetTypeID() - return CoreFoundation.CFGetTypeID(item) == expected + return CoreFoundation.CFGetTypeID(item) == expected # type: ignore[no-any-return] -def _temporary_keychain(): +def _temporary_keychain() -> tuple[SecKeychainRef, str]: """ This function creates a temporary Mac keychain that we can use to work with credentials. This keychain uses a one-time password and a temporary file to @@ -243,7 +262,9 @@ def _temporary_keychain(): return keychain, tempdirectory -def _load_items_from_file(keychain, path): +def _load_items_from_file( + keychain: SecKeychainRef, path: str +) -> tuple[list[CFTypeRef], list[CFTypeRef]]: """ Given a single file, loads all the trust objects from it into arrays and the keychain. @@ -298,7 +319,7 @@ def _load_items_from_file(keychain, path): return (identities, certificates) -def _load_client_cert_chain(keychain, *paths): +def _load_client_cert_chain(keychain: SecKeychainRef, *paths: str | None) -> CFArray: """ Load certificates and maybe keys from a number of files. Has the end goal of returning a CFArray containing one SecIdentityRef, and then zero or more @@ -334,10 +355,10 @@ def _load_client_cert_chain(keychain, *paths): identities = [] # Filter out bad paths. - paths = (path for path in paths if path) + filtered_paths = (path for path in paths if path) try: - for file_path in paths: + for file_path in filtered_paths: new_identities, new_certs = _load_items_from_file(keychain, file_path) identities.extend(new_identities) certificates.extend(new_certs) @@ -382,7 +403,7 @@ def _load_client_cert_chain(keychain, *paths): } -def _build_tls_unknown_ca_alert(version): +def _build_tls_unknown_ca_alert(version: str) -> bytes: """ Builds a TLS alert record for an unknown CA. 
""" @@ -394,3 +415,60 @@ def _build_tls_unknown_ca_alert(version): record_type_alert = 0x15 record = struct.pack(">BBBH", record_type_alert, ver_maj, ver_min, msg_len) + msg return record + + +class SecurityConst: + """ + A class object that acts as essentially a namespace for Security constants. + """ + + kSSLSessionOptionBreakOnServerAuth = 0 + + kSSLProtocol2 = 1 + kSSLProtocol3 = 2 + kTLSProtocol1 = 4 + kTLSProtocol11 = 7 + kTLSProtocol12 = 8 + # SecureTransport does not support TLS 1.3 even if there's a constant for it + kTLSProtocol13 = 10 + kTLSProtocolMaxSupported = 999 + + kSSLClientSide = 1 + kSSLStreamType = 0 + + kSecFormatPEMSequence = 10 + + kSecTrustResultInvalid = 0 + kSecTrustResultProceed = 1 + # This gap is present on purpose: this was kSecTrustResultConfirm, which + # is deprecated. + kSecTrustResultDeny = 3 + kSecTrustResultUnspecified = 4 + kSecTrustResultRecoverableTrustFailure = 5 + kSecTrustResultFatalTrustFailure = 6 + kSecTrustResultOtherError = 7 + + errSSLProtocol = -9800 + errSSLWouldBlock = -9803 + errSSLClosedGraceful = -9805 + errSSLClosedNoNotify = -9816 + errSSLClosedAbort = -9806 + + errSSLXCertChainInvalid = -9807 + errSSLCrypto = -9809 + errSSLInternal = -9810 + errSSLCertExpired = -9814 + errSSLCertNotYetValid = -9815 + errSSLUnknownRootCert = -9812 + errSSLNoRootCert = -9813 + errSSLHostNameMismatch = -9843 + errSSLPeerHandshakeFail = -9824 + errSSLPeerUserCancelled = -9839 + errSSLWeakPeerEphemeralDHKey = -9850 + errSSLServerAuthCompleted = -9841 + errSSLRecordOverflow = -9847 + + errSecVerifyFailed = -67808 + errSecNoTrustSettings = -25263 + errSecItemNotFound = -25300 + errSecInvalidTrustSettings = -25262 diff --git a/src/urllib3/contrib/appengine.py b/src/urllib3/contrib/appengine.py deleted file mode 100644 index aa64a0914c..0000000000 --- a/src/urllib3/contrib/appengine.py +++ /dev/null @@ -1,314 +0,0 @@ -""" -This module provides a pool manager that uses Google App Engine's -`URLFetch Service `_. 
- -Example usage:: - - from urllib3 import PoolManager - from urllib3.contrib.appengine import AppEngineManager, is_appengine_sandbox - - if is_appengine_sandbox(): - # AppEngineManager uses AppEngine's URLFetch API behind the scenes - http = AppEngineManager() - else: - # PoolManager uses a socket-level API behind the scenes - http = PoolManager() - - r = http.request('GET', 'https://google.com/') - -There are `limitations `_ to the URLFetch service and it may not be -the best choice for your application. There are three options for using -urllib3 on Google App Engine: - -1. You can use :class:`AppEngineManager` with URLFetch. URLFetch is - cost-effective in many circumstances as long as your usage is within the - limitations. -2. You can use a normal :class:`~urllib3.PoolManager` by enabling sockets. - Sockets also have `limitations and restrictions - `_ and have a lower free quota than URLFetch. - To use sockets, be sure to specify the following in your ``app.yaml``:: - - env_variables: - GAE_USE_SOCKETS_HTTPLIB : 'true' - -3. If you are using `App Engine Flexible -`_, you can use the standard -:class:`PoolManager` without any configuration or special environment variables. -""" - -from __future__ import absolute_import - -import io -import logging -import warnings - -from ..exceptions import ( - HTTPError, - HTTPWarning, - MaxRetryError, - ProtocolError, - SSLError, - TimeoutError, -) -from ..packages.six.moves.urllib.parse import urljoin -from ..request import RequestMethods -from ..response import HTTPResponse -from ..util.retry import Retry -from ..util.timeout import Timeout -from . import _appengine_environ - -try: - from google.appengine.api import urlfetch -except ImportError: - urlfetch = None - - -log = logging.getLogger(__name__) - - -class AppEnginePlatformWarning(HTTPWarning): - pass - - -class AppEnginePlatformError(HTTPError): - pass - - -class AppEngineManager(RequestMethods): - """ - Connection manager for Google App Engine sandbox applications. 
- - This manager uses the URLFetch service directly instead of using the - emulated httplib, and is subject to URLFetch limitations as described in - the App Engine documentation `here - `_. - - Notably it will raise an :class:`AppEnginePlatformError` if: - * URLFetch is not available. - * If you attempt to use this on App Engine Flexible, as full socket - support is available. - * If a request size is more than 10 megabytes. - * If a response size is more than 32 megabytes. - * If you use an unsupported request method such as OPTIONS. - - Beyond those cases, it will raise normal urllib3 errors. - """ - - def __init__( - self, - headers=None, - retries=None, - validate_certificate=True, - urlfetch_retries=True, - ): - if not urlfetch: - raise AppEnginePlatformError( - "URLFetch is not available in this environment." - ) - - warnings.warn( - "urllib3 is using URLFetch on Google App Engine sandbox instead " - "of sockets. To use sockets directly instead of URLFetch see " - "https://urllib3.readthedocs.io/en/latest/reference/urllib3.contrib.html.", - AppEnginePlatformWarning, - ) - - RequestMethods.__init__(self, headers) - self.validate_certificate = validate_certificate - self.urlfetch_retries = urlfetch_retries - - self.retries = retries or Retry.DEFAULT - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - # Return False to re-raise any potential exceptions - return False - - def urlopen( - self, - method, - url, - body=None, - headers=None, - retries=None, - redirect=True, - timeout=Timeout.DEFAULT_TIMEOUT, - **response_kw - ): - - retries = self._get_retries(retries, redirect) - - try: - follow_redirects = redirect and retries.redirect != 0 and retries.total - response = urlfetch.fetch( - url, - payload=body, - method=method, - headers=headers or {}, - allow_truncated=False, - follow_redirects=self.urlfetch_retries and follow_redirects, - deadline=self._get_absolute_timeout(timeout), - 
validate_certificate=self.validate_certificate, - ) - except urlfetch.DeadlineExceededError as e: - raise TimeoutError(self, e) - - except urlfetch.InvalidURLError as e: - if "too large" in str(e): - raise AppEnginePlatformError( - "URLFetch request too large, URLFetch only " - "supports requests up to 10mb in size.", - e, - ) - raise ProtocolError(e) - - except urlfetch.DownloadError as e: - if "Too many redirects" in str(e): - raise MaxRetryError(self, url, reason=e) - raise ProtocolError(e) - - except urlfetch.ResponseTooLargeError as e: - raise AppEnginePlatformError( - "URLFetch response too large, URLFetch only supports" - "responses up to 32mb in size.", - e, - ) - - except urlfetch.SSLCertificateError as e: - raise SSLError(e) - - except urlfetch.InvalidMethodError as e: - raise AppEnginePlatformError( - "URLFetch does not support method: %s" % method, e - ) - - http_response = self._urlfetch_response_to_http_response( - response, retries=retries, **response_kw - ) - - # Handle redirect? - redirect_location = redirect and http_response.get_redirect_location() - if redirect_location: - # Check for redirect response - if self.urlfetch_retries and retries.raise_on_redirect: - raise MaxRetryError(self, url, "too many redirects") - else: - if http_response.status == 303: - method = "GET" - - try: - retries = retries.increment( - method, url, response=http_response, _pool=self - ) - except MaxRetryError: - if retries.raise_on_redirect: - raise MaxRetryError(self, url, "too many redirects") - return http_response - - retries.sleep_for_retry(http_response) - log.debug("Redirecting %s -> %s", url, redirect_location) - redirect_url = urljoin(url, redirect_location) - return self.urlopen( - method, - redirect_url, - body, - headers, - retries=retries, - redirect=redirect, - timeout=timeout, - **response_kw - ) - - # Check if we should retry the HTTP response. 
- has_retry_after = bool(http_response.getheader("Retry-After")) - if retries.is_retry(method, http_response.status, has_retry_after): - retries = retries.increment(method, url, response=http_response, _pool=self) - log.debug("Retry: %s", url) - retries.sleep(http_response) - return self.urlopen( - method, - url, - body=body, - headers=headers, - retries=retries, - redirect=redirect, - timeout=timeout, - **response_kw - ) - - return http_response - - def _urlfetch_response_to_http_response(self, urlfetch_resp, **response_kw): - - if is_prod_appengine(): - # Production GAE handles deflate encoding automatically, but does - # not remove the encoding header. - content_encoding = urlfetch_resp.headers.get("content-encoding") - - if content_encoding == "deflate": - del urlfetch_resp.headers["content-encoding"] - - transfer_encoding = urlfetch_resp.headers.get("transfer-encoding") - # We have a full response's content, - # so let's make sure we don't report ourselves as chunked data. - if transfer_encoding == "chunked": - encodings = transfer_encoding.split(",") - encodings.remove("chunked") - urlfetch_resp.headers["transfer-encoding"] = ",".join(encodings) - - original_response = HTTPResponse( - # In order for decoding to work, we must present the content as - # a file-like object. - body=io.BytesIO(urlfetch_resp.content), - msg=urlfetch_resp.header_msg, - headers=urlfetch_resp.headers, - status=urlfetch_resp.status_code, - **response_kw - ) - - return HTTPResponse( - body=io.BytesIO(urlfetch_resp.content), - headers=urlfetch_resp.headers, - status=urlfetch_resp.status_code, - original_response=original_response, - **response_kw - ) - - def _get_absolute_timeout(self, timeout): - if timeout is Timeout.DEFAULT_TIMEOUT: - return None # Defer to URLFetch's default. 
- if isinstance(timeout, Timeout): - if timeout._read is not None or timeout._connect is not None: - warnings.warn( - "URLFetch does not support granular timeout settings, " - "reverting to total or default URLFetch timeout.", - AppEnginePlatformWarning, - ) - return timeout.total - return timeout - - def _get_retries(self, retries, redirect): - if not isinstance(retries, Retry): - retries = Retry.from_int(retries, redirect=redirect, default=self.retries) - - if retries.connect or retries.read or retries.redirect: - warnings.warn( - "URLFetch only supports total retries and does not " - "recognize connect, read, or redirect retry parameters.", - AppEnginePlatformWarning, - ) - - return retries - - -# Alias methods from _appengine_environ to maintain public API interface. - -is_appengine = _appengine_environ.is_appengine -is_appengine_sandbox = _appengine_environ.is_appengine_sandbox -is_local_appengine = _appengine_environ.is_local_appengine -is_prod_appengine = _appengine_environ.is_prod_appengine -is_prod_appengine_mvms = _appengine_environ.is_prod_appengine_mvms diff --git a/src/urllib3/contrib/ntlmpool.py b/src/urllib3/contrib/ntlmpool.py deleted file mode 100644 index b2df45dcf6..0000000000 --- a/src/urllib3/contrib/ntlmpool.py +++ /dev/null @@ -1,121 +0,0 @@ -""" -NTLM authenticating pool, contributed by erikcederstran - -Issue #10, see: http://code.google.com/p/urllib3/issues/detail?id=10 -""" -from __future__ import absolute_import - -from logging import getLogger - -from ntlm import ntlm - -from .. import HTTPSConnectionPool -from ..packages.six.moves.http_client import HTTPSConnection - -log = getLogger(__name__) - - -class NTLMConnectionPool(HTTPSConnectionPool): - """ - Implements an NTLM authentication version of an urllib3 connection pool - """ - - scheme = "https" - - def __init__(self, user, pw, authurl, *args, **kwargs): - """ - authurl is a random URL on the server that is protected by NTLM. 
- user is the Windows user, probably in the DOMAIN\\username format. - pw is the password for the user. - """ - super(NTLMConnectionPool, self).__init__(*args, **kwargs) - self.authurl = authurl - self.rawuser = user - user_parts = user.split("\\", 1) - self.domain = user_parts[0].upper() - self.user = user_parts[1] - self.pw = pw - - def _new_conn(self): - # Performs the NTLM handshake that secures the connection. The socket - # must be kept open while requests are performed. - self.num_connections += 1 - log.debug( - "Starting NTLM HTTPS connection no. %d: https://%s%s", - self.num_connections, - self.host, - self.authurl, - ) - - headers = {"Connection": "Keep-Alive"} - req_header = "Authorization" - resp_header = "www-authenticate" - - conn = HTTPSConnection(host=self.host, port=self.port) - - # Send negotiation message - headers[req_header] = "NTLM %s" % ntlm.create_NTLM_NEGOTIATE_MESSAGE( - self.rawuser - ) - log.debug("Request headers: %s", headers) - conn.request("GET", self.authurl, None, headers) - res = conn.getresponse() - reshdr = dict(res.getheaders()) - log.debug("Response status: %s %s", res.status, res.reason) - log.debug("Response headers: %s", reshdr) - log.debug("Response data: %s [...]", res.read(100)) - - # Remove the reference to the socket, so that it can not be closed by - # the response object (we want to keep the socket open) - res.fp = None - - # Server should respond with a challenge message - auth_header_values = reshdr[resp_header].split(", ") - auth_header_value = None - for s in auth_header_values: - if s[:5] == "NTLM ": - auth_header_value = s[5:] - if auth_header_value is None: - raise Exception( - "Unexpected %s response header: %s" % (resp_header, reshdr[resp_header]) - ) - - # Send authentication message - ServerChallenge, NegotiateFlags = ntlm.parse_NTLM_CHALLENGE_MESSAGE( - auth_header_value - ) - auth_msg = ntlm.create_NTLM_AUTHENTICATE_MESSAGE( - ServerChallenge, self.user, self.domain, self.pw, NegotiateFlags - ) - 
headers[req_header] = "NTLM %s" % auth_msg - log.debug("Request headers: %s", headers) - conn.request("GET", self.authurl, None, headers) - res = conn.getresponse() - log.debug("Response status: %s %s", res.status, res.reason) - log.debug("Response headers: %s", dict(res.getheaders())) - log.debug("Response data: %s [...]", res.read()[:100]) - if res.status != 200: - if res.status == 401: - raise Exception("Server rejected request: wrong username or password") - raise Exception("Wrong server response: %s %s" % (res.status, res.reason)) - - res.fp = None - log.debug("Connection established") - return conn - - def urlopen( - self, - method, - url, - body=None, - headers=None, - retries=3, - redirect=True, - assert_same_host=True, - ): - if headers is None: - headers = {} - headers["Connection"] = "Keep-Alive" - return super(NTLMConnectionPool, self).urlopen( - method, url, body, headers, retries, redirect, assert_same_host - ) diff --git a/src/urllib3/contrib/pyopenssl.py b/src/urllib3/contrib/pyopenssl.py index 0cabab1aed..0089cd27e0 100644 --- a/src/urllib3/contrib/pyopenssl.py +++ b/src/urllib3/contrib/pyopenssl.py @@ -1,8 +1,8 @@ """ -TLS with SNI_-support for Python 2. Follow these instructions if you would -like to verify TLS certificates in Python 2. Note, the default libraries do -*not* do certificate checking; you need to do additional work to validate -certificates yourself. +Module for using pyOpenSSL as a TLS backend. This module was relevant before +the standard library ``ssl`` module supported SNI, but now that we've dropped +support for Python 2.7 all relevant Python versions support SNI so +**this module is no longer recommended**. 
This needs the following packages installed: @@ -10,7 +10,7 @@ * `cryptography`_ (minimum 1.3.4, from pyopenssl) * `idna`_ (minimum 2.0, from cryptography) -However, pyopenssl depends on cryptography, which depends on idna, so while we +However, pyOpenSSL depends on cryptography, which depends on idna, so while we use all three directly here we end up having relatively few packages required. You can install them with the following command: @@ -33,64 +33,55 @@ except ImportError: pass -Now you can use :mod:`urllib3` as you normally would, and it will support SNI -when the required modules are installed. - -Activating this module also has the positive side effect of disabling SSL/TLS -compression in Python 2 (see `CRIME attack`_). - -.. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication -.. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit) .. _pyopenssl: https://www.pyopenssl.org .. _cryptography: https://cryptography.io .. _idna: https://github.com/kjd/idna """ -from __future__ import absolute_import -import OpenSSL.SSL +from __future__ import annotations + +import OpenSSL.SSL # type: ignore[import] from cryptography import x509 -from cryptography.hazmat.backends.openssl import backend as openssl_backend -from cryptography.hazmat.backends.openssl.x509 import _Certificate try: - from cryptography.x509 import UnsupportedExtension + from cryptography.x509 import UnsupportedExtension # type: ignore[attr-defined] except ImportError: # UnsupportedExtension is gone in cryptography >= 2.1.0 - class UnsupportedExtension(Exception): + class UnsupportedExtension(Exception): # type: ignore[no-redef] pass +import logging +import ssl +import typing +import warnings from io import BytesIO -from socket import error as SocketError +from socket import socket as socket_cls from socket import timeout -try: # Platform-specific: Python 2 - from socket import _fileobject -except ImportError: # Platform-specific: Python 3 - _fileobject = None - from 
..packages.backports.makefile import backport_makefile +from .. import util -import logging -import ssl -import sys +warnings.warn( + "'urllib3.contrib.pyopenssl' module is deprecated and will be removed " + "in urllib3 v2.1.0. Read more in this issue: " + "https://github.com/urllib3/urllib3/issues/2680", + category=DeprecationWarning, + stacklevel=2, +) -from .. import util -from ..packages import six +if typing.TYPE_CHECKING: + from OpenSSL.crypto import X509 # type: ignore[import] -__all__ = ["inject_into_urllib3", "extract_from_urllib3"] -# SNI always works. -HAS_SNI = True +__all__ = ["inject_into_urllib3", "extract_from_urllib3"] # Map from urllib3 to PyOpenSSL compatible parameter-values. _openssl_versions = { - util.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD, + util.ssl_.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined] + util.ssl_.PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined] ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD, } -if hasattr(ssl, "PROTOCOL_SSLv3") and hasattr(OpenSSL.SSL, "SSLv3_METHOD"): - _openssl_versions[ssl.PROTOCOL_SSLv3] = OpenSSL.SSL.SSLv3_METHOD - if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"): _openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD @@ -104,43 +95,77 @@ class UnsupportedExtension(Exception): ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, } -_openssl_to_stdlib_verify = dict((v, k) for k, v in _stdlib_to_openssl_verify.items()) +_openssl_to_stdlib_verify = {v: k for k, v in _stdlib_to_openssl_verify.items()} + +# The SSLvX values are the most likely to be missing in the future +# but we check them all just to be sure. 
+_OP_NO_SSLv2_OR_SSLv3: int = getattr(OpenSSL.SSL, "OP_NO_SSLv2", 0) | getattr( + OpenSSL.SSL, "OP_NO_SSLv3", 0 +) +_OP_NO_TLSv1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1", 0) +_OP_NO_TLSv1_1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_1", 0) +_OP_NO_TLSv1_2: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_2", 0) +_OP_NO_TLSv1_3: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_3", 0) + +_openssl_to_ssl_minimum_version: dict[int, int] = { + ssl.TLSVersion.MINIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3, + ssl.TLSVersion.TLSv1: _OP_NO_SSLv2_OR_SSLv3, + ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1, + ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1, + ssl.TLSVersion.TLSv1_3: ( + _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 + ), + ssl.TLSVersion.MAXIMUM_SUPPORTED: ( + _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 + ), +} +_openssl_to_ssl_maximum_version: dict[int, int] = { + ssl.TLSVersion.MINIMUM_SUPPORTED: ( + _OP_NO_SSLv2_OR_SSLv3 + | _OP_NO_TLSv1 + | _OP_NO_TLSv1_1 + | _OP_NO_TLSv1_2 + | _OP_NO_TLSv1_3 + ), + ssl.TLSVersion.TLSv1: ( + _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3 + ), + ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3, + ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_3, + ssl.TLSVersion.TLSv1_3: _OP_NO_SSLv2_OR_SSLv3, + ssl.TLSVersion.MAXIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3, +} # OpenSSL will only write 16K at a time SSL_WRITE_BLOCKSIZE = 16384 -orig_util_HAS_SNI = util.HAS_SNI orig_util_SSLContext = util.ssl_.SSLContext log = logging.getLogger(__name__) -def inject_into_urllib3(): +def inject_into_urllib3() -> None: "Monkey-patch urllib3 with PyOpenSSL-backed SSL-support." 
_validate_dependencies_met() - util.SSLContext = PyOpenSSLContext - util.ssl_.SSLContext = PyOpenSSLContext - util.HAS_SNI = HAS_SNI - util.ssl_.HAS_SNI = HAS_SNI + util.SSLContext = PyOpenSSLContext # type: ignore[assignment] + util.ssl_.SSLContext = PyOpenSSLContext # type: ignore[assignment] util.IS_PYOPENSSL = True util.ssl_.IS_PYOPENSSL = True -def extract_from_urllib3(): +def extract_from_urllib3() -> None: "Undo monkey-patching by :func:`inject_into_urllib3`." util.SSLContext = orig_util_SSLContext util.ssl_.SSLContext = orig_util_SSLContext - util.HAS_SNI = orig_util_HAS_SNI - util.ssl_.HAS_SNI = orig_util_HAS_SNI util.IS_PYOPENSSL = False util.ssl_.IS_PYOPENSSL = False -def _validate_dependencies_met(): +def _validate_dependencies_met() -> None: """ Verifies that PyOpenSSL's package-level dependencies have been met. Throws `ImportError` if they are not met. @@ -166,7 +191,7 @@ def _validate_dependencies_met(): ) -def _dnsname_to_stdlib(name): +def _dnsname_to_stdlib(name: str) -> str | None: """ Converts a dNSName SubjectAlternativeName field to the form used by the standard library on the given Python version. @@ -180,7 +205,7 @@ def _dnsname_to_stdlib(name): the name given should be skipped. """ - def idna_encode(name): + def idna_encode(name: str) -> bytes | None: """ Borrowed wholesale from the Python Cryptography Project. 
It turns out that we can't just safely call `idna.encode`: it can explode for @@ -189,7 +214,7 @@ def idna_encode(name): import idna try: - for prefix in [u"*.", u"."]: + for prefix in ["*.", "."]: if name.startswith(prefix): name = name[len(prefix) :] return prefix.encode("ascii") + idna.encode(name) @@ -201,25 +226,17 @@ def idna_encode(name): if ":" in name: return name - name = idna_encode(name) - if name is None: + encoded_name = idna_encode(name) + if encoded_name is None: return None - elif sys.version_info >= (3, 0): - name = name.decode("utf-8") - return name + return encoded_name.decode("utf-8") -def get_subj_alt_name(peer_cert): +def get_subj_alt_name(peer_cert: X509) -> list[tuple[str, str]]: """ Given an PyOpenSSL certificate, provides all the subject alternative names. """ - # Pass the cert to cryptography, which has much better APIs for this. - if hasattr(peer_cert, "to_cryptography"): - cert = peer_cert.to_cryptography() - else: - # This is technically using private APIs, but should work across all - # relevant versions before PyOpenSSL got a proper API for this. - cert = _Certificate(openssl_backend, peer_cert._x509) + cert = peer_cert.to_cryptography() # We want to find the SAN extension. Ask Cryptography to locate it (it's # faster than looping in Python) @@ -263,93 +280,94 @@ def get_subj_alt_name(peer_cert): return names -class WrappedSocket(object): - """API-compatibility wrapper for Python OpenSSL's Connection-class. - - Note: _makefile_refs, _drop() and _reuse() are needed for the garbage - collector of pypy. 
- """ +class WrappedSocket: + """API-compatibility wrapper for Python OpenSSL's Connection-class.""" - def __init__(self, connection, socket, suppress_ragged_eofs=True): + def __init__( + self, + connection: OpenSSL.SSL.Connection, + socket: socket_cls, + suppress_ragged_eofs: bool = True, + ) -> None: self.connection = connection self.socket = socket self.suppress_ragged_eofs = suppress_ragged_eofs - self._makefile_refs = 0 + self._io_refs = 0 self._closed = False - def fileno(self): + def fileno(self) -> int: return self.socket.fileno() # Copy-pasted from Python 3.5 source code - def _decref_socketios(self): - if self._makefile_refs > 0: - self._makefile_refs -= 1 + def _decref_socketios(self) -> None: + if self._io_refs > 0: + self._io_refs -= 1 if self._closed: self.close() - def recv(self, *args, **kwargs): + def recv(self, *args: typing.Any, **kwargs: typing.Any) -> bytes: try: data = self.connection.recv(*args, **kwargs) except OpenSSL.SSL.SysCallError as e: if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"): return b"" else: - raise SocketError(str(e)) + raise OSError(e.args[0], str(e)) from e except OpenSSL.SSL.ZeroReturnError: if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: return b"" else: raise - except OpenSSL.SSL.WantReadError: + except OpenSSL.SSL.WantReadError as e: if not util.wait_for_read(self.socket, self.socket.gettimeout()): - raise timeout("The read operation timed out") + raise timeout("The read operation timed out") from e else: return self.recv(*args, **kwargs) # TLS 1.3 post-handshake authentication except OpenSSL.SSL.Error as e: - raise ssl.SSLError("read error: %r" % e) + raise ssl.SSLError(f"read error: {e!r}") from e else: - return data + return data # type: ignore[no-any-return] - def recv_into(self, *args, **kwargs): + def recv_into(self, *args: typing.Any, **kwargs: typing.Any) -> int: try: - return self.connection.recv_into(*args, **kwargs) + return self.connection.recv_into(*args, **kwargs) # 
type: ignore[no-any-return] except OpenSSL.SSL.SysCallError as e: if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"): return 0 else: - raise SocketError(str(e)) + raise OSError(e.args[0], str(e)) from e except OpenSSL.SSL.ZeroReturnError: if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: return 0 else: raise - except OpenSSL.SSL.WantReadError: + except OpenSSL.SSL.WantReadError as e: if not util.wait_for_read(self.socket, self.socket.gettimeout()): - raise timeout("The read operation timed out") + raise timeout("The read operation timed out") from e else: return self.recv_into(*args, **kwargs) # TLS 1.3 post-handshake authentication except OpenSSL.SSL.Error as e: - raise ssl.SSLError("read error: %r" % e) + raise ssl.SSLError(f"read error: {e!r}") from e - def settimeout(self, timeout): + def settimeout(self, timeout: float) -> None: return self.socket.settimeout(timeout) - def _send_until_done(self, data): + def _send_until_done(self, data: bytes) -> int: while True: try: - return self.connection.send(data) - except OpenSSL.SSL.WantWriteError: + return self.connection.send(data) # type: ignore[no-any-return] + except OpenSSL.SSL.WantWriteError as e: if not util.wait_for_write(self.socket, self.socket.gettimeout()): - raise timeout() + raise timeout() from e continue except OpenSSL.SSL.SysCallError as e: - raise SocketError(str(e)) + raise OSError(e.args[0], str(e)) from e - def sendall(self, data): + def sendall(self, data: bytes) -> None: total_sent = 0 while total_sent < len(data): sent = self._send_until_done( @@ -357,136 +375,135 @@ def sendall(self, data): ) total_sent += sent - def shutdown(self): + def shutdown(self) -> None: # FIXME rethrow compatible exceptions should we ever use this self.connection.shutdown() - def close(self): - if self._makefile_refs < 1: - try: - self._closed = True - return self.connection.close() - except OpenSSL.SSL.Error: - return - else: - self._makefile_refs -= 1 + def close(self) -> None: + 
self._closed = True + if self._io_refs <= 0: + self._real_close() - def getpeercert(self, binary_form=False): + def _real_close(self) -> None: + try: + return self.connection.close() # type: ignore[no-any-return] + except OpenSSL.SSL.Error: + return + + def getpeercert( + self, binary_form: bool = False + ) -> dict[str, list[typing.Any]] | None: x509 = self.connection.get_peer_certificate() if not x509: - return x509 + return x509 # type: ignore[no-any-return] if binary_form: - return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509) + return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509) # type: ignore[no-any-return] return { - "subject": ((("commonName", x509.get_subject().CN),),), + "subject": ((("commonName", x509.get_subject().CN),),), # type: ignore[dict-item] "subjectAltName": get_subj_alt_name(x509), } - def version(self): - return self.connection.get_protocol_version_name() - - def _reuse(self): - self._makefile_refs += 1 - - def _drop(self): - if self._makefile_refs < 1: - self.close() - else: - self._makefile_refs -= 1 - - -if _fileobject: # Platform-specific: Python 2 - - def makefile(self, mode, bufsize=-1): - self._makefile_refs += 1 - return _fileobject(self, mode, bufsize, close=True) + def version(self) -> str: + return self.connection.get_protocol_version_name() # type: ignore[no-any-return] -else: # Platform-specific: Python 3 - makefile = backport_makefile +WrappedSocket.makefile = socket_cls.makefile # type: ignore[attr-defined] -WrappedSocket.makefile = makefile - -class PyOpenSSLContext(object): +class PyOpenSSLContext: """ I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible for translating the interface of the standard library ``SSLContext`` object to calls into PyOpenSSL. 
""" - def __init__(self, protocol): + def __init__(self, protocol: int) -> None: self.protocol = _openssl_versions[protocol] self._ctx = OpenSSL.SSL.Context(self.protocol) self._options = 0 self.check_hostname = False + self._minimum_version: int = ssl.TLSVersion.MINIMUM_SUPPORTED + self._maximum_version: int = ssl.TLSVersion.MAXIMUM_SUPPORTED @property - def options(self): + def options(self) -> int: return self._options @options.setter - def options(self, value): + def options(self, value: int) -> None: self._options = value - self._ctx.set_options(value) + self._set_ctx_options() @property - def verify_mode(self): + def verify_mode(self) -> int: return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()] @verify_mode.setter - def verify_mode(self, value): + def verify_mode(self, value: ssl.VerifyMode) -> None: self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback) - def set_default_verify_paths(self): + def set_default_verify_paths(self) -> None: self._ctx.set_default_verify_paths() - def set_ciphers(self, ciphers): - if isinstance(ciphers, six.text_type): + def set_ciphers(self, ciphers: bytes | str) -> None: + if isinstance(ciphers, str): ciphers = ciphers.encode("utf-8") self._ctx.set_cipher_list(ciphers) - def load_verify_locations(self, cafile=None, capath=None, cadata=None): + def load_verify_locations( + self, + cafile: str | None = None, + capath: str | None = None, + cadata: bytes | None = None, + ) -> None: if cafile is not None: - cafile = cafile.encode("utf-8") + cafile = cafile.encode("utf-8") # type: ignore[assignment] if capath is not None: - capath = capath.encode("utf-8") + capath = capath.encode("utf-8") # type: ignore[assignment] try: self._ctx.load_verify_locations(cafile, capath) if cadata is not None: self._ctx.load_verify_locations(BytesIO(cadata)) except OpenSSL.SSL.Error as e: - raise ssl.SSLError("unable to load trusted certificates: %r" % e) + raise ssl.SSLError(f"unable to load trusted certificates: {e!r}") from e 
- def load_cert_chain(self, certfile, keyfile=None, password=None): - self._ctx.use_certificate_chain_file(certfile) - if password is not None: - if not isinstance(password, six.binary_type): - password = password.encode("utf-8") - self._ctx.set_passwd_cb(lambda *_: password) - self._ctx.use_privatekey_file(keyfile or certfile) + def load_cert_chain( + self, + certfile: str, + keyfile: str | None = None, + password: str | None = None, + ) -> None: + try: + self._ctx.use_certificate_chain_file(certfile) + if password is not None: + if not isinstance(password, bytes): + password = password.encode("utf-8") # type: ignore[assignment] + self._ctx.set_passwd_cb(lambda *_: password) + self._ctx.use_privatekey_file(keyfile or certfile) + except OpenSSL.SSL.Error as e: + raise ssl.SSLError(f"Unable to load certificate chain: {e!r}") from e - def set_alpn_protocols(self, protocols): - protocols = [six.ensure_binary(p) for p in protocols] - return self._ctx.set_alpn_protos(protocols) + def set_alpn_protocols(self, protocols: list[bytes | str]) -> None: + protocols = [util.util.to_bytes(p, "ascii") for p in protocols] + return self._ctx.set_alpn_protos(protocols) # type: ignore[no-any-return] def wrap_socket( self, - sock, - server_side=False, - do_handshake_on_connect=True, - suppress_ragged_eofs=True, - server_hostname=None, - ): + sock: socket_cls, + server_side: bool = False, + do_handshake_on_connect: bool = True, + suppress_ragged_eofs: bool = True, + server_hostname: bytes | str | None = None, + ) -> WrappedSocket: cnx = OpenSSL.SSL.Connection(self._ctx, sock) - if isinstance(server_hostname, six.text_type): # Platform-specific: Python 3 - server_hostname = server_hostname.encode("utf-8") - - if server_hostname is not None: + # If server_hostname is an IP, don't use it for SNI, per RFC6066 Section 3 + if server_hostname and not util.ssl_.is_ipaddress(server_hostname): + if isinstance(server_hostname, str): + server_hostname = server_hostname.encode("utf-8") 
cnx.set_tlsext_host_name(server_hostname) cnx.set_connect_state() @@ -494,16 +511,47 @@ def wrap_socket( while True: try: cnx.do_handshake() - except OpenSSL.SSL.WantReadError: + except OpenSSL.SSL.WantReadError as e: if not util.wait_for_read(sock, sock.gettimeout()): - raise timeout("select timed out") + raise timeout("select timed out") from e continue except OpenSSL.SSL.Error as e: - raise ssl.SSLError("bad handshake: %r" % e) + raise ssl.SSLError(f"bad handshake: {e!r}") from e break return WrappedSocket(cnx, sock) + def _set_ctx_options(self) -> None: + self._ctx.set_options( + self._options + | _openssl_to_ssl_minimum_version[self._minimum_version] + | _openssl_to_ssl_maximum_version[self._maximum_version] + ) -def _verify_callback(cnx, x509, err_no, err_depth, return_code): + @property + def minimum_version(self) -> int: + return self._minimum_version + + @minimum_version.setter + def minimum_version(self, minimum_version: int) -> None: + self._minimum_version = minimum_version + self._set_ctx_options() + + @property + def maximum_version(self) -> int: + return self._maximum_version + + @maximum_version.setter + def maximum_version(self, maximum_version: int) -> None: + self._maximum_version = maximum_version + self._set_ctx_options() + + +def _verify_callback( + cnx: OpenSSL.SSL.Connection, + x509: X509, + err_no: int, + err_depth: int, + return_code: int, +) -> bool: return err_no == 0 diff --git a/src/urllib3/contrib/securetransport.py b/src/urllib3/contrib/securetransport.py index ab092de67a..11beb3dfef 100644 --- a/src/urllib3/contrib/securetransport.py +++ b/src/urllib3/contrib/securetransport.py @@ -51,7 +51,8 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" -from __future__ import absolute_import + +from __future__ import annotations import contextlib import ctypes @@ -62,13 +63,18 @@ import ssl import struct import threading +import typing +import warnings import weakref - -import six +from socket import socket as socket_cls from .. import util -from ._securetransport.bindings import CoreFoundation, Security, SecurityConst +from ._securetransport.bindings import ( # type: ignore[attr-defined] + CoreFoundation, + Security, +) from ._securetransport.low_level import ( + SecurityConst, _assert_no_error, _build_tls_unknown_ca_alert, _cert_array_from_pem, @@ -77,18 +83,19 @@ _temporary_keychain, ) -try: # Platform-specific: Python 2 - from socket import _fileobject -except ImportError: # Platform-specific: Python 3 - _fileobject = None - from ..packages.backports.makefile import backport_makefile +warnings.warn( + "'urllib3.contrib.securetransport' module is deprecated and will be removed " + "in urllib3 v2.1.0. Read more in this issue: " + "https://github.com/urllib3/urllib3/issues/2681", + category=DeprecationWarning, + stacklevel=2, +) -__all__ = ["inject_into_urllib3", "extract_from_urllib3"] +if typing.TYPE_CHECKING: + from typing_extensions import Literal -# SNI always works -HAS_SNI = True +__all__ = ["inject_into_urllib3", "extract_from_urllib3"] -orig_util_HAS_SNI = util.HAS_SNI orig_util_SSLContext = util.ssl_.SSLContext # This dictionary is used by the read callback to obtain a handle to the @@ -107,54 +114,24 @@ # # This is good: if we had to lock in the callbacks we'd drastically slow down # the performance of this code. -_connection_refs = weakref.WeakValueDictionary() +_connection_refs: weakref.WeakValueDictionary[ + int, WrappedSocket +] = weakref.WeakValueDictionary() _connection_ref_lock = threading.Lock() # Limit writes to 16kB. This is OpenSSL's limit, but we'll cargo-cult it over # for no better reason than we need *a* limit, and this one is right there. 
SSL_WRITE_BLOCKSIZE = 16384 -# This is our equivalent of util.ssl_.DEFAULT_CIPHERS, but expanded out to -# individual cipher suites. We need to do this because this is how -# SecureTransport wants them. -CIPHER_SUITES = [ - SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - SecurityConst.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - SecurityConst.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - SecurityConst.TLS_DHE_RSA_WITH_AES_256_GCM_SHA384, - SecurityConst.TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, - SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, - SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - SecurityConst.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, - SecurityConst.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - SecurityConst.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, - SecurityConst.TLS_DHE_RSA_WITH_AES_256_CBC_SHA, - SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, - SecurityConst.TLS_DHE_RSA_WITH_AES_128_CBC_SHA, - SecurityConst.TLS_AES_256_GCM_SHA384, - SecurityConst.TLS_AES_128_GCM_SHA256, - SecurityConst.TLS_RSA_WITH_AES_256_GCM_SHA384, - SecurityConst.TLS_RSA_WITH_AES_128_GCM_SHA256, - SecurityConst.TLS_AES_128_CCM_8_SHA256, - SecurityConst.TLS_AES_128_CCM_SHA256, - SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA256, - SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA256, - SecurityConst.TLS_RSA_WITH_AES_256_CBC_SHA, - SecurityConst.TLS_RSA_WITH_AES_128_CBC_SHA, -] - # Basically this is simple: for PROTOCOL_SSLv23 we turn it into a low of # TLSv1 and a high of TLSv1.2. For everything else, we pin to that version. 
# TLSv1 to 1.2 are supported on macOS 10.8+ _protocol_to_min_max = { - util.PROTOCOL_TLS: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12) + util.ssl_.PROTOCOL_TLS: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12), # type: ignore[attr-defined] + util.ssl_.PROTOCOL_TLS_CLIENT: ( # type: ignore[attr-defined] + SecurityConst.kTLSProtocol1, + SecurityConst.kTLSProtocol12, + ), } if hasattr(ssl, "PROTOCOL_SSLv2"): @@ -184,31 +161,38 @@ ) -def inject_into_urllib3(): +_tls_version_to_st: dict[int, int] = { + ssl.TLSVersion.MINIMUM_SUPPORTED: SecurityConst.kTLSProtocol1, + ssl.TLSVersion.TLSv1: SecurityConst.kTLSProtocol1, + ssl.TLSVersion.TLSv1_1: SecurityConst.kTLSProtocol11, + ssl.TLSVersion.TLSv1_2: SecurityConst.kTLSProtocol12, + ssl.TLSVersion.MAXIMUM_SUPPORTED: SecurityConst.kTLSProtocol12, +} + + +def inject_into_urllib3() -> None: """ Monkey-patch urllib3 with SecureTransport-backed SSL-support. """ - util.SSLContext = SecureTransportContext - util.ssl_.SSLContext = SecureTransportContext - util.HAS_SNI = HAS_SNI - util.ssl_.HAS_SNI = HAS_SNI + util.SSLContext = SecureTransportContext # type: ignore[assignment] + util.ssl_.SSLContext = SecureTransportContext # type: ignore[assignment] util.IS_SECURETRANSPORT = True util.ssl_.IS_SECURETRANSPORT = True -def extract_from_urllib3(): +def extract_from_urllib3() -> None: """ Undo monkey-patching by :func:`inject_into_urllib3`. """ util.SSLContext = orig_util_SSLContext util.ssl_.SSLContext = orig_util_SSLContext - util.HAS_SNI = orig_util_HAS_SNI - util.ssl_.HAS_SNI = orig_util_HAS_SNI util.IS_SECURETRANSPORT = False util.ssl_.IS_SECURETRANSPORT = False -def _read_callback(connection_id, data_buffer, data_length_pointer): +def _read_callback( + connection_id: int, data_buffer: int, data_length_pointer: bytearray +) -> int: """ SecureTransport read callback. This is called by ST to request that data be returned from the socket. 
@@ -230,7 +214,7 @@ def _read_callback(connection_id, data_buffer, data_length_pointer): while read_count < requested_length: if timeout is None or timeout >= 0: if not util.wait_for_read(base_socket, timeout): - raise socket.error(errno.EAGAIN, "timed out") + raise OSError(errno.EAGAIN, "timed out") remaining = requested_length - read_count buffer = (ctypes.c_char * remaining).from_address( @@ -242,7 +226,7 @@ def _read_callback(connection_id, data_buffer, data_length_pointer): if not read_count: return SecurityConst.errSSLClosedGraceful break - except (socket.error) as e: + except OSError as e: error = e.errno if error is not None and error != errno.EAGAIN: @@ -263,7 +247,9 @@ def _read_callback(connection_id, data_buffer, data_length_pointer): return SecurityConst.errSSLInternal -def _write_callback(connection_id, data_buffer, data_length_pointer): +def _write_callback( + connection_id: int, data_buffer: int, data_length_pointer: bytearray +) -> int: """ SecureTransport write callback. This is called by ST to request that data actually be sent on the network. @@ -286,14 +272,14 @@ def _write_callback(connection_id, data_buffer, data_length_pointer): while sent < bytes_to_write: if timeout is None or timeout >= 0: if not util.wait_for_write(base_socket, timeout): - raise socket.error(errno.EAGAIN, "timed out") + raise OSError(errno.EAGAIN, "timed out") chunk_sent = base_socket.send(data) sent += chunk_sent # This has some needless copying here, but I'm not sure there's # much value in optimising this data path. data = data[chunk_sent:] - except (socket.error) as e: + except OSError as e: error = e.errno if error is not None and error != errno.EAGAIN: @@ -321,22 +307,20 @@ def _write_callback(connection_id, data_buffer, data_length_pointer): _write_callback_pointer = Security.SSLWriteFunc(_write_callback) -class WrappedSocket(object): +class WrappedSocket: """ API-compatibility wrapper for Python's OpenSSL wrapped socket object. 
- - Note: _makefile_refs, _drop(), and _reuse() are needed for the garbage - collector of PyPy. """ - def __init__(self, socket): + def __init__(self, socket: socket_cls) -> None: self.socket = socket self.context = None - self._makefile_refs = 0 + self._io_refs = 0 self._closed = False - self._exception = None + self._real_closed = False + self._exception: Exception | None = None self._keychain = None - self._keychain_dir = None + self._keychain_dir: str | None = None self._client_cert_chain = None # We save off the previously-configured timeout and then set it to @@ -348,7 +332,7 @@ def __init__(self, socket): self.socket.settimeout(0) @contextlib.contextmanager - def _raise_on_error(self): + def _raise_on_error(self) -> typing.Generator[None, None, None]: """ A context manager that can be used to wrap calls that do I/O from SecureTransport. If any of the I/O callbacks hit an exception, this @@ -365,23 +349,10 @@ def _raise_on_error(self): yield if self._exception is not None: exception, self._exception = self._exception, None - self.close() + self._real_close() raise exception - def _set_ciphers(self): - """ - Sets up the allowed ciphers. By default this matches the set in - util.ssl_.DEFAULT_CIPHERS, at least as supported by macOS. This is done - custom and doesn't allow changing at this time, mostly because parsing - OpenSSL cipher strings is going to be a freaking nightmare. - """ - ciphers = (Security.SSLCipherSuite * len(CIPHER_SUITES))(*CIPHER_SUITES) - result = Security.SSLSetEnabledCiphers( - self.context, ciphers, len(CIPHER_SUITES) - ) - _assert_no_error(result) - - def _set_alpn_protocols(self, protocols): + def _set_alpn_protocols(self, protocols: list[bytes] | None) -> None: """ Sets up the ALPN protocols on the context. 
""" @@ -394,7 +365,7 @@ def _set_alpn_protocols(self, protocols): finally: CoreFoundation.CFRelease(protocols_arr) - def _custom_validate(self, verify, trust_bundle): + def _custom_validate(self, verify: bool, trust_bundle: bytes | None) -> None: """ Called when we have set custom validation. We do this in two cases: first, when cert validation is entirely disabled; and second, when @@ -402,7 +373,7 @@ def _custom_validate(self, verify, trust_bundle): Raises an SSLError if the connection is not trusted. """ # If we disabled cert validation, just say: cool. - if not verify: + if not verify or trust_bundle is None: return successes = ( @@ -413,10 +384,12 @@ def _custom_validate(self, verify, trust_bundle): trust_result = self._evaluate_trust(trust_bundle) if trust_result in successes: return - reason = "error code: %d" % (trust_result,) + reason = f"error code: {int(trust_result)}" + exc = None except Exception as e: # Do not trust on error - reason = "exception: %r" % (e,) + reason = f"exception: {e!r}" + exc = e # SecureTransport does not send an alert nor shuts down the connection. rec = _build_tls_unknown_ca_alert(self.version()) @@ -426,10 +399,10 @@ def _custom_validate(self, verify, trust_bundle): # l_linger = 0, linger for 0 seoncds opts = struct.pack("ii", 1, 0) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, opts) - self.close() - raise ssl.SSLError("certificate verify failed, %s" % reason) + self._real_close() + raise ssl.SSLError(f"certificate verify failed, {reason}") from exc - def _evaluate_trust(self, trust_bundle): + def _evaluate_trust(self, trust_bundle: bytes) -> int: # We want data in memory, so load it up. 
if os.path.isfile(trust_bundle): with open(trust_bundle, "rb") as f: @@ -467,20 +440,20 @@ def _evaluate_trust(self, trust_bundle): if cert_array is not None: CoreFoundation.CFRelease(cert_array) - return trust_result.value + return trust_result.value # type: ignore[no-any-return] def handshake( self, - server_hostname, - verify, - trust_bundle, - min_version, - max_version, - client_cert, - client_key, - client_key_passphrase, - alpn_protocols, - ): + server_hostname: bytes | str | None, + verify: bool, + trust_bundle: bytes | None, + min_version: int, + max_version: int, + client_cert: str | None, + client_key: str | None, + client_key_passphrase: typing.Any, + alpn_protocols: list[bytes] | None, + ) -> None: """ Actually performs the TLS handshake. This is run automatically by wrapped socket, and shouldn't be needed in user code. @@ -508,6 +481,8 @@ def handshake( _assert_no_error(result) # If we have a server hostname, we should set that too. + # RFC6066 Section 3 tells us not to use SNI when the host is an IP, but we have + # to do it anyway to match server_hostname against the server certificate if server_hostname: if not isinstance(server_hostname, bytes): server_hostname = server_hostname.encode("utf-8") @@ -517,9 +492,6 @@ def handshake( ) _assert_no_error(result) - # Setup the ciphers. - self._set_ciphers() - # Setup the ALPN protocols. 
self._set_alpn_protocols(alpn_protocols) @@ -562,25 +534,27 @@ def handshake( _assert_no_error(result) break - def fileno(self): + def fileno(self) -> int: return self.socket.fileno() # Copy-pasted from Python 3.5 source code - def _decref_socketios(self): - if self._makefile_refs > 0: - self._makefile_refs -= 1 + def _decref_socketios(self) -> None: + if self._io_refs > 0: + self._io_refs -= 1 if self._closed: self.close() - def recv(self, bufsiz): + def recv(self, bufsiz: int) -> bytes: buffer = ctypes.create_string_buffer(bufsiz) bytes_read = self.recv_into(buffer, bufsiz) data = buffer[:bytes_read] - return data + return typing.cast(bytes, data) - def recv_into(self, buffer, nbytes=None): + def recv_into( + self, buffer: ctypes.Array[ctypes.c_char], nbytes: int | None = None + ) -> int: # Read short on EOF. - if self._closed: + if self._real_closed: return 0 if nbytes is None: @@ -613,7 +587,7 @@ def recv_into(self, buffer, nbytes=None): # well. Note that we don't actually return here because in # principle this could actually be fired along with return data. # It's unlikely though. - self.close() + self._real_close() else: _assert_no_error(result) @@ -621,13 +595,13 @@ def recv_into(self, buffer, nbytes=None): # was actually read. return processed_bytes.value - def settimeout(self, timeout): + def settimeout(self, timeout: float) -> None: self._timeout = timeout - def gettimeout(self): + def gettimeout(self) -> float | None: return self._timeout - def send(self, data): + def send(self, data: bytes) -> int: processed_bytes = ctypes.c_size_t(0) with self._raise_on_error(): @@ -644,36 +618,38 @@ def send(self, data): # We sent, and probably succeeded. Tell them how much we sent. 
return processed_bytes.value - def sendall(self, data): + def sendall(self, data: bytes) -> None: total_sent = 0 while total_sent < len(data): sent = self.send(data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE]) total_sent += sent - def shutdown(self): + def shutdown(self) -> None: with self._raise_on_error(): Security.SSLClose(self.context) - def close(self): + def close(self) -> None: + self._closed = True # TODO: should I do clean shutdown here? Do I have to? - if self._makefile_refs < 1: - self._closed = True - if self.context: - CoreFoundation.CFRelease(self.context) - self.context = None - if self._client_cert_chain: - CoreFoundation.CFRelease(self._client_cert_chain) - self._client_cert_chain = None - if self._keychain: - Security.SecKeychainDelete(self._keychain) - CoreFoundation.CFRelease(self._keychain) - shutil.rmtree(self._keychain_dir) - self._keychain = self._keychain_dir = None - return self.socket.close() - else: - self._makefile_refs -= 1 - - def getpeercert(self, binary_form=False): + if self._io_refs <= 0: + self._real_close() + + def _real_close(self) -> None: + self._real_closed = True + if self.context: + CoreFoundation.CFRelease(self.context) + self.context = None + if self._client_cert_chain: + CoreFoundation.CFRelease(self._client_cert_chain) + self._client_cert_chain = None + if self._keychain: + Security.SecKeychainDelete(self._keychain) + CoreFoundation.CFRelease(self._keychain) + shutil.rmtree(self._keychain_dir) + self._keychain = self._keychain_dir = None + return self.socket.close() + + def getpeercert(self, binary_form: bool = False) -> bytes | None: # Urgh, annoying. 
# # Here's how we do this: @@ -731,7 +707,7 @@ def getpeercert(self, binary_form=False): return der_bytes - def version(self): + def version(self) -> str: protocol = Security.SSLProtocol() result = Security.SSLGetNegotiatedProtocolVersion( self.context, ctypes.byref(protocol) @@ -750,56 +726,50 @@ def version(self): elif protocol.value == SecurityConst.kSSLProtocol2: return "SSLv2" else: - raise ssl.SSLError("Unknown TLS version: %r" % protocol) - - def _reuse(self): - self._makefile_refs += 1 - - def _drop(self): - if self._makefile_refs < 1: - self.close() - else: - self._makefile_refs -= 1 - - -if _fileobject: # Platform-specific: Python 2 - - def makefile(self, mode, bufsize=-1): - self._makefile_refs += 1 - return _fileobject(self, mode, bufsize, close=True) - + raise ssl.SSLError(f"Unknown TLS version: {protocol!r}") -else: # Platform-specific: Python 3 - def makefile(self, mode="r", buffering=None, *args, **kwargs): - # We disable buffering with SecureTransport because it conflicts with - # the buffering that ST does internally (see issue #1153 for more). - buffering = 0 - return backport_makefile(self, mode, buffering, *args, **kwargs) +def makefile( + self: socket_cls, + mode: ( + Literal["r"] | Literal["w"] | Literal["rw"] | Literal["wr"] | Literal[""] + ) = "r", + buffering: int | None = None, + *args: typing.Any, + **kwargs: typing.Any, +) -> typing.BinaryIO | typing.TextIO: + # We disable buffering with SecureTransport because it conflicts with + # the buffering that ST does internally (see issue #1153 for more). + buffering = 0 + return socket_cls.makefile(self, mode, buffering, *args, **kwargs) -WrappedSocket.makefile = makefile +WrappedSocket.makefile = makefile # type: ignore[attr-defined] -class SecureTransportContext(object): +class SecureTransportContext: """ I am a wrapper class for the SecureTransport library, to translate the interface of the standard library ``SSLContext`` object to calls into SecureTransport. 
""" - def __init__(self, protocol): - self._min_version, self._max_version = _protocol_to_min_max[protocol] + def __init__(self, protocol: int) -> None: + self._minimum_version: int = ssl.TLSVersion.MINIMUM_SUPPORTED + self._maximum_version: int = ssl.TLSVersion.MAXIMUM_SUPPORTED + if protocol not in (None, ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_CLIENT): + self._min_version, self._max_version = _protocol_to_min_max[protocol] + self._options = 0 self._verify = False - self._trust_bundle = None - self._client_cert = None - self._client_key = None + self._trust_bundle: bytes | None = None + self._client_cert: str | None = None + self._client_key: str | None = None self._client_key_passphrase = None - self._alpn_protocols = None + self._alpn_protocols: list[bytes] | None = None @property - def check_hostname(self): + def check_hostname(self) -> Literal[True]: """ SecureTransport cannot have its hostname checking disabled. For more, see the comment on getpeercert() in this file. @@ -807,15 +777,14 @@ def check_hostname(self): return True @check_hostname.setter - def check_hostname(self, value): + def check_hostname(self, value: typing.Any) -> None: """ SecureTransport cannot have its hostname checking disabled. For more, see the comment on getpeercert() in this file. """ - pass @property - def options(self): + def options(self) -> int: # TODO: Well, crap. # # So this is the bit of the code that is the most likely to cause us @@ -825,19 +794,19 @@ def options(self): return self._options @options.setter - def options(self, value): + def options(self, value: int) -> None: # TODO: Update in line with above. 
self._options = value @property - def verify_mode(self): + def verify_mode(self) -> int: return ssl.CERT_REQUIRED if self._verify else ssl.CERT_NONE @verify_mode.setter - def verify_mode(self, value): - self._verify = True if value == ssl.CERT_REQUIRED else False + def verify_mode(self, value: int) -> None: + self._verify = value == ssl.CERT_REQUIRED - def set_default_verify_paths(self): + def set_default_verify_paths(self) -> None: # So, this has to do something a bit weird. Specifically, what it does # is nothing. # @@ -849,15 +818,18 @@ def set_default_verify_paths(self): # ignoring it. pass - def load_default_certs(self): + def load_default_certs(self) -> None: return self.set_default_verify_paths() - def set_ciphers(self, ciphers): - # For now, we just require the default cipher string. - if ciphers != util.ssl_.DEFAULT_CIPHERS: - raise ValueError("SecureTransport doesn't support custom cipher strings") + def set_ciphers(self, ciphers: typing.Any) -> None: + raise ValueError("SecureTransport doesn't support custom cipher strings") - def load_verify_locations(self, cafile=None, capath=None, cadata=None): + def load_verify_locations( + self, + cafile: str | None = None, + capath: str | None = None, + cadata: bytes | None = None, + ) -> None: # OK, we only really support cadata and cafile. 
if capath is not None: raise ValueError("SecureTransport does not support cert directories") @@ -867,14 +839,19 @@ def load_verify_locations(self, cafile=None, capath=None, cadata=None): with open(cafile): pass - self._trust_bundle = cafile or cadata + self._trust_bundle = cafile or cadata # type: ignore[assignment] - def load_cert_chain(self, certfile, keyfile=None, password=None): + def load_cert_chain( + self, + certfile: str, + keyfile: str | None = None, + password: str | None = None, + ) -> None: self._client_cert = certfile self._client_key = keyfile self._client_cert_passphrase = password - def set_alpn_protocols(self, protocols): + def set_alpn_protocols(self, protocols: list[str | bytes]) -> None: """ Sets the ALPN protocols that will later be set on the context. @@ -884,16 +861,16 @@ def set_alpn_protocols(self, protocols): raise NotImplementedError( "SecureTransport supports ALPN only in macOS 10.12+" ) - self._alpn_protocols = [six.ensure_binary(p) for p in protocols] + self._alpn_protocols = [util.util.to_bytes(p, "ascii") for p in protocols] def wrap_socket( self, - sock, - server_side=False, - do_handshake_on_connect=True, - suppress_ragged_eofs=True, - server_hostname=None, - ): + sock: socket_cls, + server_side: bool = False, + do_handshake_on_connect: bool = True, + suppress_ragged_eofs: bool = True, + server_hostname: bytes | str | None = None, + ) -> WrappedSocket: # So, what do we do here? Firstly, we assert some properties. This is a # stripped down shim, so there is some functionality we don't support. # See PEP 543 for the real deal. 
@@ -910,11 +887,27 @@ def wrap_socket( server_hostname, self._verify, self._trust_bundle, - self._min_version, - self._max_version, + _tls_version_to_st[self._minimum_version], + _tls_version_to_st[self._maximum_version], self._client_cert, self._client_key, self._client_key_passphrase, self._alpn_protocols, ) return wrapped_socket + + @property + def minimum_version(self) -> int: + return self._minimum_version + + @minimum_version.setter + def minimum_version(self, minimum_version: int) -> None: + self._minimum_version = minimum_version + + @property + def maximum_version(self) -> int: + return self._maximum_version + + @maximum_version.setter + def maximum_version(self, maximum_version: int) -> None: + self._maximum_version = maximum_version diff --git a/src/urllib3/contrib/socks.py b/src/urllib3/contrib/socks.py index 93df8325d5..5e552ddaed 100644 --- a/src/urllib3/contrib/socks.py +++ b/src/urllib3/contrib/socks.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ This module contains provisional support for SOCKS proxies from within urllib3. 
This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and @@ -38,10 +37,11 @@ proxy_url="socks5h://:@proxy-host" """ -from __future__ import absolute_import + +from __future__ import annotations try: - import socks + import socks # type: ignore[import] except ImportError: import warnings @@ -57,7 +57,7 @@ ) raise -from socket import error as SocketError +import typing from socket import timeout as SocketTimeout from ..connection import HTTPConnection, HTTPSConnection @@ -69,7 +69,21 @@ try: import ssl except ImportError: - ssl = None + ssl = None # type: ignore[assignment] + +try: + from typing import TypedDict + + class _TYPE_SOCKS_OPTIONS(TypedDict): + socks_version: int + proxy_host: str | None + proxy_port: str | None + username: str | None + password: str | None + rdns: bool + +except ImportError: # Python 3.7 + _TYPE_SOCKS_OPTIONS = typing.Dict[str, typing.Any] # type: ignore[misc, assignment] class SOCKSConnection(HTTPConnection): @@ -77,15 +91,20 @@ class SOCKSConnection(HTTPConnection): A plain-text HTTP connection that connects via a SOCKS proxy. """ - def __init__(self, *args, **kwargs): - self._socks_options = kwargs.pop("_socks_options") - super(SOCKSConnection, self).__init__(*args, **kwargs) - - def _new_conn(self): + def __init__( + self, + _socks_options: _TYPE_SOCKS_OPTIONS, + *args: typing.Any, + **kwargs: typing.Any, + ) -> None: + self._socks_options = _socks_options + super().__init__(*args, **kwargs) + + def _new_conn(self) -> socks.socksocket: """ Establish a new connection via the SOCKS proxy. """ - extra_kw = {} + extra_kw: dict[str, typing.Any] = {} if self.source_address: extra_kw["source_address"] = self.source_address @@ -102,15 +121,14 @@ def _new_conn(self): proxy_password=self._socks_options["password"], proxy_rdns=self._socks_options["rdns"], timeout=self.timeout, - **extra_kw + **extra_kw, ) - except SocketTimeout: + except SocketTimeout as e: raise ConnectTimeoutError( self, - "Connection to %s timed out. 
(connect timeout=%s)" - % (self.host, self.timeout), - ) + f"Connection to {self.host} timed out. (connect timeout={self.timeout})", + ) from e except socks.ProxyError as e: # This is fragile as hell, but it seems to be the only way to raise @@ -120,22 +138,23 @@ def _new_conn(self): if isinstance(error, SocketTimeout): raise ConnectTimeoutError( self, - "Connection to %s timed out. (connect timeout=%s)" - % (self.host, self.timeout), - ) + f"Connection to {self.host} timed out. (connect timeout={self.timeout})", + ) from e else: + # Adding `from e` messes with coverage somehow, so it's omitted. + # See #2386. raise NewConnectionError( - self, "Failed to establish a new connection: %s" % error + self, f"Failed to establish a new connection: {error}" ) else: raise NewConnectionError( - self, "Failed to establish a new connection: %s" % e - ) + self, f"Failed to establish a new connection: {e}" + ) from e - except SocketError as e: # Defensive: PySocks should catch all these. + except OSError as e: # Defensive: PySocks should catch all these. 
raise NewConnectionError( - self, "Failed to establish a new connection: %s" % e - ) + self, f"Failed to establish a new connection: {e}" + ) from e return conn @@ -169,12 +188,12 @@ class SOCKSProxyManager(PoolManager): def __init__( self, - proxy_url, - username=None, - password=None, - num_pools=10, - headers=None, - **connection_pool_kw + proxy_url: str, + username: str | None = None, + password: str | None = None, + num_pools: int = 10, + headers: typing.Mapping[str, str] | None = None, + **connection_pool_kw: typing.Any, ): parsed = parse_url(proxy_url) @@ -195,7 +214,7 @@ def __init__( socks_version = socks.PROXY_TYPE_SOCKS4 rdns = True else: - raise ValueError("Unable to determine SOCKS version from %s" % proxy_url) + raise ValueError(f"Unable to determine SOCKS version from {proxy_url}") self.proxy_url = proxy_url @@ -209,8 +228,6 @@ def __init__( } connection_pool_kw["_socks_options"] = socks_options - super(SOCKSProxyManager, self).__init__( - num_pools, headers, **connection_pool_kw - ) + super().__init__(num_pools, headers, **connection_pool_kw) self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme diff --git a/src/urllib3/exceptions.py b/src/urllib3/exceptions.py index d69958d5df..7f42689a4d 100644 --- a/src/urllib3/exceptions.py +++ b/src/urllib3/exceptions.py @@ -1,6 +1,16 @@ -from __future__ import absolute_import +from __future__ import annotations -from .packages.six.moves.http_client import IncompleteRead as httplib_IncompleteRead +import socket +import typing +import warnings +from email.errors import MessageDefect +from http.client import IncompleteRead as httplib_IncompleteRead + +if typing.TYPE_CHECKING: + from .connection import HTTPConnection + from .connectionpool import ConnectionPool + from .response import HTTPResponse + from .util.retry import Retry # Base Exceptions @@ -8,23 +18,24 @@ class HTTPError(Exception): """Base exception used by this module.""" - pass - class HTTPWarning(Warning): """Base warning used by 
this module.""" - pass + +_TYPE_REDUCE_RESULT = typing.Tuple[ + typing.Callable[..., object], typing.Tuple[object, ...] +] class PoolError(HTTPError): """Base exception for errors caused within a pool.""" - def __init__(self, pool, message): + def __init__(self, pool: ConnectionPool, message: str) -> None: self.pool = pool - HTTPError.__init__(self, "%s: %s" % (pool, message)) + super().__init__(f"{pool}: {message}") - def __reduce__(self): + def __reduce__(self) -> _TYPE_REDUCE_RESULT: # For pickling purposes. return self.__class__, (None, None) @@ -32,11 +43,11 @@ def __reduce__(self): class RequestError(PoolError): """Base exception for PoolErrors that have associated URLs.""" - def __init__(self, pool, url, message): + def __init__(self, pool: ConnectionPool, url: str, message: str) -> None: self.url = url - PoolError.__init__(self, pool, message) + super().__init__(pool, message) - def __reduce__(self): + def __reduce__(self) -> _TYPE_REDUCE_RESULT: # For pickling purposes. return self.__class__, (None, self.url, None) @@ -44,28 +55,25 @@ def __reduce__(self): class SSLError(HTTPError): """Raised when SSL certificate fails in an HTTPS connection.""" - pass - class ProxyError(HTTPError): """Raised when the connection to a proxy fails.""" - def __init__(self, message, error, *args): - super(ProxyError, self).__init__(message, error, *args) + # The original error is also available as __cause__. + original_error: Exception + + def __init__(self, message: str, error: Exception) -> None: + super().__init__(message, error) self.original_error = error class DecodeError(HTTPError): """Raised when automatic decoding based on Content-Type fails.""" - pass - class ProtocolError(HTTPError): """Raised when something unexpected happens mid-request/response.""" - pass - #: Renamed to ProtocolError but aliased for backwards compatibility. 
ConnectionError = ProtocolError @@ -79,33 +87,36 @@ class MaxRetryError(RequestError): :param pool: The connection pool :type pool: :class:`~urllib3.connectionpool.HTTPConnectionPool` - :param string url: The requested Url - :param exceptions.Exception reason: The underlying error + :param str url: The requested Url + :param reason: The underlying error + :type reason: :class:`Exception` """ - def __init__(self, pool, url, reason=None): + def __init__( + self, pool: ConnectionPool, url: str, reason: Exception | None = None + ) -> None: self.reason = reason - message = "Max retries exceeded with url: %s (Caused by %r)" % (url, reason) + message = f"Max retries exceeded with url: {url} (Caused by {reason!r})" - RequestError.__init__(self, pool, url, message) + super().__init__(pool, url, message) class HostChangedError(RequestError): """Raised when an existing pool gets a request for a foreign host.""" - def __init__(self, pool, url, retries=3): - message = "Tried to open a foreign host with url: %s" % url - RequestError.__init__(self, pool, url, message) + def __init__( + self, pool: ConnectionPool, url: str, retries: Retry | int = 3 + ) -> None: + message = f"Tried to open a foreign host with url: {url}" + super().__init__(pool, url, message) self.retries = retries class TimeoutStateError(HTTPError): """Raised when passing an invalid state to a timeout""" - pass - class TimeoutError(HTTPError): """Raised when a socket timeout error occurs. @@ -114,53 +125,66 @@ class TimeoutError(HTTPError): ` and :exc:`ConnectTimeoutErrors `. 
""" - pass - class ReadTimeoutError(TimeoutError, RequestError): """Raised when a socket timeout occurs while receiving data from a server""" - pass - # This timeout error does not have a URL attached and needs to inherit from the # base HTTPError class ConnectTimeoutError(TimeoutError): """Raised when a socket timeout occurs while connecting to a server""" - pass - -class NewConnectionError(ConnectTimeoutError, PoolError): +class NewConnectionError(ConnectTimeoutError, HTTPError): """Raised when we fail to establish a new connection. Usually ECONNREFUSED.""" - pass + def __init__(self, conn: HTTPConnection, message: str) -> None: + self.conn = conn + super().__init__(f"{conn}: {message}") + + @property + def pool(self) -> HTTPConnection: + warnings.warn( + "The 'pool' property is deprecated and will be removed " + "in urllib3 v2.1.0. Use 'conn' instead.", + DeprecationWarning, + stacklevel=2, + ) + + return self.conn + + +class NameResolutionError(NewConnectionError): + """Raised when host name resolution fails.""" + + def __init__(self, host: str, conn: HTTPConnection, reason: socket.gaierror): + message = f"Failed to resolve '{host}' ({reason})" + super().__init__(conn, message) class EmptyPoolError(PoolError): """Raised when a pool runs out of connections and no more are allowed.""" - pass + +class FullPoolError(PoolError): + """Raised when we try to add a connection to a full pool in blocking mode.""" class ClosedPoolError(PoolError): """Raised when a request enters a pool after the pool has been closed.""" - pass - class LocationValueError(ValueError, HTTPError): """Raised when there is something wrong with a given URL input.""" - pass - class LocationParseError(LocationValueError): """Raised when get_host or similar fails to parse the URL input.""" - def __init__(self, location): - message = "Failed to parse: %s" % location - HTTPError.__init__(self, message) + def __init__(self, location: str) -> None: + message = f"Failed to parse: {location}" + 
super().__init__(message) self.location = location @@ -168,9 +192,9 @@ def __init__(self, location): class URLSchemeUnknown(LocationValueError): """Raised when a URL input has an unsupported scheme.""" - def __init__(self, scheme): - message = "Not supported URL scheme %s" % scheme - super(URLSchemeUnknown, self).__init__(message) + def __init__(self, scheme: str): + message = f"Not supported URL scheme {scheme}" + super().__init__(message) self.scheme = scheme @@ -185,38 +209,18 @@ class ResponseError(HTTPError): class SecurityWarning(HTTPWarning): """Warned when performing security reducing actions""" - pass - - -class SubjectAltNameWarning(SecurityWarning): - """Warned when connecting to a host with a certificate missing a SAN.""" - - pass - class InsecureRequestWarning(SecurityWarning): """Warned when making an unverified HTTPS request.""" - pass - class SystemTimeWarning(SecurityWarning): """Warned when system time is suspected to be wrong""" - pass - class InsecurePlatformWarning(SecurityWarning): """Warned when certain TLS/SSL configuration is not available on a platform.""" - pass - - -class SNIMissingWarning(HTTPWarning): - """Warned when making a HTTPS request without SNI available.""" - - pass - class DependencyWarning(HTTPWarning): """ @@ -224,14 +228,10 @@ class DependencyWarning(HTTPWarning): dependencies. """ - pass - class ResponseNotChunked(ProtocolError, ValueError): """Response needs to be chunked in order to read it as chunks.""" - pass - class BodyNotHttplibCompatible(HTTPError): """ @@ -239,8 +239,6 @@ class BodyNotHttplibCompatible(HTTPError): (have an fp attribute which returns raw chunks) for read_chunked(). """ - pass - class IncompleteRead(HTTPError, httplib_IncompleteRead): """ @@ -250,12 +248,13 @@ class IncompleteRead(HTTPError, httplib_IncompleteRead): for ``partial`` to avoid creating large objects on streamed reads. 
""" - def __init__(self, partial, expected): - super(IncompleteRead, self).__init__(partial, expected) + def __init__(self, partial: int, expected: int) -> None: + self.partial = partial # type: ignore[assignment] + self.expected = expected - def __repr__(self): + def __repr__(self) -> str: return "IncompleteRead(%i bytes read, %i more expected)" % ( - self.partial, + self.partial, # type: ignore[str-format] self.expected, ) @@ -263,14 +262,13 @@ def __repr__(self): class InvalidChunkLength(HTTPError, httplib_IncompleteRead): """Invalid chunk length in a chunked response.""" - def __init__(self, response, length): - super(InvalidChunkLength, self).__init__( - response.tell(), response.length_remaining - ) + def __init__(self, response: HTTPResponse, length: bytes) -> None: + self.partial: int = response.tell() # type: ignore[assignment] + self.expected: int | None = response.length_remaining self.response = response self.length = length - def __repr__(self): + def __repr__(self) -> str: return "InvalidChunkLength(got length %r, %i bytes read)" % ( self.length, self.partial, @@ -280,34 +278,37 @@ def __repr__(self): class InvalidHeader(HTTPError): """The header provided was somehow invalid.""" - pass - class ProxySchemeUnknown(AssertionError, URLSchemeUnknown): """ProxyManager does not support the supplied scheme""" # TODO(t-8ch): Stop inheriting from AssertionError in v2.0. - def __init__(self, scheme): - message = "Not supported proxy scheme %s" % scheme - super(ProxySchemeUnknown, self).__init__(message) + def __init__(self, scheme: str | None) -> None: + # 'localhost' is here because our URL parser parses + # localhost:8080 -> scheme=localhost, remove if we fix this. 
+ if scheme == "localhost": + scheme = None + if scheme is None: + message = "Proxy URL had no scheme, should start with http:// or https://" + else: + message = f"Proxy URL had unsupported scheme {scheme}, should use http:// or https://" + super().__init__(message) class ProxySchemeUnsupported(ValueError): """Fetching HTTPS resources through HTTPS proxies is unsupported""" - pass - class HeaderParsingError(HTTPError): """Raised by assert_header_parsing, but we convert it to a log.warning statement.""" - def __init__(self, defects, unparsed_data): - message = "%s, unparsed data: %r" % (defects or "Unknown", unparsed_data) - super(HeaderParsingError, self).__init__(message) + def __init__( + self, defects: list[MessageDefect], unparsed_data: bytes | str | None + ) -> None: + message = f"{defects or 'Unknown'}, unparsed data: {unparsed_data!r}" + super().__init__(message) class UnrewindableBodyError(HTTPError): """urllib3 encountered an error when trying to rewind a body""" - - pass diff --git a/src/urllib3/exceptions.pyi b/src/urllib3/exceptions.pyi deleted file mode 100644 index ca528b09ad..0000000000 --- a/src/urllib3/exceptions.pyi +++ /dev/null @@ -1,55 +0,0 @@ -from typing import TYPE_CHECKING, Any, Optional, Tuple, Union - -if TYPE_CHECKING: - from urllib3.connectionpool import ConnectionPool - -class HTTPError(Exception): ... -class HTTPWarning(Warning): ... - -class PoolError(HTTPError): - pool: ConnectionPool - def __init__(self, pool: ConnectionPool, message: str) -> None: ... - def __reduce__(self) -> Union[str, Tuple[Any, ...]]: ... - -class RequestError(PoolError): - url: str - def __init__(self, pool: ConnectionPool, url: str, message: str) -> None: ... - def __reduce__(self) -> Union[str, Tuple[Any, ...]]: ... - -class SSLError(HTTPError): ... -class ProxyError(HTTPError): ... -class DecodeError(HTTPError): ... -class ProtocolError(HTTPError): ... 
- -ConnectionError: ProtocolError - -class MaxRetryError(RequestError): - reason: str - def __init__( - self, pool: ConnectionPool, url: str, reason: Optional[str] - ) -> None: ... - -class HostChangedError(RequestError): - retries: int - def __init__(self, pool: ConnectionPool, url: str, retries: int) -> None: ... - -class TimeoutStateError(HTTPError): ... -class TimeoutError(HTTPError): ... -class ReadTimeoutError(TimeoutError, RequestError): ... -class ConnectTimeoutError(TimeoutError): ... -class EmptyPoolError(PoolError): ... -class ClosedPoolError(PoolError): ... -class LocationValueError(ValueError, HTTPError): ... - -class LocationParseError(LocationValueError): - location: str - def __init__(self, location: str) -> None: ... - -class ResponseError(HTTPError): - GENERIC_ERROR: Any - SPECIFIC_ERROR: Any - -class SecurityWarning(HTTPWarning): ... -class InsecureRequestWarning(SecurityWarning): ... -class SystemTimeWarning(SecurityWarning): ... -class InsecurePlatformWarning(SecurityWarning): ... diff --git a/src/urllib3/fields.py b/src/urllib3/fields.py index 9d630f491d..51d898e24f 100644 --- a/src/urllib3/fields.py +++ b/src/urllib3/fields.py @@ -1,13 +1,20 @@ -from __future__ import absolute_import +from __future__ import annotations import email.utils import mimetypes -import re +import typing -from .packages import six +_TYPE_FIELD_VALUE = typing.Union[str, bytes] +_TYPE_FIELD_VALUE_TUPLE = typing.Union[ + _TYPE_FIELD_VALUE, + typing.Tuple[str, _TYPE_FIELD_VALUE], + typing.Tuple[str, _TYPE_FIELD_VALUE, str], +] -def guess_content_type(filename, default="application/octet-stream"): +def guess_content_type( + filename: str | None, default: str = "application/octet-stream" +) -> str: """ Guess the "Content-Type" of a file. 
@@ -21,7 +28,7 @@ def guess_content_type(filename, default="application/octet-stream"): return default -def format_header_param_rfc2231(name, value): +def format_header_param_rfc2231(name: str, value: _TYPE_FIELD_VALUE) -> str: """ Helper function to format and quote a single header parameter using the strategy defined in RFC 2231. @@ -34,14 +41,28 @@ def format_header_param_rfc2231(name, value): The name of the parameter, a string expected to be ASCII only. :param value: The value of the parameter, provided as ``bytes`` or `str``. - :ret: + :returns: An RFC-2231-formatted unicode string. + + .. deprecated:: 2.0.0 + Will be removed in urllib3 v2.1.0. This is not valid for + ``multipart/form-data`` header parameters. """ - if isinstance(value, six.binary_type): + import warnings + + warnings.warn( + "'format_header_param_rfc2231' is deprecated and will be " + "removed in urllib3 v2.1.0. This is not valid for " + "multipart/form-data header parameters.", + DeprecationWarning, + stacklevel=2, + ) + + if isinstance(value, bytes): value = value.decode("utf-8") if not any(ch in value for ch in '"\\\r\n'): - result = u'%s="%s"' % (name, value) + result = f'{name}="{value}"' try: result.encode("ascii") except (UnicodeEncodeError, UnicodeDecodeError): @@ -49,81 +70,87 @@ def format_header_param_rfc2231(name, value): else: return result - if six.PY2: # Python 2: - value = value.encode("utf-8") - - # encode_rfc2231 accepts an encoded string and returns an ascii-encoded - # string in Python 2 but accepts and returns unicode strings in Python 3 value = email.utils.encode_rfc2231(value, "utf-8") - value = "%s*=%s" % (name, value) - - if six.PY2: # Python 2: - value = value.decode("utf-8") + value = f"{name}*={value}" return value -_HTML5_REPLACEMENTS = { - u"\u0022": u"%22", - # Replace "\" with "\\". - u"\u005C": u"\u005C\u005C", -} - -# All control characters from 0x00 to 0x1F *except* 0x1B. 
-_HTML5_REPLACEMENTS.update( - { - six.unichr(cc): u"%{:02X}".format(cc) - for cc in range(0x00, 0x1F + 1) - if cc not in (0x1B,) - } -) - - -def _replace_multiple(value, needles_and_replacements): - def replacer(match): - return needles_and_replacements[match.group(0)] - - pattern = re.compile( - r"|".join([re.escape(needle) for needle in needles_and_replacements.keys()]) - ) - - result = pattern.sub(replacer, value) - - return result - - -def format_header_param_html5(name, value): +def format_multipart_header_param(name: str, value: _TYPE_FIELD_VALUE) -> str: """ - Helper function to format and quote a single header parameter using the - HTML5 strategy. + Format and quote a single multipart header parameter. - Particularly useful for header parameters which might contain - non-ASCII values, like file names. This follows the `HTML5 Working Draft - Section 4.10.22.7`_ and matches the behavior of curl and modern browsers. + This follows the `WHATWG HTML Standard`_ as of 2021/06/10, matching + the behavior of current browser and curl versions. Values are + assumed to be UTF-8. The ``\\n``, ``\\r``, and ``"`` characters are + percent encoded. - .. _HTML5 Working Draft Section 4.10.22.7: - https://w3c.github.io/html/sec-forms.html#multipart-form-data + .. _WHATWG HTML Standard: + https://html.spec.whatwg.org/multipage/ + form-control-infrastructure.html#multipart-form-data :param name: - The name of the parameter, a string expected to be ASCII only. + The name of the parameter, an ASCII-only ``str``. :param value: - The value of the parameter, provided as ``bytes`` or `str``. - :ret: - A unicode string, stripped of troublesome characters. + The value of the parameter, a ``str`` or UTF-8 encoded + ``bytes``. + :returns: + A string ``name="value"`` with the escaped value. + + .. versionchanged:: 2.0.0 + Matches the WHATWG HTML Standard as of 2021/06/10. Control + characters are no longer percent encoded. + + .. 
versionchanged:: 2.0.0 + Renamed from ``format_header_param_html5`` and + ``format_header_param``. The old names will be removed in + urllib3 v2.1.0. """ - if isinstance(value, six.binary_type): + if isinstance(value, bytes): value = value.decode("utf-8") - value = _replace_multiple(value, _HTML5_REPLACEMENTS) + # percent encode \n \r " + value = value.translate({10: "%0A", 13: "%0D", 34: "%22"}) + return f'{name}="{value}"' - return u'%s="%s"' % (name, value) + +def format_header_param_html5(name: str, value: _TYPE_FIELD_VALUE) -> str: + """ + .. deprecated:: 2.0.0 + Renamed to :func:`format_multipart_header_param`. Will be + removed in urllib3 v2.1.0. + """ + import warnings + + warnings.warn( + "'format_header_param_html5' has been renamed to " + "'format_multipart_header_param'. The old name will be " + "removed in urllib3 v2.1.0.", + DeprecationWarning, + stacklevel=2, + ) + return format_multipart_header_param(name, value) -# For backwards-compatibility. -format_header_param = format_header_param_html5 +def format_header_param(name: str, value: _TYPE_FIELD_VALUE) -> str: + """ + .. deprecated:: 2.0.0 + Renamed to :func:`format_multipart_header_param`. Will be + removed in urllib3 v2.1.0. + """ + import warnings + + warnings.warn( + "'format_header_param' has been renamed to " + "'format_multipart_header_param'. The old name will be " + "removed in urllib3 v2.1.0.", + DeprecationWarning, + stacklevel=2, + ) + return format_multipart_header_param(name, value) -class RequestField(object): +class RequestField: """ A data container for request body parameters. @@ -135,29 +162,47 @@ class RequestField(object): An optional filename of the request field. Must be unicode. :param headers: An optional dict-like object of headers to initially use for the field. - :param header_formatter: - An optional callable that is used to encode and format the headers. By - default, this is :func:`format_header_param_html5`. + + .. 
versionchanged:: 2.0.0 + The ``header_formatter`` parameter is deprecated and will + be removed in urllib3 v2.1.0. """ def __init__( self, - name, - data, - filename=None, - headers=None, - header_formatter=format_header_param_html5, + name: str, + data: _TYPE_FIELD_VALUE, + filename: str | None = None, + headers: typing.Mapping[str, str] | None = None, + header_formatter: typing.Callable[[str, _TYPE_FIELD_VALUE], str] | None = None, ): self._name = name self._filename = filename self.data = data - self.headers = {} + self.headers: dict[str, str | None] = {} if headers: self.headers = dict(headers) - self.header_formatter = header_formatter + + if header_formatter is not None: + import warnings + + warnings.warn( + "The 'header_formatter' parameter is deprecated and " + "will be removed in urllib3 v2.1.0.", + DeprecationWarning, + stacklevel=2, + ) + self.header_formatter = header_formatter + else: + self.header_formatter = format_multipart_header_param @classmethod - def from_tuples(cls, fieldname, value, header_formatter=format_header_param_html5): + def from_tuples( + cls, + fieldname: str, + value: _TYPE_FIELD_VALUE_TUPLE, + header_formatter: typing.Callable[[str, _TYPE_FIELD_VALUE], str] | None = None, + ) -> RequestField: """ A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters. @@ -174,11 +219,19 @@ def from_tuples(cls, fieldname, value, header_formatter=format_header_param_html Field names and filenames must be unicode. 
""" + filename: str | None + content_type: str | None + data: _TYPE_FIELD_VALUE + if isinstance(value, tuple): if len(value) == 3: - filename, data, content_type = value + filename, data, content_type = typing.cast( + typing.Tuple[str, _TYPE_FIELD_VALUE, str], value + ) else: - filename, data = value + filename, data = typing.cast( + typing.Tuple[str, _TYPE_FIELD_VALUE], value + ) content_type = guess_content_type(filename) else: filename = None @@ -192,20 +245,29 @@ def from_tuples(cls, fieldname, value, header_formatter=format_header_param_html return request_param - def _render_part(self, name, value): + def _render_part(self, name: str, value: _TYPE_FIELD_VALUE) -> str: """ - Overridable helper function to format a single header parameter. By - default, this calls ``self.header_formatter``. + Override this method to change how each multipart header + parameter is formatted. By default, this calls + :func:`format_multipart_header_param`. :param name: - The name of the parameter, a string expected to be ASCII only. + The name of the parameter, an ASCII-only ``str``. :param value: - The value of the parameter, provided as a unicode string. - """ + The value of the parameter, a ``str`` or UTF-8 encoded + ``bytes``. + :meta public: + """ return self.header_formatter(name, value) - def _render_parts(self, header_parts): + def _render_parts( + self, + header_parts: ( + dict[str, _TYPE_FIELD_VALUE | None] + | typing.Sequence[tuple[str, _TYPE_FIELD_VALUE | None]] + ), + ) -> str: """ Helper function to format and quote a single header. @@ -216,18 +278,21 @@ def _render_parts(self, header_parts): A sequence of (k, v) tuples or a :class:`dict` of (k, v) to format as `k1="v1"; k2="v2"; ...`. 
""" + iterable: typing.Iterable[tuple[str, _TYPE_FIELD_VALUE | None]] + parts = [] - iterable = header_parts if isinstance(header_parts, dict): iterable = header_parts.items() + else: + iterable = header_parts for name, value in iterable: if value is not None: parts.append(self._render_part(name, value)) - return u"; ".join(parts) + return "; ".join(parts) - def render_headers(self): + def render_headers(self) -> str: """ Renders the headers for this request field. """ @@ -236,39 +301,45 @@ def render_headers(self): sort_keys = ["Content-Disposition", "Content-Type", "Content-Location"] for sort_key in sort_keys: if self.headers.get(sort_key, False): - lines.append(u"%s: %s" % (sort_key, self.headers[sort_key])) + lines.append(f"{sort_key}: {self.headers[sort_key]}") for header_name, header_value in self.headers.items(): if header_name not in sort_keys: if header_value: - lines.append(u"%s: %s" % (header_name, header_value)) + lines.append(f"{header_name}: {header_value}") - lines.append(u"\r\n") - return u"\r\n".join(lines) + lines.append("\r\n") + return "\r\n".join(lines) def make_multipart( - self, content_disposition=None, content_type=None, content_location=None - ): + self, + content_disposition: str | None = None, + content_type: str | None = None, + content_location: str | None = None, + ) -> None: """ Makes this request field into a multipart request field. This method overrides "Content-Disposition", "Content-Type" and "Content-Location" headers to the request parameter. + :param content_disposition: + The 'Content-Disposition' of the request body. Defaults to 'form-data' :param content_type: The 'Content-Type' of the request body. :param content_location: The 'Content-Location' of the request body. 
""" - self.headers["Content-Disposition"] = content_disposition or u"form-data" - self.headers["Content-Disposition"] += u"; ".join( + content_disposition = (content_disposition or "form-data") + "; ".join( [ - u"", + "", self._render_parts( - ((u"name", self._name), (u"filename", self._filename)) + (("name", self._name), ("filename", self._filename)) ), ] ) + + self.headers["Content-Disposition"] = content_disposition self.headers["Content-Type"] = content_type self.headers["Content-Location"] = content_location diff --git a/src/urllib3/fields.pyi b/src/urllib3/fields.pyi deleted file mode 100644 index e1f3f90506..0000000000 --- a/src/urllib3/fields.pyi +++ /dev/null @@ -1,28 +0,0 @@ -# Stubs for requests.packages.urllib3.fields (Python 3.4) - -from typing import Any, Callable, Mapping, Optional - -def guess_content_type(filename: str, default: str) -> str: ... -def format_header_param_rfc2231(name: str, value: str) -> str: ... -def format_header_param_html5(name: str, value: str) -> str: ... -def format_header_param(name: str, value: str) -> str: ... - -class RequestField: - data: Any - headers: Optional[Mapping[str, str]] - def __init__( - self, - name: str, - data: Any, - filename: Optional[str], - headers: Optional[Mapping[str, str]], - header_formatter: Callable[[str, str], str], - ) -> None: ... - @classmethod - def from_tuples( - cls, fieldname: str, value: str, header_formatter: Callable[[str, str], str] - ) -> RequestField: ... - def render_headers(self) -> str: ... - def make_multipart( - self, content_disposition: str, content_type: str, content_location: str - ) -> None: ... 
diff --git a/src/urllib3/filepost.py b/src/urllib3/filepost.py index 36c9252c64..1c90a211fb 100644 --- a/src/urllib3/filepost.py +++ b/src/urllib3/filepost.py @@ -1,28 +1,32 @@ -from __future__ import absolute_import +from __future__ import annotations import binascii import codecs import os +import typing from io import BytesIO -from .fields import RequestField -from .packages import six -from .packages.six import b +from .fields import _TYPE_FIELD_VALUE_TUPLE, RequestField writer = codecs.lookup("utf-8")[3] +_TYPE_FIELDS_SEQUENCE = typing.Sequence[ + typing.Union[typing.Tuple[str, _TYPE_FIELD_VALUE_TUPLE], RequestField] +] +_TYPE_FIELDS = typing.Union[ + _TYPE_FIELDS_SEQUENCE, + typing.Mapping[str, _TYPE_FIELD_VALUE_TUPLE], +] -def choose_boundary(): + +def choose_boundary() -> str: """ Our embarrassingly-simple replacement for mimetools.choose_boundary. """ - boundary = binascii.hexlify(os.urandom(16)) - if not six.PY2: - boundary = boundary.decode("ascii") - return boundary + return binascii.hexlify(os.urandom(16)).decode() -def iter_field_objects(fields): +def iter_field_objects(fields: _TYPE_FIELDS) -> typing.Iterable[RequestField]: """ Iterate over fields. @@ -30,42 +34,29 @@ def iter_field_objects(fields): :class:`~urllib3.fields.RequestField`. """ - if isinstance(fields, dict): - i = six.iteritems(fields) + iterable: typing.Iterable[RequestField | tuple[str, _TYPE_FIELD_VALUE_TUPLE]] + + if isinstance(fields, typing.Mapping): + iterable = fields.items() else: - i = iter(fields) + iterable = fields - for field in i: + for field in iterable: if isinstance(field, RequestField): yield field else: yield RequestField.from_tuples(*field) -def iter_fields(fields): - """ - .. deprecated:: 1.6 - - Iterate over fields. - - The addition of :class:`~urllib3.fields.RequestField` makes this function - obsolete. Instead, use :func:`iter_field_objects`, which returns - :class:`~urllib3.fields.RequestField` objects. - - Supports list of (k, v) tuples and dicts. 
- """ - if isinstance(fields, dict): - return ((k, v) for k, v in six.iteritems(fields)) - - return ((k, v) for k, v in fields) - - -def encode_multipart_formdata(fields, boundary=None): +def encode_multipart_formdata( + fields: _TYPE_FIELDS, boundary: str | None = None +) -> tuple[bytes, str]: """ Encode a dictionary of ``fields`` using the multipart/form-data MIME format. :param fields: Dictionary of fields or list of (key, :class:`~urllib3.fields.RequestField`). + Values are processed by :func:`urllib3.fields.RequestField.from_tuples`. :param boundary: If not specified, then a random boundary will be generated using @@ -76,7 +67,7 @@ def encode_multipart_formdata(fields, boundary=None): boundary = choose_boundary() for field in iter_field_objects(fields): - body.write(b("--%s\r\n" % (boundary))) + body.write(f"--{boundary}\r\n".encode("latin-1")) writer(body).write(field.render_headers()) data = field.data @@ -84,15 +75,15 @@ def encode_multipart_formdata(fields, boundary=None): if isinstance(data, int): data = str(data) # Backwards compatibility - if isinstance(data, six.text_type): + if isinstance(data, str): writer(body).write(data) else: body.write(data) body.write(b"\r\n") - body.write(b("--%s--\r\n" % (boundary))) + body.write(f"--{boundary}--\r\n".encode("latin-1")) - content_type = str("multipart/form-data; boundary=%s" % boundary) + content_type = f"multipart/form-data; boundary={boundary}" return body.getvalue(), content_type diff --git a/src/urllib3/filepost.pyi b/src/urllib3/filepost.pyi deleted file mode 100644 index 54cf62a468..0000000000 --- a/src/urllib3/filepost.pyi +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Any, Generator, List, Mapping, Optional, Tuple, Union - -from . import fields - -RequestField = fields.RequestField -Fields = Union[Mapping[str, str], List[Tuple[str]], List[RequestField]] -Iterator = Generator[Tuple[str], None, None] - -writer: Any - -def choose_boundary() -> str: ... 
-def iter_field_objects(fields: Fields) -> Iterator: ... -def iter_fields(fields: Fields) -> Iterator: ... -def encode_multipart_formdata( - fields: Fields, boundary: Optional[str] -) -> Tuple[str]: ... diff --git a/src/urllib3/packages/__init__.py b/src/urllib3/packages/__init__.py deleted file mode 100644 index fce4caa65d..0000000000 --- a/src/urllib3/packages/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from __future__ import absolute_import - -from . import ssl_match_hostname - -__all__ = ("ssl_match_hostname",) diff --git a/src/urllib3/packages/__init__.pyi b/src/urllib3/packages/__init__.pyi deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/urllib3/packages/backports/__init__.py b/src/urllib3/packages/backports/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/urllib3/packages/backports/makefile.py b/src/urllib3/packages/backports/makefile.py deleted file mode 100644 index b8fb2154b6..0000000000 --- a/src/urllib3/packages/backports/makefile.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding: utf-8 -*- -""" -backports.makefile -~~~~~~~~~~~~~~~~~~ - -Backports the Python 3 ``socket.makefile`` method for use with anything that -wants to create a "fake" socket object. -""" -import io -from socket import SocketIO - - -def backport_makefile( - self, mode="r", buffering=None, encoding=None, errors=None, newline=None -): - """ - Backport of ``socket.makefile`` from Python 3.5. 
- """ - if not set(mode) <= {"r", "w", "b"}: - raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,)) - writing = "w" in mode - reading = "r" in mode or not writing - assert reading or writing - binary = "b" in mode - rawmode = "" - if reading: - rawmode += "r" - if writing: - rawmode += "w" - raw = SocketIO(self, rawmode) - self._makefile_refs += 1 - if buffering is None: - buffering = -1 - if buffering < 0: - buffering = io.DEFAULT_BUFFER_SIZE - if buffering == 0: - if not binary: - raise ValueError("unbuffered streams must be binary") - return raw - if reading and writing: - buffer = io.BufferedRWPair(raw, raw, buffering) - elif reading: - buffer = io.BufferedReader(raw, buffering) - else: - assert writing - buffer = io.BufferedWriter(raw, buffering) - if binary: - return buffer - text = io.TextIOWrapper(buffer, encoding, errors, newline) - text.mode = mode - return text diff --git a/src/urllib3/packages/six.py b/src/urllib3/packages/six.py deleted file mode 100644 index 314424099f..0000000000 --- a/src/urllib3/packages/six.py +++ /dev/null @@ -1,1021 +0,0 @@ -# Copyright (c) 2010-2019 Benjamin Peterson -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""Utilities for writing code that runs on Python 2 and 3""" - -from __future__ import absolute_import - -import functools -import itertools -import operator -import sys -import types - -__author__ = "Benjamin Peterson " -__version__ = "1.12.0" - - -# Useful for very coarse version differentiation. -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 -PY34 = sys.version_info[0:2] >= (3, 4) - -if PY3: - string_types = (str,) - integer_types = (int,) - class_types = (type,) - text_type = str - binary_type = bytes - - MAXSIZE = sys.maxsize -else: - string_types = (basestring,) - integer_types = (int, long) - class_types = (type, types.ClassType) - text_type = unicode - binary_type = str - - if sys.platform.startswith("java"): - # Jython always uses 32 bits. - MAXSIZE = int((1 << 31) - 1) - else: - # It's possible to have sizeof(long) != sizeof(Py_ssize_t). - class X(object): - def __len__(self): - return 1 << 31 - - try: - len(X()) - except OverflowError: - # 32-bit - MAXSIZE = int((1 << 31) - 1) - else: - # 64-bit - MAXSIZE = int((1 << 63) - 1) - del X - - -def _add_doc(func, doc): - """Add documentation to a function.""" - func.__doc__ = doc - - -def _import_module(name): - """Import module, returning the module after the last dot.""" - __import__(name) - return sys.modules[name] - - -class _LazyDescr(object): - def __init__(self, name): - self.name = name - - def __get__(self, obj, tp): - result = self._resolve() - setattr(obj, self.name, result) # Invokes __set__. - try: - # This is a bit ugly, but it avoids running this again by - # removing this descriptor. 
- delattr(obj.__class__, self.name) - except AttributeError: - pass - return result - - -class MovedModule(_LazyDescr): - def __init__(self, name, old, new=None): - super(MovedModule, self).__init__(name) - if PY3: - if new is None: - new = name - self.mod = new - else: - self.mod = old - - def _resolve(self): - return _import_module(self.mod) - - def __getattr__(self, attr): - _module = self._resolve() - value = getattr(_module, attr) - setattr(self, attr, value) - return value - - -class _LazyModule(types.ModuleType): - def __init__(self, name): - super(_LazyModule, self).__init__(name) - self.__doc__ = self.__class__.__doc__ - - def __dir__(self): - attrs = ["__doc__", "__name__"] - attrs += [attr.name for attr in self._moved_attributes] - return attrs - - # Subclasses should override this - _moved_attributes = [] - - -class MovedAttribute(_LazyDescr): - def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): - super(MovedAttribute, self).__init__(name) - if PY3: - if new_mod is None: - new_mod = name - self.mod = new_mod - if new_attr is None: - if old_attr is None: - new_attr = name - else: - new_attr = old_attr - self.attr = new_attr - else: - self.mod = old_mod - if old_attr is None: - old_attr = name - self.attr = old_attr - - def _resolve(self): - module = _import_module(self.mod) - return getattr(module, self.attr) - - -class _SixMetaPathImporter(object): - - """ - A meta path importer to import six.moves and its submodules. - - This class implements a PEP302 finder and loader. It should be compatible - with Python 2.5 and all existing versions of Python3 - """ - - def __init__(self, six_module_name): - self.name = six_module_name - self.known_modules = {} - - def _add_module(self, mod, *fullnames): - for fullname in fullnames: - self.known_modules[self.name + "." + fullname] = mod - - def _get_module(self, fullname): - return self.known_modules[self.name + "." 
+ fullname] - - def find_module(self, fullname, path=None): - if fullname in self.known_modules: - return self - return None - - def __get_module(self, fullname): - try: - return self.known_modules[fullname] - except KeyError: - raise ImportError("This loader does not know module " + fullname) - - def load_module(self, fullname): - try: - # in case of a reload - return sys.modules[fullname] - except KeyError: - pass - mod = self.__get_module(fullname) - if isinstance(mod, MovedModule): - mod = mod._resolve() - else: - mod.__loader__ = self - sys.modules[fullname] = mod - return mod - - def is_package(self, fullname): - """ - Return true, if the named module is a package. - - We need this method to get correct spec objects with - Python 3.4 (see PEP451) - """ - return hasattr(self.__get_module(fullname), "__path__") - - def get_code(self, fullname): - """Return None - - Required, if is_package is implemented""" - self.__get_module(fullname) # eventually raises ImportError - return None - - get_source = get_code # same as get_code - - -_importer = _SixMetaPathImporter(__name__) - - -class _MovedItems(_LazyModule): - - """Lazy loading of moved objects""" - - __path__ = [] # mark as package - - -_moved_attributes = [ - MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), - MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), - MovedAttribute( - "filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse" - ), - MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), - MovedAttribute("intern", "__builtin__", "sys"), - MovedAttribute("map", "itertools", "builtins", "imap", "map"), - MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), - MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), - MovedAttribute("getoutput", "commands", "subprocess"), - MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), - MovedAttribute( - "reload_module", "__builtin__", "importlib" if PY34 else "imp", 
"reload" - ), - MovedAttribute("reduce", "__builtin__", "functools"), - MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), - MovedAttribute("StringIO", "StringIO", "io"), - MovedAttribute("UserDict", "UserDict", "collections"), - MovedAttribute("UserList", "UserList", "collections"), - MovedAttribute("UserString", "UserString", "collections"), - MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), - MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), - MovedAttribute( - "zip_longest", "itertools", "itertools", "izip_longest", "zip_longest" - ), - MovedModule("builtins", "__builtin__"), - MovedModule("configparser", "ConfigParser"), - MovedModule("copyreg", "copy_reg"), - MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), - MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"), - MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), - MovedModule("http_cookies", "Cookie", "http.cookies"), - MovedModule("html_entities", "htmlentitydefs", "html.entities"), - MovedModule("html_parser", "HTMLParser", "html.parser"), - MovedModule("http_client", "httplib", "http.client"), - MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), - MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"), - MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), - MovedModule( - "email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart" - ), - MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), - MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), - MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), - MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), - MovedModule("cPickle", "cPickle", "pickle"), - MovedModule("queue", "Queue"), - MovedModule("reprlib", "repr"), - MovedModule("socketserver", "SocketServer"), - MovedModule("_thread", "thread", "_thread"), - MovedModule("tkinter", "Tkinter"), - 
MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), - MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), - MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), - MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), - MovedModule("tkinter_tix", "Tix", "tkinter.tix"), - MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), - MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), - MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), - MovedModule("tkinter_colorchooser", "tkColorChooser", "tkinter.colorchooser"), - MovedModule("tkinter_commondialog", "tkCommonDialog", "tkinter.commondialog"), - MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), - MovedModule("tkinter_font", "tkFont", "tkinter.font"), - MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), - MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", "tkinter.simpledialog"), - MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), - MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), - MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), - MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), - MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"), - MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"), -] -# Add windows specific modules. -if sys.platform == "win32": - _moved_attributes += [MovedModule("winreg", "_winreg")] - -for attr in _moved_attributes: - setattr(_MovedItems, attr.name, attr) - if isinstance(attr, MovedModule): - _importer._add_module(attr, "moves." 
+ attr.name) -del attr - -_MovedItems._moved_attributes = _moved_attributes - -moves = _MovedItems(__name__ + ".moves") -_importer._add_module(moves, "moves") - - -class Module_six_moves_urllib_parse(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_parse""" - - -_urllib_parse_moved_attributes = [ - MovedAttribute("ParseResult", "urlparse", "urllib.parse"), - MovedAttribute("SplitResult", "urlparse", "urllib.parse"), - MovedAttribute("parse_qs", "urlparse", "urllib.parse"), - MovedAttribute("parse_qsl", "urlparse", "urllib.parse"), - MovedAttribute("urldefrag", "urlparse", "urllib.parse"), - MovedAttribute("urljoin", "urlparse", "urllib.parse"), - MovedAttribute("urlparse", "urlparse", "urllib.parse"), - MovedAttribute("urlsplit", "urlparse", "urllib.parse"), - MovedAttribute("urlunparse", "urlparse", "urllib.parse"), - MovedAttribute("urlunsplit", "urlparse", "urllib.parse"), - MovedAttribute("quote", "urllib", "urllib.parse"), - MovedAttribute("quote_plus", "urllib", "urllib.parse"), - MovedAttribute("unquote", "urllib", "urllib.parse"), - MovedAttribute("unquote_plus", "urllib", "urllib.parse"), - MovedAttribute( - "unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes" - ), - MovedAttribute("urlencode", "urllib", "urllib.parse"), - MovedAttribute("splitquery", "urllib", "urllib.parse"), - MovedAttribute("splittag", "urllib", "urllib.parse"), - MovedAttribute("splituser", "urllib", "urllib.parse"), - MovedAttribute("splitvalue", "urllib", "urllib.parse"), - MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), - MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), - MovedAttribute("uses_params", "urlparse", "urllib.parse"), - MovedAttribute("uses_query", "urlparse", "urllib.parse"), - MovedAttribute("uses_relative", "urlparse", "urllib.parse"), -] -for attr in _urllib_parse_moved_attributes: - setattr(Module_six_moves_urllib_parse, attr.name, attr) -del attr - -Module_six_moves_urllib_parse._moved_attributes 
= _urllib_parse_moved_attributes - -_importer._add_module( - Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), - "moves.urllib_parse", - "moves.urllib.parse", -) - - -class Module_six_moves_urllib_error(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_error""" - - -_urllib_error_moved_attributes = [ - MovedAttribute("URLError", "urllib2", "urllib.error"), - MovedAttribute("HTTPError", "urllib2", "urllib.error"), - MovedAttribute("ContentTooShortError", "urllib", "urllib.error"), -] -for attr in _urllib_error_moved_attributes: - setattr(Module_six_moves_urllib_error, attr.name, attr) -del attr - -Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes - -_importer._add_module( - Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), - "moves.urllib_error", - "moves.urllib.error", -) - - -class Module_six_moves_urllib_request(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_request""" - - -_urllib_request_moved_attributes = [ - MovedAttribute("urlopen", "urllib2", "urllib.request"), - MovedAttribute("install_opener", "urllib2", "urllib.request"), - MovedAttribute("build_opener", "urllib2", "urllib.request"), - MovedAttribute("pathname2url", "urllib", "urllib.request"), - MovedAttribute("url2pathname", "urllib", "urllib.request"), - MovedAttribute("getproxies", "urllib", "urllib.request"), - MovedAttribute("Request", "urllib2", "urllib.request"), - MovedAttribute("OpenerDirector", "urllib2", "urllib.request"), - MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"), - MovedAttribute("ProxyHandler", "urllib2", "urllib.request"), - MovedAttribute("BaseHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"), - MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", 
"urllib.request"), - MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"), - MovedAttribute("FileHandler", "urllib2", "urllib.request"), - MovedAttribute("FTPHandler", "urllib2", "urllib.request"), - MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"), - MovedAttribute("UnknownHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"), - MovedAttribute("urlretrieve", "urllib", "urllib.request"), - MovedAttribute("urlcleanup", "urllib", "urllib.request"), - MovedAttribute("URLopener", "urllib", "urllib.request"), - MovedAttribute("FancyURLopener", "urllib", "urllib.request"), - MovedAttribute("proxy_bypass", "urllib", "urllib.request"), - MovedAttribute("parse_http_list", "urllib2", "urllib.request"), - MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"), -] -for attr in _urllib_request_moved_attributes: - setattr(Module_six_moves_urllib_request, attr.name, attr) -del attr - -Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes - -_importer._add_module( - Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), - "moves.urllib_request", - "moves.urllib.request", -) - - -class Module_six_moves_urllib_response(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_response""" - - -_urllib_response_moved_attributes = [ - MovedAttribute("addbase", "urllib", "urllib.response"), - MovedAttribute("addclosehook", "urllib", "urllib.response"), - 
MovedAttribute("addinfo", "urllib", "urllib.response"), - MovedAttribute("addinfourl", "urllib", "urllib.response"), -] -for attr in _urllib_response_moved_attributes: - setattr(Module_six_moves_urllib_response, attr.name, attr) -del attr - -Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes - -_importer._add_module( - Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), - "moves.urllib_response", - "moves.urllib.response", -) - - -class Module_six_moves_urllib_robotparser(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_robotparser""" - - -_urllib_robotparser_moved_attributes = [ - MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser") -] -for attr in _urllib_robotparser_moved_attributes: - setattr(Module_six_moves_urllib_robotparser, attr.name, attr) -del attr - -Module_six_moves_urllib_robotparser._moved_attributes = ( - _urllib_robotparser_moved_attributes -) - -_importer._add_module( - Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), - "moves.urllib_robotparser", - "moves.urllib.robotparser", -) - - -class Module_six_moves_urllib(types.ModuleType): - - """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" - - __path__ = [] # mark as package - parse = _importer._get_module("moves.urllib_parse") - error = _importer._get_module("moves.urllib_error") - request = _importer._get_module("moves.urllib_request") - response = _importer._get_module("moves.urllib_response") - robotparser = _importer._get_module("moves.urllib_robotparser") - - def __dir__(self): - return ["parse", "error", "request", "response", "robotparser"] - - -_importer._add_module( - Module_six_moves_urllib(__name__ + ".moves.urllib"), "moves.urllib" -) - - -def add_move(move): - """Add an item to six.moves.""" - setattr(_MovedItems, move.name, move) - - -def remove_move(name): - """Remove item from six.moves.""" - try: - delattr(_MovedItems, name) - except 
AttributeError: - try: - del moves.__dict__[name] - except KeyError: - raise AttributeError("no such move, %r" % (name,)) - - -if PY3: - _meth_func = "__func__" - _meth_self = "__self__" - - _func_closure = "__closure__" - _func_code = "__code__" - _func_defaults = "__defaults__" - _func_globals = "__globals__" -else: - _meth_func = "im_func" - _meth_self = "im_self" - - _func_closure = "func_closure" - _func_code = "func_code" - _func_defaults = "func_defaults" - _func_globals = "func_globals" - - -try: - advance_iterator = next -except NameError: - - def advance_iterator(it): - return it.next() - - -next = advance_iterator - - -try: - callable = callable -except NameError: - - def callable(obj): - return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) - - -if PY3: - - def get_unbound_function(unbound): - return unbound - - create_bound_method = types.MethodType - - def create_unbound_method(func, cls): - return func - - Iterator = object -else: - - def get_unbound_function(unbound): - return unbound.im_func - - def create_bound_method(func, obj): - return types.MethodType(func, obj, obj.__class__) - - def create_unbound_method(func, cls): - return types.MethodType(func, None, cls) - - class Iterator(object): - def next(self): - return type(self).__next__(self) - - callable = callable -_add_doc( - get_unbound_function, """Get the function out of a possibly unbound function""" -) - - -get_method_function = operator.attrgetter(_meth_func) -get_method_self = operator.attrgetter(_meth_self) -get_function_closure = operator.attrgetter(_func_closure) -get_function_code = operator.attrgetter(_func_code) -get_function_defaults = operator.attrgetter(_func_defaults) -get_function_globals = operator.attrgetter(_func_globals) - - -if PY3: - - def iterkeys(d, **kw): - return iter(d.keys(**kw)) - - def itervalues(d, **kw): - return iter(d.values(**kw)) - - def iteritems(d, **kw): - return iter(d.items(**kw)) - - def iterlists(d, **kw): - return 
iter(d.lists(**kw)) - - viewkeys = operator.methodcaller("keys") - - viewvalues = operator.methodcaller("values") - - viewitems = operator.methodcaller("items") -else: - - def iterkeys(d, **kw): - return d.iterkeys(**kw) - - def itervalues(d, **kw): - return d.itervalues(**kw) - - def iteritems(d, **kw): - return d.iteritems(**kw) - - def iterlists(d, **kw): - return d.iterlists(**kw) - - viewkeys = operator.methodcaller("viewkeys") - - viewvalues = operator.methodcaller("viewvalues") - - viewitems = operator.methodcaller("viewitems") - -_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") -_add_doc(itervalues, "Return an iterator over the values of a dictionary.") -_add_doc(iteritems, "Return an iterator over the (key, value) pairs of a dictionary.") -_add_doc( - iterlists, "Return an iterator over the (key, [values]) pairs of a dictionary." -) - - -if PY3: - - def b(s): - return s.encode("latin-1") - - def u(s): - return s - - unichr = chr - import struct - - int2byte = struct.Struct(">B").pack - del struct - byte2int = operator.itemgetter(0) - indexbytes = operator.getitem - iterbytes = iter - import io - - StringIO = io.StringIO - BytesIO = io.BytesIO - del io - _assertCountEqual = "assertCountEqual" - if sys.version_info[1] <= 1: - _assertRaisesRegex = "assertRaisesRegexp" - _assertRegex = "assertRegexpMatches" - else: - _assertRaisesRegex = "assertRaisesRegex" - _assertRegex = "assertRegex" -else: - - def b(s): - return s - - # Workaround for standalone backslash - - def u(s): - return unicode(s.replace(r"\\", r"\\\\"), "unicode_escape") - - unichr = unichr - int2byte = chr - - def byte2int(bs): - return ord(bs[0]) - - def indexbytes(buf, i): - return ord(buf[i]) - - iterbytes = functools.partial(itertools.imap, ord) - import StringIO - - StringIO = BytesIO = StringIO.StringIO - _assertCountEqual = "assertItemsEqual" - _assertRaisesRegex = "assertRaisesRegexp" - _assertRegex = "assertRegexpMatches" -_add_doc(b, """Byte literal""") 
-_add_doc(u, """Text literal""") - - -def assertCountEqual(self, *args, **kwargs): - return getattr(self, _assertCountEqual)(*args, **kwargs) - - -def assertRaisesRegex(self, *args, **kwargs): - return getattr(self, _assertRaisesRegex)(*args, **kwargs) - - -def assertRegex(self, *args, **kwargs): - return getattr(self, _assertRegex)(*args, **kwargs) - - -if PY3: - exec_ = getattr(moves.builtins, "exec") - - def reraise(tp, value, tb=None): - try: - if value is None: - value = tp() - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value - finally: - value = None - tb = None - - -else: - - def exec_(_code_, _globs_=None, _locs_=None): - """Execute code in a namespace.""" - if _globs_ is None: - frame = sys._getframe(1) - _globs_ = frame.f_globals - if _locs_ is None: - _locs_ = frame.f_locals - del frame - elif _locs_ is None: - _locs_ = _globs_ - exec("""exec _code_ in _globs_, _locs_""") - - exec_( - """def reraise(tp, value, tb=None): - try: - raise tp, value, tb - finally: - tb = None -""" - ) - - -if sys.version_info[:2] == (3, 2): - exec_( - """def raise_from(value, from_value): - try: - if from_value is None: - raise value - raise value from from_value - finally: - value = None -""" - ) -elif sys.version_info[:2] > (3, 2): - exec_( - """def raise_from(value, from_value): - try: - raise value from from_value - finally: - value = None -""" - ) -else: - - def raise_from(value, from_value): - raise value - - -print_ = getattr(moves.builtins, "print", None) -if print_ is None: - - def print_(*args, **kwargs): - """The new-style print function for Python 2.4 and 2.5.""" - fp = kwargs.pop("file", sys.stdout) - if fp is None: - return - - def write(data): - if not isinstance(data, basestring): - data = str(data) - # If the file has an encoding, encode unicode with it. 
- if ( - isinstance(fp, file) - and isinstance(data, unicode) - and fp.encoding is not None - ): - errors = getattr(fp, "errors", None) - if errors is None: - errors = "strict" - data = data.encode(fp.encoding, errors) - fp.write(data) - - want_unicode = False - sep = kwargs.pop("sep", None) - if sep is not None: - if isinstance(sep, unicode): - want_unicode = True - elif not isinstance(sep, str): - raise TypeError("sep must be None or a string") - end = kwargs.pop("end", None) - if end is not None: - if isinstance(end, unicode): - want_unicode = True - elif not isinstance(end, str): - raise TypeError("end must be None or a string") - if kwargs: - raise TypeError("invalid keyword arguments to print()") - if not want_unicode: - for arg in args: - if isinstance(arg, unicode): - want_unicode = True - break - if want_unicode: - newline = unicode("\n") - space = unicode(" ") - else: - newline = "\n" - space = " " - if sep is None: - sep = space - if end is None: - end = newline - for i, arg in enumerate(args): - if i: - write(sep) - write(arg) - write(end) - - -if sys.version_info[:2] < (3, 3): - _print = print_ - - def print_(*args, **kwargs): - fp = kwargs.get("file", sys.stdout) - flush = kwargs.pop("flush", False) - _print(*args, **kwargs) - if flush and fp is not None: - fp.flush() - - -_add_doc(reraise, """Reraise an exception.""") - -if sys.version_info[0:2] < (3, 4): - - def wraps( - wrapped, - assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES, - ): - def wrapper(f): - f = functools.wraps(wrapped, assigned, updated)(f) - f.__wrapped__ = wrapped - return f - - return wrapper - - -else: - wraps = functools.wraps - - -def with_metaclass(meta, *bases): - """Create a base class with a metaclass.""" - # This requires a bit of explanation: the basic idea is to make a dummy - # metaclass for one level of class instantiation that replaces itself with - # the actual metaclass. 
- class metaclass(type): - def __new__(cls, name, this_bases, d): - return meta(name, bases, d) - - @classmethod - def __prepare__(cls, name, this_bases): - return meta.__prepare__(name, bases) - - return type.__new__(metaclass, "temporary_class", (), {}) - - -def add_metaclass(metaclass): - """Class decorator for creating a class with a metaclass.""" - - def wrapper(cls): - orig_vars = cls.__dict__.copy() - slots = orig_vars.get("__slots__") - if slots is not None: - if isinstance(slots, str): - slots = [slots] - for slots_var in slots: - orig_vars.pop(slots_var) - orig_vars.pop("__dict__", None) - orig_vars.pop("__weakref__", None) - if hasattr(cls, "__qualname__"): - orig_vars["__qualname__"] = cls.__qualname__ - return metaclass(cls.__name__, cls.__bases__, orig_vars) - - return wrapper - - -def ensure_binary(s, encoding="utf-8", errors="strict"): - """Coerce **s** to six.binary_type. - - For Python 2: - - `unicode` -> encoded to `str` - - `str` -> `str` - - For Python 3: - - `str` -> encoded to `bytes` - - `bytes` -> `bytes` - """ - if isinstance(s, text_type): - return s.encode(encoding, errors) - elif isinstance(s, binary_type): - return s - else: - raise TypeError("not expecting type '%s'" % type(s)) - - -def ensure_str(s, encoding="utf-8", errors="strict"): - """Coerce *s* to `str`. - - For Python 2: - - `unicode` -> encoded to `str` - - `str` -> `str` - - For Python 3: - - `str` -> `str` - - `bytes` -> decoded to `str` - """ - if not isinstance(s, (text_type, binary_type)): - raise TypeError("not expecting type '%s'" % type(s)) - if PY2 and isinstance(s, text_type): - s = s.encode(encoding, errors) - elif PY3 and isinstance(s, binary_type): - s = s.decode(encoding, errors) - return s - - -def ensure_text(s, encoding="utf-8", errors="strict"): - """Coerce *s* to six.text_type. 
- - For Python 2: - - `unicode` -> `unicode` - - `str` -> `unicode` - - For Python 3: - - `str` -> `str` - - `bytes` -> decoded to `str` - """ - if isinstance(s, binary_type): - return s.decode(encoding, errors) - elif isinstance(s, text_type): - return s - else: - raise TypeError("not expecting type '%s'" % type(s)) - - -def python_2_unicode_compatible(klass): - """ - A decorator that defines __unicode__ and __str__ methods under Python 2. - Under Python 3 it does nothing. - - To support Python 2 and 3 with a single code base, define a __str__ method - returning text and apply this decorator to the class. - """ - if PY2: - if "__str__" not in klass.__dict__: - raise ValueError( - "@python_2_unicode_compatible cannot be applied " - "to %s because it doesn't define __str__()." % klass.__name__ - ) - klass.__unicode__ = klass.__str__ - klass.__str__ = lambda self: self.__unicode__().encode("utf-8") - return klass - - -# Complete the moves implementation. -# This code is at the end of this module to speed up module loading. -# Turn this module into a package. -__path__ = [] # required for PEP 302 and PEP 451 -__package__ = __name__ # see PEP 366 @ReservedAssignment -if globals().get("__spec__") is not None: - __spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable -# Remove other six meta path importers, since they cause problems. This can -# happen if six is removed from sys.modules and then reloaded. (Setuptools does -# this for some reason.) -if sys.meta_path: - for i, importer in enumerate(sys.meta_path): - # Here's some real nastiness: Another "instance" of the six module might - # be floating around. Therefore, we can't use isinstance() to check for - # the six meta path importer, since the other six instance will have - # inserted an importer with different class. 
- if ( - type(importer).__name__ == "_SixMetaPathImporter" - and importer.name == __name__ - ): - del sys.meta_path[i] - break - del i, importer -# Finally, add the importer to the meta path import hook. -sys.meta_path.append(_importer) diff --git a/src/urllib3/packages/six.pyi b/src/urllib3/packages/six.pyi deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/urllib3/packages/ssl_match_hostname/__init__.py b/src/urllib3/packages/ssl_match_hostname/__init__.py deleted file mode 100644 index 6b12fd90aa..0000000000 --- a/src/urllib3/packages/ssl_match_hostname/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -import sys - -try: - # Our match_hostname function is the same as 3.5's, so we only want to - # import the match_hostname function if it's at least that good. - if sys.version_info < (3, 5): - raise ImportError("Fallback to vendored code") - - from ssl import CertificateError, match_hostname -except ImportError: - try: - # Backport of the function from a pypi module - from backports.ssl_match_hostname import ( # type: ignore - CertificateError, - match_hostname, - ) - except ImportError: - # Our vendored copy - from ._implementation import CertificateError, match_hostname # type: ignore - -# Not needed, but documenting what we provide. 
-__all__ = ("CertificateError", "match_hostname") diff --git a/src/urllib3/packages/ssl_match_hostname/__init__.pyi b/src/urllib3/packages/ssl_match_hostname/__init__.pyi deleted file mode 100644 index 1915c0e5d0..0000000000 --- a/src/urllib3/packages/ssl_match_hostname/__init__.pyi +++ /dev/null @@ -1,4 +0,0 @@ -import ssl - -CertificateError = ssl.CertificateError -match_hostname = ssl.match_hostname diff --git a/src/urllib3/packages/ssl_match_hostname/_implementation.pyi b/src/urllib3/packages/ssl_match_hostname/_implementation.pyi deleted file mode 100644 index ed472ba01e..0000000000 --- a/src/urllib3/packages/ssl_match_hostname/_implementation.pyi +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Dict, Tuple, Union - -# https://github.com/python/typeshed/blob/master/stdlib/2and3/ssl.pyi -_PCTRTT = Tuple[Tuple[str, str], ...] -_PCTRTTT = Tuple[_PCTRTT, ...] -_PeerCertRetDictType = Dict[str, Union[str, _PCTRTTT, _PCTRTT]] -_PeerCertRetType = Union[_PeerCertRetDictType, bytes, None] - -class CertificateError(ValueError): ... - -def match_hostname(cert: _PeerCertRetType, hostname: str) -> None: ... 
diff --git a/src/urllib3/poolmanager.py b/src/urllib3/poolmanager.py index 3a31a285bf..b8434a3263 100644 --- a/src/urllib3/poolmanager.py +++ b/src/urllib3/poolmanager.py @@ -1,24 +1,33 @@ -from __future__ import absolute_import +from __future__ import annotations -import collections import functools import logging +import typing +import warnings +from types import TracebackType +from urllib.parse import urljoin from ._collections import RecentlyUsedContainer +from ._request_methods import RequestMethods +from .connection import ProxyConfig from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, port_by_scheme from .exceptions import ( LocationValueError, MaxRetryError, ProxySchemeUnknown, - ProxySchemeUnsupported, URLSchemeUnknown, ) -from .packages import six -from .packages.six.moves.urllib.parse import urljoin -from .request import RequestMethods +from .response import BaseHTTPResponse +from .util.connection import _TYPE_SOCKET_OPTIONS from .util.proxy import connection_requires_http_tunnel from .util.retry import Retry -from .util.url import parse_url +from .util.timeout import Timeout +from .util.url import Url, parse_url + +if typing.TYPE_CHECKING: + import ssl + + from typing_extensions import Literal __all__ = ["PoolManager", "ProxyManager", "proxy_from_url"] @@ -31,51 +40,61 @@ "cert_reqs", "ca_certs", "ssl_version", + "ssl_minimum_version", + "ssl_maximum_version", "ca_cert_dir", "ssl_context", "key_password", + "server_hostname", ) +# Default value for `blocksize` - a new parameter introduced to +# http.client.HTTPConnection & http.client.HTTPSConnection in Python 3.7 +_DEFAULT_BLOCKSIZE = 16384 -# All known keyword arguments that could be provided to the pool manager, its -# pools, or the underlying connections. This is used to construct a pool key. 
-_key_fields = ( - "key_scheme", # str - "key_host", # str - "key_port", # int - "key_timeout", # int or float or Timeout - "key_retries", # int or Retry - "key_strict", # bool - "key_block", # bool - "key_source_address", # str - "key_key_file", # str - "key_key_password", # str - "key_cert_file", # str - "key_cert_reqs", # str - "key_ca_certs", # str - "key_ssl_version", # str - "key_ca_cert_dir", # str - "key_ssl_context", # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext - "key_maxsize", # int - "key_headers", # dict - "key__proxy", # parsed proxy url - "key__proxy_headers", # dict - "key__proxy_config", # class - "key_socket_options", # list of (level (int), optname (int), value (int or str)) tuples - "key__socks_options", # dict - "key_assert_hostname", # bool or string - "key_assert_fingerprint", # str - "key_server_hostname", # str -) +_SelfT = typing.TypeVar("_SelfT") -#: The namedtuple class used to construct keys for the connection pool. -#: All custom key schemes should include the fields in this key at a minimum. -PoolKey = collections.namedtuple("PoolKey", _key_fields) -_proxy_config_fields = ("ssl_context", "use_forwarding_for_https") -ProxyConfig = collections.namedtuple("ProxyConfig", _proxy_config_fields) +class PoolKey(typing.NamedTuple): + """ + All known keyword arguments that could be provided to the pool manager, its + pools, or the underlying connections. + All custom key schemes should include the fields in this key at a minimum. 
+ """ -def _default_key_normalizer(key_class, request_context): + key_scheme: str + key_host: str + key_port: int | None + key_timeout: Timeout | float | int | None + key_retries: Retry | bool | int | None + key_block: bool | None + key_source_address: tuple[str, int] | None + key_key_file: str | None + key_key_password: str | None + key_cert_file: str | None + key_cert_reqs: str | None + key_ca_certs: str | None + key_ssl_version: int | str | None + key_ssl_minimum_version: ssl.TLSVersion | None + key_ssl_maximum_version: ssl.TLSVersion | None + key_ca_cert_dir: str | None + key_ssl_context: ssl.SSLContext | None + key_maxsize: int | None + key_headers: frozenset[tuple[str, str]] | None + key__proxy: Url | None + key__proxy_headers: frozenset[tuple[str, str]] | None + key__proxy_config: ProxyConfig | None + key_socket_options: _TYPE_SOCKET_OPTIONS | None + key__socks_options: frozenset[tuple[str, str]] | None + key_assert_hostname: bool | str | None + key_assert_fingerprint: str | None + key_server_hostname: str | None + key_blocksize: int | None + + +def _default_key_normalizer( + key_class: type[PoolKey], request_context: dict[str, typing.Any] +) -> PoolKey: """ Create a pool key out of a request context dictionary. @@ -121,6 +140,10 @@ def _default_key_normalizer(key_class, request_context): if field not in context: context[field] = None + # Default key_blocksize to _DEFAULT_BLOCKSIZE if missing from the context + if context.get("key_blocksize") is None: + context["key_blocksize"] = _DEFAULT_BLOCKSIZE + return key_class(**context) @@ -153,39 +176,63 @@ class PoolManager(RequestMethods): Additional parameters are used to create fresh :class:`urllib3.connectionpool.ConnectionPool` instances. - Example:: + Example: + + .. 
code-block:: python + + import urllib3 + + http = urllib3.PoolManager(num_pools=2) - >>> manager = PoolManager(num_pools=2) - >>> r = manager.request('GET', 'http://google.com/') - >>> r = manager.request('GET', 'http://google.com/mail') - >>> r = manager.request('GET', 'http://yahoo.com/') - >>> len(manager.pools) - 2 + resp1 = http.request("GET", "https://google.com/") + resp2 = http.request("GET", "https://google.com/mail") + resp3 = http.request("GET", "https://yahoo.com/") + + print(len(http.pools)) + # 2 """ - proxy = None - proxy_config = None + proxy: Url | None = None + proxy_config: ProxyConfig | None = None - def __init__(self, num_pools=10, headers=None, **connection_pool_kw): - RequestMethods.__init__(self, headers) + def __init__( + self, + num_pools: int = 10, + headers: typing.Mapping[str, str] | None = None, + **connection_pool_kw: typing.Any, + ) -> None: + super().__init__(headers) self.connection_pool_kw = connection_pool_kw - self.pools = RecentlyUsedContainer(num_pools, dispose_func=lambda p: p.close()) + + self.pools: RecentlyUsedContainer[PoolKey, HTTPConnectionPool] + self.pools = RecentlyUsedContainer(num_pools) # Locally set the pool classes and keys so other PoolManagers can # override them. 
self.pool_classes_by_scheme = pool_classes_by_scheme self.key_fn_by_scheme = key_fn_by_scheme.copy() - def __enter__(self): + def __enter__(self: _SelfT) -> _SelfT: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> Literal[False]: self.clear() # Return False to re-raise any potential exceptions return False - def _new_pool(self, scheme, host, port, request_context=None): + def _new_pool( + self, + scheme: str, + host: str, + port: int, + request_context: dict[str, typing.Any] | None = None, + ) -> HTTPConnectionPool: """ Create a new :class:`urllib3.connectionpool.ConnectionPool` based on host, port, scheme, and any additional pool keyword arguments. @@ -195,10 +242,15 @@ def _new_pool(self, scheme, host, port, request_context=None): connection pools handed out by :meth:`connection_from_url` and companion methods. It is intended to be overridden for customization. """ - pool_cls = self.pool_classes_by_scheme[scheme] + pool_cls: type[HTTPConnectionPool] = self.pool_classes_by_scheme[scheme] if request_context is None: request_context = self.connection_pool_kw.copy() + # Default blocksize to _DEFAULT_BLOCKSIZE if missing or explicitly + # set to 'None' in the request_context. + if request_context.get("blocksize") is None: + request_context["blocksize"] = _DEFAULT_BLOCKSIZE + # Although the context has everything necessary to create the pool, # this function has historically only used the scheme, host, and port # in the positional args. When an API change is acceptable these can @@ -212,7 +264,7 @@ def _new_pool(self, scheme, host, port, request_context=None): return pool_cls(host, port, **request_context) - def clear(self): + def clear(self) -> None: """ Empty our store of pools and direct them all to close. 
@@ -221,7 +273,13 @@ def clear(self): """ self.pools.clear() - def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None): + def connection_from_host( + self, + host: str | None, + port: int | None = None, + scheme: str | None = "http", + pool_kwargs: dict[str, typing.Any] | None = None, + ) -> HTTPConnectionPool: """ Get a :class:`urllib3.connectionpool.ConnectionPool` based on the host, port, and scheme. @@ -244,13 +302,23 @@ def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None) return self.connection_from_context(request_context) - def connection_from_context(self, request_context): + def connection_from_context( + self, request_context: dict[str, typing.Any] + ) -> HTTPConnectionPool: """ Get a :class:`urllib3.connectionpool.ConnectionPool` based on the request context. ``request_context`` must at least contain the ``scheme`` key and its value must be a key in ``key_fn_by_scheme`` instance variable. """ + if "strict" in request_context: + warnings.warn( + "The 'strict' parameter is no longer needed on Python 3+. " + "This will raise an error in urllib3 v2.1.0.", + DeprecationWarning, + ) + request_context.pop("strict") + scheme = request_context["scheme"].lower() pool_key_constructor = self.key_fn_by_scheme.get(scheme) if not pool_key_constructor: @@ -259,7 +327,9 @@ def connection_from_context(self, request_context): return self.connection_from_pool_key(pool_key, request_context=request_context) - def connection_from_pool_key(self, pool_key, request_context=None): + def connection_from_pool_key( + self, pool_key: PoolKey, request_context: dict[str, typing.Any] + ) -> HTTPConnectionPool: """ Get a :class:`urllib3.connectionpool.ConnectionPool` based on the provided pool key. 
@@ -283,7 +353,9 @@ def connection_from_pool_key(self, pool_key, request_context=None): return pool - def connection_from_url(self, url, pool_kwargs=None): + def connection_from_url( + self, url: str, pool_kwargs: dict[str, typing.Any] | None = None + ) -> HTTPConnectionPool: """ Similar to :func:`urllib3.connectionpool.connection_from_url`. @@ -299,7 +371,9 @@ def connection_from_url(self, url, pool_kwargs=None): u.host, port=u.port, scheme=u.scheme, pool_kwargs=pool_kwargs ) - def _merge_pool_kwargs(self, override): + def _merge_pool_kwargs( + self, override: dict[str, typing.Any] | None + ) -> dict[str, typing.Any]: """ Merge a dictionary of override values for self.connection_pool_kw. @@ -319,7 +393,7 @@ def _merge_pool_kwargs(self, override): base_pool_kwargs[key] = value return base_pool_kwargs - def _proxy_requires_url_absolute_form(self, parsed_url): + def _proxy_requires_url_absolute_form(self, parsed_url: Url) -> bool: """ Indicates if the proxy requires the complete destination URL in the request. Normally this is only needed when not using an HTTP CONNECT @@ -332,24 +406,9 @@ def _proxy_requires_url_absolute_form(self, parsed_url): self.proxy, self.proxy_config, parsed_url.scheme ) - def _validate_proxy_scheme_url_selection(self, url_scheme): - """ - Validates that were not attempting to do TLS in TLS connections on - Python2 or with unsupported SSL implementations. 
- """ - if self.proxy is None or url_scheme != "https": - return - - if self.proxy.scheme != "https": - return - - if six.PY2 and not self.proxy_config.use_forwarding_for_https: - raise ProxySchemeUnsupported( - "Contacting HTTPS destinations through HTTPS proxies " - "'via CONNECT tunnels' is not supported in Python 2" - ) - - def urlopen(self, method, url, redirect=True, **kw): + def urlopen( # type: ignore[override] + self, method: str, url: str, redirect: bool = True, **kw: typing.Any + ) -> BaseHTTPResponse: """ Same as :meth:`urllib3.HTTPConnectionPool.urlopen` with custom cross-host redirect logic and only sends the request-uri @@ -359,7 +418,6 @@ def urlopen(self, method, url, redirect=True, **kw): :class:`urllib3.connectionpool.ConnectionPool` can be chosen for it. """ u = parse_url(url) - self._validate_proxy_scheme_url_selection(u.scheme) conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme) @@ -367,7 +425,7 @@ def urlopen(self, method, url, redirect=True, **kw): kw["redirect"] = False if "headers" not in kw: - kw["headers"] = self.headers.copy() + kw["headers"] = self.headers if self._proxy_requires_url_absolute_form(u): response = conn.urlopen(method, url, **kw) @@ -395,10 +453,11 @@ def urlopen(self, method, url, redirect=True, **kw): if retries.remove_headers_on_redirect and not conn.is_same_host( redirect_location ): - headers = list(six.iterkeys(kw["headers"])) - for header in headers: + new_headers = kw["headers"].copy() + for header in kw["headers"]: if header.lower() in retries.remove_headers_on_redirect: - kw["headers"].pop(header, None) + new_headers.pop(header, None) + kw["headers"] = new_headers try: retries = retries.increment(method, url, response=response, _pool=conn) @@ -444,37 +503,51 @@ class ProxyManager(PoolManager): private. IP address, target hostname, SNI, and port are always visible to an HTTPS proxy even when this flag is disabled. 
+ :param proxy_assert_hostname: + The hostname of the certificate to verify against. + + :param proxy_assert_fingerprint: + The fingerprint of the certificate to verify against. + Example: - >>> proxy = urllib3.ProxyManager('http://localhost:3128/') - >>> r1 = proxy.request('GET', 'http://google.com/') - >>> r2 = proxy.request('GET', 'http://httpbin.org/') - >>> len(proxy.pools) - 1 - >>> r3 = proxy.request('GET', 'https://httpbin.org/') - >>> r4 = proxy.request('GET', 'https://twitter.com/') - >>> len(proxy.pools) - 3 + + .. code-block:: python + + import urllib3 + + proxy = urllib3.ProxyManager("https://localhost:3128/") + + resp1 = proxy.request("GET", "https://google.com/") + resp2 = proxy.request("GET", "https://httpbin.org/") + + print(len(proxy.pools)) + # 1 + + resp3 = proxy.request("GET", "https://httpbin.org/") + resp4 = proxy.request("GET", "https://twitter.com/") + + print(len(proxy.pools)) + # 3 """ def __init__( self, - proxy_url, - num_pools=10, - headers=None, - proxy_headers=None, - proxy_ssl_context=None, - use_forwarding_for_https=False, - **connection_pool_kw - ): - + proxy_url: str, + num_pools: int = 10, + headers: typing.Mapping[str, str] | None = None, + proxy_headers: typing.Mapping[str, str] | None = None, + proxy_ssl_context: ssl.SSLContext | None = None, + use_forwarding_for_https: bool = False, + proxy_assert_hostname: None | str | Literal[False] = None, + proxy_assert_fingerprint: str | None = None, + **connection_pool_kw: typing.Any, + ) -> None: if isinstance(proxy_url, HTTPConnectionPool): - proxy_url = "%s://%s:%i" % ( - proxy_url.scheme, - proxy_url.host, - proxy_url.port, - ) - proxy = parse_url(proxy_url) + str_proxy_url = f"{proxy_url.scheme}://{proxy_url.host}:{proxy_url.port}" + else: + str_proxy_url = proxy_url + proxy = parse_url(str_proxy_url) if proxy.scheme not in ("http", "https"): raise ProxySchemeUnknown(proxy.scheme) @@ -486,25 +559,38 @@ def __init__( self.proxy = proxy self.proxy_headers = proxy_headers or {} 
self.proxy_ssl_context = proxy_ssl_context - self.proxy_config = ProxyConfig(proxy_ssl_context, use_forwarding_for_https) + self.proxy_config = ProxyConfig( + proxy_ssl_context, + use_forwarding_for_https, + proxy_assert_hostname, + proxy_assert_fingerprint, + ) connection_pool_kw["_proxy"] = self.proxy connection_pool_kw["_proxy_headers"] = self.proxy_headers connection_pool_kw["_proxy_config"] = self.proxy_config - super(ProxyManager, self).__init__(num_pools, headers, **connection_pool_kw) + super().__init__(num_pools, headers, **connection_pool_kw) - def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None): + def connection_from_host( + self, + host: str | None, + port: int | None = None, + scheme: str | None = "http", + pool_kwargs: dict[str, typing.Any] | None = None, + ) -> HTTPConnectionPool: if scheme == "https": - return super(ProxyManager, self).connection_from_host( + return super().connection_from_host( host, port, scheme, pool_kwargs=pool_kwargs ) - return super(ProxyManager, self).connection_from_host( - self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs + return super().connection_from_host( + self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs # type: ignore[union-attr] ) - def _set_proxy_headers(self, url, headers=None): + def _set_proxy_headers( + self, url: str, headers: typing.Mapping[str, str] | None = None + ) -> typing.Mapping[str, str]: """ Sets headers needed by proxies: specifically, the Accept and Host headers. Only sets headers not provided by the user. @@ -519,7 +605,9 @@ def _set_proxy_headers(self, url, headers=None): headers_.update(headers) return headers_ - def urlopen(self, method, url, redirect=True, **kw): + def urlopen( # type: ignore[override] + self, method: str, url: str, redirect: bool = True, **kw: typing.Any + ) -> BaseHTTPResponse: "Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute." 
u = parse_url(url) if not connection_requires_http_tunnel(self.proxy, self.proxy_config, u.scheme): @@ -529,8 +617,8 @@ def urlopen(self, method, url, redirect=True, **kw): headers = kw.get("headers", self.headers) kw["headers"] = self._set_proxy_headers(url, headers) - return super(ProxyManager, self).urlopen(method, url, redirect=redirect, **kw) + return super().urlopen(method, url, redirect=redirect, **kw) -def proxy_from_url(url, **kw): +def proxy_from_url(url: str, **kw: typing.Any) -> ProxyManager: return ProxyManager(proxy_url=url, **kw) diff --git a/src/urllib3/py.typed b/src/urllib3/py.typed new file mode 100644 index 0000000000..5f3ea3d919 --- /dev/null +++ b/src/urllib3/py.typed @@ -0,0 +1,2 @@ +# Instruct type checkers to look for inline type annotations in this package. +# See PEP 561. diff --git a/src/urllib3/response.py b/src/urllib3/response.py index 38693f4fc6..69e1bd66d0 100644 --- a/src/urllib3/response.py +++ b/src/urllib3/response.py @@ -1,19 +1,47 @@ -from __future__ import absolute_import +from __future__ import annotations +import collections import io +import json as _json import logging +import re +import sys +import typing +import warnings import zlib from contextlib import contextmanager -from socket import error as SocketError +from http.client import HTTPMessage as _HttplibHTTPMessage +from http.client import HTTPResponse as _HttplibHTTPResponse from socket import timeout as SocketTimeout try: - import brotli + try: + import brotlicffi as brotli # type: ignore[import] + except ImportError: + import brotli # type: ignore[import] except ImportError: brotli = None +try: + import zstandard as zstd # type: ignore[import] + + # The package 'zstandard' added the 'eof' property starting + # in v0.18.0 which we require to ensure a complete and + # valid zstd stream was fed into the ZstdDecoder. 
+ # See: https://github.com/urllib3/urllib3/pull/2624 + _zstd_version = _zstd_version = tuple( + map(int, re.search(r"^([0-9]+)\.([0-9]+)", zstd.__version__).groups()) # type: ignore[union-attr] + ) + if _zstd_version < (0, 18): # Defensive: + zstd = None + +except (AttributeError, ImportError, ValueError): # Defensive: + zstd = None + +from . import util +from ._base_connection import _TYPE_BODY from ._collections import HTTPHeaderDict -from .connection import BaseSSLError, HTTPException +from .connection import BaseSSLError, HTTPConnection, HTTPException from .exceptions import ( BodyNotHttplibCompatible, DecodeError, @@ -26,22 +54,32 @@ ResponseNotChunked, SSLError, ) -from .packages import six from .util.response import is_fp_closed, is_response_to_head +from .util.retry import Retry + +if typing.TYPE_CHECKING: + from typing_extensions import Literal + + from .connectionpool import HTTPConnectionPool log = logging.getLogger(__name__) -class DeflateDecoder(object): - def __init__(self): +class ContentDecoder: + def decompress(self, data: bytes) -> bytes: + raise NotImplementedError() + + def flush(self) -> bytes: + raise NotImplementedError() + + +class DeflateDecoder(ContentDecoder): + def __init__(self) -> None: self._first_try = True self._data = b"" self._obj = zlib.decompressobj() - def __getattr__(self, name): - return getattr(self._obj, name) - - def decompress(self, data): + def decompress(self, data: bytes) -> bytes: if not data: return data @@ -53,7 +91,7 @@ def decompress(self, data): decompressed = self._obj.decompress(data) if decompressed: self._first_try = False - self._data = None + self._data = None # type: ignore[assignment] return decompressed except zlib.error: self._first_try = False @@ -61,25 +99,24 @@ def decompress(self, data): try: return self.decompress(self._data) finally: - self._data = None + self._data = None # type: ignore[assignment] + def flush(self) -> bytes: + return self._obj.flush() -class GzipDecoderState(object): +class 
GzipDecoderState: FIRST_MEMBER = 0 OTHER_MEMBERS = 1 SWALLOW_DATA = 2 -class GzipDecoder(object): - def __init__(self): +class GzipDecoder(ContentDecoder): + def __init__(self) -> None: self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS) self._state = GzipDecoderState.FIRST_MEMBER - def __getattr__(self, name): - return getattr(self._obj, name) - - def decompress(self, data): + def decompress(self, data: bytes) -> bytes: ret = bytearray() if self._state == GzipDecoderState.SWALLOW_DATA or not data: return bytes(ret) @@ -100,27 +137,48 @@ def decompress(self, data): self._state = GzipDecoderState.OTHER_MEMBERS self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS) + def flush(self) -> bytes: + return self._obj.flush() + if brotli is not None: - class BrotliDecoder(object): + class BrotliDecoder(ContentDecoder): # Supports both 'brotlipy' and 'Brotli' packages # since they share an import name. The top branches # are for 'brotlipy' and bottom branches for 'Brotli' - def __init__(self): + def __init__(self) -> None: self._obj = brotli.Decompressor() if hasattr(self._obj, "decompress"): - self.decompress = self._obj.decompress + setattr(self, "decompress", self._obj.decompress) else: - self.decompress = self._obj.process + setattr(self, "decompress", self._obj.process) - def flush(self): + def flush(self) -> bytes: if hasattr(self._obj, "flush"): - return self._obj.flush() + return self._obj.flush() # type: ignore[no-any-return] return b"" -class MultiDecoder(object): +if zstd is not None: + + class ZstdDecoder(ContentDecoder): + def __init__(self) -> None: + self._obj = zstd.ZstdDecompressor().decompressobj() + + def decompress(self, data: bytes) -> bytes: + if not data: + return b"" + return self._obj.decompress(data) # type: ignore[no-any-return] + + def flush(self) -> bytes: + ret = self._obj.flush() + if not self._obj.eof: + raise DecodeError("Zstandard data is incomplete") + return ret # type: ignore[no-any-return] + + +class MultiDecoder(ContentDecoder): """ From 
RFC7231: If one or more encodings have been applied to a representation, the @@ -129,19 +187,19 @@ class MultiDecoder(object): they were applied. """ - def __init__(self, modes): + def __init__(self, modes: str) -> None: self._decoders = [_get_decoder(m.strip()) for m in modes.split(",")] - def flush(self): + def flush(self) -> bytes: return self._decoders[0].flush() - def decompress(self, data): + def decompress(self, data: bytes) -> bytes: for d in reversed(self._decoders): data = d.decompress(data) return data -def _get_decoder(mode): +def _get_decoder(mode: str) -> ContentDecoder: if "," in mode: return MultiDecoder(mode) @@ -151,10 +209,292 @@ def _get_decoder(mode): if brotli is not None and mode == "br": return BrotliDecoder() + if zstd is not None and mode == "zstd": + return ZstdDecoder() + return DeflateDecoder() -class HTTPResponse(io.IOBase): +class BytesQueueBuffer: + """Memory-efficient bytes buffer + + To return decoded data in read() and still follow the BufferedIOBase API, we need a + buffer to always return the correct amount of bytes. + + This buffer should be filled using calls to put() + + Our maximum memory usage is determined by the sum of the size of: + + * self.buffer, which contains the full data + * the largest chunk that we will copy in get() + + The worst case scenario is a single chunk, in which case we'll make a full copy of + the data inside get(). 
+ """ + + def __init__(self) -> None: + self.buffer: typing.Deque[bytes] = collections.deque() + self._size: int = 0 + + def __len__(self) -> int: + return self._size + + def put(self, data: bytes) -> None: + self.buffer.append(data) + self._size += len(data) + + def get(self, n: int) -> bytes: + if not self.buffer: + raise RuntimeError("buffer is empty") + elif n < 0: + raise ValueError("n should be > 0") + + fetched = 0 + ret = io.BytesIO() + while fetched < n: + remaining = n - fetched + chunk = self.buffer.popleft() + chunk_length = len(chunk) + if remaining < chunk_length: + left_chunk, right_chunk = chunk[:remaining], chunk[remaining:] + ret.write(left_chunk) + self.buffer.appendleft(right_chunk) + self._size -= remaining + break + else: + ret.write(chunk) + self._size -= chunk_length + fetched += chunk_length + + if not self.buffer: + break + + return ret.getvalue() + + +class BaseHTTPResponse(io.IOBase): + CONTENT_DECODERS = ["gzip", "deflate"] + if brotli is not None: + CONTENT_DECODERS += ["br"] + if zstd is not None: + CONTENT_DECODERS += ["zstd"] + REDIRECT_STATUSES = [301, 302, 303, 307, 308] + + DECODER_ERROR_CLASSES: tuple[type[Exception], ...] 
= (IOError, zlib.error) + if brotli is not None: + DECODER_ERROR_CLASSES += (brotli.error,) + + if zstd is not None: + DECODER_ERROR_CLASSES += (zstd.ZstdError,) + + def __init__( + self, + *, + headers: typing.Mapping[str, str] | typing.Mapping[bytes, bytes] | None = None, + status: int, + version: int, + reason: str | None, + decode_content: bool, + request_url: str | None, + retries: Retry | None = None, + ) -> None: + if isinstance(headers, HTTPHeaderDict): + self.headers = headers + else: + self.headers = HTTPHeaderDict(headers) # type: ignore[arg-type] + self.status = status + self.version = version + self.reason = reason + self.decode_content = decode_content + self._has_decoded_content = False + self._request_url: str | None = request_url + self.retries = retries + + self.chunked = False + tr_enc = self.headers.get("transfer-encoding", "").lower() + # Don't incur the penalty of creating a list and then discarding it + encodings = (enc.strip() for enc in tr_enc.split(",")) + if "chunked" in encodings: + self.chunked = True + + self._decoder: ContentDecoder | None = None + + def get_redirect_location(self) -> str | None | Literal[False]: + """ + Should we redirect and where to? + + :returns: Truthy redirect location string if we got a redirect status + code and valid location. ``None`` if redirect status and no + location. ``False`` if not a redirect status code. + """ + if self.status in self.REDIRECT_STATUSES: + return self.headers.get("location") + return False + + @property + def data(self) -> bytes: + raise NotImplementedError() + + def json(self) -> typing.Any: + """ + Parses the body of the HTTP response as JSON. + + To use a custom JSON decoder pass the result of :attr:`HTTPResponse.data` to the decoder. + + This method can raise either `UnicodeDecodeError` or `json.JSONDecodeError`. + + Read more :ref:`here `. 
+ """ + data = self.data.decode("utf-8") + return _json.loads(data) + + @property + def url(self) -> str | None: + raise NotImplementedError() + + @url.setter + def url(self, url: str | None) -> None: + raise NotImplementedError() + + @property + def connection(self) -> HTTPConnection | None: + raise NotImplementedError() + + @property + def retries(self) -> Retry | None: + return self._retries + + @retries.setter + def retries(self, retries: Retry | None) -> None: + # Override the request_url if retries has a redirect location. + if retries is not None and retries.history: + self.url = retries.history[-1].redirect_location + self._retries = retries + + def stream( + self, amt: int | None = 2**16, decode_content: bool | None = None + ) -> typing.Iterator[bytes]: + raise NotImplementedError() + + def read( + self, + amt: int | None = None, + decode_content: bool | None = None, + cache_content: bool = False, + ) -> bytes: + raise NotImplementedError() + + def read_chunked( + self, + amt: int | None = None, + decode_content: bool | None = None, + ) -> typing.Iterator[bytes]: + raise NotImplementedError() + + def release_conn(self) -> None: + raise NotImplementedError() + + def drain_conn(self) -> None: + raise NotImplementedError() + + def close(self) -> None: + raise NotImplementedError() + + def _init_decoder(self) -> None: + """ + Set-up the _decoder attribute if necessary. 
+ """ + # Note: content-encoding value should be case-insensitive, per RFC 7230 + # Section 3.2 + content_encoding = self.headers.get("content-encoding", "").lower() + if self._decoder is None: + if content_encoding in self.CONTENT_DECODERS: + self._decoder = _get_decoder(content_encoding) + elif "," in content_encoding: + encodings = [ + e.strip() + for e in content_encoding.split(",") + if e.strip() in self.CONTENT_DECODERS + ] + if encodings: + self._decoder = _get_decoder(content_encoding) + + def _decode( + self, data: bytes, decode_content: bool | None, flush_decoder: bool + ) -> bytes: + """ + Decode the data passed in and potentially flush the decoder. + """ + if not decode_content: + if self._has_decoded_content: + raise RuntimeError( + "Calling read(decode_content=False) is not supported after " + "read(decode_content=True) was called." + ) + return data + + try: + if self._decoder: + data = self._decoder.decompress(data) + self._has_decoded_content = True + except self.DECODER_ERROR_CLASSES as e: + content_encoding = self.headers.get("content-encoding", "").lower() + raise DecodeError( + "Received response with content-encoding: %s, but " + "failed to decode it." % content_encoding, + e, + ) from e + if flush_decoder: + data += self._flush_decoder() + + return data + + def _flush_decoder(self) -> bytes: + """ + Flushes the decoder. Should only be called if the decoder is actually + being used. + """ + if self._decoder: + return self._decoder.decompress(b"") + self._decoder.flush() + return b"" + + # Compatibility methods for `io` module + def readinto(self, b: bytearray) -> int: + temp = self.read(len(b)) + if len(temp) == 0: + return 0 + else: + b[: len(temp)] = temp + return len(temp) + + # Compatibility methods for http.client.HTTPResponse + def getheaders(self) -> HTTPHeaderDict: + warnings.warn( + "HTTPResponse.getheaders() is deprecated and will be removed " + "in urllib3 v2.1.0. 
Instead access HTTPResponse.headers directly.", + category=DeprecationWarning, + stacklevel=2, + ) + return self.headers + + def getheader(self, name: str, default: str | None = None) -> str | None: + warnings.warn( + "HTTPResponse.getheader() is deprecated and will be removed " + "in urllib3 v2.1.0. Instead use HTTPResponse.headers.get(name, default).", + category=DeprecationWarning, + stacklevel=2, + ) + return self.headers.get(name, default) + + # Compatibility method for http.cookiejar + def info(self) -> HTTPHeaderDict: + return self.headers + + def geturl(self) -> str | None: + return self.url + + +class HTTPResponse(BaseHTTPResponse): """ HTTP Response container. @@ -187,99 +527,74 @@ class is also compatible with the Python standard library's :mod:`io` value of Content-Length header, if present. Otherwise, raise error. """ - CONTENT_DECODERS = ["gzip", "deflate"] - if brotli is not None: - CONTENT_DECODERS += ["br"] - REDIRECT_STATUSES = [301, 302, 303, 307, 308] - def __init__( self, - body="", - headers=None, - status=0, - version=0, - reason=None, - strict=0, - preload_content=True, - decode_content=True, - original_response=None, - pool=None, - connection=None, - msg=None, - retries=None, - enforce_content_length=False, - request_method=None, - request_url=None, - auto_close=True, - ): + body: _TYPE_BODY = "", + headers: typing.Mapping[str, str] | typing.Mapping[bytes, bytes] | None = None, + status: int = 0, + version: int = 0, + reason: str | None = None, + preload_content: bool = True, + decode_content: bool = True, + original_response: _HttplibHTTPResponse | None = None, + pool: HTTPConnectionPool | None = None, + connection: HTTPConnection | None = None, + msg: _HttplibHTTPMessage | None = None, + retries: Retry | None = None, + enforce_content_length: bool = True, + request_method: str | None = None, + request_url: str | None = None, + auto_close: bool = True, + ) -> None: + super().__init__( + headers=headers, + status=status, + version=version, 
+ reason=reason, + decode_content=decode_content, + request_url=request_url, + retries=retries, + ) - if isinstance(headers, HTTPHeaderDict): - self.headers = headers - else: - self.headers = HTTPHeaderDict(headers) - self.status = status - self.version = version - self.reason = reason - self.strict = strict - self.decode_content = decode_content - self.retries = retries self.enforce_content_length = enforce_content_length self.auto_close = auto_close - self._decoder = None self._body = None - self._fp = None + self._fp: _HttplibHTTPResponse | None = None self._original_response = original_response self._fp_bytes_read = 0 self.msg = msg - self._request_url = request_url - if body and isinstance(body, (six.string_types, bytes)): + if body and isinstance(body, (str, bytes)): self._body = body self._pool = pool self._connection = connection if hasattr(body, "read"): - self._fp = body + self._fp = body # type: ignore[assignment] # Are we using the chunked-style of transfer encoding? - self.chunked = False - self.chunk_left = None - tr_enc = self.headers.get("transfer-encoding", "").lower() - # Don't incur the penalty of creating a list and then discarding it - encodings = (enc.strip() for enc in tr_enc.split(",")) - if "chunked" in encodings: - self.chunked = True + self.chunk_left: int | None = None # Determine length of response self.length_remaining = self._init_length(request_method) + # Used to return the correct amount of bytes for partial read()s + self._decoded_buffer = BytesQueueBuffer() + # If requested, preload the body. if preload_content and not self._body: self._body = self.read(decode_content=decode_content) - def get_redirect_location(self): - """ - Should we redirect and where to? - - :returns: Truthy redirect location string if we got a redirect status - code and valid location. ``None`` if redirect status and no - location. ``False`` if not a redirect status code. 
- """ - if self.status in self.REDIRECT_STATUSES: - return self.headers.get("location") - - return False - - def release_conn(self): + def release_conn(self) -> None: if not self._pool or not self._connection: - return + return None self._pool._put_conn(self._connection) self._connection = None - def drain_conn(self): + def drain_conn(self) -> None: """ Read and discard any remaining HTTP response data in the response connection. @@ -287,26 +602,28 @@ def drain_conn(self): """ try: self.read() - except (HTTPError, SocketError, BaseSSLError, HTTPException): + except (HTTPError, OSError, BaseSSLError, HTTPException): pass @property - def data(self): + def data(self) -> bytes: # For backwards-compat with earlier urllib3 0.4 and earlier. if self._body: - return self._body + return self._body # type: ignore[return-value] if self._fp: return self.read(cache_content=True) + return None # type: ignore[return-value] + @property - def connection(self): + def connection(self) -> HTTPConnection | None: return self._connection - def isclosed(self): + def isclosed(self) -> bool: return is_fp_closed(self._fp) - def tell(self): + def tell(self) -> int: """ Obtain the number of bytes pulled over the wire so far. May differ from the amount of content returned by :meth:``urllib3.response.HTTPResponse.read`` @@ -314,13 +631,14 @@ def tell(self): """ return self._fp_bytes_read - def _init_length(self, request_method): + def _init_length(self, request_method: str | None) -> int | None: """ Set initial length value for Response content if available. """ - length = self.headers.get("content-length") + length: int | None + content_length: str | None = self.headers.get("content-length") - if length is not None: + if content_length is not None: if self.chunked: # This Response will fail with an IncompleteRead if it can't be # received as chunked. This method falls back to attempt reading @@ -340,11 +658,11 @@ def _init_length(self, request_method): # (e.g. Content-Length: 42, 42). 
This line ensures the values # are all valid ints and that as long as the `set` length is 1, # all values are the same. Otherwise, the header is invalid. - lengths = set([int(val) for val in length.split(",")]) + lengths = {int(val) for val in content_length.split(",")} if len(lengths) > 1: raise InvalidHeader( "Content-Length contained multiple " - "unmatching values (%s)" % length + "unmatching values (%s)" % content_length ) length = lengths.pop() except ValueError: @@ -353,6 +671,9 @@ def _init_length(self, request_method): if length < 0: length = None + else: # if content_length is None + length = None + # Convert status to int for comparison # In some cases, httplib returns a status of "_UNKNOWN" try: @@ -366,64 +687,8 @@ def _init_length(self, request_method): return length - def _init_decoder(self): - """ - Set-up the _decoder attribute if necessary. - """ - # Note: content-encoding value should be case-insensitive, per RFC 7230 - # Section 3.2 - content_encoding = self.headers.get("content-encoding", "").lower() - if self._decoder is None: - if content_encoding in self.CONTENT_DECODERS: - self._decoder = _get_decoder(content_encoding) - elif "," in content_encoding: - encodings = [ - e.strip() - for e in content_encoding.split(",") - if e.strip() in self.CONTENT_DECODERS - ] - if len(encodings): - self._decoder = _get_decoder(content_encoding) - - DECODER_ERROR_CLASSES = (IOError, zlib.error) - if brotli is not None: - DECODER_ERROR_CLASSES += (brotli.error,) - - def _decode(self, data, decode_content, flush_decoder): - """ - Decode the data passed in and potentially flush the decoder. - """ - if not decode_content: - return data - - try: - if self._decoder: - data = self._decoder.decompress(data) - except self.DECODER_ERROR_CLASSES as e: - content_encoding = self.headers.get("content-encoding", "").lower() - raise DecodeError( - "Received response with content-encoding: %s, but " - "failed to decode it." 
% content_encoding, - e, - ) - if flush_decoder: - data += self._flush_decoder() - - return data - - def _flush_decoder(self): - """ - Flushes the decoder. Should only be called if the decoder is actually - being used. - """ - if self._decoder: - buf = self._decoder.decompress(b"") - return buf + self._decoder.flush() - - return b"" - @contextmanager - def _error_catcher(self): + def _error_catcher(self) -> typing.Generator[None, None, None]: """ Catch low-level python exceptions, instead re-raising urllib3 variants, so that low-level exceptions are not leaked in the @@ -437,22 +702,22 @@ def _error_catcher(self): try: yield - except SocketTimeout: + except SocketTimeout as e: # FIXME: Ideally we'd like to include the url in the ReadTimeoutError but # there is yet no clean way to get at it from this context. - raise ReadTimeoutError(self._pool, None, "Read timed out.") + raise ReadTimeoutError(self._pool, None, "Read timed out.") from e # type: ignore[arg-type] except BaseSSLError as e: # FIXME: Is there a better way to differentiate between SSLErrors? if "read operation timed out" not in str(e): # SSL errors related to framing/MAC get wrapped and reraised here - raise SSLError(e) + raise SSLError(e) from e - raise ReadTimeoutError(self._pool, None, "Read timed out.") + raise ReadTimeoutError(self._pool, None, "Read timed out.") from e # type: ignore[arg-type] - except (HTTPException, SocketError) as e: + except (HTTPException, OSError) as e: # This includes IncompleteRead. - raise ProtocolError("Connection broken: %r" % e, e) + raise ProtocolError(f"Connection broken: {e!r}", e) from e # If no exception is thrown, we should avoid cleaning up # unnecessarily. 
@@ -478,7 +743,102 @@ def _error_catcher(self): if self._original_response and self._original_response.isclosed(): self.release_conn() - def read(self, amt=None, decode_content=None, cache_content=False): + def _fp_read(self, amt: int | None = None) -> bytes: + """ + Read a response with the thought that reading the number of bytes + larger than can fit in a 32-bit int at a time via SSL in some + known cases leads to an overflow error that has to be prevented + if `amt` or `self.length_remaining` indicate that a problem may + happen. + + The known cases: + * 3.8 <= CPython < 3.9.7 because of a bug + https://github.com/urllib3/urllib3/issues/2513#issuecomment-1152559900. + * urllib3 injected with pyOpenSSL-backed SSL-support. + * CPython < 3.10 only when `amt` does not fit 32-bit int. + """ + assert self._fp + c_int_max = 2**31 - 1 + if ( + ( + (amt and amt > c_int_max) + or (self.length_remaining and self.length_remaining > c_int_max) + ) + and not util.IS_SECURETRANSPORT + and (util.IS_PYOPENSSL or sys.version_info < (3, 10)) + ): + buffer = io.BytesIO() + # Besides `max_chunk_amt` being a maximum chunk size, it + # affects memory overhead of reading a response by this + # method in CPython. + # `c_int_max` equal to 2 GiB - 1 byte is the actual maximum + # chunk size that does not lead to an overflow error, but + # 256 MiB is a compromise. + max_chunk_amt = 2**28 + while amt is None or amt != 0: + if amt is not None: + chunk_amt = min(amt, max_chunk_amt) + amt -= chunk_amt + else: + chunk_amt = max_chunk_amt + data = self._fp.read(chunk_amt) + if not data: + break + buffer.write(data) + del data # to reduce peak memory usage by `max_chunk_amt`. + return buffer.getvalue() + else: + # StringIO doesn't like amt=None + return self._fp.read(amt) if amt is not None else self._fp.read() + + def _raw_read( + self, + amt: int | None = None, + ) -> bytes: + """ + Reads `amt` of bytes from the socket. 
+ """ + if self._fp is None: + return None # type: ignore[return-value] + + fp_closed = getattr(self._fp, "closed", False) + + with self._error_catcher(): + data = self._fp_read(amt) if not fp_closed else b"" + if amt is not None and amt != 0 and not data: + # Platform-specific: Buggy versions of Python. + # Close the connection when no data is returned + # + # This is redundant to what httplib/http.client _should_ + # already do. However, versions of python released before + # December 15, 2012 (http://bugs.python.org/issue16298) do + # not properly close the connection in all cases. There is + # no harm in redundantly calling close. + self._fp.close() + if ( + self.enforce_content_length + and self.length_remaining is not None + and self.length_remaining != 0 + ): + # This is an edge case that httplib failed to cover due + # to concerns of backward compatibility. We're + # addressing it here to make sure IncompleteRead is + # raised during streaming, so all calls with incorrect + # Content-Length are caught. + raise IncompleteRead(self._fp_bytes_read, self.length_remaining) + + if data: + self._fp_bytes_read += len(data) + if self.length_remaining is not None: + self.length_remaining -= len(data) + return data + + def read( + self, + amt: int | None = None, + decode_content: bool | None = None, + cache_content: bool = False, + ) -> bytes: """ Similar to :meth:`http.client.HTTPResponse.read`, but with two additional parameters: ``decode_content`` and ``cache_content``. 
@@ -503,56 +863,54 @@ def read(self, amt=None, decode_content=None, cache_content=False): if decode_content is None: decode_content = self.decode_content - if self._fp is None: - return + if amt is not None: + cache_content = False - flush_decoder = False - fp_closed = getattr(self._fp, "closed", False) + if len(self._decoded_buffer) >= amt: + return self._decoded_buffer.get(amt) - with self._error_catcher(): - if amt is None: - # cStringIO doesn't like amt=None - data = self._fp.read() if not fp_closed else b"" - flush_decoder = True - else: - cache_content = False - data = self._fp.read(amt) if not fp_closed else b"" - if ( - amt != 0 and not data - ): # Platform-specific: Buggy versions of Python. - # Close the connection when no data is returned - # - # This is redundant to what httplib/http.client _should_ - # already do. However, versions of python released before - # December 15, 2012 (http://bugs.python.org/issue16298) do - # not properly close the connection in all cases. There is - # no harm in redundantly calling close. - self._fp.close() - flush_decoder = True - if self.enforce_content_length and self.length_remaining not in ( - 0, - None, - ): - # This is an edge case that httplib failed to cover due - # to concerns of backward compatibility. We're - # addressing it here to make sure IncompleteRead is - # raised during streaming, so all calls with incorrect - # Content-Length are caught. 
- raise IncompleteRead(self._fp_bytes_read, self.length_remaining) + data = self._raw_read(amt) - if data: - self._fp_bytes_read += len(data) - if self.length_remaining is not None: - self.length_remaining -= len(data) + flush_decoder = False + if amt is None: + flush_decoder = True + elif amt != 0 and not data: + flush_decoder = True - data = self._decode(data, decode_content, flush_decoder) + if not data and len(self._decoded_buffer) == 0: + return data + if amt is None: + data = self._decode(data, decode_content, flush_decoder) if cache_content: self._body = data + else: + # do not waste memory on buffer when not decoding + if not decode_content: + if self._has_decoded_content: + raise RuntimeError( + "Calling read(decode_content=False) is not supported after " + "read(decode_content=True) was called." + ) + return data + + decoded_data = self._decode(data, decode_content, flush_decoder) + self._decoded_buffer.put(decoded_data) + + while len(self._decoded_buffer) < amt and data: + # TODO make sure to initially read enough data to get past the headers + # For example, the GZ file header takes 10 bytes, we don't want to read + # it one byte at a time + data = self._raw_read(amt) + decoded_data = self._decode(data, decode_content, flush_decoder) + self._decoded_buffer.put(decoded_data) + data = self._decoded_buffer.get(amt) return data - def stream(self, amt=2 ** 16, decode_content=None): + def stream( + self, amt: int | None = 2**16, decode_content: bool | None = None + ) -> typing.Generator[bytes, None, None]: """ A generator wrapper for the read() method. A call will block until ``amt`` bytes have been read from the connection or until the @@ -569,8 +927,7 @@ def stream(self, amt=2 ** 16, decode_content=None): 'content-encoding' header. 
""" if self.chunked and self.supports_chunked_reads(): - for line in self.read_chunked(amt, decode_content=decode_content): - yield line + yield from self.read_chunked(amt, decode_content=decode_content) else: while not is_fp_closed(self._fp): data = self.read(amt=amt, decode_content=decode_content) @@ -578,52 +935,12 @@ def stream(self, amt=2 ** 16, decode_content=None): if data: yield data - @classmethod - def from_httplib(ResponseCls, r, **response_kw): - """ - Given an :class:`http.client.HTTPResponse` instance ``r``, return a - corresponding :class:`urllib3.response.HTTPResponse` object. - - Remaining parameters are passed to the HTTPResponse constructor, along - with ``original_response=r``. - """ - headers = r.msg - - if not isinstance(headers, HTTPHeaderDict): - if six.PY2: - # Python 2.7 - headers = HTTPHeaderDict.from_httplib(headers) - else: - headers = HTTPHeaderDict(headers.items()) - - # HTTPResponse objects in Python 3 don't have a .strict attribute - strict = getattr(r, "strict", 0) - resp = ResponseCls( - body=r, - headers=headers, - status=r.status, - version=r.version, - reason=r.reason, - strict=strict, - original_response=r, - **response_kw - ) - return resp - - # Backwards-compatibility methods for http.client.HTTPResponse - def getheaders(self): - return self.headers - - def getheader(self, name, default=None): - return self.headers.get(name, default) - - # Backwards compatibility for http.cookiejar - def info(self): - return self.headers - # Overrides from io.IOBase - def close(self): - if not self.closed: + def readable(self) -> bool: + return True + + def close(self) -> None: + if not self.closed and self._fp: self._fp.close() if self._connection: @@ -633,9 +950,9 @@ def close(self): io.IOBase.close(self) @property - def closed(self): + def closed(self) -> bool: if not self.auto_close: - return io.IOBase.closed.__get__(self) + return io.IOBase.closed.__get__(self) # type: ignore[no-any-return] elif self._fp is None: return True elif 
hasattr(self._fp, "isclosed"): @@ -645,18 +962,18 @@ def closed(self): else: return True - def fileno(self): + def fileno(self) -> int: if self._fp is None: - raise IOError("HTTPResponse has no file to get a fileno from") + raise OSError("HTTPResponse has no file to get a fileno from") elif hasattr(self._fp, "fileno"): return self._fp.fileno() else: - raise IOError( + raise OSError( "The file-like object this HTTPResponse is wrapped " "around has no file descriptor" ) - def flush(self): + def flush(self) -> None: if ( self._fp is not None and hasattr(self._fp, "flush") @@ -664,20 +981,7 @@ def flush(self): ): return self._fp.flush() - def readable(self): - # This method is required for `io` module compatibility. - return True - - def readinto(self, b): - # This method is required for `io` module compatibility. - temp = self.read(len(b)) - if len(temp) == 0: - return 0 - else: - b[: len(temp)] = temp - return len(temp) - - def supports_chunked_reads(self): + def supports_chunked_reads(self) -> bool: """ Checks if the underlying file-like object looks like a :class:`http.client.HTTPResponse` object. We do this by testing for @@ -686,43 +990,45 @@ def supports_chunked_reads(self): """ return hasattr(self._fp, "fp") - def _update_chunk_length(self): + def _update_chunk_length(self) -> None: # First, we'll figure out length of a chunk and then # we'll try to read it from socket. if self.chunk_left is not None: - return - line = self._fp.fp.readline() + return None + line = self._fp.fp.readline() # type: ignore[union-attr] line = line.split(b";", 1)[0] try: self.chunk_left = int(line, 16) except ValueError: # Invalid chunked protocol response, abort. 
self.close() - raise InvalidChunkLength(self, line) + raise InvalidChunkLength(self, line) from None - def _handle_chunk(self, amt): + def _handle_chunk(self, amt: int | None) -> bytes: returned_chunk = None if amt is None: - chunk = self._fp._safe_read(self.chunk_left) + chunk = self._fp._safe_read(self.chunk_left) # type: ignore[union-attr] returned_chunk = chunk - self._fp._safe_read(2) # Toss the CRLF at the end of the chunk. + self._fp._safe_read(2) # type: ignore[union-attr] # Toss the CRLF at the end of the chunk. self.chunk_left = None - elif amt < self.chunk_left: - value = self._fp._safe_read(amt) + elif self.chunk_left is not None and amt < self.chunk_left: + value = self._fp._safe_read(amt) # type: ignore[union-attr] self.chunk_left = self.chunk_left - amt returned_chunk = value elif amt == self.chunk_left: - value = self._fp._safe_read(amt) - self._fp._safe_read(2) # Toss the CRLF at the end of the chunk. + value = self._fp._safe_read(amt) # type: ignore[union-attr] + self._fp._safe_read(2) # type: ignore[union-attr] # Toss the CRLF at the end of the chunk. self.chunk_left = None returned_chunk = value else: # amt > self.chunk_left - returned_chunk = self._fp._safe_read(self.chunk_left) - self._fp._safe_read(2) # Toss the CRLF at the end of the chunk. + returned_chunk = self._fp._safe_read(self.chunk_left) # type: ignore[union-attr] + self._fp._safe_read(2) # type: ignore[union-attr] # Toss the CRLF at the end of the chunk. self.chunk_left = None - return returned_chunk + return returned_chunk # type: ignore[no-any-return] - def read_chunked(self, amt=None, decode_content=None): + def read_chunked( + self, amt: int | None = None, decode_content: bool | None = None + ) -> typing.Generator[bytes, None, None]: """ Similar to :meth:`HTTPResponse.read`, but with an additional parameter: ``decode_content``. @@ -753,12 +1059,12 @@ def read_chunked(self, amt=None, decode_content=None): # Don't bother reading the body of a HEAD request. 
if self._original_response and is_response_to_head(self._original_response): self._original_response.close() - return + return None # If a response is already read and closed # then return immediately. - if self._fp.fp is None: - return + if self._fp.fp is None: # type: ignore[union-attr] + return None while True: self._update_chunk_length() @@ -780,7 +1086,7 @@ def read_chunked(self, amt=None, decode_content=None): yield decoded # Chunk content ends with \r\n: discard it. - while True: + while self._fp is not None: line = self._fp.fp.readline() if not line: # Some sites may not end with '\r\n'. @@ -792,27 +1098,29 @@ def read_chunked(self, amt=None, decode_content=None): if self._original_response: self._original_response.close() - def geturl(self): + @property + def url(self) -> str | None: """ Returns the URL that was the source of this response. If the request that generated this response redirected, this method will return the final redirect location. """ - if self.retries is not None and len(self.retries.history): - return self.retries.history[-1].redirect_location - else: - return self._request_url + return self._request_url + + @url.setter + def url(self, url: str) -> None: + self._request_url = url - def __iter__(self): - buffer = [] + def __iter__(self) -> typing.Iterator[bytes]: + buffer: list[bytes] = [] for chunk in self.stream(decode_content=True): if b"\n" in chunk: - chunk = chunk.split(b"\n") - yield b"".join(buffer) + chunk[0] + b"\n" - for x in chunk[1:-1]: + chunks = chunk.split(b"\n") + yield b"".join(buffer) + chunks[0] + b"\n" + for x in chunks[1:-1]: yield x + b"\n" - if chunk[-1]: - buffer = [chunk[-1]] + if chunks[-1]: + buffer = [chunks[-1]] else: buffer = [] else: diff --git a/src/urllib3/util/__init__.py b/src/urllib3/util/__init__.py index 4547fc522b..ff56c55bae 100644 --- a/src/urllib3/util/__init__.py +++ b/src/urllib3/util/__init__.py @@ -1,46 +1,41 @@ -from __future__ import absolute_import - # For backwards compatibility, provide 
imports that used to be here. +from __future__ import annotations + from .connection import is_connection_dropped from .request import SKIP_HEADER, SKIPPABLE_HEADERS, make_headers from .response import is_fp_closed from .retry import Retry from .ssl_ import ( ALPN_PROTOCOLS, - HAS_SNI, IS_PYOPENSSL, IS_SECURETRANSPORT, - PROTOCOL_TLS, SSLContext, assert_fingerprint, + create_urllib3_context, resolve_cert_reqs, resolve_ssl_version, ssl_wrap_socket, ) -from .timeout import Timeout, current_time -from .url import Url, get_host, parse_url, split_first +from .timeout import Timeout +from .url import Url, parse_url from .wait import wait_for_read, wait_for_write __all__ = ( - "HAS_SNI", "IS_PYOPENSSL", "IS_SECURETRANSPORT", "SSLContext", - "PROTOCOL_TLS", "ALPN_PROTOCOLS", "Retry", "Timeout", "Url", "assert_fingerprint", - "current_time", + "create_urllib3_context", "is_connection_dropped", "is_fp_closed", - "get_host", "parse_url", "make_headers", "resolve_cert_reqs", "resolve_ssl_version", - "split_first", "ssl_wrap_socket", "wait_for_read", "wait_for_write", diff --git a/src/urllib3/util/connection.py b/src/urllib3/util/connection.py index cd57455748..5c7da73f4e 100644 --- a/src/urllib3/util/connection.py +++ b/src/urllib3/util/connection.py @@ -1,34 +1,23 @@ -from __future__ import absolute_import +from __future__ import annotations import socket +import typing -from urllib3.exceptions import LocationParseError +from ..exceptions import LocationParseError +from .timeout import _DEFAULT_TIMEOUT, _TYPE_TIMEOUT -from ..contrib import _appengine_environ -from ..packages import six -from .wait import NoWayToWaitForSocketError, wait_for_read +_TYPE_SOCKET_OPTIONS = typing.Sequence[typing.Tuple[int, int, typing.Union[int, bytes]]] +if typing.TYPE_CHECKING: + from .._base_connection import BaseHTTPConnection -def is_connection_dropped(conn): # Platform-specific + +def is_connection_dropped(conn: BaseHTTPConnection) -> bool: # Platform-specific """ Returns True if the 
connection is dropped and should be closed. - - :param conn: - :class:`http.client.HTTPConnection` object. - - Note: For platforms like AppEngine, this will always return ``False`` to - let the platform handle connection recycling transparently for us. + :param conn: :class:`urllib3.connection.HTTPConnection` object. """ - sock = getattr(conn, "sock", False) - if sock is False: # Platform-specific: AppEngine - return False - if sock is None: # Connection already closed (such as by httplib). - return True - try: - # Returns True if readable, which here means it's been dropped - return wait_for_read(sock, timeout=0.0) - except NoWayToWaitForSocketError: # Platform-specific: AppEngine - return False + return not conn.is_connected # This function is copied from socket.py in the Python 2.7 standard @@ -36,11 +25,11 @@ def is_connection_dropped(conn): # Platform-specific # One additional modification is that we avoid binding to IPv6 servers # discovered in DNS if the system doesn't have IPv6 functionality. def create_connection( - address, - timeout=socket._GLOBAL_DEFAULT_TIMEOUT, - source_address=None, - socket_options=None, -): + address: tuple[str, int], + timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + source_address: tuple[str, int] | None = None, + socket_options: _TYPE_SOCKET_OPTIONS | None = None, +) -> socket.socket: """Connect to *address* and return the socket object. Convenience function. Connect to *address* (a 2-tuple ``(host, @@ -66,9 +55,7 @@ def create_connection( try: host.encode("idna") except UnicodeError: - return six.raise_from( - LocationParseError(u"'%s', label empty or too long" % host), None - ) + raise LocationParseError(f"'{host}', label empty or too long") from None for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): af, socktype, proto, canonname, sa = res @@ -79,26 +66,33 @@ def create_connection( # If provided, set socket level options before connecting. 
_set_socket_options(sock, socket_options) - if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: + if timeout is not _DEFAULT_TIMEOUT: sock.settimeout(timeout) if source_address: sock.bind(source_address) sock.connect(sa) + # Break explicitly a reference cycle + err = None return sock - except socket.error as e: - err = e + except OSError as _: + err = _ if sock is not None: sock.close() - sock = None if err is not None: - raise err - - raise socket.error("getaddrinfo returns an empty list") + try: + raise err + finally: + # Break explicitly a reference cycle + err = None + else: + raise OSError("getaddrinfo returns an empty list") -def _set_socket_options(sock, options): +def _set_socket_options( + sock: socket.socket, options: _TYPE_SOCKET_OPTIONS | None +) -> None: if options is None: return @@ -106,7 +100,7 @@ def _set_socket_options(sock, options): sock.setsockopt(*opt) -def allowed_gai_family(): +def allowed_gai_family() -> socket.AddressFamily: """This function is designed to work in the context of getaddrinfo, where family=socket.AF_UNSPEC is the default and will perform a DNS search for both IPv6 and IPv4 records.""" @@ -117,18 +111,11 @@ def allowed_gai_family(): return family -def _has_ipv6(host): - """ Returns True if the system can bind an IPv6 address. """ +def _has_ipv6(host: str) -> bool: + """Returns True if the system can bind an IPv6 address.""" sock = None has_ipv6 = False - # App Engine doesn't support IPV6 sockets and actually has a quota on the - # number of sockets that can be used, so just early out here instead of - # creating a socket needlessly. - # See https://github.com/urllib3/urllib3/issues/1446 - if _appengine_environ.is_appengine_sandbox(): - return False - if socket.has_ipv6: # has_ipv6 returns true if cPython was compiled with IPv6 support. # It does not tell us if the system has IPv6 support enabled. 
To diff --git a/src/urllib3/util/proxy.py b/src/urllib3/util/proxy.py index 34f884d5b3..908fc6621d 100644 --- a/src/urllib3/util/proxy.py +++ b/src/urllib3/util/proxy.py @@ -1,9 +1,18 @@ -from .ssl_ import create_urllib3_context, resolve_cert_reqs, resolve_ssl_version +from __future__ import annotations + +import typing + +from .url import Url + +if typing.TYPE_CHECKING: + from ..connection import ProxyConfig def connection_requires_http_tunnel( - proxy_url=None, proxy_config=None, destination_scheme=None -): + proxy_url: Url | None = None, + proxy_config: ProxyConfig | None = None, + destination_scheme: str | None = None, +) -> bool: """ Returns True if the connection requires an HTTP CONNECT through the proxy. @@ -32,25 +41,3 @@ def connection_requires_http_tunnel( # Otherwise always use a tunnel. return True - - -def create_proxy_ssl_context( - ssl_version, cert_reqs, ca_certs=None, ca_cert_dir=None, ca_cert_data=None -): - """ - Generates a default proxy ssl context if one hasn't been provided by the - user. - """ - ssl_context = create_urllib3_context( - ssl_version=resolve_ssl_version(ssl_version), - cert_reqs=resolve_cert_reqs(cert_reqs), - ) - if ( - not ca_certs - and not ca_cert_dir - and not ca_cert_data - and hasattr(ssl_context, "load_default_certs") - ): - ssl_context.load_default_certs() - - return ssl_context diff --git a/src/urllib3/util/queue.py b/src/urllib3/util/queue.py deleted file mode 100644 index 41784104ee..0000000000 --- a/src/urllib3/util/queue.py +++ /dev/null @@ -1,22 +0,0 @@ -import collections - -from ..packages import six -from ..packages.six.moves import queue - -if six.PY2: - # Queue is imported for side effects on MS Windows. See issue #229. 
- import Queue as _unused_module_Queue # noqa: F401 - - -class LifoQueue(queue.Queue): - def _init(self, _): - self.queue = collections.deque() - - def _qsize(self, len=len): - return len(self.queue) - - def _put(self, item): - self.queue.append(item) - - def _get(self): - return self.queue.pop() diff --git a/src/urllib3/util/queue.pyi b/src/urllib3/util/queue.pyi deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/urllib3/util/request.py b/src/urllib3/util/request.py index 25103383ec..def099d745 100644 --- a/src/urllib3/util/request.py +++ b/src/urllib3/util/request.py @@ -1,9 +1,15 @@ -from __future__ import absolute_import +from __future__ import annotations +import io +import typing from base64 import b64encode +from enum import Enum from ..exceptions import UnrewindableBodyError -from ..packages.six import b, integer_types +from .util import to_bytes + +if typing.TYPE_CHECKING: + from typing_extensions import Final # Pass as a value within ``headers`` to skip # emitting some HTTP headers that are added automatically. 
@@ -14,23 +20,46 @@ ACCEPT_ENCODING = "gzip,deflate" try: - import brotli as _unused_module_brotli # noqa: F401 + try: + import brotlicffi as _unused_module_brotli # type: ignore[import] # noqa: F401 + except ImportError: + import brotli as _unused_module_brotli # type: ignore[import] # noqa: F401 except ImportError: pass else: ACCEPT_ENCODING += ",br" +try: + import zstandard as _unused_module_zstd # type: ignore[import] # noqa: F401 +except ImportError: + pass +else: + ACCEPT_ENCODING += ",zstd" + + +class _TYPE_FAILEDTELL(Enum): + token = 0 + -_FAILEDTELL = object() +_FAILEDTELL: Final[_TYPE_FAILEDTELL] = _TYPE_FAILEDTELL.token + +_TYPE_BODY_POSITION = typing.Union[int, _TYPE_FAILEDTELL] + +# When sending a request with these methods we aren't expecting +# a body so don't need to set an explicit 'Content-Length: 0' +# The reason we do this in the negative instead of tracking methods +# which 'should' have a body is because unknown methods should be +# treated as if they were 'POST' which *does* expect a body. +_METHODS_NOT_EXPECTING_BODY = {"GET", "HEAD", "DELETE", "TRACE", "OPTIONS", "CONNECT"} def make_headers( - keep_alive=None, - accept_encoding=None, - user_agent=None, - basic_auth=None, - proxy_basic_auth=None, - disable_cache=None, -): + keep_alive: bool | None = None, + accept_encoding: bool | list[str] | str | None = None, + user_agent: str | None = None, + basic_auth: str | None = None, + proxy_basic_auth: str | None = None, + disable_cache: bool | None = None, +) -> dict[str, str]: """ Shortcuts for generating request headers. @@ -39,7 +68,8 @@ def make_headers( :param accept_encoding: Can be a boolean, list, or string. - ``True`` translates to 'gzip,deflate'. + ``True`` translates to 'gzip,deflate'. If either the ``brotli`` or + ``brotlicffi`` package is installed 'gzip,deflate,br' is used instead. List will get joined by comma. String will be used as provided. 
@@ -58,14 +88,18 @@ def make_headers( :param disable_cache: If ``True``, adds 'cache-control: no-cache' header. - Example:: + Example: + + .. code-block:: python + + import urllib3 - >>> make_headers(keep_alive=True, user_agent="Batman/1.0") - {'connection': 'keep-alive', 'user-agent': 'Batman/1.0'} - >>> make_headers(accept_encoding=True) - {'accept-encoding': 'gzip,deflate'} + print(urllib3.util.make_headers(keep_alive=True, user_agent="Batman/1.0")) + # {'connection': 'keep-alive', 'user-agent': 'Batman/1.0'} + print(urllib3.util.make_headers(accept_encoding=True)) + # {'accept-encoding': 'gzip,deflate'} """ - headers = {} + headers: dict[str, str] = {} if accept_encoding: if isinstance(accept_encoding, str): pass @@ -82,12 +116,14 @@ def make_headers( headers["connection"] = "keep-alive" if basic_auth: - headers["authorization"] = "Basic " + b64encode(b(basic_auth)).decode("utf-8") + headers[ + "authorization" + ] = f"Basic {b64encode(basic_auth.encode('latin-1')).decode()}" if proxy_basic_auth: - headers["proxy-authorization"] = "Basic " + b64encode( - b(proxy_basic_auth) - ).decode("utf-8") + headers[ + "proxy-authorization" + ] = f"Basic {b64encode(proxy_basic_auth.encode('latin-1')).decode()}" if disable_cache: headers["cache-control"] = "no-cache" @@ -95,7 +131,9 @@ def make_headers( return headers -def set_file_position(body, pos): +def set_file_position( + body: typing.Any, pos: _TYPE_BODY_POSITION | None +) -> _TYPE_BODY_POSITION | None: """ If a position is provided, move file to that point. Otherwise, we'll attempt to record a position for future use. @@ -105,7 +143,7 @@ def set_file_position(body, pos): elif getattr(body, "tell", None) is not None: try: pos = body.tell() - except (IOError, OSError): + except OSError: # This differentiates from None, allowing us to catch # a failed `tell()` later when trying to rewind the body. 
pos = _FAILEDTELL @@ -113,7 +151,7 @@ def set_file_position(body, pos): return pos -def rewind_body(body, body_pos): +def rewind_body(body: typing.IO[typing.AnyStr], body_pos: _TYPE_BODY_POSITION) -> None: """ Attempt to rewind body to a certain position. Primarily used for request redirects and retries. @@ -125,13 +163,13 @@ def rewind_body(body, body_pos): Position to seek to in file. """ body_seek = getattr(body, "seek", None) - if body_seek is not None and isinstance(body_pos, integer_types): + if body_seek is not None and isinstance(body_pos, int): try: body_seek(body_pos) - except (IOError, OSError): + except OSError as e: raise UnrewindableBodyError( "An error occurred when rewinding request body for redirect/retry." - ) + ) from e elif body_pos is _FAILEDTELL: raise UnrewindableBodyError( "Unable to record file position for rewinding " @@ -139,5 +177,80 @@ def rewind_body(body, body_pos): ) else: raise ValueError( - "body_pos must be of type integer, instead it was %s." % type(body_pos) + f"body_pos must be of type integer, instead it was {type(body_pos)}." ) + + +class ChunksAndContentLength(typing.NamedTuple): + chunks: typing.Iterable[bytes] | None + content_length: int | None + + +def body_to_chunks( + body: typing.Any | None, method: str, blocksize: int +) -> ChunksAndContentLength: + """Takes the HTTP request method, body, and blocksize and + transforms them into an iterable of chunks to pass to + socket.sendall() and an optional 'Content-Length' header. + + A 'Content-Length' of 'None' indicates the length of the body + can't be determined so should use 'Transfer-Encoding: chunked' + for framing instead. + """ + + chunks: typing.Iterable[bytes] | None + content_length: int | None + + # No body, we need to make a recommendation on 'Content-Length' + # based on whether that request method is expected to have + # a body or not. 
+ if body is None: + chunks = None + if method.upper() not in _METHODS_NOT_EXPECTING_BODY: + content_length = 0 + else: + content_length = None + + # Bytes or strings become bytes + elif isinstance(body, (str, bytes)): + chunks = (to_bytes(body),) + content_length = len(chunks[0]) + + # File-like object, TODO: use seek() and tell() for length? + elif hasattr(body, "read"): + + def chunk_readable() -> typing.Iterable[bytes]: + nonlocal body, blocksize + encode = isinstance(body, io.TextIOBase) + while True: + datablock = body.read(blocksize) # type: ignore[union-attr] + if not datablock: + break + if encode: + datablock = datablock.encode("iso-8859-1") + yield datablock + + chunks = chunk_readable() + content_length = None + + # Otherwise we need to start checking via duck-typing. + else: + try: + # Check if the body implements the buffer API. + mv = memoryview(body) + except TypeError: + try: + # Check if the body is an iterable + chunks = iter(body) + content_length = None + except TypeError: + raise TypeError( + f"'body' must be a bytes-like object, file-like " + f"object, or iterable. Instead was {body!r}" + ) from None + else: + # Since it implements the buffer API can be passed directly to socket.sendall() + chunks = (body,) + content_length = mv.nbytes + + return ChunksAndContentLength(chunks=chunks, content_length=content_length) diff --git a/src/urllib3/util/response.py b/src/urllib3/util/response.py index 5ea609cced..0f4578696f 100644 --- a/src/urllib3/util/response.py +++ b/src/urllib3/util/response.py @@ -1,12 +1,12 @@ -from __future__ import absolute_import +from __future__ import annotations +import http.client as httplib from email.errors import MultipartInvariantViolationDefect, StartBoundaryNotFoundDefect from ..exceptions import HeaderParsingError -from ..packages.six.moves import http_client as httplib -def is_fp_closed(obj): +def is_fp_closed(obj: object) -> bool: """ Checks whether a given file-like object is closed. 
@@ -17,27 +17,27 @@ def is_fp_closed(obj): try: # Check `isclosed()` first, in case Python3 doesn't set `closed`. # GH Issue #928 - return obj.isclosed() + return obj.isclosed() # type: ignore[no-any-return, attr-defined] except AttributeError: pass try: # Check via the official file-like-object way. - return obj.closed + return obj.closed # type: ignore[no-any-return, attr-defined] except AttributeError: pass try: # Check if the object is a container for another file-like object that # gets released on exhaustion (e.g. HTTPResponse). - return obj.fp is None + return obj.fp is None # type: ignore[attr-defined] except AttributeError: pass raise ValueError("Unable to determine whether fp is closed.") -def assert_header_parsing(headers): +def assert_header_parsing(headers: httplib.HTTPMessage) -> None: """ Asserts whether all headers have been successfully parsed. Extracts encountered errors from the result of parsing headers. @@ -53,55 +53,49 @@ def assert_header_parsing(headers): # This will fail silently if we pass in the wrong kind of parameter. # To make debugging easier add an explicit check. if not isinstance(headers, httplib.HTTPMessage): - raise TypeError("expected httplib.Message, got {0}.".format(type(headers))) - - defects = getattr(headers, "defects", None) - get_payload = getattr(headers, "get_payload", None) + raise TypeError(f"expected httplib.Message, got {type(headers)}.") unparsed_data = None - if get_payload: - # get_payload is actually email.message.Message.get_payload; - # we're only interested in the result if it's not a multipart message - if not headers.is_multipart(): - payload = get_payload() - - if isinstance(payload, (bytes, str)): - unparsed_data = payload - if defects: - # httplib is assuming a response body is available - # when parsing headers even when httplib only sends - # header data to parse_headers() This results in - # defects on multipart responses in particular. 
- # See: https://github.com/urllib3/urllib3/issues/800 - - # So we ignore the following defects: - # - StartBoundaryNotFoundDefect: - # The claimed start boundary was never found. - # - MultipartInvariantViolationDefect: - # A message claimed to be a multipart but no subparts were found. - defects = [ - defect - for defect in defects - if not isinstance( - defect, (StartBoundaryNotFoundDefect, MultipartInvariantViolationDefect) - ) - ] + + # get_payload is actually email.message.Message.get_payload; + # we're only interested in the result if it's not a multipart message + if not headers.is_multipart(): + payload = headers.get_payload() + + if isinstance(payload, (bytes, str)): + unparsed_data = payload + + # httplib is assuming a response body is available + # when parsing headers even when httplib only sends + # header data to parse_headers() This results in + # defects on multipart responses in particular. + # See: https://github.com/urllib3/urllib3/issues/800 + + # So we ignore the following defects: + # - StartBoundaryNotFoundDefect: + # The claimed start boundary was never found. + # - MultipartInvariantViolationDefect: + # A message claimed to be a multipart but no subparts were found. + defects = [ + defect + for defect in headers.defects + if not isinstance( + defect, (StartBoundaryNotFoundDefect, MultipartInvariantViolationDefect) + ) + ] if defects or unparsed_data: raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data) -def is_response_to_head(response): +def is_response_to_head(response: httplib.HTTPResponse) -> bool: """ Checks whether the request of a response has been a HEAD-request. - Handles the quirks of AppEngine. :param http.client.HTTPResponse response: Response to check if the originating request used 'HEAD' as a method. """ # FIXME: Can we do this somehow without accessing private httplib _method? 
- method = response._method - if isinstance(method, int): # Platform-specific: Appengine - return method == 3 - return method.upper() == "HEAD" + method_str = response._method # type: str # type: ignore[attr-defined] + return method_str.upper() == "HEAD" diff --git a/src/urllib3/util/retry.py b/src/urllib3/util/retry.py index ee51f922f8..691f62c6a7 100644 --- a/src/urllib3/util/retry.py +++ b/src/urllib3/util/retry.py @@ -1,12 +1,13 @@ -from __future__ import absolute_import +from __future__ import annotations import email import logging +import random import re import time -import warnings -from collections import namedtuple +import typing from itertools import takewhile +from types import TracebackType from ..exceptions import ( ConnectTimeoutError, @@ -17,79 +18,49 @@ ReadTimeoutError, ResponseError, ) -from ..packages import six +from .util import reraise + +if typing.TYPE_CHECKING: + from ..connectionpool import ConnectionPool + from ..response import BaseHTTPResponse log = logging.getLogger(__name__) # Data structure for representing the metadata of requests that result in a retry. -RequestHistory = namedtuple( - "RequestHistory", ["method", "url", "error", "status", "redirect_location"] -) - - -# TODO: In v2 we can remove this sentinel and metaclass with deprecated options. -_Default = object() +class RequestHistory(typing.NamedTuple): + method: str | None + url: str | None + error: Exception | None + status: int | None + redirect_location: str | None -class _RetryMeta(type): - @property - def DEFAULT_METHOD_WHITELIST(cls): - warnings.warn( - "Using 'Retry.DEFAULT_METHOD_WHITELIST' is deprecated and " - "will be removed in v2.0. Use 'Retry.DEFAULT_METHODS_ALLOWED' instead", - DeprecationWarning, - ) - return cls.DEFAULT_ALLOWED_METHODS - - @DEFAULT_METHOD_WHITELIST.setter - def DEFAULT_METHOD_WHITELIST(cls, value): - warnings.warn( - "Using 'Retry.DEFAULT_METHOD_WHITELIST' is deprecated and " - "will be removed in v2.0. 
Use 'Retry.DEFAULT_ALLOWED_METHODS' instead", - DeprecationWarning, - ) - cls.DEFAULT_ALLOWED_METHODS = value - - @property - def DEFAULT_REDIRECT_HEADERS_BLACKLIST(cls): - warnings.warn( - "Using 'Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST' is deprecated and " - "will be removed in v2.0. Use 'Retry.DEFAULT_REMOVE_HEADERS_ON_REDIRECT' instead", - DeprecationWarning, - ) - return cls.DEFAULT_REMOVE_HEADERS_ON_REDIRECT - - @DEFAULT_REDIRECT_HEADERS_BLACKLIST.setter - def DEFAULT_REDIRECT_HEADERS_BLACKLIST(cls, value): - warnings.warn( - "Using 'Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST' is deprecated and " - "will be removed in v2.0. Use 'Retry.DEFAULT_REMOVE_HEADERS_ON_REDIRECT' instead", - DeprecationWarning, - ) - cls.DEFAULT_REMOVE_HEADERS_ON_REDIRECT = value - - -@six.add_metaclass(_RetryMeta) -class Retry(object): +class Retry: """Retry configuration. Each retry attempt will create a new Retry object with updated values, so they can be safely reused. - Retries can be defined as a default for a pool:: + Retries can be defined as a default for a pool: + + .. code-block:: python retries = Retry(connect=5, read=2, redirect=5) http = PoolManager(retries=retries) - response = http.request('GET', 'http://example.com/') + response = http.request("GET", "https://example.com/") + + Or per-request (which overrides the default for the pool): - Or per-request (which overrides the default for the pool):: + .. code-block:: python - response = http.request('GET', 'http://example.com/', retries=Retry(10)) + response = http.request("GET", "https://example.com/", retries=Retry(10)) - Retries can be disabled by passing ``False``:: + Retries can be disabled by passing ``False``: - response = http.request('GET', 'http://example.com/', retries=False) + .. 
code-block:: python + + response = http.request("GET", "https://example.com/", retries=False) Errors will be wrapped in :class:`~urllib3.exceptions.MaxRetryError` unless retries are disabled, in which case the causing exception will be raised. @@ -151,21 +122,16 @@ class Retry(object): If ``total`` is not set, it's a good idea to set this to 0 to account for unexpected edge cases and avoid infinite retry loops. - :param iterable allowed_methods: + :param Collection allowed_methods: Set of uppercased HTTP method verbs that we should retry on. By default, we only retry on methods which are considered to be idempotent (multiple requests with the same parameters end with the same state). See :attr:`Retry.DEFAULT_ALLOWED_METHODS`. - Set to a ``False`` value to retry on any verb. - - .. warning:: - - Previously this parameter was named ``method_whitelist``, that - usage is deprecated in v1.26.0 and will be removed in v2.0. + Set to a ``None`` value to retry on any verb. - :param iterable status_forcelist: + :param Collection status_forcelist: A set of integer HTTP status codes that we should force a retry on. A retry is initiated if the request method is in ``allowed_methods`` and the response status code is in ``status_forcelist``. @@ -179,11 +145,15 @@ class Retry(object): {backoff factor} * (2 ** ({number of total retries} - 1)) - seconds. If the backoff_factor is 0.1, then :func:`.sleep` will sleep - for [0.0s, 0.2s, 0.4s, ...] between retries. It will never be longer - than :attr:`Retry.BACKOFF_MAX`. + seconds. If `backoff_jitter` is non-zero, this sleep is extended by:: - By default, backoff is disabled (set to 0). + random.uniform(0, {backoff jitter}) + + seconds. For example, if the backoff_factor is 0.1, then :func:`Retry.sleep` will + sleep for [0.0s, 0.2s, 0.4s, 0.8s, ...] between retries. No backoff will ever + be longer than `backoff_max`. + + By default, backoff is disabled (factor set to 0). 
:param bool raise_on_redirect: Whether, if the number of redirects is exhausted, to raise a MaxRetryError, or to return a response with a @@ -202,7 +172,7 @@ class Retry(object): Whether to respect Retry-After header on status codes defined as :attr:`Retry.RETRY_AFTER_STATUS_CODES` or not. - :param iterable remove_headers_on_redirect: + :param Collection remove_headers_on_redirect: Sequence of headers to remove from the request when a response indicating a redirect is returned before firing off the redirected request. @@ -219,47 +189,33 @@ class Retry(object): #: Default headers to be used for ``remove_headers_on_redirect`` DEFAULT_REMOVE_HEADERS_ON_REDIRECT = frozenset(["Authorization"]) - #: Maximum backoff time. - BACKOFF_MAX = 120 + #: Default maximum backoff time. + DEFAULT_BACKOFF_MAX = 120 + + # Backward compatibility; assigned outside of the class. + DEFAULT: typing.ClassVar[Retry] def __init__( self, - total=10, - connect=None, - read=None, - redirect=None, - status=None, - other=None, - allowed_methods=_Default, - status_forcelist=None, - backoff_factor=0, - raise_on_redirect=True, - raise_on_status=True, - history=None, - respect_retry_after_header=True, - remove_headers_on_redirect=_Default, - # TODO: Deprecated, remove in v2.0 - method_whitelist=_Default, - ): - - if method_whitelist is not _Default: - if allowed_methods is not _Default: - raise ValueError( - "Using both 'allowed_methods' and " - "'method_whitelist' together is not allowed. " - "Instead only use 'allowed_methods'" - ) - warnings.warn( - "Using 'method_whitelist' with Retry is deprecated and " - "will be removed in v2.0. 
Use 'allowed_methods' instead", - DeprecationWarning, - ) - allowed_methods = method_whitelist - if allowed_methods is _Default: - allowed_methods = self.DEFAULT_ALLOWED_METHODS - if remove_headers_on_redirect is _Default: - remove_headers_on_redirect = self.DEFAULT_REMOVE_HEADERS_ON_REDIRECT - + total: bool | int | None = 10, + connect: int | None = None, + read: int | None = None, + redirect: bool | int | None = None, + status: int | None = None, + other: int | None = None, + allowed_methods: typing.Collection[str] | None = DEFAULT_ALLOWED_METHODS, + status_forcelist: typing.Collection[int] | None = None, + backoff_factor: float = 0, + backoff_max: float = DEFAULT_BACKOFF_MAX, + raise_on_redirect: bool = True, + raise_on_status: bool = True, + history: tuple[RequestHistory, ...] | None = None, + respect_retry_after_header: bool = True, + remove_headers_on_redirect: typing.Collection[ + str + ] = DEFAULT_REMOVE_HEADERS_ON_REDIRECT, + backoff_jitter: float = 0.0, + ) -> None: self.total = total self.connect = connect self.read = read @@ -274,15 +230,17 @@ def __init__( self.status_forcelist = status_forcelist or set() self.allowed_methods = allowed_methods self.backoff_factor = backoff_factor + self.backoff_max = backoff_max self.raise_on_redirect = raise_on_redirect self.raise_on_status = raise_on_status - self.history = history or tuple() + self.history = history or () self.respect_retry_after_header = respect_retry_after_header self.remove_headers_on_redirect = frozenset( - [h.lower() for h in remove_headers_on_redirect] + h.lower() for h in remove_headers_on_redirect ) + self.backoff_jitter = backoff_jitter - def new(self, **kw): + def new(self, **kw: typing.Any) -> Retry: params = dict( total=self.total, connect=self.connect, @@ -290,37 +248,29 @@ def new(self, **kw): redirect=self.redirect, status=self.status, other=self.other, + allowed_methods=self.allowed_methods, status_forcelist=self.status_forcelist, backoff_factor=self.backoff_factor, + 
backoff_max=self.backoff_max, raise_on_redirect=self.raise_on_redirect, raise_on_status=self.raise_on_status, history=self.history, remove_headers_on_redirect=self.remove_headers_on_redirect, respect_retry_after_header=self.respect_retry_after_header, + backoff_jitter=self.backoff_jitter, ) - # TODO: If already given in **kw we use what's given to us - # If not given we need to figure out what to pass. We decide - # based on whether our class has the 'method_whitelist' property - # and if so we pass the deprecated 'method_whitelist' otherwise - # we use 'allowed_methods'. Remove in v2.0 - if "method_whitelist" not in kw and "allowed_methods" not in kw: - if "method_whitelist" in self.__dict__: - warnings.warn( - "Using 'method_whitelist' with Retry is deprecated and " - "will be removed in v2.0. Use 'allowed_methods' instead", - DeprecationWarning, - ) - params["method_whitelist"] = self.allowed_methods - else: - params["allowed_methods"] = self.allowed_methods - params.update(kw) - return type(self)(**params) + return type(self)(**params) # type: ignore[arg-type] @classmethod - def from_int(cls, retries, redirect=True, default=None): - """ Backwards-compatibility for the old retries format.""" + def from_int( + cls, + retries: Retry | bool | int | None, + redirect: bool | int | None = True, + default: Retry | bool | int | None = None, + ) -> Retry: + """Backwards-compatibility for the old retries format.""" if retries is None: retries = default if default is not None else cls.DEFAULT @@ -332,7 +282,7 @@ def from_int(cls, retries, redirect=True, default=None): log.debug("Converted retries value: %r -> %r", retries, new_retries) return new_retries - def get_backoff_time(self): + def get_backoff_time(self) -> float: """Formula for computing the current backoff :rtype: float @@ -347,42 +297,38 @@ def get_backoff_time(self): return 0 backoff_value = self.backoff_factor * (2 ** (consecutive_errors_len - 1)) - return min(self.BACKOFF_MAX, backoff_value) + if 
self.backoff_jitter != 0.0: + backoff_value += random.random() * self.backoff_jitter + return float(max(0, min(self.backoff_max, backoff_value))) - def parse_retry_after(self, retry_after): + def parse_retry_after(self, retry_after: str) -> float: + seconds: float # Whitespace: https://tools.ietf.org/html/rfc7230#section-3.2.4 if re.match(r"^\s*[0-9]+\s*$", retry_after): seconds = int(retry_after) else: retry_date_tuple = email.utils.parsedate_tz(retry_after) if retry_date_tuple is None: - raise InvalidHeader("Invalid Retry-After header: %s" % retry_after) - if retry_date_tuple[9] is None: # Python 2 - # Assume UTC if no timezone was specified - # On Python2.7, parsedate_tz returns None for a timezone offset - # instead of 0 if no timezone is given, where mktime_tz treats - # a None timezone offset as local time. - retry_date_tuple = retry_date_tuple[:9] + (0,) + retry_date_tuple[10:] + raise InvalidHeader(f"Invalid Retry-After header: {retry_after}") retry_date = email.utils.mktime_tz(retry_date_tuple) seconds = retry_date - time.time() - if seconds < 0: - seconds = 0 + seconds = max(seconds, 0) return seconds - def get_retry_after(self, response): - """ Get the value of Retry-After in seconds. 
""" + def get_retry_after(self, response: BaseHTTPResponse) -> float | None: + """Get the value of Retry-After in seconds.""" - retry_after = response.getheader("Retry-After") + retry_after = response.headers.get("Retry-After") if retry_after is None: return None return self.parse_retry_after(retry_after) - def sleep_for_retry(self, response=None): + def sleep_for_retry(self, response: BaseHTTPResponse) -> bool: retry_after = self.get_retry_after(response) if retry_after: time.sleep(retry_after) @@ -390,13 +336,13 @@ def sleep_for_retry(self, response=None): return False - def _sleep_backoff(self): + def _sleep_backoff(self) -> None: backoff = self.get_backoff_time() if backoff <= 0: return time.sleep(backoff) - def sleep(self, response=None): + def sleep(self, response: BaseHTTPResponse | None = None) -> None: """Sleep between retry attempts. This method will respect a server's ``Retry-After`` response header @@ -412,7 +358,7 @@ def sleep(self, response=None): self._sleep_backoff() - def _is_connection_error(self, err): + def _is_connection_error(self, err: Exception) -> bool: """Errors when we're fairly sure that the server did not receive the request, so it should be safe to retry. """ @@ -420,33 +366,23 @@ def _is_connection_error(self, err): err = err.original_error return isinstance(err, ConnectTimeoutError) - def _is_read_error(self, err): + def _is_read_error(self, err: Exception) -> bool: """Errors that occur after the request has been started, so we should assume that the server began processing it. """ return isinstance(err, (ReadTimeoutError, ProtocolError)) - def _is_method_retryable(self, method): + def _is_method_retryable(self, method: str) -> bool: """Checks if a given HTTP method should be retried upon, depending if it is included in the allowed_methods """ - # TODO: For now favor if the Retry implementation sets its own method_whitelist - # property outside of our constructor to avoid breaking custom implementations. 
- if "method_whitelist" in self.__dict__: - warnings.warn( - "Using 'method_whitelist' with Retry is deprecated and " - "will be removed in v2.0. Use 'allowed_methods' instead", - DeprecationWarning, - ) - allowed_methods = self.method_whitelist - else: - allowed_methods = self.allowed_methods - - if allowed_methods and method.upper() not in allowed_methods: + if self.allowed_methods and method.upper() not in self.allowed_methods: return False return True - def is_retry(self, method, status_code, has_retry_after=False): + def is_retry( + self, method: str, status_code: int, has_retry_after: bool = False + ) -> bool: """Is this method/status code retryable? (Based on allowlists and control variables such as the number of total retries to allow, whether to respect the Retry-After header, whether this header is present, and @@ -459,24 +395,27 @@ def is_retry(self, method, status_code, has_retry_after=False): if self.status_forcelist and status_code in self.status_forcelist: return True - return ( + return bool( self.total and self.respect_retry_after_header and has_retry_after and (status_code in self.RETRY_AFTER_STATUS_CODES) ) - def is_exhausted(self): - """ Are we out of retries? 
""" - retry_counts = ( - self.total, - self.connect, - self.read, - self.redirect, - self.status, - self.other, - ) - retry_counts = list(filter(None, retry_counts)) + def is_exhausted(self) -> bool: + """Are we out of retries?""" + retry_counts = [ + x + for x in ( + self.total, + self.connect, + self.read, + self.redirect, + self.status, + self.other, + ) + if x + ] if not retry_counts: return False @@ -484,18 +423,18 @@ def is_exhausted(self): def increment( self, - method=None, - url=None, - response=None, - error=None, - _pool=None, - _stacktrace=None, - ): + method: str | None = None, + url: str | None = None, + response: BaseHTTPResponse | None = None, + error: Exception | None = None, + _pool: ConnectionPool | None = None, + _stacktrace: TracebackType | None = None, + ) -> Retry: """Return a new Retry object with incremented retry counters. :param response: A response object, or None, if the server did not return a response. - :type response: :class:`~urllib3.response.HTTPResponse` + :type response: :class:`~urllib3.response.BaseHTTPResponse` :param Exception error: An error encountered during the request, or None if the response was received successfully. @@ -503,7 +442,7 @@ def increment( """ if self.total is False and error: # Disabled, indicate to re-raise the error. - raise six.reraise(type(error), error, _stacktrace) + raise reraise(type(error), error, _stacktrace) total = self.total if total is not None: @@ -521,14 +460,14 @@ def increment( if error and self._is_connection_error(error): # Connect retry? if connect is False: - raise six.reraise(type(error), error, _stacktrace) + raise reraise(type(error), error, _stacktrace) elif connect is not None: connect -= 1 elif error and self._is_read_error(error): # Read retry? 
- if read is False or not self._is_method_retryable(method): - raise six.reraise(type(error), error, _stacktrace) + if read is False or method is None or not self._is_method_retryable(method): + raise reraise(type(error), error, _stacktrace) elif read is not None: read -= 1 @@ -542,7 +481,9 @@ def increment( if redirect is not None: redirect -= 1 cause = "too many redirects" - redirect_location = response.get_redirect_location() + response_redirect_location = response.get_redirect_location() + if response_redirect_location: + redirect_location = response_redirect_location status = response.status else: @@ -570,31 +511,18 @@ def increment( ) if new_retry.is_exhausted(): - raise MaxRetryError(_pool, url, error or ResponseError(cause)) + reason = error or ResponseError(cause) + raise MaxRetryError(_pool, url, reason) from reason # type: ignore[arg-type] log.debug("Incremented Retry for (url='%s'): %r", url, new_retry) return new_retry - def __repr__(self): + def __repr__(self) -> str: return ( - "{cls.__name__}(total={self.total}, connect={self.connect}, " - "read={self.read}, redirect={self.redirect}, status={self.status})" - ).format(cls=type(self), self=self) - - def __getattr__(self, item): - if item == "method_whitelist": - # TODO: Remove this deprecated alias in v2.0 - warnings.warn( - "Using 'method_whitelist' with Retry is deprecated and " - "will be removed in v2.0. 
Use 'allowed_methods' instead", - DeprecationWarning, - ) - return self.allowed_methods - try: - return getattr(super(Retry, self), item) - except AttributeError: - return getattr(Retry, item) + f"{type(self).__name__}(total={self.total}, connect={self.connect}, " + f"read={self.read}, redirect={self.redirect}, status={self.status})" + ) # For backwards compatibility (equivalent to pre-v1.9): diff --git a/src/urllib3/util/ssl_.py b/src/urllib3/util/ssl_.py index 1cb5e7cdc1..5e24810120 100644 --- a/src/urllib3/util/ssl_.py +++ b/src/urllib3/util/ssl_.py @@ -1,172 +1,142 @@ -from __future__ import absolute_import +from __future__ import annotations import hmac import os +import socket import sys +import typing import warnings -from binascii import hexlify, unhexlify +from binascii import unhexlify from hashlib import md5, sha1, sha256 -from ..exceptions import ( - InsecurePlatformWarning, - ProxySchemeUnsupported, - SNIMissingWarning, - SSLError, -) -from ..packages import six -from .url import BRACELESS_IPV6_ADDRZ_RE, IPV4_RE +from ..exceptions import ProxySchemeUnsupported, SSLError +from .url import _BRACELESS_IPV6_ADDRZ_RE, _IPV4_RE SSLContext = None SSLTransport = None -HAS_SNI = False +HAS_NEVER_CHECK_COMMON_NAME = False IS_PYOPENSSL = False IS_SECURETRANSPORT = False ALPN_PROTOCOLS = ["http/1.1"] +_TYPE_VERSION_INFO = typing.Tuple[int, int, int, str, int] + # Maps the length of a digest to a possible hash function producing this digest HASHFUNC_MAP = {32: md5, 40: sha1, 64: sha256} -def _const_compare_digest_backport(a, b): - """ - Compare two digests of equal length in constant time. +def _is_bpo_43522_fixed( + implementation_name: str, version_info: _TYPE_VERSION_INFO +) -> bool: + """Return True for CPython 3.8.9+, 3.9.3+ or 3.10+ where setting + SSLContext.hostname_checks_common_name to False works. 
+ + PyPy 7.3.7 doesn't work as it doesn't ship with OpenSSL 1.1.1l+ + so we're waiting for a version of PyPy that works before + allowing this function to return 'True'. + + Outside of CPython and PyPy we don't know which implementations work + or not so we conservatively use our hostname matching as we know that works + on all implementations. - The digests must be of type str/bytes. - Returns True if the digests match, and False otherwise. + https://github.com/urllib3/urllib3/issues/2192#issuecomment-821832963 + https://foss.heptapod.net/pypy/pypy/-/issues/3539# """ - result = abs(len(a) - len(b)) - for left, right in zip(bytearray(a), bytearray(b)): - result |= left ^ right - return result == 0 + if implementation_name != "cpython": + return False + + major_minor = version_info[:2] + micro = version_info[2] + return ( + (major_minor == (3, 8) and micro >= 9) + or (major_minor == (3, 9) and micro >= 3) + or major_minor >= (3, 10) + ) -_const_compare_digest = getattr(hmac, "compare_digest", _const_compare_digest_backport) +def _is_has_never_check_common_name_reliable( + openssl_version_number: int, + implementation_name: str, + version_info: _TYPE_VERSION_INFO, +) -> bool: + # Before fixing OpenSSL issue #14579, the SSL_new() API was not copying hostflags + # like X509_CHECK_FLAG_NEVER_CHECK_SUBJECT, which tripped up CPython. + # https://github.com/openssl/openssl/issues/14579 + # This was released in OpenSSL 1.1.1l+ (>=0x101010cf) + is_openssl_issue_14579_fixed = openssl_version_number >= 0x101010CF -try: # Test for SSL features - import ssl - from ssl import HAS_SNI # Has SNI? 
- from ssl import CERT_REQUIRED, wrap_socket + return is_openssl_issue_14579_fixed or _is_bpo_43522_fixed( + implementation_name, version_info + ) - from .ssltransport import SSLTransport -except ImportError: - pass -try: # Platform-specific: Python 3.6 - from ssl import PROTOCOL_TLS +if typing.TYPE_CHECKING: + from ssl import VerifyMode - PROTOCOL_SSLv23 = PROTOCOL_TLS -except ImportError: - try: - from ssl import PROTOCOL_SSLv23 as PROTOCOL_TLS + from typing_extensions import Literal, TypedDict - PROTOCOL_SSLv23 = PROTOCOL_TLS - except ImportError: - PROTOCOL_SSLv23 = PROTOCOL_TLS = 2 + from .ssltransport import SSLTransport as SSLTransportType + class _TYPE_PEER_CERT_RET_DICT(TypedDict, total=False): + subjectAltName: tuple[tuple[str, str], ...] + subject: tuple[tuple[tuple[str, str], ...], ...] + serialNumber: str -try: - from ssl import OP_NO_COMPRESSION, OP_NO_SSLv2, OP_NO_SSLv3 -except ImportError: - OP_NO_SSLv2, OP_NO_SSLv3 = 0x1000000, 0x2000000 - OP_NO_COMPRESSION = 0x20000 +# Mapping from 'ssl.PROTOCOL_TLSX' to 'TLSVersion.X' +_SSL_VERSION_TO_TLS_VERSION: dict[int, int] = {} -try: # OP_NO_TICKET was added in Python 3.6 - from ssl import OP_NO_TICKET -except ImportError: - OP_NO_TICKET = 0x4000 - - -# A secure default. -# Sources for more information on TLS ciphers: -# -# - https://wiki.mozilla.org/Security/Server_Side_TLS -# - https://www.ssllabs.com/projects/best-practices/index.html -# - https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/ -# -# The general intent is: -# - prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE), -# - prefer ECDHE over DHE for better performance, -# - prefer any AES-GCM and ChaCha20 over any AES-CBC for better performance and -# security, -# - prefer AES-GCM over ChaCha20 because hardware-accelerated AES is common, -# - disable NULL authentication, MD5 MACs, DSS, and other -# insecure ciphers for security reasons. 
-# - NOTE: TLS 1.3 cipher suites are managed through a different interface -# not exposed by CPython (yet!) and are enabled by default if they're available. -DEFAULT_CIPHERS = ":".join( - [ - "ECDHE+AESGCM", - "ECDHE+CHACHA20", - "DHE+AESGCM", - "DHE+CHACHA20", - "ECDH+AESGCM", - "DH+AESGCM", - "ECDH+AES", - "DH+AES", - "RSA+AESGCM", - "RSA+AES", - "!aNULL", - "!eNULL", - "!MD5", - "!DSS", - ] -) - -try: - from ssl import SSLContext # Modern SSL? -except ImportError: +try: # Do we have ssl at all? + import ssl + from ssl import ( # type: ignore[assignment] + CERT_REQUIRED, + HAS_NEVER_CHECK_COMMON_NAME, + OP_NO_COMPRESSION, + OP_NO_TICKET, + OPENSSL_VERSION_NUMBER, + PROTOCOL_TLS, + PROTOCOL_TLS_CLIENT, + OP_NO_SSLv2, + OP_NO_SSLv3, + SSLContext, + TLSVersion, + ) - class SSLContext(object): # Platform-specific: Python 2 - def __init__(self, protocol_version): - self.protocol = protocol_version - # Use default values from a real SSLContext - self.check_hostname = False - self.verify_mode = ssl.CERT_NONE - self.ca_certs = None - self.options = 0 - self.certfile = None - self.keyfile = None - self.ciphers = None + PROTOCOL_SSLv23 = PROTOCOL_TLS - def load_cert_chain(self, certfile, keyfile): - self.certfile = certfile - self.keyfile = keyfile + # Setting SSLContext.hostname_checks_common_name = False didn't work before CPython + # 3.8.9, 3.9.3, and 3.10 (but OK on PyPy) or OpenSSL 1.1.1l+ + if HAS_NEVER_CHECK_COMMON_NAME and not _is_has_never_check_common_name_reliable( + OPENSSL_VERSION_NUMBER, + sys.implementation.name, + sys.version_info, + ): + HAS_NEVER_CHECK_COMMON_NAME = False + + # Need to be careful here in case old TLS versions get + # removed in future 'ssl' module implementations. 
+ for attr in ("TLSv1", "TLSv1_1", "TLSv1_2"): + try: + _SSL_VERSION_TO_TLS_VERSION[getattr(ssl, f"PROTOCOL_{attr}")] = getattr( + TLSVersion, attr + ) + except AttributeError: # Defensive: + continue - def load_verify_locations(self, cafile=None, capath=None, cadata=None): - self.ca_certs = cafile + from .ssltransport import SSLTransport # type: ignore[assignment] +except ImportError: + OP_NO_COMPRESSION = 0x20000 # type: ignore[assignment] + OP_NO_TICKET = 0x4000 # type: ignore[assignment] + OP_NO_SSLv2 = 0x1000000 # type: ignore[assignment] + OP_NO_SSLv3 = 0x2000000 # type: ignore[assignment] + PROTOCOL_SSLv23 = PROTOCOL_TLS = 2 # type: ignore[assignment] + PROTOCOL_TLS_CLIENT = 16 # type: ignore[assignment] - if capath is not None: - raise SSLError("CA directories not supported in older Pythons") - if cadata is not None: - raise SSLError("CA data not supported in older Pythons") +_TYPE_PEER_CERT_RET = typing.Union["_TYPE_PEER_CERT_RET_DICT", bytes, None] - def set_ciphers(self, cipher_suite): - self.ciphers = cipher_suite - def wrap_socket(self, socket, server_hostname=None, server_side=False): - warnings.warn( - "A true SSLContext object is not available. This prevents " - "urllib3 from configuring SSL appropriately and may cause " - "certain SSL connections to fail. You can upgrade to a newer " - "version of Python to solve this. For more information, see " - "https://urllib3.readthedocs.io/en/latest/advanced-usage.html" - "#ssl-warnings", - InsecurePlatformWarning, - ) - kwargs = { - "keyfile": self.keyfile, - "certfile": self.certfile, - "ca_certs": self.ca_certs, - "cert_reqs": self.verify_mode, - "ssl_version": self.protocol, - "server_side": server_side, - } - return wrap_socket(socket, ciphers=self.ciphers, **kwargs) - - -def assert_fingerprint(cert, fingerprint): +def assert_fingerprint(cert: bytes | None, fingerprint: str) -> None: """ Checks if given fingerprint matches the supplied certificate. 
@@ -176,26 +146,27 @@ def assert_fingerprint(cert, fingerprint): Fingerprint as string of hexdigits, can be interspersed by colons. """ + if cert is None: + raise SSLError("No certificate for the peer.") + fingerprint = fingerprint.replace(":", "").lower() digest_length = len(fingerprint) hashfunc = HASHFUNC_MAP.get(digest_length) if not hashfunc: - raise SSLError("Fingerprint of invalid length: {0}".format(fingerprint)) + raise SSLError(f"Fingerprint of invalid length: {fingerprint}") # We need encode() here for py32; works on py2 and p33. fingerprint_bytes = unhexlify(fingerprint.encode()) cert_digest = hashfunc(cert).digest() - if not _const_compare_digest(cert_digest, fingerprint_bytes): + if not hmac.compare_digest(cert_digest, fingerprint_bytes): raise SSLError( - 'Fingerprints did not match. Expected "{0}", got "{1}".'.format( - fingerprint, hexlify(cert_digest) - ) + f'Fingerprints did not match. Expected "{fingerprint}", got "{cert_digest.hex()}"' ) -def resolve_cert_reqs(candidate): +def resolve_cert_reqs(candidate: None | int | str) -> VerifyMode: """ Resolves the argument to a numeric constant, which can be passed to the wrap_socket function/method from the ssl module. @@ -213,12 +184,12 @@ def resolve_cert_reqs(candidate): res = getattr(ssl, candidate, None) if res is None: res = getattr(ssl, "CERT_" + candidate) - return res + return res # type: ignore[no-any-return] - return candidate + return candidate # type: ignore[return-value] -def resolve_ssl_version(candidate): +def resolve_ssl_version(candidate: None | int | str) -> int: """ like resolve_cert_reqs """ @@ -229,35 +200,33 @@ def resolve_ssl_version(candidate): res = getattr(ssl, candidate, None) if res is None: res = getattr(ssl, "PROTOCOL_" + candidate) - return res + return typing.cast(int, res) return candidate def create_urllib3_context( - ssl_version=None, cert_reqs=None, options=None, ciphers=None -): - """All arguments have the same meaning as ``ssl_wrap_socket``. 
- - By default, this function does a lot of the same work that - ``ssl.create_default_context`` does on Python 3.4+. It: - - - Disables SSLv2, SSLv3, and compression - - Sets a restricted set of server ciphers - - If you wish to enable SSLv3, you can do:: - - from urllib3.util import ssl_ - context = ssl_.create_urllib3_context() - context.options &= ~ssl_.OP_NO_SSLv3 - - You can do the same to enable compression (substituting ``COMPRESSION`` - for ``SSLv3`` in the last line above). + ssl_version: int | None = None, + cert_reqs: int | None = None, + options: int | None = None, + ciphers: str | None = None, + ssl_minimum_version: int | None = None, + ssl_maximum_version: int | None = None, +) -> ssl.SSLContext: + """Creates and configures an :class:`ssl.SSLContext` instance for use with urllib3. :param ssl_version: The desired protocol version to use. This will default to PROTOCOL_SSLv23 which will negotiate the highest protocol that both the server and your installation of OpenSSL support. + + This parameter is deprecated instead use 'ssl_minimum_version'. + :param ssl_minimum_version: + The minimum version of TLS to be used. Use the 'ssl.TLSVersion' enum for specifying the value. + :param ssl_maximum_version: + The maximum version of TLS to be used. Use the 'ssl.TLSVersion' enum for specifying the value. + Not recommended to set to anything other than 'ssl.TLSVersion.MAXIMUM_SUPPORTED' which is the + default value. :param cert_reqs: Whether to require the certificate verification. This defaults to ``ssl.CERT_REQUIRED``. @@ -265,14 +234,60 @@ def create_urllib3_context( Specific OpenSSL options. These default to ``ssl.OP_NO_SSLv2``, ``ssl.OP_NO_SSLv3``, ``ssl.OP_NO_COMPRESSION``, and ``ssl.OP_NO_TICKET``. :param ciphers: - Which cipher suites to allow the server to select. + Which cipher suites to allow the server to select. Defaults to either system configured + ciphers if OpenSSL 1.1.1+, otherwise uses a secure default set of ciphers. 
:returns: Constructed SSLContext object with specified options :rtype: SSLContext """ - context = SSLContext(ssl_version or PROTOCOL_TLS) + if SSLContext is None: + raise TypeError("Can't create an SSLContext object without an ssl module") + + # This means 'ssl_version' was specified as an exact value. + if ssl_version not in (None, PROTOCOL_TLS, PROTOCOL_TLS_CLIENT): + # Disallow setting 'ssl_version' and 'ssl_minimum|maximum_version' + # to avoid conflicts. + if ssl_minimum_version is not None or ssl_maximum_version is not None: + raise ValueError( + "Can't specify both 'ssl_version' and either " + "'ssl_minimum_version' or 'ssl_maximum_version'" + ) + + # 'ssl_version' is deprecated and will be removed in the future. + else: + # Use 'ssl_minimum_version' and 'ssl_maximum_version' instead. + ssl_minimum_version = _SSL_VERSION_TO_TLS_VERSION.get( + ssl_version, TLSVersion.MINIMUM_SUPPORTED + ) + ssl_maximum_version = _SSL_VERSION_TO_TLS_VERSION.get( + ssl_version, TLSVersion.MAXIMUM_SUPPORTED + ) + + # This warning message is pushing users to use 'ssl_minimum_version' + # instead of both min/max. Best practice is to only set the minimum version and + # keep the maximum version to be its default value: 'TLSVersion.MAXIMUM_SUPPORTED' + warnings.warn( + "'ssl_version' option is deprecated and will be " + "removed in urllib3 v2.1.0. 
Instead use 'ssl_minimum_version'", + category=DeprecationWarning, + stacklevel=2, + ) + + # PROTOCOL_TLS is deprecated in Python 3.10 so we always use PROTOCOL_TLS_CLIENT + context = SSLContext(PROTOCOL_TLS_CLIENT) + + if ssl_minimum_version is not None: + context.minimum_version = ssl_minimum_version + else: # Python <3.10 defaults to 'MINIMUM_SUPPORTED' so explicitly set TLSv1.2 here + context.minimum_version = TLSVersion.TLSv1_2 + + if ssl_maximum_version is not None: + context.maximum_version = ssl_maximum_version - context.set_ciphers(ciphers or DEFAULT_CIPHERS) + # Unless we're given ciphers defer to either system ciphers in + # the case of OpenSSL 1.1.1+ or use our own secure default ciphers. + if ciphers: + context.set_ciphers(ciphers) # Setting the default here, as we may have no ssl module on import cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs @@ -305,13 +320,22 @@ def create_urllib3_context( ) is not None: context.post_handshake_auth = True - context.verify_mode = cert_reqs - if ( - getattr(context, "check_hostname", None) is not None - ): # Platform-specific: Python 3.2 - # We do our own verification, including fingerprints and alternative - # hostnames. So disable it here + # The order of the below lines setting verify_mode and check_hostname + # matter due to safe-guards SSLContext has to prevent an SSLContext with + # check_hostname=True, verify_mode=NONE/OPTIONAL. + # We always set 'check_hostname=False' for pyOpenSSL so we rely on our own + # 'ssl.match_hostname()' implementation. + if cert_reqs == ssl.CERT_REQUIRED and not IS_PYOPENSSL: + context.verify_mode = cert_reqs + context.check_hostname = True + else: context.check_hostname = False + context.verify_mode = cert_reqs + + try: + context.hostname_checks_common_name = False + except AttributeError: + pass # Enable logging of TLS session keys via defacto standard environment variable # 'SSLKEYLOGFILE', if the feature is available (Python 3.8+). Skip empty values. 
@@ -323,21 +347,59 @@ def create_urllib3_context( return context +@typing.overload +def ssl_wrap_socket( + sock: socket.socket, + keyfile: str | None = ..., + certfile: str | None = ..., + cert_reqs: int | None = ..., + ca_certs: str | None = ..., + server_hostname: str | None = ..., + ssl_version: int | None = ..., + ciphers: str | None = ..., + ssl_context: ssl.SSLContext | None = ..., + ca_cert_dir: str | None = ..., + key_password: str | None = ..., + ca_cert_data: None | str | bytes = ..., + tls_in_tls: Literal[False] = ..., +) -> ssl.SSLSocket: + ... + + +@typing.overload +def ssl_wrap_socket( + sock: socket.socket, + keyfile: str | None = ..., + certfile: str | None = ..., + cert_reqs: int | None = ..., + ca_certs: str | None = ..., + server_hostname: str | None = ..., + ssl_version: int | None = ..., + ciphers: str | None = ..., + ssl_context: ssl.SSLContext | None = ..., + ca_cert_dir: str | None = ..., + key_password: str | None = ..., + ca_cert_data: None | str | bytes = ..., + tls_in_tls: bool = ..., +) -> ssl.SSLSocket | SSLTransportType: + ... + + def ssl_wrap_socket( - sock, - keyfile=None, - certfile=None, - cert_reqs=None, - ca_certs=None, - server_hostname=None, - ssl_version=None, - ciphers=None, - ssl_context=None, - ca_cert_dir=None, - key_password=None, - ca_cert_data=None, - tls_in_tls=False, -): + sock: socket.socket, + keyfile: str | None = None, + certfile: str | None = None, + cert_reqs: int | None = None, + ca_certs: str | None = None, + server_hostname: str | None = None, + ssl_version: int | None = None, + ciphers: str | None = None, + ssl_context: ssl.SSLContext | None = None, + ca_cert_dir: str | None = None, + key_password: str | None = None, + ca_cert_data: None | str | bytes = None, + tls_in_tls: bool = False, +) -> ssl.SSLSocket | SSLTransportType: """ All arguments except for server_hostname, ssl_context, and ca_cert_dir have the same meaning as they do when using :func:`ssl.wrap_socket`. 
@@ -363,19 +425,18 @@ def ssl_wrap_socket( """ context = ssl_context if context is None: - # Note: This branch of code and all the variables in it are no longer - # used by urllib3 itself. We should consider deprecating and removing - # this code. + # Note: This branch of code and all the variables in it are only used in tests. + # We should consider deprecating and removing this code. context = create_urllib3_context(ssl_version, cert_reqs, ciphers=ciphers) if ca_certs or ca_cert_dir or ca_cert_data: try: context.load_verify_locations(ca_certs, ca_cert_dir, ca_cert_data) - except (IOError, OSError) as e: - raise SSLError(e) + except OSError as e: + raise SSLError(e) from e elif ssl_context is None and hasattr(context, "load_default_certs"): - # try to load OS default certs; works well on Windows (require Python3.4+) + # try to load OS default certs; works well on Windows. context.load_default_certs() # Attempt to detect if we get the goofy behavior of the @@ -391,56 +452,30 @@ def ssl_wrap_socket( context.load_cert_chain(certfile, keyfile, key_password) try: - if hasattr(context, "set_alpn_protocols"): - context.set_alpn_protocols(ALPN_PROTOCOLS) - except NotImplementedError: + context.set_alpn_protocols(ALPN_PROTOCOLS) + except NotImplementedError: # Defensive: in CI, we always have set_alpn_protocols pass - # If we detect server_hostname is an IP address then the SNI - # extension should not be used according to RFC3546 Section 3.1 - use_sni_hostname = server_hostname and not is_ipaddress(server_hostname) - # SecureTransport uses server_hostname in certificate verification. - send_sni = (use_sni_hostname and HAS_SNI) or ( - IS_SECURETRANSPORT and server_hostname - ) - # Do not warn the user if server_hostname is an invalid SNI hostname. - if not HAS_SNI and use_sni_hostname: - warnings.warn( - "An HTTPS request has been made, but the SNI (Server Name " - "Indication) extension to TLS is not available on this platform. 
" - "This may cause the server to present an incorrect TLS " - "certificate, which can cause validation failures. You can upgrade to " - "a newer version of Python to solve this. For more information, see " - "https://urllib3.readthedocs.io/en/latest/advanced-usage.html" - "#ssl-warnings", - SNIMissingWarning, - ) - - if send_sni: - ssl_sock = _ssl_wrap_socket_impl( - sock, context, tls_in_tls, server_hostname=server_hostname - ) - else: - ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls) + ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls, server_hostname) return ssl_sock -def is_ipaddress(hostname): +def is_ipaddress(hostname: str | bytes) -> bool: """Detects whether the hostname given is an IPv4 or IPv6 address. Also detects IPv6 addresses with Zone IDs. :param str hostname: Hostname to examine. :return: True if the hostname is an IP address, False otherwise. """ - if not six.PY2 and isinstance(hostname, bytes): + if isinstance(hostname, bytes): # IDN A-label bytes are ASCII compatible. hostname = hostname.decode("ascii") - return bool(IPV4_RE.match(hostname) or BRACELESS_IPV6_ADDRZ_RE.match(hostname)) + return bool(_IPV4_RE.match(hostname) or _BRACELESS_IPV6_ADDRZ_RE.match(hostname)) -def _is_key_file_encrypted(key_file): +def _is_key_file_encrypted(key_file: str) -> bool: """Detects if a key file is encrypted or not.""" - with open(key_file, "r") as f: + with open(key_file) as f: for line in f: # Look for Proc-Type: 4,ENCRYPTED if "ENCRYPTED" in line: @@ -449,7 +484,12 @@ def _is_key_file_encrypted(key_file): return False -def _ssl_wrap_socket_impl(sock, ssl_context, tls_in_tls, server_hostname=None): +def _ssl_wrap_socket_impl( + sock: socket.socket, + ssl_context: ssl.SSLContext, + tls_in_tls: bool, + server_hostname: str | None = None, +) -> ssl.SSLSocket | SSLTransportType: if tls_in_tls: if not SSLTransport: # Import error, ssl is not available. 
@@ -460,7 +500,4 @@ def _ssl_wrap_socket_impl(sock, ssl_context, tls_in_tls, server_hostname=None): SSLTransport._validate_ssl_context_for_tls_in_tls(ssl_context) return SSLTransport(sock, ssl_context, server_hostname) - if server_hostname: - return ssl_context.wrap_socket(sock, server_hostname=server_hostname) - else: - return ssl_context.wrap_socket(sock) + return ssl_context.wrap_socket(sock, server_hostname=server_hostname) diff --git a/src/urllib3/packages/ssl_match_hostname/_implementation.py b/src/urllib3/util/ssl_match_hostname.py similarity index 65% rename from src/urllib3/packages/ssl_match_hostname/_implementation.py rename to src/urllib3/util/ssl_match_hostname.py index 689208d3c6..453cfd420d 100644 --- a/src/urllib3/packages/ssl_match_hostname/_implementation.py +++ b/src/urllib3/util/ssl_match_hostname.py @@ -1,19 +1,18 @@ -"""The match_hostname() function from Python 3.3.3, essential when using SSL.""" +"""The match_hostname() function from Python 3.5, essential when using SSL.""" # Note: This file is under the PSF license as the code comes from the python # stdlib. http://docs.python.org/3/license.html +# It is modified to remove commonName support. +from __future__ import annotations + +import ipaddress import re -import sys +import typing +from ipaddress import IPv4Address, IPv6Address -# ipaddress has been backported to 2.6+ in pypi. If it is installed on the -# system, use it to handle IPAddress ServerAltnames (this was added in -# python-3.5) otherwise only do DNS matching. This allows -# backports.ssl_match_hostname to continue to be used in Python 2.7. 
-try: - import ipaddress -except ImportError: - ipaddress = None +if typing.TYPE_CHECKING: + from .ssl_ import _TYPE_PEER_CERT_RET_DICT __version__ = "3.5.0.1" @@ -22,7 +21,9 @@ class CertificateError(ValueError): pass -def _dnsname_match(dn, hostname, max_wildcards=1): +def _dnsname_match( + dn: typing.Any, hostname: str, max_wildcards: int = 1 +) -> typing.Match[str] | None | bool: """Matching according to RFC 6125, section 6.4.3 http://tools.ietf.org/html/rfc6125#section-6.4.3 @@ -49,7 +50,7 @@ def _dnsname_match(dn, hostname, max_wildcards=1): # speed up common case w/o wildcards if not wildcards: - return dn.lower() == hostname.lower() + return bool(dn.lower() == hostname.lower()) # RFC 6125, section 6.4.3, subitem 1. # The client SHOULD NOT attempt to match a presented identifier in which @@ -76,25 +77,26 @@ def _dnsname_match(dn, hostname, max_wildcards=1): return pat.match(hostname) -def _to_unicode(obj): - if isinstance(obj, str) and sys.version_info < (3,): - obj = unicode(obj, encoding="ascii", errors="strict") - return obj - - -def _ipaddress_match(ipname, host_ip): +def _ipaddress_match(ipname: str, host_ip: IPv4Address | IPv6Address) -> bool: """Exact matching of IP addresses. - RFC 6125 explicitly doesn't define an algorithm for this - (section 1.7.2 - "Out of Scope"). + RFC 9110 section 4.3.5: "A reference identity of IP-ID contains the decoded + bytes of the IP address. An IP version 4 address is 4 octets, and an IP + version 6 address is 16 octets. [...] A reference identity of type IP-ID + matches if the address is identical to an iPAddress value of the + subjectAltName extension of the certificate." 
""" # OpenSSL may add a trailing newline to a subjectAltName's IP address # Divergence from upstream: ipaddress can't handle byte str - ip = ipaddress.ip_address(_to_unicode(ipname).rstrip()) - return ip == host_ip + ip = ipaddress.ip_address(ipname.rstrip()) + return bool(ip.packed == host_ip.packed) -def match_hostname(cert, hostname): +def match_hostname( + cert: _TYPE_PEER_CERT_RET_DICT | None, + hostname: str, + hostname_checks_common_name: bool = False, +) -> None: """Verify that *cert* (in decoded format as returned by SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125 rules are followed, but IP addresses are not accepted for *hostname*. @@ -110,23 +112,22 @@ def match_hostname(cert, hostname): ) try: # Divergence from upstream: ipaddress can't handle byte str - host_ip = ipaddress.ip_address(_to_unicode(hostname)) + # + # The ipaddress module shipped with Python < 3.9 does not support + # scoped IPv6 addresses so we unconditionally strip the Zone IDs for + # now. Once we drop support for Python 3.9 we can remove this branch. + if "%" in hostname: + host_ip = ipaddress.ip_address(hostname[: hostname.rfind("%")]) + else: + host_ip = ipaddress.ip_address(hostname) + except ValueError: # Not an IP address (common case) host_ip = None - except UnicodeError: - # Divergence from upstream: Have to deal with ipaddress not taking - # byte strings. addresses should be all ascii, so we consider it not - # an ipaddress in this case - host_ip = None - except AttributeError: - # Divergence from upstream: Make ipaddress library optional - if ipaddress is None: - host_ip = None - else: - raise dnsnames = [] - san = cert.get("subjectAltName", ()) + san: tuple[tuple[str, str], ...] 
= cert.get("subjectAltName", ()) + key: str + value: str for key, value in san: if key == "DNS": if host_ip is None and _dnsname_match(value, hostname): @@ -136,25 +137,23 @@ def match_hostname(cert, hostname): if host_ip is not None and _ipaddress_match(value, host_ip): return dnsnames.append(value) - if not dnsnames: - # The subject is only checked when there is no dNSName entry - # in subjectAltName + + # We only check 'commonName' if it's enabled and we're not verifying + # an IP address. IP addresses aren't valid within 'commonName'. + if hostname_checks_common_name and host_ip is None and not dnsnames: for sub in cert.get("subject", ()): for key, value in sub: - # XXX according to RFC 2818, the most specific Common Name - # must be used. if key == "commonName": if _dnsname_match(value, hostname): return dnsnames.append(value) + if len(dnsnames) > 1: raise CertificateError( "hostname %r " "doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames))) ) elif len(dnsnames) == 1: - raise CertificateError("hostname %r doesn't match %r" % (hostname, dnsnames[0])) + raise CertificateError(f"hostname {hostname!r} doesn't match {dnsnames[0]!r}") else: - raise CertificateError( - "no appropriate commonName or subjectAltName fields were found" - ) + raise CertificateError("no appropriate subjectAltName fields were found") diff --git a/src/urllib3/util/ssltransport.py b/src/urllib3/util/ssltransport.py index 1e41354f5d..5ec86473b4 100644 --- a/src/urllib3/util/ssltransport.py +++ b/src/urllib3/util/ssltransport.py @@ -1,9 +1,21 @@ +from __future__ import annotations + import io import socket import ssl +import typing + +from ..exceptions import ProxySchemeUnsupported + +if typing.TYPE_CHECKING: + from typing_extensions import Literal + + from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT + -from urllib3.exceptions import ProxySchemeUnsupported -from urllib3.packages import six +_SelfT = typing.TypeVar("_SelfT", bound="SSLTransport") +_WriteBuffer 
= typing.Union[bytearray, memoryview] +_ReturnValue = typing.TypeVar("_ReturnValue") SSL_BLOCKSIZE = 16384 @@ -20,7 +32,7 @@ class SSLTransport: """ @staticmethod - def _validate_ssl_context_for_tls_in_tls(ssl_context): + def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None: """ Raises a ProxySchemeUnsupported if the provided ssl_context can't be used for TLS in TLS. @@ -30,20 +42,18 @@ def _validate_ssl_context_for_tls_in_tls(ssl_context): """ if not hasattr(ssl_context, "wrap_bio"): - if six.PY2: - raise ProxySchemeUnsupported( - "TLS in TLS requires SSLContext.wrap_bio() which isn't " - "supported on Python 2" - ) - else: - raise ProxySchemeUnsupported( - "TLS in TLS requires SSLContext.wrap_bio() which isn't " - "available on non-native SSLContext" - ) + raise ProxySchemeUnsupported( + "TLS in TLS requires SSLContext.wrap_bio() which isn't " + "available on non-native SSLContext" + ) def __init__( - self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True - ): + self, + socket: socket.socket, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + suppress_ragged_eofs: bool = True, + ) -> None: """ Create an SSLTransport around socket using the provided ssl_context. """ @@ -60,33 +70,36 @@ def __init__( # Perform initial handshake. 
self._ssl_io_loop(self.sslobj.do_handshake) - def __enter__(self): + def __enter__(self: _SelfT) -> _SelfT: return self - def __exit__(self, *_): + def __exit__(self, *_: typing.Any) -> None: self.close() - def fileno(self): + def fileno(self) -> int: return self.socket.fileno() - def read(self, len=1024, buffer=None): + def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes: return self._wrap_ssl_read(len, buffer) - def recv(self, len=1024, flags=0): + def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes: if flags != 0: raise ValueError("non-zero flags not allowed in calls to recv") - return self._wrap_ssl_read(len) - - def recv_into(self, buffer, nbytes=None, flags=0): + return self._wrap_ssl_read(buflen) + + def recv_into( + self, + buffer: _WriteBuffer, + nbytes: int | None = None, + flags: int = 0, + ) -> None | int | bytes: if flags != 0: raise ValueError("non-zero flags not allowed in calls to recv_into") - if buffer and (nbytes is None): + if nbytes is None: nbytes = len(buffer) - elif nbytes is None: - nbytes = 1024 return self.read(nbytes, buffer) - def sendall(self, data, flags=0): + def sendall(self, data: bytes, flags: int = 0) -> None: if flags != 0: raise ValueError("non-zero flags not allowed in calls to sendall") count = 0 @@ -96,15 +109,20 @@ def sendall(self, data, flags=0): v = self.send(byte_view[count:]) count += v - def send(self, data, flags=0): + def send(self, data: bytes, flags: int = 0) -> int: if flags != 0: raise ValueError("non-zero flags not allowed in calls to send") - response = self._ssl_io_loop(self.sslobj.write, data) - return response + return self._ssl_io_loop(self.sslobj.write, data) def makefile( - self, mode="r", buffering=None, encoding=None, errors=None, newline=None - ): + self, + mode: str, + buffering: int | None = None, + *, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO: """ 
Python's httpclient uses makefile and buffered io when reading HTTP messages and we need to support it. @@ -113,7 +131,7 @@ def makefile( changes to point to the socket directly. """ if not set(mode) <= {"r", "w", "b"}: - raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,)) + raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)") writing = "w" in mode reading = "r" in mode or not writing @@ -124,8 +142,8 @@ def makefile( rawmode += "r" if writing: rawmode += "w" - raw = socket.SocketIO(self, rawmode) - self.socket._io_refs += 1 + raw = socket.SocketIO(self, rawmode) # type: ignore[arg-type] + self.socket._io_refs += 1 # type: ignore[attr-defined] if buffering is None: buffering = -1 if buffering < 0: @@ -134,8 +152,9 @@ def makefile( if not binary: raise ValueError("unbuffered streams must be binary") return raw + buffer: typing.BinaryIO if reading and writing: - buffer = io.BufferedRWPair(raw, raw, buffering) + buffer = io.BufferedRWPair(raw, raw, buffering) # type: ignore[assignment] elif reading: buffer = io.BufferedReader(raw, buffering) else: @@ -144,46 +163,56 @@ def makefile( if binary: return buffer text = io.TextIOWrapper(buffer, encoding, errors, newline) - text.mode = mode + text.mode = mode # type: ignore[misc] return text - def unwrap(self): + def unwrap(self) -> None: self._ssl_io_loop(self.sslobj.unwrap) - def close(self): + def close(self) -> None: self.socket.close() - def getpeercert(self, binary_form=False): - return self.sslobj.getpeercert(binary_form) + @typing.overload + def getpeercert( + self, binary_form: Literal[False] = ... + ) -> _TYPE_PEER_CERT_RET_DICT | None: + ... + + @typing.overload + def getpeercert(self, binary_form: Literal[True]) -> bytes | None: + ... 
- def version(self): + def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET: + return self.sslobj.getpeercert(binary_form) # type: ignore[return-value] + + def version(self) -> str | None: return self.sslobj.version() - def cipher(self): + def cipher(self) -> tuple[str, str, int] | None: return self.sslobj.cipher() - def selected_alpn_protocol(self): + def selected_alpn_protocol(self) -> str | None: return self.sslobj.selected_alpn_protocol() - def selected_npn_protocol(self): + def selected_npn_protocol(self) -> str | None: return self.sslobj.selected_npn_protocol() - def shared_ciphers(self): + def shared_ciphers(self) -> list[tuple[str, str, int]] | None: return self.sslobj.shared_ciphers() - def compression(self): + def compression(self) -> str | None: return self.sslobj.compression() - def settimeout(self, value): + def settimeout(self, value: float | None) -> None: self.socket.settimeout(value) - def gettimeout(self): + def gettimeout(self) -> float | None: return self.socket.gettimeout() - def _decref_socketios(self): - self.socket._decref_socketios() + def _decref_socketios(self) -> None: + self.socket._decref_socketios() # type: ignore[attr-defined] - def _wrap_ssl_read(self, len, buffer=None): + def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes: try: return self._ssl_io_loop(self.sslobj.read, len, buffer) except ssl.SSLError as e: @@ -192,15 +221,45 @@ def _wrap_ssl_read(self, len, buffer=None): else: raise - def _ssl_io_loop(self, func, *args): - """ Performs an I/O loop between incoming/outgoing and the socket.""" + # func is sslobj.do_handshake or sslobj.unwrap + @typing.overload + def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None: + ... + + # func is sslobj.write, arg1 is data + @typing.overload + def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int: + ... 
+ + # func is sslobj.read, arg1 is len, arg2 is buffer + @typing.overload + def _ssl_io_loop( + self, + func: typing.Callable[[int, bytearray | None], bytes], + arg1: int, + arg2: bytearray | None, + ) -> bytes: + ... + + def _ssl_io_loop( + self, + func: typing.Callable[..., _ReturnValue], + arg1: None | bytes | int = None, + arg2: bytearray | None = None, + ) -> _ReturnValue: + """Performs an I/O loop between incoming/outgoing and the socket.""" should_loop = True ret = None while should_loop: errno = None try: - ret = func(*args) + if arg1 is None and arg2 is None: + ret = func() + elif arg2 is None: + ret = func(arg1) + else: + ret = func(arg1, arg2) except ssl.SSLError as e: if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): # WANT_READ, and WANT_WRITE are expected, others are not. @@ -218,4 +277,4 @@ def _ssl_io_loop(self, func, *args): self.incoming.write(buf) else: self.incoming.write_eof() - return ret + return typing.cast(_ReturnValue, ret) diff --git a/src/urllib3/util/timeout.py b/src/urllib3/util/timeout.py index ff69593b05..ec090f69cc 100644 --- a/src/urllib3/util/timeout.py +++ b/src/urllib3/util/timeout.py @@ -1,45 +1,56 @@ -from __future__ import absolute_import +from __future__ import annotations import time - -# The default socket timeout, used by httplib to indicate that no timeout was -# specified by the user -from socket import _GLOBAL_DEFAULT_TIMEOUT +import typing +from enum import Enum +from socket import getdefaulttimeout from ..exceptions import TimeoutStateError -# A sentinel value to indicate that no timeout was specified by the user in -# urllib3 -_Default = object() +if typing.TYPE_CHECKING: + from typing_extensions import Final + + +class _TYPE_DEFAULT(Enum): + # This value should never be passed to socket.settimeout() so for safety we use a -1. + # socket.settimeout() raises a ValueError for negative values. 
+ token = -1 + +_DEFAULT_TIMEOUT: Final[_TYPE_DEFAULT] = _TYPE_DEFAULT.token -# Use time.monotonic if available. -current_time = getattr(time, "monotonic", time.time) +_TYPE_TIMEOUT = typing.Optional[typing.Union[float, _TYPE_DEFAULT]] -class Timeout(object): +class Timeout: """Timeout configuration. Timeouts can be defined as a default for a pool: .. code-block:: python - timeout = Timeout(connect=2.0, read=7.0) - http = PoolManager(timeout=timeout) - response = http.request('GET', 'http://example.com/') + import urllib3 + + timeout = urllib3.util.Timeout(connect=2.0, read=7.0) + + http = urllib3.PoolManager(timeout=timeout) + + resp = http.request("GET", "https://example.com/") + + print(resp.status) Or per-request (which overrides the default for the pool): .. code-block:: python - response = http.request('GET', 'http://example.com/', timeout=Timeout(10)) + response = http.request("GET", "https://example.com/", timeout=Timeout(10)) Timeouts can be disabled by setting all the parameters to ``None``: .. 
code-block:: python no_timeout = Timeout(connect=None, read=None) - response = http.request('GET', 'http://example.com/, timeout=no_timeout) + response = http.request("GET", "https://example.com/", timeout=no_timeout) :param total: @@ -97,27 +108,31 @@ class Timeout(object): """ #: A sentinel object representing the default timeout value - DEFAULT_TIMEOUT = _GLOBAL_DEFAULT_TIMEOUT - - def __init__(self, total=None, connect=_Default, read=_Default): + DEFAULT_TIMEOUT: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT + + def __init__( + self, + total: _TYPE_TIMEOUT = None, + connect: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + read: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + ) -> None: self._connect = self._validate_timeout(connect, "connect") self._read = self._validate_timeout(read, "read") self.total = self._validate_timeout(total, "total") - self._start_connect = None + self._start_connect: float | None = None - def __repr__(self): - return "%s(connect=%r, read=%r, total=%r)" % ( - type(self).__name__, - self._connect, - self._read, - self.total, - ) + def __repr__(self) -> str: + return f"{type(self).__name__}(connect={self._connect!r}, read={self._read!r}, total={self.total!r})" # __str__ provided for backwards compatibility __str__ = __repr__ + @staticmethod + def resolve_default_timeout(timeout: _TYPE_TIMEOUT) -> float | None: + return getdefaulttimeout() if timeout is _DEFAULT_TIMEOUT else timeout + @classmethod - def _validate_timeout(cls, value, name): + def _validate_timeout(cls, value: _TYPE_TIMEOUT, name: str) -> _TYPE_TIMEOUT: """Check that a timeout attribute is valid. :param value: The timeout value to validate @@ -127,10 +142,7 @@ def _validate_timeout(cls, value, name): :raises ValueError: If it is a numeric value less than or equal to zero, or the type is not an integer, float, or None. 
""" - if value is _Default: - return cls.DEFAULT_TIMEOUT - - if value is None or value is cls.DEFAULT_TIMEOUT: + if value is None or value is _DEFAULT_TIMEOUT: return value if isinstance(value, bool): @@ -144,7 +156,7 @@ def _validate_timeout(cls, value, name): raise ValueError( "Timeout value %s was %s, but it must be an " "int, float or None." % (name, value) - ) + ) from None try: if value <= 0: @@ -154,16 +166,15 @@ def _validate_timeout(cls, value, name): "than or equal to 0." % (name, value) ) except TypeError: - # Python 3 raise ValueError( "Timeout value %s was %s, but it must be an " "int, float or None." % (name, value) - ) + ) from None return value @classmethod - def from_float(cls, timeout): + def from_float(cls, timeout: _TYPE_TIMEOUT) -> Timeout: """Create a new Timeout from a legacy timeout value. The timeout value used by httplib.py sets the same timeout on the @@ -172,13 +183,13 @@ def from_float(cls, timeout): passed to this function. :param timeout: The legacy timeout value. - :type timeout: integer, float, sentinel default object, or None + :type timeout: integer, float, :attr:`urllib3.util.Timeout.DEFAULT_TIMEOUT`, or None :return: Timeout object :rtype: :class:`Timeout` """ return Timeout(read=timeout, connect=timeout) - def clone(self): + def clone(self) -> Timeout: """Create a copy of the timeout object Timeout properties are stored per-pool but each request needs a fresh @@ -192,7 +203,7 @@ def clone(self): # detect the user default. 
return Timeout(connect=self._connect, read=self._read, total=self.total) - def start_connect(self): + def start_connect(self) -> float: """Start the timeout clock, used during a connect() attempt :raises urllib3.exceptions.TimeoutStateError: if you attempt @@ -200,10 +211,10 @@ def start_connect(self): """ if self._start_connect is not None: raise TimeoutStateError("Timeout timer has already been started.") - self._start_connect = current_time() + self._start_connect = time.monotonic() return self._start_connect - def get_connect_duration(self): + def get_connect_duration(self) -> float: """Gets the time elapsed since the call to :meth:`start_connect`. :return: Elapsed time in seconds. @@ -215,10 +226,10 @@ def get_connect_duration(self): raise TimeoutStateError( "Can't get connect duration for timer that has not started." ) - return current_time() - self._start_connect + return time.monotonic() - self._start_connect @property - def connect_timeout(self): + def connect_timeout(self) -> _TYPE_TIMEOUT: """Get the value to use when setting a connection timeout. This will be a positive float or integer, the value None @@ -230,13 +241,13 @@ def connect_timeout(self): if self.total is None: return self._connect - if self._connect is None or self._connect is self.DEFAULT_TIMEOUT: + if self._connect is None or self._connect is _DEFAULT_TIMEOUT: return self.total - return min(self._connect, self.total) + return min(self._connect, self.total) # type: ignore[type-var] @property - def read_timeout(self): + def read_timeout(self) -> float | None: """Get the value for the read timeout. This assumes some time has elapsed in the connection timeout and @@ -248,21 +259,21 @@ def read_timeout(self): raised. :return: Value to use for the read timeout. - :rtype: int, float, :attr:`Timeout.DEFAULT_TIMEOUT` or None + :rtype: int, float or None :raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect` has not yet been called on this object. 
""" if ( self.total is not None - and self.total is not self.DEFAULT_TIMEOUT + and self.total is not _DEFAULT_TIMEOUT and self._read is not None - and self._read is not self.DEFAULT_TIMEOUT + and self._read is not _DEFAULT_TIMEOUT ): # In case the connect timeout has not yet been established. if self._start_connect is None: return self._read return max(0, min(self.total - self.get_connect_duration(), self._read)) - elif self.total is not None and self.total is not self.DEFAULT_TIMEOUT: + elif self.total is not None and self.total is not _DEFAULT_TIMEOUT: return max(0, self.total - self.get_connect_duration()) else: - return self._read + return self.resolve_default_timeout(self._read) diff --git a/src/urllib3/util/url.py b/src/urllib3/util/url.py index 6ff238fe3c..d53ea932a0 100644 --- a/src/urllib3/util/url.py +++ b/src/urllib3/util/url.py @@ -1,22 +1,20 @@ -from __future__ import absolute_import +from __future__ import annotations import re -from collections import namedtuple +import typing from ..exceptions import LocationParseError -from ..packages import six - -url_attrs = ["scheme", "auth", "host", "port", "path", "query", "fragment"] +from .util import to_str # We only want to normalize urls with an HTTP(S) scheme. # urllib3 infers URLs without a scheme (None) to be http. -NORMALIZABLE_SCHEMES = ("http", "https", None) +_NORMALIZABLE_SCHEMES = ("http", "https", None) # Almost all of these patterns were derived from the # 'rfc3986' module: https://github.com/python-hyper/rfc3986 -PERCENT_RE = re.compile(r"%[a-fA-F0-9]{2}") -SCHEME_RE = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+-]*:|/)") -URI_RE = re.compile( +_PERCENT_RE = re.compile(r"%[a-fA-F0-9]{2}") +_SCHEME_RE = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+-]*:|/)") +_URI_RE = re.compile( r"^(?:([a-zA-Z][a-zA-Z0-9+.-]*):)?" r"(?://([^\\/?#]*))?" 
r"([^?#]*)" @@ -25,10 +23,10 @@ re.UNICODE | re.DOTALL, ) -IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}" -HEX_PAT = "[0-9A-Fa-f]{1,4}" -LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=HEX_PAT, ipv4=IPV4_PAT) -_subs = {"hex": HEX_PAT, "ls32": LS32_PAT} +_IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}" +_HEX_PAT = "[0-9A-Fa-f]{1,4}" +_LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=_HEX_PAT, ipv4=_IPV4_PAT) +_subs = {"hex": _HEX_PAT, "ls32": _LS32_PAT} _variations = [ # 6( h16 ":" ) ls32 "(?:%(hex)s:){6}%(ls32)s", @@ -50,69 +48,78 @@ "(?:(?:%(hex)s:){0,6}%(hex)s)?::", ] -UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._!\-~" -IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")" -ZONE_ID_PAT = "(?:%25|%)(?:[" + UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+" -IPV6_ADDRZ_PAT = r"\[" + IPV6_PAT + r"(?:" + ZONE_ID_PAT + r")?\]" -REG_NAME_PAT = r"(?:[^\[\]%:/?#]|%[a-fA-F0-9]{2})*" -TARGET_RE = re.compile(r"^(/[^?#]*)(?:\?([^#]*))?(?:#.*)?$") - -IPV4_RE = re.compile("^" + IPV4_PAT + "$") -IPV6_RE = re.compile("^" + IPV6_PAT + "$") -IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT + "$") -BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT[2:-2] + "$") -ZONE_ID_RE = re.compile("(" + ZONE_ID_PAT + r")\]$") - -SUBAUTHORITY_PAT = (u"^(?:(.*)@)?(%s|%s|%s)(?::([0-9]{0,5}))?$") % ( - REG_NAME_PAT, - IPV4_PAT, - IPV6_ADDRZ_PAT, +_UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._\-~" +_IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")" +_ZONE_ID_PAT = "(?:%25|%)(?:[" + _UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+" +_IPV6_ADDRZ_PAT = r"\[" + _IPV6_PAT + r"(?:" + _ZONE_ID_PAT + r")?\]" +_REG_NAME_PAT = r"(?:[^\[\]%:/?#]|%[a-fA-F0-9]{2})*" +_TARGET_RE = re.compile(r"^(/[^?#]*)(?:\?([^#]*))?(?:#.*)?$") + +_IPV4_RE = re.compile("^" + _IPV4_PAT + "$") +_IPV6_RE = re.compile("^" + _IPV6_PAT + "$") +_IPV6_ADDRZ_RE = re.compile("^" + _IPV6_ADDRZ_PAT + "$") +_BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + 
_IPV6_ADDRZ_PAT[2:-2] + "$") +_ZONE_ID_RE = re.compile("(" + _ZONE_ID_PAT + r")\]$") + +_HOST_PORT_PAT = ("^(%s|%s|%s)(?::0*?(|0|[1-9][0-9]{0,4}))?$") % ( + _REG_NAME_PAT, + _IPV4_PAT, + _IPV6_ADDRZ_PAT, ) -SUBAUTHORITY_RE = re.compile(SUBAUTHORITY_PAT, re.UNICODE | re.DOTALL) +_HOST_PORT_RE = re.compile(_HOST_PORT_PAT, re.UNICODE | re.DOTALL) -UNRESERVED_CHARS = set( +_UNRESERVED_CHARS = set( "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-~" ) -SUB_DELIM_CHARS = set("!$&'()*+,;=") -USERINFO_CHARS = UNRESERVED_CHARS | SUB_DELIM_CHARS | {":"} -PATH_CHARS = USERINFO_CHARS | {"@", "/"} -QUERY_CHARS = FRAGMENT_CHARS = PATH_CHARS | {"?"} - - -class Url(namedtuple("Url", url_attrs)): +_SUB_DELIM_CHARS = set("!$&'()*+,;=") +_USERINFO_CHARS = _UNRESERVED_CHARS | _SUB_DELIM_CHARS | {":"} +_PATH_CHARS = _USERINFO_CHARS | {"@", "/"} +_QUERY_CHARS = _FRAGMENT_CHARS = _PATH_CHARS | {"?"} + + +class Url( + typing.NamedTuple( + "Url", + [ + ("scheme", typing.Optional[str]), + ("auth", typing.Optional[str]), + ("host", typing.Optional[str]), + ("port", typing.Optional[int]), + ("path", typing.Optional[str]), + ("query", typing.Optional[str]), + ("fragment", typing.Optional[str]), + ], + ) +): """ Data structure for representing an HTTP URL. Used as a return value for :func:`parse_url`. Both the scheme and host are normalized as they are both case-insensitive according to RFC 3986. 
""" - __slots__ = () - - def __new__( + def __new__( # type: ignore[no-untyped-def] cls, - scheme=None, - auth=None, - host=None, - port=None, - path=None, - query=None, - fragment=None, + scheme: str | None = None, + auth: str | None = None, + host: str | None = None, + port: int | None = None, + path: str | None = None, + query: str | None = None, + fragment: str | None = None, ): if path and not path.startswith("/"): path = "/" + path if scheme is not None: scheme = scheme.lower() - return super(Url, cls).__new__( - cls, scheme, auth, host, port, path, query, fragment - ) + return super().__new__(cls, scheme, auth, host, port, path, query, fragment) @property - def hostname(self): + def hostname(self) -> str | None: """For backwards-compatibility with urlparse. We're nice like that.""" return self.host @property - def request_uri(self): + def request_uri(self) -> str: """Absolute path including the query string.""" uri = self.path or "/" @@ -122,14 +129,37 @@ def request_uri(self): return uri @property - def netloc(self): - """Network location including host and port""" + def authority(self) -> str | None: + """ + Authority component as defined in RFC 3986 3.2. + This includes userinfo (auth), host and port. + + i.e. + userinfo@host:port + """ + userinfo = self.auth + netloc = self.netloc + if netloc is None or userinfo is None: + return netloc + else: + return f"{userinfo}@{netloc}" + + @property + def netloc(self) -> str | None: + """ + Network location including host and port. + + If you need the equivalent of urllib.parse's ``netloc``, + use the ``authority`` property instead. + """ + if self.host is None: + return None if self.port: - return "%s:%d" % (self.host, self.port) + return f"{self.host}:{self.port}" return self.host @property - def url(self): + def url(self) -> str: """ Convert self into a url @@ -138,88 +168,77 @@ def url(self): :func:`.parse_url`, but it should be equivalent by the RFC (e.g., urls with a blank port will have : removed). 
- Example: :: + Example: + + .. code-block:: python + + import urllib3 - >>> U = parse_url('http://google.com/mail/') - >>> U.url - 'http://google.com/mail/' - >>> Url('http', 'username:password', 'host.com', 80, - ... '/path', 'query', 'fragment').url - 'http://username:password@host.com:80/path?query#fragment' + U = urllib3.util.parse_url("https://google.com/mail/") + + print(U.url) + # "https://google.com/mail/" + + print( urllib3.util.Url("https", "username:password", + "host.com", 80, "/path", "query", "fragment" + ).url + ) + # "https://username:password@host.com:80/path?query#fragment" """ scheme, auth, host, port, path, query, fragment = self - url = u"" + url = "" # We use "is not None" we want things to happen with empty strings (or 0 port) if scheme is not None: - url += scheme + u"://" + url += scheme + "://" if auth is not None: - url += auth + u"@" + url += auth + "@" if host is not None: url += host if port is not None: - url += u":" + str(port) + url += ":" + str(port) if path is not None: url += path if query is not None: - url += u"?" + query + url += "?" + query if fragment is not None: - url += u"#" + fragment + url += "#" + fragment return url - def __str__(self): + def __str__(self) -> str: return self.url -def split_first(s, delims): - """ - .. deprecated:: 1.25 - - Given a string and an iterable of delimiters, split on the first found - delimiter. Return two split parts and the matched delimiter. - - If not found, then the first part is the full input string. - - Example:: - - >>> split_first('foo/bar?baz', '?/=') - ('foo', 'bar?baz', '/') - >>> split_first('foo/bar?baz', '123') - ('foo/bar?baz', '', None) - - Scales linearly with number of delims. Not ideal for large number of delims. - """ - min_idx = None - min_delim = None - for d in delims: - idx = s.find(d) - if idx < 0: - continue +@typing.overload +def _encode_invalid_chars( + component: str, allowed_chars: typing.Container[str] +) -> str: # Abstract + ... 
- if min_idx is None or idx < min_idx: - min_idx = idx - min_delim = d - if min_idx is None or min_idx < 0: - return s, "", None +@typing.overload +def _encode_invalid_chars( + component: None, allowed_chars: typing.Container[str] +) -> None: # Abstract + ... - return s[:min_idx], s[min_idx + 1 :], min_delim - -def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"): +def _encode_invalid_chars( + component: str | None, allowed_chars: typing.Container[str] +) -> str | None: """Percent-encodes a URI component without reapplying onto an already percent-encoded component. """ if component is None: return component - component = six.ensure_text(component) + component = to_str(component) # Normalize existing percent-encoded bytes. # Try to see if the component we're encoding is already percent-encoded # so we can skip all '%' characters but still encode all others. - component, percent_encodings = PERCENT_RE.subn( + component, percent_encodings = _PERCENT_RE.subn( lambda match: match.group(0).upper(), component ) @@ -228,7 +247,7 @@ def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"): encoded_component = bytearray() for i in range(0, len(uri_bytes)): - # Will return a single character bytestring on both Python 2 & 3 + # Will return a single character bytestring byte = uri_bytes[i : i + 1] byte_ord = ord(byte) if (is_percent_encoded and byte == b"%") or ( @@ -238,10 +257,10 @@ def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"): continue encoded_component.extend(b"%" + (hex(byte_ord)[2:].encode().zfill(2).upper())) - return encoded_component.decode(encoding) + return encoded_component.decode() -def _remove_path_dot_segments(path): +def _remove_path_dot_segments(path: str) -> str: # See http://tools.ietf.org/html/rfc3986#section-5.2.4 for pseudo-code segments = path.split("/") # Turn the path into a list of segments output = [] # Initialize the variable to use to store output @@ -251,7 +270,7 @@ def 
_remove_path_dot_segments(path): if segment == ".": continue # Anything other than '..', should be appended to the output - elif segment != "..": + if segment != "..": output.append(segment) # In this case segment == '..', if we can, we should pop the last # element @@ -271,15 +290,25 @@ def _remove_path_dot_segments(path): return "/".join(output) -def _normalize_host(host, scheme): - if host: - if isinstance(host, six.binary_type): - host = six.ensure_str(host) +@typing.overload +def _normalize_host(host: None, scheme: str | None) -> None: + ... + + +@typing.overload +def _normalize_host(host: str, scheme: str | None) -> str: + ... + - if scheme in NORMALIZABLE_SCHEMES: - is_ipv6 = IPV6_ADDRZ_RE.match(host) +def _normalize_host(host: str | None, scheme: str | None) -> str | None: + if host: + if scheme in _NORMALIZABLE_SCHEMES: + is_ipv6 = _IPV6_ADDRZ_RE.match(host) if is_ipv6: - match = ZONE_ID_RE.search(host) + # IPv6 hosts of the form 'a::b%zone' are encoded in a URL as + # such per RFC 6874: 'a::b%25zone'. Unquote the ZoneID + # separator as necessary to return a valid RFC 4007 scoped IP. 
+ match = _ZONE_ID_RE.search(host) if match: start, end = match.span(1) zone_id = host[start:end] @@ -288,108 +317,138 @@ def _normalize_host(host, scheme): zone_id = zone_id[3:] else: zone_id = zone_id[1:] - zone_id = "%" + _encode_invalid_chars(zone_id, UNRESERVED_CHARS) - return host[:start].lower() + zone_id + host[end:] + zone_id = _encode_invalid_chars(zone_id, _UNRESERVED_CHARS) + return f"{host[:start].lower()}%{zone_id}{host[end:]}" else: return host.lower() - elif not IPV4_RE.match(host): - return six.ensure_str( - b".".join([_idna_encode(label) for label in host.split(".")]) + elif not _IPV4_RE.match(host): + return to_str( + b".".join([_idna_encode(label) for label in host.split(".")]), + "ascii", ) return host -def _idna_encode(name): - if name and any([ord(x) > 128 for x in name]): +def _idna_encode(name: str) -> bytes: + if not name.isascii(): try: import idna except ImportError: - six.raise_from( - LocationParseError("Unable to parse URL without the 'idna' module"), - None, - ) + raise LocationParseError( + "Unable to parse URL without the 'idna' module" + ) from None + try: return idna.encode(name.lower(), strict=True, std3_rules=True) except idna.IDNAError: - six.raise_from( - LocationParseError(u"Name '%s' is not a valid IDNA label" % name), None - ) + raise LocationParseError( + f"Name '{name}' is not a valid IDNA label" + ) from None + return name.lower().encode("ascii") -def _encode_target(target): - """Percent-encodes a request target so that there are no invalid characters""" - path, query = TARGET_RE.match(target).groups() - target = _encode_invalid_chars(path, PATH_CHARS) - query = _encode_invalid_chars(query, QUERY_CHARS) +def _encode_target(target: str) -> str: + """Percent-encodes a request target so that there are no invalid characters + + Pre-condition for this function is that 'target' must start with '/'. + If that is the case then _TARGET_RE will always produce a match. 
+ """ + match = _TARGET_RE.match(target) + if not match: # Defensive: + raise LocationParseError(f"{target!r} is not a valid request URI") + + path, query = match.groups() + encoded_target = _encode_invalid_chars(path, _PATH_CHARS) if query is not None: - target += "?" + query - return target + query = _encode_invalid_chars(query, _QUERY_CHARS) + encoded_target += "?" + query + return encoded_target -def parse_url(url): +def parse_url(url: str) -> Url: """ Given a url, return a parsed :class:`.Url` namedtuple. Best-effort is performed to parse incomplete urls. Fields not provided will be None. - This parser is RFC 3986 compliant. + This parser is RFC 3986 and RFC 6874 compliant. The parser logic and helper functions are based heavily on work done in the ``rfc3986`` module. :param str url: URL to parse into a :class:`.Url` namedtuple. - Partly backwards-compatible with :mod:`urlparse`. + Partly backwards-compatible with :mod:`urllib.parse`. - Example:: + Example: - >>> parse_url('http://google.com/mail/') - Url(scheme='http', host='google.com', port=None, path='/mail/', ...) - >>> parse_url('google.com:80') - Url(scheme=None, host='google.com', port=80, path=None, ...) - >>> parse_url('/foo?bar') - Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...) + .. code-block:: python + + import urllib3 + + print( urllib3.util.parse_url('http://google.com/mail/')) + # Url(scheme='http', host='google.com', port=None, path='/mail/', ...) + + print( urllib3.util.parse_url('google.com:80')) + # Url(scheme=None, host='google.com', port=80, path=None, ...) + + print( urllib3.util.parse_url('/foo?bar')) + # Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...) 
""" if not url: # Empty return Url() source_url = url - if not SCHEME_RE.search(url): + if not _SCHEME_RE.search(url): url = "//" + url + scheme: str | None + authority: str | None + auth: str | None + host: str | None + port: str | None + port_int: int | None + path: str | None + query: str | None + fragment: str | None + try: - scheme, authority, path, query, fragment = URI_RE.match(url).groups() - normalize_uri = scheme is None or scheme.lower() in NORMALIZABLE_SCHEMES + scheme, authority, path, query, fragment = _URI_RE.match(url).groups() # type: ignore[union-attr] + normalize_uri = scheme is None or scheme.lower() in _NORMALIZABLE_SCHEMES if scheme: scheme = scheme.lower() if authority: - auth, host, port = SUBAUTHORITY_RE.match(authority).groups() + auth, _, host_port = authority.rpartition("@") + auth = auth or None + host, port = _HOST_PORT_RE.match(host_port).groups() # type: ignore[union-attr] if auth and normalize_uri: - auth = _encode_invalid_chars(auth, USERINFO_CHARS) + auth = _encode_invalid_chars(auth, _USERINFO_CHARS) if port == "": port = None else: auth, host, port = None, None, None if port is not None: - port = int(port) - if not (0 <= port <= 65535): + port_int = int(port) + if not (0 <= port_int <= 65535): raise LocationParseError(url) + else: + port_int = None host = _normalize_host(host, scheme) if normalize_uri and path: path = _remove_path_dot_segments(path) - path = _encode_invalid_chars(path, PATH_CHARS) + path = _encode_invalid_chars(path, _PATH_CHARS) if normalize_uri and query: - query = _encode_invalid_chars(query, QUERY_CHARS) + query = _encode_invalid_chars(query, _QUERY_CHARS) if normalize_uri and fragment: - fragment = _encode_invalid_chars(fragment, FRAGMENT_CHARS) + fragment = _encode_invalid_chars(fragment, _FRAGMENT_CHARS) - except (ValueError, AttributeError): - return six.raise_from(LocationParseError(source_url), None) + except (ValueError, AttributeError) as e: + raise LocationParseError(source_url) from e # For the 
sake of backwards compatibility we put empty # string values for path if there are any defined values @@ -401,30 +460,12 @@ def parse_url(url): else: path = None - # Ensure that each part of the URL is a `str` for - # backwards compatibility. - if isinstance(url, six.text_type): - ensure_func = six.ensure_text - else: - ensure_func = six.ensure_str - - def ensure_type(x): - return x if x is None else ensure_func(x) - return Url( - scheme=ensure_type(scheme), - auth=ensure_type(auth), - host=ensure_type(host), - port=port, - path=ensure_type(path), - query=ensure_type(query), - fragment=ensure_type(fragment), + scheme=scheme, + auth=auth, + host=host, + port=port_int, + path=path, + query=query, + fragment=fragment, ) - - -def get_host(url): - """ - Deprecated. Use :func:`parse_url` instead. - """ - p = parse_url(url) - return p.scheme or "http", p.hostname, p.port diff --git a/src/urllib3/util/url.pyi b/src/urllib3/util/url.pyi deleted file mode 100644 index 67e308fe29..0000000000 --- a/src/urllib3/util/url.pyi +++ /dev/null @@ -1,32 +0,0 @@ -from typing import Any, List, Optional, Tuple, Union - -from .. import exceptions - -LocationParseError = exceptions.LocationParseError - -url_attrs: List[str] - -class Url: - slots: Any - def __new__( - cls, - scheme: Optional[str], - auth: Optional[str], - host: Optional[str], - port: Optional[str], - path: Optional[str], - query: Optional[str], - fragment: Optional[str], - ) -> Url: ... - @property - def hostname(self) -> str: ... - @property - def request_uri(self) -> str: ... - @property - def netloc(self) -> str: ... - @property - def url(self) -> str: ... - -def split_first(s: str, delims: str) -> Tuple[str, str, Optional[str]]: ... -def parse_url(url: str) -> Url: ... -def get_host(url: str) -> Union[str, Tuple[str]]: ... 
diff --git a/src/urllib3/util/util.py b/src/urllib3/util/util.py new file mode 100644 index 0000000000..35c77e4025 --- /dev/null +++ b/src/urllib3/util/util.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import typing +from types import TracebackType + + +def to_bytes( + x: str | bytes, encoding: str | None = None, errors: str | None = None +) -> bytes: + if isinstance(x, bytes): + return x + elif not isinstance(x, str): + raise TypeError(f"not expecting type {type(x).__name__}") + if encoding or errors: + return x.encode(encoding or "utf-8", errors=errors or "strict") + return x.encode() + + +def to_str( + x: str | bytes, encoding: str | None = None, errors: str | None = None +) -> str: + if isinstance(x, str): + return x + elif not isinstance(x, bytes): + raise TypeError(f"not expecting type {type(x).__name__}") + if encoding or errors: + return x.decode(encoding or "utf-8", errors=errors or "strict") + return x.decode() + + +def reraise( + tp: type[BaseException] | None, + value: BaseException, + tb: TracebackType | None = None, +) -> typing.NoReturn: + try: + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + finally: + value = None # type: ignore[assignment] + tb = None diff --git a/src/urllib3/util/wait.py b/src/urllib3/util/wait.py index c280646c7b..aeca0c7ad5 100644 --- a/src/urllib3/util/wait.py +++ b/src/urllib3/util/wait.py @@ -1,18 +1,10 @@ -import errno +from __future__ import annotations + import select -import sys +import socket from functools import partial -try: - from time import monotonic -except ImportError: - from time import time as monotonic - -__all__ = ["NoWayToWaitForSocketError", "wait_for_read", "wait_for_write"] - - -class NoWayToWaitForSocketError(Exception): - pass +__all__ = ["wait_for_read", "wait_for_write"] # How should we wait on sockets? @@ -37,38 +29,13 @@ class NoWayToWaitForSocketError(Exception): # So: on Windows we use select(), and everywhere else we use poll(). 
We also # fall back to select() in case poll() is somehow broken or missing. -if sys.version_info >= (3, 5): - # Modern Python, that retries syscalls by default - def _retry_on_intr(fn, timeout): - return fn(timeout) - - -else: - # Old and broken Pythons. - def _retry_on_intr(fn, timeout): - if timeout is None: - deadline = float("inf") - else: - deadline = monotonic() + timeout - - while True: - try: - return fn(timeout) - # OSError for 3 <= pyver < 3.5, select.error for pyver <= 2.7 - except (OSError, select.error) as e: - # 'e.args[0]' incantation works for both OSError and select.error - if e.args[0] != errno.EINTR: - raise - else: - timeout = deadline - monotonic() - if timeout < 0: - timeout = 0 - if timeout == float("inf"): - timeout = None - continue - - -def select_wait_for_socket(sock, read=False, write=False, timeout=None): + +def select_wait_for_socket( + sock: socket.socket, + read: bool = False, + write: bool = False, + timeout: float | None = None, +) -> bool: if not read and not write: raise RuntimeError("must specify at least one of read=True, write=True") rcheck = [] @@ -83,11 +50,16 @@ def select_wait_for_socket(sock, read=False, write=False, timeout=None): # sockets for both conditions. (The stdlib selectors module does the same # thing.) 
fn = partial(select.select, rcheck, wcheck, wcheck) - rready, wready, xready = _retry_on_intr(fn, timeout) + rready, wready, xready = fn(timeout) return bool(rready or wready or xready) -def poll_wait_for_socket(sock, read=False, write=False, timeout=None): +def poll_wait_for_socket( + sock: socket.socket, + read: bool = False, + write: bool = False, + timeout: float | None = None, +) -> bool: if not read and not write: raise RuntimeError("must specify at least one of read=True, write=True") mask = 0 @@ -99,32 +71,33 @@ def poll_wait_for_socket(sock, read=False, write=False, timeout=None): poll_obj.register(sock, mask) # For some reason, poll() takes timeout in milliseconds - def do_poll(t): + def do_poll(t: float | None) -> list[tuple[int, int]]: if t is not None: t *= 1000 return poll_obj.poll(t) - return bool(_retry_on_intr(do_poll, timeout)) - - -def null_wait_for_socket(*args, **kwargs): - raise NoWayToWaitForSocketError("no select-equivalent available") + return bool(do_poll(timeout)) -def _have_working_poll(): +def _have_working_poll() -> bool: # Apparently some systems have a select.poll that fails as soon as you try # to use it, either due to strange configuration or broken monkeypatching # from libraries like eventlet/greenlet. try: poll_obj = select.poll() - _retry_on_intr(poll_obj.poll, 0) + poll_obj.poll(0) except (AttributeError, OSError): return False else: return True -def wait_for_socket(*args, **kwargs): +def wait_for_socket( + sock: socket.socket, + read: bool = False, + write: bool = False, + timeout: float | None = None, +) -> bool: # We delay choosing which implementation to use until the first time we're # called. 
We could do it at import time, but then we might make the wrong # decision if someone goes wild with monkeypatching select.poll after @@ -134,19 +107,17 @@ def wait_for_socket(*args, **kwargs): wait_for_socket = poll_wait_for_socket elif hasattr(select, "select"): wait_for_socket = select_wait_for_socket - else: # Platform-specific: Appengine. - wait_for_socket = null_wait_for_socket - return wait_for_socket(*args, **kwargs) + return wait_for_socket(sock, read, write, timeout) -def wait_for_read(sock, timeout=None): +def wait_for_read(sock: socket.socket, timeout: float | None = None) -> bool: """Waits for reading to be available on a given socket. Returns True if the socket is readable, or False if the timeout expired. """ return wait_for_socket(sock, read=True, timeout=timeout) -def wait_for_write(sock, timeout=None): +def wait_for_write(sock: socket.socket, timeout: float | None = None) -> bool: """Waits for writing to be available on a given socket. Returns True if the socket is readable, or False if the timeout expired. 
""" diff --git a/test/__init__.py b/test/__init__.py index c03cfac257..a1629e7021 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1,28 +1,55 @@ +from __future__ import annotations + import errno +import importlib.util import logging import os import platform import socket -import ssl import sys +import typing import warnings +from collections.abc import Sequence +from importlib.abc import Loader, MetaPathFinder +from importlib.machinery import ModuleSpec +from types import ModuleType, TracebackType import pytest try: - import brotli + try: + import brotlicffi as brotli # type: ignore[import] + except ImportError: + import brotli # type: ignore[import] except ImportError: brotli = None +try: + import zstandard as zstd # type: ignore[import] +except ImportError: + zstd = None + +import functools + from urllib3 import util +from urllib3.connectionpool import ConnectionPool from urllib3.exceptions import HTTPWarning -from urllib3.packages import six from urllib3.util import ssl_ try: import urllib3.contrib.pyopenssl as pyopenssl except ImportError: - pyopenssl = None + pyopenssl = None # type: ignore[assignment] + +if typing.TYPE_CHECKING: + import ssl + + from typing_extensions import Literal + + +_RT = typing.TypeVar("_RT") # return type +_TestFuncT = typing.TypeVar("_TestFuncT", bound=typing.Callable[..., typing.Any]) + # We need a host that will not immediately close the connection with a TCP # Reset. @@ -38,7 +65,7 @@ VALID_SOURCE_ADDRESSES = [(("::1", 0), True), (("127.0.0.1", 0), False)] # RFC 5737: 192.0.2.0/24 is for testing only. # RFC 3849: 2001:db8::/32 is for documentation only. 
-INVALID_SOURCE_ADDRESSES = [("192.0.2.255", 0), ("2001:db8::1", 0)] +INVALID_SOURCE_ADDRESSES = [(("192.0.2.255", 0), False), (("2001:db8::1", 0), True)] # We use timeouts in three different ways in our tests # @@ -52,9 +79,11 @@ if os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS") == "true": LONG_TIMEOUT = 0.5 +DUMMY_POOL = ConnectionPool("dummy") + -def _can_resolve(host): - """ Returns True if the system can resolve host to an address. """ +def _can_resolve(host: str) -> bool: + """Returns True if the system can resolve host to an address.""" try: socket.getaddrinfo(host, None, socket.AF_UNSPEC) return True @@ -62,10 +91,10 @@ def _can_resolve(host): return False -def has_alpn(ctx_cls=None): - """ Detect if ALPN support is enabled. """ +def has_alpn(ctx_cls: type[ssl.SSLContext] | None = None) -> bool: + """Detect if ALPN support is enabled.""" ctx_cls = ctx_cls or util.SSLContext - ctx = ctx_cls(protocol=ssl_.PROTOCOL_TLS) + ctx = ctx_cls(protocol=ssl_.PROTOCOL_TLS) # type: ignore[misc, attr-defined] try: if hasattr(ctx, "set_alpn_protocols"): ctx.set_alpn_protocols(ssl_.ALPN_PROTOCOLS) @@ -81,208 +110,134 @@ def has_alpn(ctx_cls=None): RESOLVES_LOCALHOST_FQDN = _can_resolve("localhost.") -def clear_warnings(cls=HTTPWarning): +def clear_warnings(cls: type[Warning] = HTTPWarning) -> None: new_filters = [] for f in warnings.filters: if issubclass(f[2], cls): continue new_filters.append(f) - warnings.filters[:] = new_filters + warnings.filters[:] = new_filters # type: ignore[index] -def setUp(): +def setUp() -> None: clear_warnings() warnings.simplefilter("ignore", HTTPWarning) -def onlyPy279OrNewer(test): - """Skips this test unless you are on Python 2.7.9 or later.""" - - @six.wraps(test) - def wrapper(*args, **kwargs): - msg = "{name} requires Python 2.7.9+ to run".format(name=test.__name__) - if sys.version_info < (2, 7, 9): - pytest.skip(msg) - return test(*args, **kwargs) - - return wrapper - - -def onlyPy2(test): - """Skips this test unless you are 
on Python 2.x""" - - @six.wraps(test) - def wrapper(*args, **kwargs): - msg = "{name} requires Python 2.x to run".format(name=test.__name__) - if not six.PY2: - pytest.skip(msg) - return test(*args, **kwargs) - - return wrapper - - -def onlyPy3(test): - """Skips this test unless you are on Python3.x""" - - @six.wraps(test) - def wrapper(*args, **kwargs): - msg = "{name} requires Python3.x to run".format(name=test.__name__) - if six.PY2: - pytest.skip(msg) - return test(*args, **kwargs) - - return wrapper - - -def notPyPy2(test): - """Skips this test on PyPy2""" - - @six.wraps(test) - def wrapper(*args, **kwargs): - # https://github.com/testing-cabal/mock/issues/438 - msg = "{} fails with PyPy 2 dues to funcsigs bugs".format(test.__name__) - if platform.python_implementation() == "PyPy" and sys.version_info[0] == 2: - pytest.xfail(msg) - return test(*args, **kwargs) - - return wrapper - - -def notWindows(test): +def notWindows() -> typing.Callable[[_TestFuncT], _TestFuncT]: """Skips this test on Windows""" - - @six.wraps(test) - def wrapper(*args, **kwargs): - msg = "{name} does not run on Windows".format(name=test.__name__) - if platform.system() == "Windows": - pytest.skip(msg) - return test(*args, **kwargs) - - return wrapper + return pytest.mark.skipif( + platform.system() == "Windows", + reason="Test does not run on Windows", + ) -def onlyBrotlipy(): - return pytest.mark.skipif(brotli is None, reason="only run if brotlipy is present") +def onlyBrotli() -> typing.Callable[[_TestFuncT], _TestFuncT]: + return pytest.mark.skipif( + brotli is None, reason="only run if brotli library is present" + ) -def notBrotlipy(): +def notBrotli() -> typing.Callable[[_TestFuncT], _TestFuncT]: return pytest.mark.skipif( - brotli is not None, reason="only run if brotlipy is absent" + brotli is not None, reason="only run if a brotli library is absent" ) -def onlySecureTransport(test): - """Runs this test when SecureTransport is in use.""" +def onlyZstd() -> 
typing.Callable[[_TestFuncT], _TestFuncT]: + return pytest.mark.skipif( + zstd is None, reason="only run if a python-zstandard library is installed" + ) - @six.wraps(test) - def wrapper(*args, **kwargs): - msg = "{name} only runs with SecureTransport".format(name=test.__name__) - if not ssl_.IS_SECURETRANSPORT: - pytest.skip(msg) - return test(*args, **kwargs) - - return wrapper +def notZstd() -> typing.Callable[[_TestFuncT], _TestFuncT]: + return pytest.mark.skipif( + zstd is not None, + reason="only run if a python-zstandard library is not installed", + ) -def notSecureTransport(test): - """Skips this test when SecureTransport is in use.""" - @six.wraps(test) - def wrapper(*args, **kwargs): - msg = "{name} does not run with SecureTransport".format(name=test.__name__) - if ssl_.IS_SECURETRANSPORT: - pytest.skip(msg) - return test(*args, **kwargs) +# Hack to make pytest evaluate a condition at test runtime instead of collection time. +def lazy_condition(condition: typing.Callable[[], bool]) -> bool: + class LazyCondition: + def __bool__(self) -> bool: + return condition() - return wrapper + return typing.cast(bool, LazyCondition()) -def notOpenSSL098(test): - """Skips this test for Python 3.5 macOS python.org distribution""" +def onlySecureTransport() -> typing.Callable[[_TestFuncT], _TestFuncT]: + """Runs this test when SecureTransport is in use.""" + return pytest.mark.skipif( + lazy_condition(lambda: not ssl_.IS_SECURETRANSPORT), + reason="Test only runs with SecureTransport", + ) - @six.wraps(test) - def wrapper(*args, **kwargs): - is_stdlib_ssl = not ssl_.IS_SECURETRANSPORT and not ssl_.IS_PYOPENSSL - if is_stdlib_ssl and ssl.OPENSSL_VERSION == "OpenSSL 0.9.8zh 14 Jan 2016": - pytest.xfail("{name} fails with OpenSSL 0.9.8zh".format(name=test.__name__)) - return test(*args, **kwargs) - return wrapper +def notSecureTransport() -> typing.Callable[[_TestFuncT], _TestFuncT]: + """Skips this test when SecureTransport is in use.""" + return pytest.mark.skipif( + 
lazy_condition(lambda: ssl_.IS_SECURETRANSPORT), + reason="Test does not run with SecureTransport", + ) _requires_network_has_route = None -def requires_network(test): +def requires_network() -> typing.Callable[[_TestFuncT], _TestFuncT]: """Helps you skip tests that require the network""" - def _is_unreachable_err(err): + def _is_unreachable_err(err: Exception) -> bool: return getattr(err, "errno", None) in ( errno.ENETUNREACH, errno.EHOSTUNREACH, # For OSX ) - def _has_route(): + def _has_route() -> bool: try: sock = socket.create_connection((TARPIT_HOST, 80), 0.0001) sock.close() return True except socket.timeout: return True - except socket.error as e: + except OSError as e: if _is_unreachable_err(e): return False else: raise - @six.wraps(test) - def wrapper(*args, **kwargs): - global _requires_network_has_route - - if _requires_network_has_route is None: - _requires_network_has_route = _has_route() + global _requires_network_has_route - if _requires_network_has_route: - return test(*args, **kwargs) - else: - msg = "Can't run {name} because the network is unreachable".format( - name=test.__name__ - ) - pytest.skip(msg) - - return wrapper + if _requires_network_has_route is None: + _requires_network_has_route = _has_route() + return pytest.mark.skipif( + not _requires_network_has_route, + reason="Can't run the test because the network is unreachable", + ) -def requires_ssl_context_keyfile_password(test): - @six.wraps(test) - def wrapper(*args, **kwargs): - if ( - not ssl_.IS_PYOPENSSL and sys.version_info < (2, 7, 9) - ) or ssl_.IS_SECURETRANSPORT: - pytest.skip( - "%s requires password parameter for " - "SSLContext.load_cert_chain()" % test.__name__ - ) - return test(*args, **kwargs) - return wrapper +def requires_ssl_context_keyfile_password() -> ( + typing.Callable[[_TestFuncT], _TestFuncT] +): + return pytest.mark.skipif( + lazy_condition(lambda: ssl_.IS_SECURETRANSPORT), + reason="Test requires password parameter for SSLContext.load_cert_chain()", + ) -def 
resolvesLocalhostFQDN(test): +def resolvesLocalhostFQDN() -> typing.Callable[[_TestFuncT], _TestFuncT]: """Test requires successful resolving of 'localhost.'""" - - @six.wraps(test) - def wrapper(*args, **kwargs): - if not RESOLVES_LOCALHOST_FQDN: - pytest.skip("Can't resolve localhost.") - return test(*args, **kwargs) - - return wrapper + return pytest.mark.skipif( + not RESOLVES_LOCALHOST_FQDN, + reason="Can't resolve localhost.", + ) -def withPyOpenSSL(test): - @six.wraps(test) - def wrapper(*args, **kwargs): +def withPyOpenSSL(test: typing.Callable[..., _RT]) -> typing.Callable[..., _RT]: + @functools.wraps(test) + def wrapper(*args: typing.Any, **kwargs: typing.Any) -> _RT: if not pyopenssl: pytest.skip("pyopenssl not available, skipping test.") return test(*args, **kwargs) @@ -296,34 +251,114 @@ def wrapper(*args, **kwargs): class _ListHandler(logging.Handler): - def __init__(self): - super(_ListHandler, self).__init__() - self.records = [] + def __init__(self) -> None: + super().__init__() + self.records: list[logging.LogRecord] = [] - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: self.records.append(record) -class LogRecorder(object): - def __init__(self, target=logging.root): - super(LogRecorder, self).__init__() +class LogRecorder: + def __init__(self, target: logging.Logger = logging.root) -> None: + super().__init__() self._target = target self._handler = _ListHandler() @property - def records(self): + def records(self) -> list[logging.LogRecord]: return self._handler.records - def install(self): + def install(self) -> None: self._target.addHandler(self._handler) - def uninstall(self): + def uninstall(self) -> None: self._target.removeHandler(self._handler) - def __enter__(self): + def __enter__(self) -> list[logging.LogRecord]: self.install() return self.records - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + 
traceback: TracebackType | None, + ) -> Literal[False]: self.uninstall() return False + + +class ImportBlockerLoader(Loader): + def __init__(self, fullname: str) -> None: + self._fullname = fullname + + def load_module(self, fullname: str) -> ModuleType: + raise ImportError(f"import of {fullname} is blocked") + + def exec_module(self, module: ModuleType) -> None: + raise ImportError(f"import of {self._fullname} is blocked") + + +class ImportBlocker(MetaPathFinder): + """ + Block Imports + + To be placed on ``sys.meta_path``. This ensures that the modules + specified cannot be imported, even if they are a builtin. + """ + + def __init__(self, *namestoblock: str) -> None: + self.namestoblock = namestoblock + + def find_module( + self, fullname: str, path: typing.Sequence[bytes | str] | None = None + ) -> Loader | None: + if fullname in self.namestoblock: + return ImportBlockerLoader(fullname) + return None + + def find_spec( + self, + fullname: str, + path: Sequence[bytes | str] | None, + target: ModuleType | None = None, + ) -> ModuleSpec | None: + loader = self.find_module(fullname, path) + if loader is None: + return None + + return importlib.util.spec_from_loader(fullname, loader) + + +class ModuleStash(MetaPathFinder): + """ + Stashes away previously imported modules + + If we reimport a module the data from coverage is lost, so we reuse the old + modules + """ + + def __init__( + self, namespace: str, modules: dict[str, ModuleType] = sys.modules + ) -> None: + self.namespace = namespace + self.modules = modules + self._data: dict[str, ModuleType] = {} + + def stash(self) -> None: + if self.namespace in self.modules: + self._data[self.namespace] = self.modules.pop(self.namespace) + + for module in list(self.modules.keys()): + if module.startswith(self.namespace + "."): + self._data[module] = self.modules.pop(module) + + def pop(self) -> None: + self.modules.pop(self.namespace, None) + + for module in list(self.modules.keys()): + if 
module.startswith(self.namespace + "."): + self.modules.pop(module) + + self.modules.update(self._data) diff --git a/test/appengine/__init__.py b/test/appengine/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/appengine/conftest.py b/test/appengine/conftest.py deleted file mode 100644 index 0b9d1f1fb0..0000000000 --- a/test/appengine/conftest.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2015 Google Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys - -# Import py.test hooks and fixtures for App Engine -try: - from gcp_devrel.testing.appengine import ( - pytest_configure, - pytest_runtest_call, - testbed, - ) -except ImportError: - pass - -import pytest -import six - -__all__ = [ - "pytest_configure", - "pytest_runtest_call", - "pytest_ignore_collect", - "testbed", - "sandbox", -] - - -@pytest.fixture -def sandbox(testbed): - """ - Enables parts of the GAE sandbox that are relevant. - Inserts the stub module import hook which causes the usage of - appengine-specific httplib, httplib2, socket, etc. 
- """ - try: - from google.appengine.tools.devappserver2.python import sandbox - except ImportError: - from google.appengine.tools.devappserver2.python.runtime import sandbox - - for name in list(sys.modules): - if name in sandbox.dist27.MODULE_OVERRIDES: - del sys.modules[name] - sys.meta_path.insert(0, sandbox.StubModuleImportHook()) - sys.path_importer_cache = {} - - yield testbed - - sys.meta_path = [ - x for x in sys.meta_path if not isinstance(x, sandbox.StubModuleImportHook) - ] - sys.path_importer_cache = {} - - # Delete any instances of sandboxed modules. - for name in list(sys.modules): - if name in sandbox.dist27.MODULE_OVERRIDES: - del sys.modules[name] - - -def pytest_ignore_collect(path, config): - """Skip App Engine tests in python 3 or if no SDK is available.""" - if "appengine" in str(path): - if not six.PY2: - return True - if not os.environ.get("GAE_SDK_PATH"): - return True - return False diff --git a/test/appengine/test_gae_manager.py b/test/appengine/test_gae_manager.py deleted file mode 100644 index 3047f249d4..0000000000 --- a/test/appengine/test_gae_manager.py +++ /dev/null @@ -1,178 +0,0 @@ -from test import SHORT_TIMEOUT -from test.with_dummyserver import test_connectionpool - -import pytest - -import dummyserver.testcase -import urllib3.exceptions -import urllib3.util.retry -import urllib3.util.url -from urllib3.contrib import appengine - - -# This class is used so we can re-use the tests from the connection pool. -# It proxies all requests to the manager. 
-class MockPool(object): - def __init__(self, host, port, manager, scheme="http"): - self.host = host - self.port = port - self.manager = manager - self.scheme = scheme - - def request(self, method, url, *args, **kwargs): - url = self._absolute_url(url) - return self.manager.request(method, url, *args, **kwargs) - - def urlopen(self, method, url, *args, **kwargs): - url = self._absolute_url(url) - return self.manager.urlopen(method, url, *args, **kwargs) - - def _absolute_url(self, path): - return urllib3.util.url.Url( - scheme=self.scheme, host=self.host, port=self.port, path=path - ).url - - -# Note that this doesn't run in the sandbox, it only runs with the URLFetch -# API stub enabled. There's no need to enable the sandbox as we know for a fact -# that URLFetch is used by the connection manager. -@pytest.mark.usefixtures("testbed") -class TestGAEConnectionManager(test_connectionpool.TestConnectionPool): - def setup_method(self, method): - self.manager = appengine.AppEngineManager() - self.pool = MockPool(self.host, self.port, self.manager) - - # Tests specific to AppEngineManager - - def test_exceptions(self): - # DeadlineExceededError -> TimeoutError - with pytest.raises(urllib3.exceptions.TimeoutError): - self.pool.request( - "GET", - "/sleep?seconds={}".format(5 * SHORT_TIMEOUT), - timeout=SHORT_TIMEOUT, - ) - - # InvalidURLError -> ProtocolError - with pytest.raises(urllib3.exceptions.ProtocolError): - self.manager.request("GET", "ftp://invalid/url") - - # DownloadError -> ProtocolError - with pytest.raises(urllib3.exceptions.ProtocolError): - self.manager.request("GET", "http://0.0.0.0") - - # ResponseTooLargeError -> AppEnginePlatformError - with pytest.raises(appengine.AppEnginePlatformError): - self.pool.request( - "GET", "/nbytes?length=33554433" - ) # One byte over 32 megabytes. - - # URLFetch reports the request too large error as a InvalidURLError, - # which maps to a AppEnginePlatformError. - body = b"1" * 10485761 # One byte over 10 megabytes. 
- with pytest.raises(appengine.AppEnginePlatformError): - self.manager.request("POST", "/", body=body) - - # Re-used tests below this line. - # Subsumed tests - test_timeout_float = None # Covered by test_exceptions. - - # Non-applicable tests - test_conn_closed = None - test_nagle = None - test_socket_options = None - test_disable_default_socket_options = None - test_defaults_are_applied = None - test_tunnel = None - test_keepalive = None - test_keepalive_close = None - test_connection_count = None - test_connection_count_bigpool = None - test_for_double_release = None - test_release_conn_parameter = None - test_stream_keepalive = None - test_cleanup_on_connection_error = None - test_read_chunked_short_circuit = None - test_read_chunked_on_closed_response = None - - # Tests that should likely be modified for appengine specific stuff - test_timeout = None - test_connect_timeout = None - test_connection_error_retries = None - test_total_timeout = None - test_none_total_applies_connect = None - test_timeout_success = None - test_source_address_error = None - test_bad_connect = None - test_partial_response = None - test_dns_error = None - - -@pytest.mark.usefixtures("testbed") -class TestGAEConnectionManagerWithSSL(dummyserver.testcase.HTTPSDummyServerTestCase): - def setup_method(self, method): - self.manager = appengine.AppEngineManager() - self.pool = MockPool(self.host, self.port, self.manager, "https") - - def test_exceptions(self): - # SSLCertificateError -> SSLError - # SSLError is raised with dummyserver because URLFetch doesn't allow - # self-signed certs. 
- with pytest.raises(urllib3.exceptions.SSLError): - self.pool.request("GET", "/") - - -@pytest.mark.usefixtures("testbed") -class TestGAERetry(test_connectionpool.TestRetry): - def setup_method(self, method): - self.manager = appengine.AppEngineManager() - self.pool = MockPool(self.host, self.port, self.manager) - - def test_default_method_whitelist_retried(self): - """ urllib3 should retry methods in the default method whitelist """ - retry = urllib3.util.retry.Retry(total=1, status_forcelist=[418]) - # Use HEAD instead of OPTIONS, as URLFetch doesn't support OPTIONS - resp = self.pool.request( - "HEAD", - "/successful_retry", - headers={"test-name": "test_default_whitelist"}, - retries=retry, - ) - assert resp.status == 200 - - def test_retry_return_in_response(self): - headers = {"test-name": "test_retry_return_in_response"} - retry = urllib3.util.retry.Retry(total=2, status_forcelist=[418]) - resp = self.pool.request( - "GET", "/successful_retry", headers=headers, retries=retry - ) - assert resp.status == 200 - assert resp.retries.total == 1 - # URLFetch use absolute urls. - assert resp.retries.history == ( - urllib3.util.retry.RequestHistory( - "GET", self.pool._absolute_url("/successful_retry"), None, 418, None - ), - ) - - # test_max_retry = None - # test_disabled_retry = None - # We don't need these tests because URLFetch resolves its own redirects. - test_retry_redirect_history = None - test_multi_redirect_history = None - - -@pytest.mark.usefixtures("testbed") -class TestGAERetryAfter(test_connectionpool.TestRetryAfter): - def setup_method(self, method): - # Disable urlfetch which doesn't respect Retry-After header. 
- self.manager = appengine.AppEngineManager(urlfetch_retries=False) - self.pool = MockPool(self.host, self.port, self.manager) - - -def test_gae_environ(): - assert not appengine.is_appengine() - assert not appengine.is_appengine_sandbox() - assert not appengine.is_local_appengine() - assert not appengine.is_prod_appengine() - assert not appengine.is_prod_appengine_mvms() diff --git a/test/appengine/test_urlfetch.py b/test/appengine/test_urlfetch.py deleted file mode 100644 index 74484ea405..0000000000 --- a/test/appengine/test_urlfetch.py +++ /dev/null @@ -1,66 +0,0 @@ -"""These tests ensure that when running in App Engine standard with the -App Engine sandbox enabled that urllib3 appropriately uses the App -Engine-patched version of httplib to make requests.""" - -import httplib -import pytest -import StringIO -from mock import patch - -from ..test_no_ssl import TestWithoutSSL - - -class MockResponse(object): - def __init__(self, content, status_code, content_was_truncated, final_url, headers): - - self.content = content - self.status_code = status_code - self.content_was_truncated = content_was_truncated - self.final_url = final_url - self.header_msg = httplib.HTTPMessage( - StringIO.StringIO( - "".join(["%s: %s\n" % (k, v) for k, v in headers.iteritems()] + ["\n"]) - ) - ) - self.headers = headers - - -@pytest.mark.usefixtures("sandbox") -class TestHTTP(TestWithoutSSL): - def test_urlfetch_called_with_http(self): - """Check that URLFetch is used to fetch non-https resources.""" - resp = MockResponse( - "OK", 200, False, "http://www.google.com", {"content-type": "text/plain"} - ) - fetch_patch = patch("google.appengine.api.urlfetch.fetch", return_value=resp) - with fetch_patch as fetch_mock: - import urllib3 - - pool = urllib3.HTTPConnectionPool("www.google.com", "80") - r = pool.request("GET", "/") - assert r.status == 200, r.data - assert fetch_mock.call_count == 1 - - -@pytest.mark.usefixtures("sandbox") -class TestHTTPS(object): - @pytest.mark.xfail( - 
reason="This is not yet supported by urlfetch, presence of the ssl " - "module will bypass urlfetch." - ) - def test_urlfetch_called_with_https(self): - """ - Check that URLFetch is used when fetching https resources - """ - resp = MockResponse( - "OK", 200, False, "https://www.google.com", {"content-type": "text/plain"} - ) - fetch_patch = patch("google.appengine.api.urlfetch.fetch", return_value=resp) - with fetch_patch as fetch_mock: - import urllib3 - - pool = urllib3.HTTPSConnectionPool("www.google.com", "443") - pool.ConnectionCls = urllib3.connection.UnverifiedHTTPSConnection - r = pool.request("GET", "/") - assert r.status == 200, r.data - assert fetch_mock.call_count == 1 diff --git a/test/benchmark.py b/test/benchmark.py deleted file mode 100644 index 67d141b252..0000000000 --- a/test/benchmark.py +++ /dev/null @@ -1,76 +0,0 @@ -#!/usr/bin/env python - -""" -Really simple rudimentary benchmark to compare ConnectionPool versus standard -urllib to demonstrate the usefulness of connection re-using. -""" -from __future__ import print_function - -import sys -import time -import urllib - -sys.path.append("../") -import urllib3 # noqa: E402 - -# URLs to download. Doesn't matter as long as they're from the same host, so we -# can take advantage of connection re-using. 
-TO_DOWNLOAD = [ - "http://code.google.com/apis/apps/", - "http://code.google.com/apis/base/", - "http://code.google.com/apis/blogger/", - "http://code.google.com/apis/calendar/", - "http://code.google.com/apis/codesearch/", - "http://code.google.com/apis/contact/", - "http://code.google.com/apis/books/", - "http://code.google.com/apis/documents/", - "http://code.google.com/apis/finance/", - "http://code.google.com/apis/health/", - "http://code.google.com/apis/notebook/", - "http://code.google.com/apis/picasaweb/", - "http://code.google.com/apis/spreadsheets/", - "http://code.google.com/apis/webmastertools/", - "http://code.google.com/apis/youtube/", -] - - -def urllib_get(url_list): - assert url_list - for url in url_list: - now = time.time() - urllib.urlopen(url) - elapsed = time.time() - now - print("Got in %0.3f: %s" % (elapsed, url)) - - -def pool_get(url_list): - assert url_list - pool = urllib3.PoolManager() - for url in url_list: - now = time.time() - pool.request("GET", url, assert_same_host=False) - elapsed = time.time() - now - print("Got in %0.3fs: %s" % (elapsed, url)) - - -if __name__ == "__main__": - print("Running pool_get ...") - now = time.time() - pool_get(TO_DOWNLOAD) - pool_elapsed = time.time() - now - - print("Running urllib_get ...") - now = time.time() - urllib_get(TO_DOWNLOAD) - urllib_elapsed = time.time() - now - - print("Completed pool_get in %0.3fs" % pool_elapsed) - print("Completed urllib_get in %0.3fs" % urllib_elapsed) - - -""" -Example results: - -Completed pool_get in 1.163s -Completed urllib_get in 2.318s -""" diff --git a/test/conftest.py b/test/conftest.py index ff8e463186..9aafdacfac 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,110 +1,278 @@ -import collections +from __future__ import annotations + +import asyncio import contextlib -import platform import socket import ssl -import sys -import threading +import typing +from pathlib import Path import pytest import trustme -from tornado import ioloop, web +from 
tornado import web from dummyserver.handlers import TestingApp -from dummyserver.server import HAS_IPV6, run_tornado_app +from dummyserver.proxy import ProxyHandler +from dummyserver.server import HAS_IPV6, run_loop_in_thread, run_tornado_app from dummyserver.testcase import HTTPSDummyServerTestCase from urllib3.util import ssl_ from .tz_stub import stub_timezone_ctx -# The Python 3.8+ default loop on Windows breaks Tornado -@pytest.fixture(scope="session", autouse=True) -def configure_windows_event_loop(): - if sys.version_info >= (3, 8) and platform.system() == "Windows": - import asyncio +class ServerConfig(typing.NamedTuple): + scheme: str + host: str + port: int + ca_certs: str + + @property + def base_url(self) -> str: + host = self.host + if ":" in host: + host = f"[{host}]" + return f"{self.scheme}://{host}:{self.port}" + + +def _write_cert_to_dir( + cert: trustme.LeafCert, tmpdir: Path, file_prefix: str = "server" +) -> dict[str, str]: + cert_path = str(tmpdir / ("%s.pem" % file_prefix)) + key_path = str(tmpdir / ("%s.key" % file_prefix)) + cert.private_key_pem.write_to_path(key_path) + cert.cert_chain_pems[0].write_to_path(cert_path) + certs = {"keyfile": key_path, "certfile": cert_path} + return certs + + +@contextlib.contextmanager +def run_server_in_thread( + scheme: str, host: str, tmpdir: Path, ca: trustme.CA, server_cert: trustme.LeafCert +) -> typing.Generator[ServerConfig, None, None]: + ca_cert_path = str(tmpdir / "ca.pem") + ca.cert_pem.write_to_path(ca_cert_path) + server_certs = _write_cert_to_dir(server_cert, tmpdir) - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + with run_loop_in_thread() as io_loop: + async def run_app() -> int: + app = web.Application([(r".*", TestingApp)]) + server, port = run_tornado_app(app, server_certs, scheme, host) + return port -ServerConfig = collections.namedtuple("ServerConfig", ["host", "port", "ca_certs"]) + port = asyncio.run_coroutine_threadsafe( + run_app(), io_loop.asyncio_loop 
# type: ignore[attr-defined] + ).result() + yield ServerConfig("https", host, port, ca_cert_path) @contextlib.contextmanager -def run_server_in_thread(scheme, host, tmpdir, ca, server_cert): +def run_server_and_proxy_in_thread( + proxy_scheme: str, + proxy_host: str, + tmpdir: Path, + ca: trustme.CA, + proxy_cert: trustme.LeafCert, + server_cert: trustme.LeafCert, +) -> typing.Generator[tuple[ServerConfig, ServerConfig], None, None]: ca_cert_path = str(tmpdir / "ca.pem") - server_cert_path = str(tmpdir / "server.pem") - server_key_path = str(tmpdir / "server.key") ca.cert_pem.write_to_path(ca_cert_path) - server_cert.private_key_pem.write_to_path(server_key_path) - server_cert.cert_chain_pems[0].write_to_path(server_cert_path) - server_certs = {"keyfile": server_key_path, "certfile": server_cert_path} - io_loop = ioloop.IOLoop.current() - app = web.Application([(r".*", TestingApp)]) - server, port = run_tornado_app(app, io_loop, server_certs, scheme, host) - server_thread = threading.Thread(target=io_loop.start) - server_thread.start() + server_certs = _write_cert_to_dir(server_cert, tmpdir) + proxy_certs = _write_cert_to_dir(proxy_cert, tmpdir, "proxy") + + with run_loop_in_thread() as io_loop: + + async def run_app() -> tuple[ServerConfig, ServerConfig]: + app = web.Application([(r".*", TestingApp)]) + server_app, port = run_tornado_app(app, server_certs, "https", "localhost") + server_config = ServerConfig("https", "localhost", port, ca_cert_path) + + proxy = web.Application([(r".*", ProxyHandler)]) + proxy_app, proxy_port = run_tornado_app( + proxy, proxy_certs, proxy_scheme, proxy_host + ) + proxy_config = ServerConfig( + proxy_scheme, proxy_host, proxy_port, ca_cert_path + ) + return proxy_config, server_config + + proxy_config, server_config = asyncio.run_coroutine_threadsafe( + run_app(), io_loop.asyncio_loop # type: ignore[attr-defined] + ).result() + yield (proxy_config, server_config) + + +@pytest.fixture(params=["localhost", "127.0.0.1", "::1"]) +def 
loopback_host(request: typing.Any) -> typing.Generator[str, None, None]: + host = request.param + if host == "::1" and not HAS_IPV6: + pytest.skip("Test requires IPv6 on loopback") + yield host + - yield ServerConfig(host, port, ca_cert_path) +@pytest.fixture() +def san_server( + loopback_host: str, tmp_path_factory: pytest.TempPathFactory +) -> typing.Generator[ServerConfig, None, None]: + tmpdir = tmp_path_factory.mktemp("certs") + ca = trustme.CA() + + server_cert = ca.issue_cert(loopback_host) + + with run_server_in_thread("https", loopback_host, tmpdir, ca, server_cert) as cfg: + yield cfg + + +@pytest.fixture() +def no_san_server( + loopback_host: str, tmp_path_factory: pytest.TempPathFactory +) -> typing.Generator[ServerConfig, None, None]: + tmpdir = tmp_path_factory.mktemp("certs") + ca = trustme.CA() + server_cert = ca.issue_cert(common_name=loopback_host) - io_loop.add_callback(server.stop) - io_loop.add_callback(io_loop.stop) - server_thread.join() + with run_server_in_thread("https", loopback_host, tmpdir, ca, server_cert) as cfg: + yield cfg + + +@pytest.fixture() +def no_san_server_with_different_commmon_name( + tmp_path_factory: pytest.TempPathFactory, +) -> typing.Generator[ServerConfig, None, None]: + tmpdir = tmp_path_factory.mktemp("certs") + ca = trustme.CA() + server_cert = ca.issue_cert(common_name="example.com") + + with run_server_in_thread("https", "localhost", tmpdir, ca, server_cert) as cfg: + yield cfg @pytest.fixture -def no_san_server(tmp_path_factory): +def san_proxy_with_server( + loopback_host: str, tmp_path_factory: pytest.TempPathFactory +) -> typing.Generator[tuple[ServerConfig, ServerConfig], None, None]: + tmpdir = tmp_path_factory.mktemp("certs") + ca = trustme.CA() + proxy_cert = ca.issue_cert(loopback_host) + server_cert = ca.issue_cert("localhost") + + with run_server_and_proxy_in_thread( + "https", loopback_host, tmpdir, ca, proxy_cert, server_cert + ) as cfg: + yield cfg + + +@pytest.fixture +def 
no_san_proxy_with_server( + tmp_path_factory: pytest.TempPathFactory, +) -> typing.Generator[tuple[ServerConfig, ServerConfig], None, None]: tmpdir = tmp_path_factory.mktemp("certs") ca = trustme.CA() # only common name, no subject alternative names - server_cert = ca.issue_cert(common_name=u"localhost") + proxy_cert = ca.issue_cert(common_name="localhost") + server_cert = ca.issue_cert("localhost") + + with run_server_and_proxy_in_thread( + "https", "localhost", tmpdir, ca, proxy_cert, server_cert + ) as cfg: + yield cfg + + +@pytest.fixture +def no_localhost_san_server( + tmp_path_factory: pytest.TempPathFactory, +) -> typing.Generator[ServerConfig, None, None]: + tmpdir = tmp_path_factory.mktemp("certs") + ca = trustme.CA() + # non localhost common name + server_cert = ca.issue_cert("example.com") with run_server_in_thread("https", "localhost", tmpdir, ca, server_cert) as cfg: yield cfg @pytest.fixture -def ip_san_server(tmp_path_factory): +def ipv4_san_proxy_with_server( + tmp_path_factory: pytest.TempPathFactory, +) -> typing.Generator[tuple[ServerConfig, ServerConfig], None, None]: tmpdir = tmp_path_factory.mktemp("certs") ca = trustme.CA() # IP address in Subject Alternative Name - server_cert = ca.issue_cert(u"127.0.0.1") + proxy_cert = ca.issue_cert("127.0.0.1") + + server_cert = ca.issue_cert("localhost") + + with run_server_and_proxy_in_thread( + "https", "127.0.0.1", tmpdir, ca, proxy_cert, server_cert + ) as cfg: + yield cfg + + +@pytest.fixture +def ipv6_san_proxy_with_server( + tmp_path_factory: pytest.TempPathFactory, +) -> typing.Generator[tuple[ServerConfig, ServerConfig], None, None]: + tmpdir = tmp_path_factory.mktemp("certs") + ca = trustme.CA() + # IP addresses in Subject Alternative Name + proxy_cert = ca.issue_cert("::1") + + server_cert = ca.issue_cert("localhost") + + with run_server_and_proxy_in_thread( + "https", "::1", tmpdir, ca, proxy_cert, server_cert + ) as cfg: + yield cfg + + +@pytest.fixture +def ipv4_san_server( + 
tmp_path_factory: pytest.TempPathFactory, +) -> typing.Generator[ServerConfig, None, None]: + tmpdir = tmp_path_factory.mktemp("certs") + ca = trustme.CA() + # IP address in Subject Alternative Name + server_cert = ca.issue_cert("127.0.0.1") with run_server_in_thread("https", "127.0.0.1", tmpdir, ca, server_cert) as cfg: yield cfg @pytest.fixture -def ipv6_addr_server(tmp_path_factory): +def ipv6_san_server( + tmp_path_factory: pytest.TempPathFactory, +) -> typing.Generator[ServerConfig, None, None]: if not HAS_IPV6: pytest.skip("Only runs on IPv6 systems") tmpdir = tmp_path_factory.mktemp("certs") ca = trustme.CA() - # IP address in Common Name - server_cert = ca.issue_cert(common_name=u"::1") + # IP address in Subject Alternative Name + server_cert = ca.issue_cert("::1") with run_server_in_thread("https", "::1", tmpdir, ca, server_cert) as cfg: yield cfg @pytest.fixture -def ipv6_san_server(tmp_path_factory): +def ipv6_no_san_server( + tmp_path_factory: pytest.TempPathFactory, +) -> typing.Generator[ServerConfig, None, None]: if not HAS_IPV6: pytest.skip("Only runs on IPv6 systems") tmpdir = tmp_path_factory.mktemp("certs") ca = trustme.CA() - # IP address in Subject Alternative Name - server_cert = ca.issue_cert(u"::1") + # IP address in Common Name + server_cert = ca.issue_cert(common_name="::1") with run_server_in_thread("https", "::1", tmpdir, ca, server_cert) as cfg: yield cfg -@pytest.yield_fixture -def stub_timezone(request): +@pytest.fixture +def stub_timezone(request: pytest.FixtureRequest) -> typing.Generator[None, None, None]: """ A pytest fixture that runs the test with a stub timezone. """ @@ -113,7 +281,7 @@ def stub_timezone(request): @pytest.fixture(scope="session") -def supported_tls_versions(): +def supported_tls_versions() -> typing.AbstractSet[str | None]: # We have to create an actual TLS connection # to test if the TLS version is not disabled by # OpenSSL config. 
Ubuntu 20.04 specifically @@ -122,11 +290,11 @@ def supported_tls_versions(): _server = HTTPSDummyServerTestCase() _server._start_server() - for _ssl_version_name in ( - "PROTOCOL_TLSv1", - "PROTOCOL_TLSv1_1", - "PROTOCOL_TLSv1_2", - "PROTOCOL_TLS", + for _ssl_version_name, min_max_version in ( + ("PROTOCOL_TLSv1", ssl.TLSVersion.TLSv1), + ("PROTOCOL_TLSv1_1", ssl.TLSVersion.TLSv1_1), + ("PROTOCOL_TLSv1_2", ssl.TLSVersion.TLSv1_2), + ("PROTOCOL_TLS", None), ): _ssl_version = getattr(ssl, _ssl_version_name, 0) if _ssl_version == 0: @@ -134,7 +302,12 @@ def supported_tls_versions(): _sock = socket.create_connection((_server.host, _server.port)) try: _sock = ssl_.ssl_wrap_socket( - _sock, cert_reqs=ssl.CERT_NONE, ssl_version=_ssl_version + _sock, + ssl_context=ssl_.create_urllib3_context( + cert_reqs=ssl.CERT_NONE, + ssl_minimum_version=min_max_version, + ssl_maximum_version=min_max_version, + ), ) except ssl.SSLError: pass @@ -146,28 +319,28 @@ def supported_tls_versions(): @pytest.fixture(scope="function") -def requires_tlsv1(supported_tls_versions): +def requires_tlsv1(supported_tls_versions: typing.AbstractSet[str]) -> None: """Test requires TLSv1 available""" if not hasattr(ssl, "PROTOCOL_TLSv1") or "TLSv1" not in supported_tls_versions: pytest.skip("Test requires TLSv1") @pytest.fixture(scope="function") -def requires_tlsv1_1(supported_tls_versions): +def requires_tlsv1_1(supported_tls_versions: typing.AbstractSet[str]) -> None: """Test requires TLSv1.1 available""" if not hasattr(ssl, "PROTOCOL_TLSv1_1") or "TLSv1.1" not in supported_tls_versions: pytest.skip("Test requires TLSv1.1") @pytest.fixture(scope="function") -def requires_tlsv1_2(supported_tls_versions): +def requires_tlsv1_2(supported_tls_versions: typing.AbstractSet[str]) -> None: """Test requires TLSv1.2 available""" if not hasattr(ssl, "PROTOCOL_TLSv1_2") or "TLSv1.2" not in supported_tls_versions: pytest.skip("Test requires TLSv1.2") @pytest.fixture(scope="function") -def 
requires_tlsv1_3(supported_tls_versions): +def requires_tlsv1_3(supported_tls_versions: typing.AbstractSet[str]) -> None: """Test requires TLSv1.3 available""" if ( not getattr(ssl, "HAS_TLSv1_3", False) diff --git a/test/contrib/test_pyopenssl.py b/test/contrib/test_pyopenssl.py index 1a7f6f9714..b0231295a5 100644 --- a/test/contrib/test_pyopenssl.py +++ b/test/contrib/test_pyopenssl.py @@ -1,28 +1,29 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations + import os +from unittest import mock -import mock import pytest try: from cryptography import x509 - from OpenSSL.crypto import FILETYPE_PEM, load_certificate + from OpenSSL.crypto import FILETYPE_PEM, load_certificate # type: ignore[import] from urllib3.contrib.pyopenssl import _dnsname_to_stdlib, get_subj_alt_name except ImportError: pass -def setup_module(): +def setup_module() -> None: try: from urllib3.contrib.pyopenssl import inject_into_urllib3 inject_into_urllib3() except ImportError as e: - pytest.skip("Could not import PyOpenSSL: %r" % e) + pytest.skip(f"Could not import PyOpenSSL: {e!r}") -def teardown_module(): +def teardown_module() -> None: try: from urllib3.contrib.pyopenssl import extract_from_urllib3 @@ -31,13 +32,12 @@ def teardown_module(): pass +from ..test_ssl import TestSSL # noqa: E402, F401 from ..test_util import TestUtilSSL # noqa: E402, F401 from ..with_dummyserver.test_https import ( # noqa: E402, F401 TestHTTPS, - TestHTTPS_IPSAN, - TestHTTPS_IPv6Addr, + TestHTTPS_IPV4SAN, TestHTTPS_IPV6SAN, - TestHTTPS_NoSAN, TestHTTPS_TLSv1, TestHTTPS_TLSv1_1, TestHTTPS_TLSv1_2, @@ -47,51 +47,53 @@ def teardown_module(): TestClientCerts, TestSNI, TestSocketClosing, - TestSSL, +) +from ..with_dummyserver.test_socketlevel import ( # noqa: E402, F401 + TestSSL as TestSocketSSL, ) -class TestPyOpenSSLHelpers(object): +class TestPyOpenSSLHelpers: """ Tests for PyOpenSSL helper functions. 
""" - def test_dnsname_to_stdlib_simple(self): + def test_dnsname_to_stdlib_simple(self) -> None: """ We can convert a dnsname to a native string when the domain is simple. """ - name = u"उदाहरण.परीक" + name = "उदाहरण.परीक" expected_result = "xn--p1b6ci4b4b3a.xn--11b5bs8d" assert _dnsname_to_stdlib(name) == expected_result - def test_dnsname_to_stdlib_leading_period(self): + def test_dnsname_to_stdlib_leading_period(self) -> None: """ If there is a . in front of the domain name we correctly encode it. """ - name = u".उदाहरण.परीक" + name = ".उदाहरण.परीक" expected_result = ".xn--p1b6ci4b4b3a.xn--11b5bs8d" assert _dnsname_to_stdlib(name) == expected_result - def test_dnsname_to_stdlib_leading_splat(self): + def test_dnsname_to_stdlib_leading_splat(self) -> None: """ If there's a wildcard character in the front of the string we handle it appropriately. """ - name = u"*.उदाहरण.परीक" + name = "*.उदाहरण.परीक" expected_result = "*.xn--p1b6ci4b4b3a.xn--11b5bs8d" assert _dnsname_to_stdlib(name) == expected_result @mock.patch("urllib3.contrib.pyopenssl.log.warning") - def test_get_subj_alt_name(self, mock_warning): + def test_get_subj_alt_name(self, mock_warning: mock.MagicMock) -> None: """ If a certificate has two subject alternative names, cryptography raises an x509.DuplicateExtension exception. 
""" path = os.path.join(os.path.dirname(__file__), "duplicate_san.pem") - with open(path, "r") as fp: + with open(path) as fp: cert = load_certificate(FILETYPE_PEM, fp.read()) assert get_subj_alt_name(cert) == [] diff --git a/test/contrib/test_pyopenssl_dependencies.py b/test/contrib/test_pyopenssl_dependencies.py index d1498e9218..d182727866 100644 --- a/test/contrib/test_pyopenssl_dependencies.py +++ b/test/contrib/test_pyopenssl_dependencies.py @@ -1,6 +1,8 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations + +from unittest.mock import Mock, patch + import pytest -from mock import Mock, patch try: from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 @@ -8,16 +10,16 @@ pass -def setup_module(): +def setup_module() -> None: try: from urllib3.contrib.pyopenssl import inject_into_urllib3 inject_into_urllib3() except ImportError as e: - pytest.skip("Could not import PyOpenSSL: %r" % e) + pytest.skip(f"Could not import PyOpenSSL: {e!r}") -def teardown_module(): +def teardown_module() -> None: try: from urllib3.contrib.pyopenssl import extract_from_urllib3 @@ -26,12 +28,12 @@ def teardown_module(): pass -class TestPyOpenSSLInjection(object): +class TestPyOpenSSLInjection: """ Tests for error handling in pyopenssl's 'inject_into urllib3' """ - def test_inject_validate_fail_cryptography(self): + def test_inject_validate_fail_cryptography(self) -> None: """ Injection should not be supported if cryptography is too old. """ @@ -46,7 +48,7 @@ def test_inject_validate_fail_cryptography(self): # clean up so that subsequent tests are unaffected. extract_from_urllib3() - def test_inject_validate_fail_pyopenssl(self): + def test_inject_validate_fail_pyopenssl(self) -> None: """ Injection should not be supported if pyOpenSSL is too old. 
""" diff --git a/test/contrib/test_securetransport.py b/test/contrib/test_securetransport.py index 9a49a35521..ac41fe5caa 100644 --- a/test/contrib/test_securetransport.py +++ b/test/contrib/test_securetransport.py @@ -1,4 +1,6 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations + +import base64 import contextlib import socket import ssl @@ -11,16 +13,16 @@ pass -def setup_module(): +def setup_module() -> None: try: from urllib3.contrib.securetransport import inject_into_urllib3 inject_into_urllib3() except ImportError as e: - pytest.skip("Could not import SecureTransport: %r" % e) + pytest.skip(f"Could not import SecureTransport: {repr(e)}") -def teardown_module(): +def teardown_module() -> None: try: from urllib3.contrib.securetransport import extract_from_urllib3 @@ -47,8 +49,20 @@ def teardown_module(): ) -def test_no_crash_with_empty_trust_bundle(): +def test_no_crash_with_empty_trust_bundle() -> None: with contextlib.closing(socket.socket()) as s: ws = WrappedSocket(s) with pytest.raises(ssl.SSLError): ws._custom_validate(True, b"") + + +def test_no_crash_with_invalid_trust_bundle() -> None: + invalid_cert = base64.b64encode(b"invalid-cert") + cert_bundle = ( + b"-----BEGIN CERTIFICATE-----\n" + invalid_cert + b"\n-----END CERTIFICATE-----" + ) + + with contextlib.closing(socket.socket()) as s: + ws = WrappedSocket(s) + with pytest.raises(ssl.SSLError): + ws._custom_validate(True, cert_bundle) diff --git a/test/contrib/test_socks.py b/test/contrib/test_socks.py index 1966513c18..2878cc8d8b 100644 --- a/test/contrib/test_socks.py +++ b/test/contrib/test_socks.py @@ -1,8 +1,15 @@ +from __future__ import annotations + import socket import threading +import typing +from socket import getaddrinfo as real_getaddrinfo +from socket import timeout as SocketTimeout from test import SHORT_TIMEOUT +from unittest.mock import Mock, patch import pytest +import socks as py_socks # type: ignore[import] from dummyserver.server import DEFAULT_CA, DEFAULT_CERTS from 
dummyserver.testcase import IPV4SocketDummyServerTestCase @@ -16,8 +23,8 @@ HAS_SSL = True except ImportError: - ssl = None - better_ssl = None + ssl = None # type: ignore[assignment] + better_ssl = None # type: ignore[assignment] HAS_SSL = False @@ -28,7 +35,7 @@ SOCKS_VERSION_SOCKS5 = b"\x05" -def _get_free_port(host): +def _get_free_port(host: str) -> int: """ Gets a free port by opening a socket, binding it, checking the assigned port, and then closing it. @@ -37,10 +44,10 @@ def _get_free_port(host): s.bind((host, 0)) port = s.getsockname()[1] s.close() - return port + return port # type: ignore[no-any-return] -def _read_exactly(sock, amt): +def _read_exactly(sock: socket.socket, amt: int) -> bytes: """ Read *exactly* ``amt`` bytes from the socket ``sock``. """ @@ -54,7 +61,7 @@ def _read_exactly(sock, amt): return data -def _read_until(sock, char): +def _read_until(sock: socket.socket, char: bytes) -> bytes: """ Read from the socket until the character is received. """ @@ -68,7 +75,7 @@ def _read_until(sock, char): return b"".join(chunks) -def _address_from_socket(sock): +def _address_from_socket(sock: socket.socket) -> bytes | str: """ Returns the address from the SOCKS socket """ @@ -84,10 +91,45 @@ def _address_from_socket(sock): addr_len = ord(sock.recv(1)) return _read_exactly(sock, addr_len) else: - raise RuntimeError("Unexpected addr type: %r" % addr_type) - - -def handle_socks5_negotiation(sock, negotiate, username=None, password=None): + raise RuntimeError(f"Unexpected addr type: {addr_type!r}") + + +def _set_up_fake_getaddrinfo(monkeypatch: pytest.MonkeyPatch) -> None: + # Work around https://github.com/urllib3/urllib3/pull/2034 + # Nothing prevents localhost to point to two different IPs. For example, in the + # Ubuntu set up by GitHub Actions, localhost points both to 127.0.0.1 and ::1. 
+ # + # In case of failure, PySocks will try the same request on both IPs, but our + # handle_socks[45]_negotiation functions don't handle retries, which leads either to + # a deadlock or a timeout in case of a failure on the first address. + # + # However, some tests need to exercise failure. We don't want retries there, but + # can't affect PySocks retries via its API. Instead, we monkeypatch PySocks so that + # it only sees a single address, which effectively disables retries. + def fake_getaddrinfo( + addr: str, port: int, family: int, socket_type: int + ) -> list[ + tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ]: + gai_list = real_getaddrinfo(addr, port, family, socket_type) + gai_list = [gai for gai in gai_list if gai[0] == socket.AF_INET] + return gai_list[:1] + + monkeypatch.setattr(py_socks.socket, "getaddrinfo", fake_getaddrinfo) + + +def handle_socks5_negotiation( + sock: socket.socket, + negotiate: bool, + username: bytes | None = None, + password: bytes | None = None, +) -> typing.Generator[tuple[bytes | str, int], bool, None]: """ Handle the SOCKS5 handshake. @@ -117,7 +159,6 @@ def handle_socks5_negotiation(sock, negotiate, username=None, password=None): else: sock.sendall(b"\x01\x01") sock.close() - yield False return else: assert SOCKS_NEGOTIATION_NONE in methods @@ -129,8 +170,8 @@ def handle_socks5_negotiation(sock, negotiate, username=None, password=None): command = sock.recv(1) reserved = sock.recv(1) addr = _address_from_socket(sock) - port = _read_exactly(sock, 2) - port = (ord(port[0:1]) << 8) + (ord(port[1:2])) + port_raw = _read_exactly(sock, 2) + port = (ord(port_raw[0:1]) << 8) + (ord(port_raw[1:2])) # Check some basic stuff. 
assert received_version == SOCKS_VERSION_SOCKS5 @@ -148,10 +189,11 @@ def handle_socks5_negotiation(sock, negotiate, username=None, password=None): response = SOCKS_VERSION_SOCKS5 + b"\x01\00" sock.sendall(response) - yield True # Avoid StopIteration exceptions getting fired. -def handle_socks4_negotiation(sock, username=None): +def handle_socks4_negotiation( + sock: socket.socket, username: bytes | None = None +) -> typing.Generator[tuple[bytes | str, int], bool, None]: """ Handle the SOCKS4 handshake. @@ -160,16 +202,17 @@ def handle_socks4_negotiation(sock, username=None): """ received_version = sock.recv(1) command = sock.recv(1) - port = _read_exactly(sock, 2) - port = (ord(port[0:1]) << 8) + (ord(port[1:2])) - addr = _read_exactly(sock, 4) + port_raw = _read_exactly(sock, 2) + port = (ord(port_raw[0:1]) << 8) + (ord(port_raw[1:2])) + addr_raw = _read_exactly(sock, 4) provided_username = _read_until(sock, b"\x00")[:-1] # Strip trailing null. - if addr == b"\x00\x00\x00\x01": + addr: bytes | str + if addr_raw == b"\x00\x00\x00\x01": # Magic string: means DNS name. addr = _read_until(sock, b"\x00")[:-1] # Strip trailing null. else: - addr = socket.inet_ntoa(addr) + addr = socket.inet_ntoa(addr_raw) # Check some basic stuff. assert received_version == SOCKS_VERSION_SOCKS4 @@ -178,7 +221,6 @@ def handle_socks4_negotiation(sock, username=None): if username is not None and username != provided_username: sock.sendall(b"\x00\x5d\x00\x00\x00\x00\x00\x00") sock.close() - yield False return # Yield the address port tuple. @@ -190,14 +232,12 @@ def handle_socks4_negotiation(sock, username=None): response = b"\x00\x5b\x00\x00\x00\x00\x00\x00" sock.sendall(response) - yield True # Avoid StopIteration exceptions getting fired. 
-class TestSOCKSProxyManager(object): - def test_invalid_socks_version_is_valueerror(self): - with pytest.raises(ValueError) as e: +class TestSOCKSProxyManager: + def test_invalid_socks_version_is_valueerror(self) -> None: + with pytest.raises(ValueError, match="Unable to determine SOCKS version"): socks.SOCKSProxyManager(proxy_url="http://example.org") - assert "Unable to determine SOCKS version" in e.value.args[0] class TestSocks5Proxy(IPV4SocketDummyServerTestCase): @@ -205,8 +245,8 @@ class TestSocks5Proxy(IPV4SocketDummyServerTestCase): Test the SOCKS proxy in SOCKS5 mode. """ - def test_basic_request(self): - def request_handler(listener): + def test_basic_request(self) -> None: + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks5_negotiation(sock, negotiate=False) @@ -214,7 +254,8 @@ def request_handler(listener): assert addr == "16.17.18.19" assert port == 80 - handler.send(True) + with pytest.raises(StopIteration): + handler.send(True) while True: buf = sock.recv(65535) @@ -230,7 +271,7 @@ def request_handler(listener): sock.close() self._start_server(request_handler) - proxy_url = "socks5://%s:%s" % (self.host, self.port) + proxy_url = f"socks5://{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url) as pm: response = pm.request("GET", "http://16.17.18.19") @@ -238,8 +279,8 @@ def request_handler(listener): assert response.data == b"" assert response.headers["Server"] == "SocksTestServer" - def test_local_dns(self): - def request_handler(listener): + def test_local_dns(self) -> None: + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks5_negotiation(sock, negotiate=False) @@ -247,7 +288,8 @@ def request_handler(listener): assert addr in ["127.0.0.1", "::1"] assert port == 80 - handler.send(True) + with pytest.raises(StopIteration): + handler.send(True) while True: buf = sock.recv(65535) @@ -263,7 +305,7 @@ def request_handler(listener): 
sock.close() self._start_server(request_handler) - proxy_url = "socks5://%s:%s" % (self.host, self.port) + proxy_url = f"socks5://{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url) as pm: response = pm.request("GET", "http://localhost") @@ -271,8 +313,8 @@ def request_handler(listener): assert response.data == b"" assert response.headers["Server"] == "SocksTestServer" - def test_correct_header_line(self): - def request_handler(listener): + def test_correct_header_line(self) -> None: + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks5_negotiation(sock, negotiate=False) @@ -280,7 +322,8 @@ def request_handler(listener): assert addr == b"example.com" assert port == 80 - handler.send(True) + with pytest.raises(StopIteration): + handler.send(True) buf = b"" while True: @@ -300,19 +343,19 @@ def request_handler(listener): sock.close() self._start_server(request_handler) - proxy_url = "socks5h://%s:%s" % (self.host, self.port) + proxy_url = f"socks5h://{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url) as pm: response = pm.request("GET", "http://example.com") assert response.status == 200 - def test_connection_timeouts(self): + def test_connection_timeouts(self) -> None: event = threading.Event() - def request_handler(listener): + def request_handler(listener: socket.socket) -> None: event.wait() self._start_server(request_handler) - proxy_url = "socks5h://%s:%s" % (self.host, self.port) + proxy_url = f"socks5h://{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url) as pm: with pytest.raises(ConnectTimeoutError): pm.request( @@ -320,42 +363,52 @@ def request_handler(listener): ) event.set() - def test_connection_failure(self): + @patch("socks.create_connection") + def test_socket_timeout(self, create_connection: Mock) -> None: + create_connection.side_effect = SocketTimeout() + proxy_url = f"socks5h://{self.host}:{self.port}" + with socks.SOCKSProxyManager(proxy_url) as pm: + 
with pytest.raises(ConnectTimeoutError, match="timed out"): + pm.request("GET", "http://example.com", retries=False) + + def test_connection_failure(self) -> None: event = threading.Event() - def request_handler(listener): + def request_handler(listener: socket.socket) -> None: listener.close() event.set() self._start_server(request_handler) - proxy_url = "socks5h://%s:%s" % (self.host, self.port) + proxy_url = f"socks5h://{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url) as pm: event.wait() with pytest.raises(NewConnectionError): pm.request("GET", "http://example.com", retries=False) - def test_proxy_rejection(self): + def test_proxy_rejection(self, monkeypatch: pytest.MonkeyPatch) -> None: + _set_up_fake_getaddrinfo(monkeypatch) evt = threading.Event() - def request_handler(listener): + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks5_negotiation(sock, negotiate=False) addr, port = next(handler) - handler.send(False) + with pytest.raises(StopIteration): + handler.send(False) evt.wait() sock.close() self._start_server(request_handler) - proxy_url = "socks5h://%s:%s" % (self.host, self.port) + proxy_url = f"socks5h://{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url) as pm: with pytest.raises(NewConnectionError): pm.request("GET", "http://example.com", retries=False) evt.set() - def test_socks_with_password(self): - def request_handler(listener): + def test_socks_with_password(self) -> None: + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks5_negotiation( @@ -365,7 +418,8 @@ def request_handler(listener): assert addr == "16.17.18.19" assert port == 80 - handler.send(True) + with pytest.raises(StopIteration): + handler.send(True) while True: buf = sock.recv(65535) @@ -381,7 +435,7 @@ def request_handler(listener): sock.close() self._start_server(request_handler) - proxy_url = "socks5://%s:%s" % (self.host, self.port) + 
proxy_url = f"socks5://{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url, username="user", password="pass") as pm: response = pm.request("GET", "http://16.17.18.19") @@ -389,13 +443,13 @@ def request_handler(listener): assert response.data == b"" assert response.headers["Server"] == "SocksTestServer" - def test_socks_with_auth_in_url(self): + def test_socks_with_auth_in_url(self) -> None: """ Test when we have auth info in url, i.e. socks5://user:pass@host:port and no username/password as params """ - def request_handler(listener): + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks5_negotiation( @@ -405,7 +459,8 @@ def request_handler(listener): assert addr == "16.17.18.19" assert port == 80 - handler.send(True) + with pytest.raises(StopIteration): + handler.send(True) while True: buf = sock.recv(65535) @@ -421,7 +476,7 @@ def request_handler(listener): sock.close() self._start_server(request_handler) - proxy_url = "socks5://user:pass@%s:%s" % (self.host, self.port) + proxy_url = f"socks5://user:pass@{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url) as pm: response = pm.request("GET", "http://16.17.18.19") @@ -429,28 +484,32 @@ def request_handler(listener): assert response.data == b"" assert response.headers["Server"] == "SocksTestServer" - def test_socks_with_invalid_password(self): - def request_handler(listener): + def test_socks_with_invalid_password(self, monkeypatch: pytest.MonkeyPatch) -> None: + _set_up_fake_getaddrinfo(monkeypatch) + + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks5_negotiation( sock, negotiate=True, username=b"user", password=b"pass" ) - next(handler) + with pytest.raises(StopIteration): + next(handler) self._start_server(request_handler) - proxy_url = "socks5h://%s:%s" % (self.host, self.port) + proxy_url = f"socks5h://{self.host}:{self.port}" with socks.SOCKSProxyManager( proxy_url, 
username="user", password="badpass" ) as pm: - with pytest.raises(NewConnectionError) as e: + with pytest.raises( + NewConnectionError, match="SOCKS5 authentication failed" + ): pm.request("GET", "http://example.com", retries=False) - assert "SOCKS5 authentication failed" in str(e.value) - def test_source_address_works(self): + def test_source_address_works(self) -> None: expected_port = _get_free_port(self.host) - def request_handler(listener): + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] assert sock.getpeername()[0] == "127.0.0.1" assert sock.getpeername()[1] == expected_port @@ -460,7 +519,8 @@ def request_handler(listener): assert addr == "16.17.18.19" assert port == 80 - handler.send(True) + with pytest.raises(StopIteration): + handler.send(True) while True: buf = sock.recv(65535) @@ -476,7 +536,7 @@ def request_handler(listener): sock.close() self._start_server(request_handler) - proxy_url = "socks5://%s:%s" % (self.host, self.port) + proxy_url = f"socks5://{self.host}:{self.port}" with socks.SOCKSProxyManager( proxy_url, source_address=("127.0.0.1", expected_port) ) as pm: @@ -492,8 +552,8 @@ class TestSOCKS4Proxy(IPV4SocketDummyServerTestCase): negotiation is done the two cases behave identically. 
""" - def test_basic_request(self): - def request_handler(listener): + def test_basic_request(self) -> None: + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks4_negotiation(sock) @@ -501,7 +561,8 @@ def request_handler(listener): assert addr == "16.17.18.19" assert port == 80 - handler.send(True) + with pytest.raises(StopIteration): + handler.send(True) while True: buf = sock.recv(65535) @@ -517,7 +578,7 @@ def request_handler(listener): sock.close() self._start_server(request_handler) - proxy_url = "socks4://%s:%s" % (self.host, self.port) + proxy_url = f"socks4://{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url) as pm: response = pm.request("GET", "http://16.17.18.19") @@ -525,8 +586,8 @@ def request_handler(listener): assert response.headers["Server"] == "SocksTestServer" assert response.data == b"" - def test_local_dns(self): - def request_handler(listener): + def test_local_dns(self) -> None: + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks4_negotiation(sock) @@ -534,7 +595,8 @@ def request_handler(listener): assert addr == "127.0.0.1" assert port == 80 - handler.send(True) + with pytest.raises(StopIteration): + handler.send(True) while True: buf = sock.recv(65535) @@ -550,7 +612,7 @@ def request_handler(listener): sock.close() self._start_server(request_handler) - proxy_url = "socks4://%s:%s" % (self.host, self.port) + proxy_url = f"socks4://{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url) as pm: response = pm.request("GET", "http://localhost") @@ -558,8 +620,8 @@ def request_handler(listener): assert response.headers["Server"] == "SocksTestServer" assert response.data == b"" - def test_correct_header_line(self): - def request_handler(listener): + def test_correct_header_line(self) -> None: + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks4_negotiation(sock) @@ 
-567,7 +629,8 @@ def request_handler(listener): assert addr == b"example.com" assert port == 80 - handler.send(True) + with pytest.raises(StopIteration): + handler.send(True) buf = b"" while True: @@ -587,33 +650,35 @@ def request_handler(listener): sock.close() self._start_server(request_handler) - proxy_url = "socks4a://%s:%s" % (self.host, self.port) + proxy_url = f"socks4a://{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url) as pm: response = pm.request("GET", "http://example.com") assert response.status == 200 - def test_proxy_rejection(self): + def test_proxy_rejection(self, monkeypatch: pytest.MonkeyPatch) -> None: + _set_up_fake_getaddrinfo(monkeypatch) evt = threading.Event() - def request_handler(listener): + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks4_negotiation(sock) addr, port = next(handler) - handler.send(False) + with pytest.raises(StopIteration): + handler.send(False) evt.wait() sock.close() self._start_server(request_handler) - proxy_url = "socks4a://%s:%s" % (self.host, self.port) + proxy_url = f"socks4a://{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url) as pm: with pytest.raises(NewConnectionError): pm.request("GET", "http://example.com", retries=False) evt.set() - def test_socks4_with_username(self): - def request_handler(listener): + def test_socks4_with_username(self) -> None: + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks4_negotiation(sock, username=b"user") @@ -621,7 +686,8 @@ def request_handler(listener): assert addr == "16.17.18.19" assert port == 80 - handler.send(True) + with pytest.raises(StopIteration): + handler.send(True) while True: buf = sock.recv(65535) @@ -637,7 +703,7 @@ def request_handler(listener): sock.close() self._start_server(request_handler) - proxy_url = "socks4://%s:%s" % (self.host, self.port) + proxy_url = f"socks4://{self.host}:{self.port}" with 
socks.SOCKSProxyManager(proxy_url, username="user") as pm: response = pm.request("GET", "http://16.17.18.19") @@ -645,19 +711,18 @@ def request_handler(listener): assert response.data == b"" assert response.headers["Server"] == "SocksTestServer" - def test_socks_with_invalid_username(self): - def request_handler(listener): + def test_socks_with_invalid_username(self) -> None: + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks4_negotiation(sock, username=b"user") - next(handler) + next(handler, None) self._start_server(request_handler) - proxy_url = "socks4a://%s:%s" % (self.host, self.port) + proxy_url = f"socks4a://{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url, username="baduser") as pm: - with pytest.raises(NewConnectionError) as e: + with pytest.raises(NewConnectionError, match="different user-ids"): pm.request("GET", "http://example.com", retries=False) - assert "different user-ids" in str(e.value) class TestSOCKSWithTLS(IPV4SocketDummyServerTestCase): @@ -666,8 +731,8 @@ class TestSOCKSWithTLS(IPV4SocketDummyServerTestCase): """ @pytest.mark.skipif(not HAS_SSL, reason="No TLS available") - def test_basic_request(self): - def request_handler(listener): + def test_basic_request(self) -> None: + def request_handler(listener: socket.socket) -> None: sock = listener.accept()[0] handler = handle_socks5_negotiation(sock, negotiate=False) @@ -675,10 +740,11 @@ def request_handler(listener): assert addr == b"localhost" assert port == 443 - handler.send(True) + with pytest.raises(StopIteration): + handler.send(True) # Wrap in TLS - context = better_ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context = better_ssl.SSLContext(ssl.PROTOCOL_SSLv23) # type: ignore[misc] context.load_cert_chain(DEFAULT_CERTS["certfile"], DEFAULT_CERTS["keyfile"]) tls = context.wrap_socket(sock, server_side=True) buf = b"" @@ -700,7 +766,7 @@ def request_handler(listener): sock.close() self._start_server(request_handler) - 
proxy_url = "socks5h://%s:%s" % (self.host, self.port) + proxy_url = f"socks5h://{self.host}:{self.port}" with socks.SOCKSProxyManager(proxy_url, ca_certs=DEFAULT_CA) as pm: response = pm.request("GET", "https://localhost") diff --git a/test/port_helpers.py b/test/port_helpers.py index ae18ccae6d..e8c94843ba 100644 --- a/test/port_helpers.py +++ b/test/port_helpers.py @@ -1,6 +1,8 @@ -# These helpers are copied from test_support.py in the Python 2.7 standard +# These helpers are copied from test/support/socket_helper.py in the Python 3.9 standard # library test suite. +from __future__ import annotations + import socket # Don't use "localhost", since resolving it uses the DNS under recent @@ -9,7 +11,10 @@ HOSTv6 = "::1" -def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): +def find_unused_port( + family: socket.AddressFamily = socket.AF_INET, + socktype: socket.SocketKind = socket.SOCK_STREAM, +) -> int: """Returns an unused port that should be suitable for binding. This is achieved by creating a temporary socket with the same family and type as the 'sock' parameter (default is AF_INET, SOCK_STREAM), and binding it to @@ -36,7 +41,7 @@ def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): the SO_REUSEADDR socket option having different semantics on Windows versus Unix/Linux. On Unix, you can't have two AF_INET SOCK_STREAM sockets bind, listen and then accept connections on identical host/ports. An EADDRINUSE - socket.error will be raised at some point (depending on the platform and + OSError will be raised at some point (depending on the platform and the order bind and listen were called on each socket). However, on Windows, if SO_REUSEADDR is set on the sockets, no EADDRINUSE @@ -63,15 +68,15 @@ def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): other process when we close and delete our temporary socket but before our calling code has a chance to bind the returned port. 
We can deal with this issue if/when we come across it.""" - tempsock = socket.socket(family, socktype) - port = bind_port(tempsock) - tempsock.close() + + with socket.socket(family, socktype) as tempsock: + port = bind_port(tempsock) del tempsock return port -def bind_port(sock, host=HOST): - """Bind the socket to a free port and return the port number. Relies on +def bind_port(sock: socket.socket, host: str = HOST) -> int: + """Bind the socket to a free port and return the port number. Relies on ephemeral ports in order to ensure we are using an unbound port. This is important as many tests may be running simultaneously, especially in a buildbot environment. This method raises an exception if the sock.family @@ -84,6 +89,7 @@ def bind_port(sock, host=HOST): on Windows), it will be set on the socket. This will prevent anyone else from bind()'ing to our host/port for the duration of the test. """ + if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: if hasattr(socket, "SO_REUSEADDR"): if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: @@ -92,14 +98,21 @@ def bind_port(sock, host=HOST): "socket option on TCP/IP sockets!" ) if hasattr(socket, "SO_REUSEPORT"): - if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1: - raise ValueError( - "tests should never set the SO_REUSEPORT " - "socket option on TCP/IP sockets!" - ) + try: + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1: + raise ValueError( + "tests should never set the SO_REUSEPORT " + "socket option on TCP/IP sockets!" + ) + except OSError: + # Python's socket module was compiled using modern headers + # thus defining SO_REUSEPORT but this process is running + # under an older kernel that does not support SO_REUSEPORT. 
+ pass if hasattr(socket, "SO_EXCLUSIVEADDRUSE"): sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) sock.bind((host, 0)) port = sock.getsockname()[1] + assert isinstance(port, int) return port diff --git a/test/socketpair_helper.py b/test/socketpair_helper.py deleted file mode 100644 index 7ddb6009d7..0000000000 --- a/test/socketpair_helper.py +++ /dev/null @@ -1,63 +0,0 @@ -import socket - -# Figuring out what errors could come out of a socket. There are three -# different situations. Python 3 post-PEP3151 will define and use -# BlockingIOError and InterruptedError from sockets. For Python pre-PEP3151 -# both OSError and socket.error can be raised except on Windows where -# WindowsError can also be raised. We want to catch all of these possible -# exceptions so we catch WindowsError if it's defined. -try: - _CONNECT_ERROR = (BlockingIOError, InterruptedError) -except NameError: - try: - _CONNECT_ERROR = (WindowsError, OSError, socket.error) # noqa: F821 - except NameError: - _CONNECT_ERROR = (OSError, socket.error) - -if hasattr(socket, "socketpair"): - # Since Python 3.5, socket.socketpair() is now also available on Windows - socketpair = socket.socketpair -else: - # Replacement for socket.socketpair() - def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): - """A socket pair usable as a self-pipe, for Windows. - - Origin: https://gist.github.com/4325783, by Geert Jansen. - Public domain. - """ - if family == socket.AF_INET: - host = "127.0.0.1" - elif family == socket.AF_INET6: - host = "::1" - else: - raise ValueError( - "Only AF_INET and AF_INET6 socket address families are supported" - ) - if type != socket.SOCK_STREAM: - raise ValueError("Only SOCK_STREAM socket type is supported") - if proto != 0: - raise ValueError("Only protocol zero is supported") - - # We create a connected TCP socket. Note the trick with setblocking(0) - # that prevents us from having to create a thread. 
- lsock = socket.socket(family, type, proto) - try: - lsock.bind((host, 0)) - lsock.listen(1) - # On IPv6, ignore flow_info and scope_id - addr, port = lsock.getsockname()[:2] - csock = socket.socket(family, type, proto) - try: - csock.setblocking(False) - try: - csock.connect((addr, port)) - except _CONNECT_ERROR: - pass - csock.setblocking(True) - ssock, _ = lsock.accept() - except Exception: - csock.close() - raise - finally: - lsock.close() - return (ssock, csock) diff --git a/test/test_collections.py b/test/test_collections.py index 4b8624cb6c..56cdea865d 100644 --- a/test/test_collections.py +++ b/test/test_collections.py @@ -1,23 +1,23 @@ +from __future__ import annotations + +import typing + import pytest from urllib3._collections import HTTPHeaderDict from urllib3._collections import RecentlyUsedContainer as Container -from urllib3.exceptions import InvalidHeader -from urllib3.packages import six -xrange = six.moves.xrange +class TestLRUContainer: + def test_maxsize(self) -> None: + d: Container[int, str] = Container(5) -class TestLRUContainer(object): - def test_maxsize(self): - d = Container(5) - - for i in xrange(5): + for i in range(5): d[i] = str(i) assert len(d) == 5 - for i in xrange(5): + for i in range(5): assert d[i] == str(i) d[i + 1] = str(i + 1) @@ -26,49 +26,54 @@ def test_maxsize(self): assert 0 not in d assert (i + 1) in d - def test_expire(self): - d = Container(5) + def test_maxsize_0(self) -> None: + d: Container[int, int] = Container(0) + d[1] = 1 + assert len(d) == 0 + + def test_expire(self) -> None: + d: Container[int, str] = Container(5) - for i in xrange(5): + for i in range(5): d[i] = str(i) - for i in xrange(5): + for i in range(5): d.get(0) # Add one more entry d[5] = "5" # Check state - assert list(d.keys()) == [2, 3, 4, 0, 5] + assert list(d._container.keys()) == [2, 3, 4, 0, 5] - def test_same_key(self): - d = Container(5) + def test_same_key(self) -> None: + d: Container[str, int] = Container(5) - for i in xrange(10): + for 
i in range(10): d["foo"] = i - assert list(d.keys()) == ["foo"] + assert list(d._container.keys()) == ["foo"] assert len(d) == 1 - def test_access_ordering(self): - d = Container(5) + def test_access_ordering(self) -> None: + d: Container[int, bool] = Container(5) - for i in xrange(10): + for i in range(10): d[i] = True # Keys should be ordered by access time - assert list(d.keys()) == [5, 6, 7, 8, 9] + assert list(d._container.keys()) == [5, 6, 7, 8, 9] new_order = [7, 8, 6, 9, 5] for k in new_order: d[k] - assert list(d.keys()) == new_order + assert list(d._container.keys()) == new_order - def test_delete(self): - d = Container(5) + def test_delete(self) -> None: + d: Container[int, bool] = Container(5) - for i in xrange(5): + for i in range(5): d[i] = True del d[0] @@ -79,10 +84,10 @@ def test_delete(self): d.pop(1, None) - def test_get(self): - d = Container(5) + def test_get(self) -> None: + d: Container[int, bool | int] = Container(5) - for i in xrange(5): + for i in range(5): d[i] = True r = d.get(4) @@ -97,21 +102,21 @@ def test_get(self): with pytest.raises(KeyError): d[5] - def test_disposal(self): - evicted_items = [] + def test_disposal(self) -> None: + evicted_items: list[int] = [] - def dispose_func(arg): + def dispose_func(arg: int) -> None: # Save the evicted datum for inspection evicted_items.append(arg) - d = Container(5, dispose_func=dispose_func) - for i in xrange(5): + d: Container[int, int] = Container(5, dispose_func=dispose_func) + for i in range(5): d[i] = i - assert list(d.keys()) == list(xrange(5)) + assert list(d._container.keys()) == list(range(5)) assert evicted_items == [] # Nothing disposed d[5] = 5 - assert list(d.keys()) == list(xrange(1, 6)) + assert list(d._container.keys()) == list(range(1, 6)) assert evicted_items == [0] del d[1] @@ -120,49 +125,57 @@ def dispose_func(arg): d.clear() assert evicted_items == [0, 1, 2, 3, 4, 5] - def test_iter(self): - d = Container() + def test_iter(self) -> None: + d: Container[str, str] = 
Container() with pytest.raises(NotImplementedError): d.__iter__() -class NonMappingHeaderContainer(object): - def __init__(self, **kwargs): +class NonMappingHeaderContainer: + def __init__(self, **kwargs: str) -> None: self._data = {} self._data.update(kwargs) - def keys(self): - return self._data.keys() + def keys(self) -> typing.Iterator[str]: + return iter(self._data) - def __getitem__(self, key): + def __getitem__(self, key: str) -> str: return self._data[key] @pytest.fixture() -def d(): +def d() -> HTTPHeaderDict: header_dict = HTTPHeaderDict(Cookie="foo") header_dict.add("cookie", "bar") return header_dict -class TestHTTPHeaderDict(object): - def test_create_from_kwargs(self): - h = HTTPHeaderDict(ab=1, cd=2, ef=3, gh=4) +class TestHTTPHeaderDict: + def test_create_from_kwargs(self) -> None: + h = HTTPHeaderDict(ab="1", cd="2", ef="3", gh="4") assert len(h) == 4 assert "ab" in h - def test_create_from_dict(self): - h = HTTPHeaderDict(dict(ab=1, cd=2, ef=3, gh=4)) + def test_setdefault(self) -> None: + h = HTTPHeaderDict(a="1") + assert h.setdefault("A", "3") == "1" + assert h.setdefault("b", "2") == "2" + assert h.setdefault("c") == "" + assert h["c"] == "" + assert h["b"] == "2" + + def test_create_from_dict(self) -> None: + h = HTTPHeaderDict(dict(ab="1", cd="2", ef="3", gh="4")) assert len(h) == 4 assert "ab" in h - def test_create_from_iterator(self): + def test_create_from_iterator(self) -> None: teststr = "urllib3ontherocks" h = HTTPHeaderDict((c, c * 5) for c in teststr) assert len(h) == len(set(teststr)) - def test_create_from_list(self): + def test_create_from_list(self) -> None: headers = [ ("ab", "A"), ("cd", "B"), @@ -178,7 +191,7 @@ def test_create_from_list(self): assert clist[0] == "C" assert clist[-1] == "E" - def test_create_from_headerdict(self): + def test_create_from_headerdict(self) -> None: headers = [ ("ab", "A"), ("cd", "B"), @@ -197,54 +210,87 @@ def test_create_from_headerdict(self): assert h is not org assert h == org - def 
test_setitem(self, d): + def test_setitem(self, d: HTTPHeaderDict) -> None: d["Cookie"] = "foo" - assert d["cookie"] == "foo" + # The bytes value gets converted to str. The API is typed for str only, + # but the implementation continues supports bytes. + d[b"Cookie"] = "bar" # type: ignore[index] + assert d["cookie"] == "bar" d["cookie"] = "with, comma" assert d.getlist("cookie") == ["with, comma"] - def test_update(self, d): + def test_update(self, d: HTTPHeaderDict) -> None: d.update(dict(Cookie="foo")) assert d["cookie"] == "foo" d.update(dict(cookie="with, comma")) assert d.getlist("cookie") == ["with, comma"] - def test_delitem(self, d): + def test_delitem(self, d: HTTPHeaderDict) -> None: del d["cookie"] assert "cookie" not in d assert "COOKIE" not in d - def test_add_well_known_multiheader(self, d): + def test_add_well_known_multiheader(self, d: HTTPHeaderDict) -> None: d.add("COOKIE", "asdf") assert d.getlist("cookie") == ["foo", "bar", "asdf"] assert d["cookie"] == "foo, bar, asdf" - def test_add_comma_separated_multiheader(self, d): + def test_add_comma_separated_multiheader(self, d: HTTPHeaderDict) -> None: d.add("bar", "foo") - d.add("BAR", "bar") + # The bytes value gets converted to str. The API is typed for str only, + # but the implementation continues supports bytes. 
+ d.add(b"BAR", "bar") # type: ignore[arg-type] d.add("Bar", "asdf") assert d.getlist("bar") == ["foo", "bar", "asdf"] assert d["bar"] == "foo, bar, asdf" - def test_extend_from_list(self, d): + def test_extend_from_list(self, d: HTTPHeaderDict) -> None: d.extend([("set-cookie", "100"), ("set-cookie", "200"), ("set-cookie", "300")]) assert d["set-cookie"] == "100, 200, 300" - def test_extend_from_dict(self, d): + def test_extend_from_dict(self, d: HTTPHeaderDict) -> None: d.extend(dict(cookie="asdf"), b="100") assert d["cookie"] == "foo, bar, asdf" assert d["b"] == "100" d.add("cookie", "with, comma") assert d.getlist("cookie") == ["foo", "bar", "asdf", "with, comma"] - def test_extend_from_container(self, d): + def test_extend_from_container(self, d: HTTPHeaderDict) -> None: h = NonMappingHeaderContainer(Cookie="foo", e="foofoo") d.extend(h) assert d["cookie"] == "foo, bar, foo" assert d["e"] == "foofoo" assert len(d) == 2 - def test_extend_from_headerdict(self, d): + def test_header_repeat(self, d: HTTPHeaderDict) -> None: + d["other-header"] = "hello" + d.add("other-header", "world") + + assert list(d.items()) == [ + ("Cookie", "foo"), + ("Cookie", "bar"), + ("other-header", "hello"), + ("other-header", "world"), + ] + + d.add("other-header", "!", combine=True) + expected_results = [ + ("Cookie", "foo"), + ("Cookie", "bar"), + ("other-header", "hello"), + ("other-header", "world, !"), + ] + + assert list(d.items()) == expected_results + # make sure the values persist over copys + assert list(d.copy().items()) == expected_results + + other_dict = HTTPHeaderDict() + # we also need for extensions to properly maintain results + other_dict.extend(d) + assert list(other_dict.items()) == expected_results + + def test_extend_from_headerdict(self, d: HTTPHeaderDict) -> None: h = HTTPHeaderDict(Cookie="foo", e="foofoo") d.extend(h) assert d["cookie"] == "foo, bar, foo" @@ -252,41 +298,48 @@ def test_extend_from_headerdict(self, d): assert len(d) == 2 
@pytest.mark.parametrize("args", [(1, 2), (1, 2, 3, 4, 5)]) - def test_extend_with_wrong_number_of_args_is_typeerror(self, d, args): - with pytest.raises(TypeError) as err: - d.extend(*args) - assert "extend() takes at most 1 positional arguments" in err.value.args[0] - - def test_copy(self, d): + def test_extend_with_wrong_number_of_args_is_typeerror( + self, d: HTTPHeaderDict, args: tuple[int, ...] + ) -> None: + with pytest.raises( + TypeError, match=r"extend\(\) takes at most 1 positional arguments" + ): + d.extend(*args) # type: ignore[arg-type] + + def test_copy(self, d: HTTPHeaderDict) -> None: h = d.copy() assert d is not h assert d == h - def test_getlist(self, d): + def test_getlist(self, d: HTTPHeaderDict) -> None: assert d.getlist("cookie") == ["foo", "bar"] assert d.getlist("Cookie") == ["foo", "bar"] assert d.getlist("b") == [] d.add("b", "asdf") assert d.getlist("b") == ["asdf"] - def test_getlist_after_copy(self, d): + def test_getlist_after_copy(self, d: HTTPHeaderDict) -> None: assert d.getlist("cookie") == HTTPHeaderDict(d).getlist("cookie") - def test_equal(self, d): + def test_equal(self, d: HTTPHeaderDict) -> None: b = HTTPHeaderDict(cookie="foo, bar") c = NonMappingHeaderContainer(cookie="foo, bar") + e = [("cookie", "foo, bar")] assert d == b assert d == c + assert d == e assert d != 2 - def test_not_equal(self, d): + def test_not_equal(self, d: HTTPHeaderDict) -> None: b = HTTPHeaderDict(cookie="foo, bar") c = NonMappingHeaderContainer(cookie="foo, bar") + e = [("cookie", "foo, bar")] assert not (d != b) assert not (d != c) + assert not (d != e) assert d != 2 - def test_pop(self, d): + def test_pop(self, d: HTTPHeaderDict) -> None: key = "Cookie" a = d[key] b = d.pop(key) @@ -297,31 +350,37 @@ def test_pop(self, d): dummy = object() assert dummy is d.pop(key, dummy) - def test_discard(self, d): + def test_discard(self, d: HTTPHeaderDict) -> None: d.discard("cookie") assert "cookie" not in d d.discard("cookie") - def test_len(self, d): + def 
test_len(self, d: HTTPHeaderDict) -> None: assert len(d) == 1 d.add("cookie", "bla") d.add("asdf", "foo") # len determined by unique fieldnames assert len(d) == 2 - def test_repr(self, d): + def test_repr(self, d: HTTPHeaderDict) -> None: rep = "HTTPHeaderDict({'Cookie': 'foo, bar'})" assert repr(d) == rep - def test_items(self, d): + def test_items(self, d: HTTPHeaderDict) -> None: items = d.items() assert len(items) == 2 - assert items[0][0] == "Cookie" - assert items[0][1] == "foo" - assert items[1][0] == "Cookie" - assert items[1][1] == "bar" - - def test_dict_conversion(self, d): + assert list(items) == [ + ("Cookie", "foo"), + ("Cookie", "bar"), + ] + assert ("Cookie", "foo") in items + assert ("Cookie", "bar") in items + assert ("X-Some-Header", "foo") not in items + assert ("Cookie", "not_present") not in items + assert ("Cookie", 1) not in items # type: ignore[comparison-overlap] + assert "Cookie" not in items # type: ignore[comparison-overlap] + + def test_dict_conversion(self, d: HTTPHeaderDict) -> None: # Also tested in connectionpool, needs to preserve case hdict = { "Content-Length": "0", @@ -332,51 +391,35 @@ def test_dict_conversion(self, d): assert hdict == h assert hdict == dict(HTTPHeaderDict(hdict)) - def test_string_enforcement(self, d): + def test_string_enforcement(self, d: HTTPHeaderDict) -> None: # This currently throws AttributeError on key.lower(), should # probably be something nicer with pytest.raises(Exception): - d[3] = 5 + d[3] = "5" # type: ignore[index] with pytest.raises(Exception): - d.add(3, 4) + d.add(3, "4") # type: ignore[arg-type] with pytest.raises(Exception): - del d[3] + del d[3] # type: ignore[arg-type] with pytest.raises(Exception): - HTTPHeaderDict({3: 3}) - - @pytest.mark.skipif( - not six.PY2, reason="python3 has a different internal header implementation" - ) - def test_from_httplib_py2(self): - msg = """ -Server: nginx -Content-Type: text/html; charset=windows-1251 -Connection: keep-alive -X-Some-Multiline: asdf - 
asdf\t -\t asdf -Set-Cookie: bb_lastvisit=1348253375; expires=Sat, 21-Sep-2013 18:49:35 GMT; path=/ -Set-Cookie: bb_lastactivity=0; expires=Sat, 21-Sep-2013 18:49:35 GMT; path=/ -www-authenticate: asdf -www-authenticate: bla - -""" - buffer = six.moves.StringIO(msg.lstrip().replace("\n", "\r\n")) - msg = six.moves.http_client.HTTPMessage(buffer) - d = HTTPHeaderDict.from_httplib(msg) - assert d["server"] == "nginx" - cookies = d.getlist("set-cookie") - assert len(cookies) == 2 - assert cookies[0].startswith("bb_lastvisit") - assert cookies[1].startswith("bb_lastactivity") - assert d["x-some-multiline"] == "asdf asdf asdf" - assert d["www-authenticate"] == "asdf, bla" - assert d.getlist("www-authenticate") == ["asdf", "bla"] - with_invalid_multiline = """\tthis-is-not-a-header: but it has a pretend value -Authorization: Bearer 123 - -""" - buffer = six.moves.StringIO(with_invalid_multiline.replace("\n", "\r\n")) - msg = six.moves.http_client.HTTPMessage(buffer) - with pytest.raises(InvalidHeader): - HTTPHeaderDict.from_httplib(msg) + HTTPHeaderDict({3: 3}) # type: ignore[arg-type] + + def test_dunder_contains(self, d: HTTPHeaderDict) -> None: + """ + Test: + + HTTPHeaderDict.__contains__ returns True + - for matched string objects + - for case-similar string objects + HTTPHeaderDict.__contains__ returns False + - for non-similar strings + - for non-strings, even if they are keys + in the underlying datastructure + """ + assert "cookie" in d + assert "CoOkIe" in d + assert "Not a cookie" not in d + + marker = object() + d._container[marker] = ["some", "strings"] # type: ignore[index] + assert marker not in d + assert marker in d._container diff --git a/test/test_compatibility.py b/test/test_compatibility.py index 58a9ab5c6f..f18b0035ce 100644 --- a/test/test_compatibility.py +++ b/test/test_compatibility.py @@ -1,37 +1,15 @@ -import warnings +from __future__ import annotations -import pytest +import http.cookiejar +import urllib -from urllib3.connection import 
HTTPConnection -from urllib3.packages.six.moves import http_cookiejar, urllib from urllib3.response import HTTPResponse -class TestVersionCompatibility(object): - def test_connection_strict(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - # strict=True is deprecated in Py33+ - HTTPConnection("localhost", 12345, strict=True) - - if w: - pytest.fail( - "HTTPConnection raised warning on strict=True: %r" % w[0].message - ) - - def test_connection_source_address(self): - try: - # source_address does not exist in Py26- - HTTPConnection("localhost", 12345, source_address="127.0.0.1") - except TypeError as e: - pytest.fail("HTTPConnection raised TypeError on source_address: %r" % e) - - -class TestCookiejar(object): - def test_extract(self): +class TestCookiejar: + def test_extract(self) -> None: request = urllib.request.Request("http://google.com") - cookiejar = http_cookiejar.CookieJar() + cookiejar = http.cookiejar.CookieJar() response = HTTPResponse() cookies = [ @@ -40,5 +18,5 @@ def test_extract(self): ] for c in cookies: response.headers.add("set-cookie", c) - cookiejar.extract_cookies(response, request) + cookiejar.extract_cookies(response, request) # type: ignore[arg-type] assert len(cookiejar) == len(cookies) diff --git a/test/test_connection.py b/test/test_connection.py index 821ce4c226..8237311fe8 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -1,35 +1,56 @@ +from __future__ import annotations + import datetime +import socket +import typing +from http.client import ResponseNotReady +from unittest import mock -import mock import pytest -from urllib3.connection import RECENT_DATE, CertificateError, _match_hostname +from urllib3.connection import ( # type: ignore[attr-defined] + RECENT_DATE, + CertificateError, + HTTPConnection, + HTTPSConnection, + _match_hostname, + _url_from_connection, + _wrap_proxy_error, +) +from urllib3.exceptions import HTTPError, ProxyError +from 
urllib3.util.ssl_match_hostname import ( + CertificateError as ImplementationCertificateError, +) +from urllib3.util.ssl_match_hostname import _dnsname_match, match_hostname + +if typing.TYPE_CHECKING: + from urllib3.util.ssl_ import _TYPE_PEER_CERT_RET_DICT -class TestConnection(object): +class TestConnection: """ Tests in this suite should not make any network requests or connections. """ - def test_match_hostname_no_cert(self): + def test_match_hostname_no_cert(self) -> None: cert = None asserted_hostname = "foo" with pytest.raises(ValueError): _match_hostname(cert, asserted_hostname) - def test_match_hostname_empty_cert(self): - cert = {} + def test_match_hostname_empty_cert(self) -> None: + cert: _TYPE_PEER_CERT_RET_DICT = {} asserted_hostname = "foo" with pytest.raises(ValueError): _match_hostname(cert, asserted_hostname) - def test_match_hostname_match(self): - cert = {"subjectAltName": [("DNS", "foo")]} + def test_match_hostname_match(self) -> None: + cert: _TYPE_PEER_CERT_RET_DICT = {"subjectAltName": (("DNS", "foo"),)} asserted_hostname = "foo" _match_hostname(cert, asserted_hostname) - def test_match_hostname_mismatch(self): - cert = {"subjectAltName": [("DNS", "foo")]} + def test_match_hostname_mismatch(self) -> None: + cert: _TYPE_PEER_CERT_RET_DICT = {"subjectAltName": (("DNS", "foo"),)} asserted_hostname = "bar" try: with mock.patch("urllib3.connection.log.warning") as mock_log: @@ -39,14 +60,178 @@ def test_match_hostname_mismatch(self): mock_log.assert_called_once_with( "Certificate did not match expected hostname: %s. 
Certificate: %s", "bar", - {"subjectAltName": [("DNS", "foo")]}, + {"subjectAltName": (("DNS", "foo"),)}, + ) + assert e._peer_cert == cert # type: ignore[attr-defined] + + def test_match_hostname_no_dns(self) -> None: + cert: _TYPE_PEER_CERT_RET_DICT = {"subjectAltName": (("DNS", ""),)} + asserted_hostname = "bar" + try: + with mock.patch("urllib3.connection.log.warning") as mock_log: + _match_hostname(cert, asserted_hostname) + except CertificateError as e: + assert "hostname 'bar' doesn't match ''" in str(e) + mock_log.assert_called_once_with( + "Certificate did not match expected hostname: %s. Certificate: %s", + "bar", + {"subjectAltName": (("DNS", ""),)}, ) - assert e._peer_cert == cert + assert e._peer_cert == cert # type: ignore[attr-defined] + + def test_match_hostname_startwith_wildcard(self) -> None: + cert: _TYPE_PEER_CERT_RET_DICT = {"subjectAltName": (("DNS", "*"),)} + asserted_hostname = "foo" + _match_hostname(cert, asserted_hostname) + + def test_match_hostname_dnsname(self) -> None: + cert: _TYPE_PEER_CERT_RET_DICT = { + "subjectAltName": (("DNS", "xn--p1b6ci4b4b3a*.xn--11b5bs8d"),) + } + asserted_hostname = "xn--p1b6ci4b4b3a*.xn--11b5bs8d" + _match_hostname(cert, asserted_hostname) + + def test_match_hostname_include_wildcard(self) -> None: + cert: _TYPE_PEER_CERT_RET_DICT = {"subjectAltName": (("DNS", "foo*"),)} + asserted_hostname = "foobar" + _match_hostname(cert, asserted_hostname) + + def test_match_hostname_more_than_one_dnsname_error(self) -> None: + cert: _TYPE_PEER_CERT_RET_DICT = { + "subjectAltName": (("DNS", "foo*"), ("DNS", "fo*")) + } + asserted_hostname = "bar" + with pytest.raises(CertificateError, match="doesn't match either of"): + _match_hostname(cert, asserted_hostname) + + def test_dnsname_match_include_more_than_one_wildcard_error(self) -> None: + with pytest.raises(CertificateError, match="too many wildcards in certificate"): + _dnsname_match("foo**", "foobar") + + def test_match_hostname_ignore_common_name(self) -> None: + 
cert: _TYPE_PEER_CERT_RET_DICT = {"subject": ((("commonName", "foo"),),)} + asserted_hostname = "foo" + with pytest.raises( + ImplementationCertificateError, + match="no appropriate subjectAltName fields were found", + ): + match_hostname(cert, asserted_hostname) - def test_recent_date(self): + def test_match_hostname_check_common_name(self) -> None: + cert: _TYPE_PEER_CERT_RET_DICT = {"subject": ((("commonName", "foo"),),)} + asserted_hostname = "foo" + match_hostname(cert, asserted_hostname, True) + + def test_match_hostname_ip_address(self) -> None: + cert: _TYPE_PEER_CERT_RET_DICT = { + "subjectAltName": (("IP Address", "1.1.1.1"),) + } + asserted_hostname = "1.1.1.2" + try: + with mock.patch("urllib3.connection.log.warning") as mock_log: + _match_hostname(cert, asserted_hostname) + except CertificateError as e: + assert "hostname '1.1.1.2' doesn't match '1.1.1.1'" in str(e) + mock_log.assert_called_once_with( + "Certificate did not match expected hostname: %s. Certificate: %s", + "1.1.1.2", + {"subjectAltName": (("IP Address", "1.1.1.1"),)}, + ) + assert e._peer_cert == cert # type: ignore[attr-defined] + + @pytest.mark.parametrize( + ["asserted_hostname", "san_ip"], + [ + ("1:2::3:4", "1:2:0:0:0:0:3:4"), + ("1:2:0:0::3:4", "1:2:0:0:0:0:3:4"), + ("::0.1.0.2", "0:0:0:0:0:0:1:2"), + ("::1%42", "0:0:0:0:0:0:0:1"), + ("::2%iface", "0:0:0:0:0:0:0:2"), + ], + ) + def test_match_hostname_ip_address_ipv6( + self, asserted_hostname: str, san_ip: str + ) -> None: + """Check that hostname matches follow RFC 9110 rules for IPv6.""" + cert: _TYPE_PEER_CERT_RET_DICT = {"subjectAltName": (("IP Address", san_ip),)} + match_hostname(cert, asserted_hostname) + + def test_match_hostname_ip_address_ipv6_doesnt_match(self) -> None: + cert: _TYPE_PEER_CERT_RET_DICT = { + "subjectAltName": (("IP Address", "1:2::2:1"),) + } + asserted_hostname = "1:2::2:2" + try: + with mock.patch("urllib3.connection.log.warning") as mock_log: + _match_hostname(cert, asserted_hostname) + except 
CertificateError as e: + assert "hostname '1:2::2:2' doesn't match '1:2::2:1'" in str(e) + mock_log.assert_called_once_with( + "Certificate did not match expected hostname: %s. Certificate: %s", + "1:2::2:2", + {"subjectAltName": (("IP Address", "1:2::2:1"),)}, + ) + assert e._peer_cert == cert # type: ignore[attr-defined] + + def test_match_hostname_dns_with_brackets_doesnt_match(self) -> None: + cert: _TYPE_PEER_CERT_RET_DICT = { + "subjectAltName": ( + ("DNS", "localhost"), + ("IP Address", "localhost"), + ) + } + asserted_hostname = "[localhost]" + with pytest.raises(CertificateError) as e: + _match_hostname(cert, asserted_hostname) + assert ( + "hostname '[localhost]' doesn't match either of 'localhost', 'localhost'" + in str(e.value) + ) + + def test_match_hostname_ip_address_ipv6_brackets(self) -> None: + cert: _TYPE_PEER_CERT_RET_DICT = { + "subjectAltName": (("IP Address", "1:2::2:1"),) + } + asserted_hostname = "[1:2::2:1]" + # Assert no error is raised + _match_hostname(cert, asserted_hostname) + + def test_recent_date(self) -> None: # This test is to make sure that the RECENT_DATE value # doesn't get too far behind what the current date is. # When this test fails update urllib3.connection.RECENT_DATE # according to the rules defined in that file. two_years = datetime.timedelta(days=365 * 2) assert RECENT_DATE > (datetime.datetime.today() - two_years).date() + + def test_HTTPSConnection_default_socket_options(self) -> None: + conn = HTTPSConnection("not.a.real.host", port=443) + assert conn.socket_options == [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] + + @pytest.mark.parametrize( + "proxy_scheme, err_part", + [ + ("http", "Unable to connect to proxy"), + ( + "https", + "Unable to connect to proxy. 
Your proxy appears to only use HTTP and not HTTPS", + ), + ], + ) + def test_wrap_proxy_error(self, proxy_scheme: str, err_part: str) -> None: + new_err = _wrap_proxy_error(HTTPError("unknown protocol"), proxy_scheme) + assert isinstance(new_err, ProxyError) is True + assert err_part in new_err.args[0] + + def test_url_from_pool(self) -> None: + conn = HTTPConnection("google.com", port=80) + + path = "path?query=foo" + assert f"http://google.com:80/{path}" == _url_from_connection(conn, path) + + def test_getresponse_requires_reponseoptions(self) -> None: + conn = HTTPConnection("google.com", port=80) + + # Should error if a request has not been sent + with pytest.raises(ResponseNotReady): + conn.getresponse() diff --git a/test/test_connectionpool.py b/test/test_connectionpool.py index eec6bd27c8..d81d33f7bd 100644 --- a/test/test_connectionpool.py +++ b/test/test_connectionpool.py @@ -1,44 +1,48 @@ -from __future__ import absolute_import +from __future__ import annotations +import http.client as httplib import ssl +import typing +from http.client import HTTPException +from queue import Empty from socket import error as SocketError from ssl import SSLError as BaseSSLError from test import SHORT_TIMEOUT +from unittest.mock import Mock, patch import pytest -from mock import Mock from dummyserver.server import DEFAULT_CA -from urllib3._collections import HTTPHeaderDict +from urllib3 import Retry +from urllib3.connection import HTTPConnection from urllib3.connectionpool import ( - HTTPConnection, HTTPConnectionPool, HTTPSConnectionPool, + _url_from_pool, connection_from_url, ) from urllib3.exceptions import ( ClosedPoolError, EmptyPoolError, + FullPoolError, HostChangedError, LocationValueError, MaxRetryError, ProtocolError, + ReadTimeoutError, SSLError, TimeoutError, ) -from urllib3.packages.six.moves import http_client as httplib -from urllib3.packages.six.moves.http_client import HTTPException -from urllib3.packages.six.moves.queue import Empty -from 
urllib3.packages.ssl_match_hostname import CertificateError from urllib3.response import HTTPResponse -from urllib3.util.timeout import Timeout +from urllib3.util.ssl_match_hostname import CertificateError +from urllib3.util.timeout import _DEFAULT_TIMEOUT, Timeout from .test_response import MockChunkedEncodingResponse, MockSock class HTTPUnixConnection(HTTPConnection): - def __init__(self, host, timeout=60, **kwargs): - super(HTTPUnixConnection, self).__init__("localhost") + def __init__(self, host: str, timeout: int = 60, **kwargs: typing.Any) -> None: + super().__init__("localhost") self.unix_socket = host self.timeout = timeout self.sock = None @@ -49,7 +53,7 @@ class HTTPUnixConnectionPool(HTTPConnectionPool): ConnectionCls = HTTPUnixConnection -class TestConnectionPool(object): +class TestConnectionPool: """ Tests in this suite should exercise the ConnectionPool functionality without actually making any network requests or connections. @@ -83,7 +87,7 @@ class TestConnectionPool(object): ), ], ) - def test_same_host(self, a, b): + def test_same_host(self, a: str, b: str) -> None: with connection_from_url(a) as c: assert c.is_same_host(b) @@ -109,7 +113,7 @@ def test_same_host(self, a, b): ("http://[dead::beef]", "https://[dead::beef%en5]/"), ], ) - def test_not_same_host(self, a, b): + def test_not_same_host(self, a: str, b: str) -> None: with connection_from_url(a) as c: assert not c.is_same_host(b) @@ -127,7 +131,7 @@ def test_not_same_host(self, a, b): ("google.com", "http://google.com:80/abracadabra"), ], ) - def test_same_host_no_port_http(self, a, b): + def test_same_host_no_port_http(self, a: str, b: str) -> None: # This test was introduced in #801 to deal with the fact that urllib3 # never initializes ConnectionPool objects with port=None. 
with HTTPConnectionPool(a) as c: @@ -144,7 +148,7 @@ def test_same_host_no_port_http(self, a, b): ("google.com", "https://google.com:443/abracadabra"), ], ) - def test_same_host_no_port_https(self, a, b): + def test_same_host_no_port_https(self, a: str, b: str) -> None: # This test was introduced in #801 to deal with the fact that urllib3 # never initializes ConnectionPool objects with port=None. with HTTPSConnectionPool(a) as c: @@ -159,7 +163,7 @@ def test_same_host_no_port_https(self, a, b): ("google.com", "http://google.com./"), ], ) - def test_not_same_host_no_port_http(self, a, b): + def test_not_same_host_no_port_http(self, a: str, b: str) -> None: with HTTPConnectionPool(a) as c: assert not c.is_same_host(b) @@ -175,7 +179,7 @@ def test_not_same_host_no_port_http(self, a, b): ("google.com", "https://google.com./"), ], ) - def test_not_same_host_no_port_https(self, a, b): + def test_not_same_host_no_port_https(self, a: str, b: str) -> None: with HTTPSConnectionPool(a) as c: assert not c.is_same_host(b) @@ -196,7 +200,7 @@ def test_not_same_host_no_port_https(self, a, b): ("%2Ftmp%2FTEST.sock", "http+unix://%2Ftmp%2FTEST.sock/abracadabra"), ], ) - def test_same_host_custom_protocol(self, a, b): + def test_same_host_custom_protocol(self, a: str, b: str) -> None: with HTTPUnixConnectionPool(a) as c: assert c.is_same_host(b) @@ -209,11 +213,11 @@ def test_same_host_custom_protocol(self, a, b): ("%2Fvar%2Frun%2Fdocker.sock", "http+unix://%2Ftmp%2FTEST.sock"), ], ) - def test_not_same_host_custom_protocol(self, a, b): + def test_not_same_host_custom_protocol(self, a: str, b: str) -> None: with HTTPUnixConnectionPool(a) as c: assert not c.is_same_host(b) - def test_max_connections(self): + def test_max_connections(self) -> None: with HTTPConnectionPool(host="localhost", maxsize=1, block=True) as pool: pool._get_conn(timeout=SHORT_TIMEOUT) @@ -225,26 +229,75 @@ def test_max_connections(self): assert pool.num_connections == 1 - def test_pool_edgecases(self): + def 
test_put_conn_when_pool_is_full_nonblocking( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """ + If maxsize = n and we _put_conn n + 1 conns, the n + 1th conn will + get closed and will not get added to the pool. + """ with HTTPConnectionPool(host="localhost", maxsize=1, block=False) as pool: conn1 = pool._get_conn() - conn2 = pool._get_conn() # New because block=False + # pool.pool is empty because we popped the one None that pool.pool was initialized with + # but this pool._get_conn call will not raise EmptyPoolError because block is False + conn2 = pool._get_conn() - pool._put_conn(conn1) - pool._put_conn(conn2) # Should be discarded + with patch.object(conn1, "close") as conn1_close: + with patch.object(conn2, "close") as conn2_close: + pool._put_conn(conn1) + pool._put_conn(conn2) + + assert conn1_close.called is False + assert conn2_close.called is True assert conn1 == pool._get_conn() assert conn2 != pool._get_conn() assert pool.num_connections == 3 + assert "Connection pool is full, discarding connection" in caplog.text + assert "Connection pool size: 1" in caplog.text + + def test_put_conn_when_pool_is_full_blocking(self) -> None: + """ + If maxsize = n and we _put_conn n + 1 conns, the n + 1th conn will + cause a FullPoolError. 
+ """ + with HTTPConnectionPool(host="localhost", maxsize=1, block=True) as pool: + conn1 = pool._get_conn() + conn2 = pool._new_conn() + + with patch.object(conn1, "close") as conn1_close: + with patch.object(conn2, "close") as conn2_close: + pool._put_conn(conn1) + with pytest.raises(FullPoolError): + pool._put_conn(conn2) + + assert conn1_close.called is False + assert conn2_close.called is True + + assert conn1 == pool._get_conn() + + def test_put_conn_closed_pool(self) -> None: + with HTTPConnectionPool(host="localhost", maxsize=1, block=True) as pool: + conn1 = pool._get_conn() + with patch.object(conn1, "close") as conn1_close: + pool.close() + + assert pool.pool is None - def test_exception_str(self): + # Accessing pool.pool will raise AttributeError, which will get + # caught and will close conn1 + pool._put_conn(conn1) + + assert conn1_close.called is True + + def test_exception_str(self) -> None: assert ( str(EmptyPoolError(HTTPConnectionPool(host="localhost"), "Test.")) == "HTTPConnectionPool(host='localhost', port=None): Test." 
) - def test_retry_exception_str(self): + def test_retry_exception_str(self) -> None: assert ( str(MaxRetryError(HTTPConnectionPool(host="localhost"), "Test.", None)) == "HTTPConnectionPool(host='localhost', port=None): " @@ -262,21 +315,23 @@ def test_retry_exception_str(self): "(Caused by %r)" % err ) - def test_pool_size(self): + def test_pool_size(self) -> None: POOL_SIZE = 1 with HTTPConnectionPool( host="localhost", maxsize=POOL_SIZE, block=True ) as pool: - def _raise(ex): - raise ex() - - def _test(exception, expect, reason=None): - pool._make_request = lambda *args, **kwargs: _raise(exception) - with pytest.raises(expect) as excinfo: - pool.request("GET", "/") + def _test( + exception: type[BaseException], + expect: type[BaseException], + reason: type[BaseException] | None = None, + ) -> None: + with patch.object(pool, "_make_request", side_effect=exception()): + with pytest.raises(expect) as excinfo: + pool.request("GET", "/") if reason is not None: - assert isinstance(excinfo.value.reason, reason) + assert isinstance(excinfo.value.reason, reason) # type: ignore[attr-defined] + assert pool.pool is not None assert pool.pool.qsize() == POOL_SIZE # Make sure that all of the exceptions return the connection @@ -288,26 +343,33 @@ def _test(exception, expect, reason=None): # being raised, a retry will be triggered, but that retry will # fail, eventually raising MaxRetryError, not EmptyPoolError # See: https://github.com/urllib3/urllib3/issues/76 - pool._make_request = lambda *args, **kwargs: _raise(HTTPException) - with pytest.raises(MaxRetryError): - pool.request("GET", "/", retries=1, pool_timeout=SHORT_TIMEOUT) + with patch.object(pool, "_make_request", side_effect=HTTPException()): + with pytest.raises(MaxRetryError): + pool.request("GET", "/", retries=1, pool_timeout=SHORT_TIMEOUT) + assert pool.pool is not None assert pool.pool.qsize() == POOL_SIZE - def test_empty_does_not_put_conn(self): + def test_empty_does_not_put_conn(self) -> None: """Do not put 
None back in the pool if the pool was empty""" with HTTPConnectionPool(host="localhost", maxsize=1, block=True) as pool: - pool._get_conn = Mock(side_effect=EmptyPoolError(pool, "Pool is empty")) - pool._put_conn = Mock(side_effect=AssertionError("Unexpected _put_conn")) - with pytest.raises(EmptyPoolError): - pool.request("GET", "/") - - def test_assert_same_host(self): + with patch.object( + pool, "_get_conn", side_effect=EmptyPoolError(pool, "Pool is empty") + ): + with patch.object( + pool, + "_put_conn", + side_effect=AssertionError("Unexpected _put_conn"), + ): + with pytest.raises(EmptyPoolError): + pool.request("GET", "/") + + def test_assert_same_host(self) -> None: with connection_from_url("http://google.com:80") as c: with pytest.raises(HostChangedError): c.request("GET", "http://yahoo.com:80", assert_same_host=True) - def test_pool_close(self): + def test_pool_close(self) -> None: pool = connection_from_url("http://google.com:80") # Populate with some connections @@ -331,9 +393,10 @@ def test_pool_close(self): pool._get_conn() with pytest.raises(Empty): + assert old_pool_queue is not None old_pool_queue.get(block=False) - def test_pool_close_twice(self): + def test_pool_close_twice(self) -> None: pool = connection_from_url("http://google.com:80") # Populate with some connections @@ -350,13 +413,13 @@ def test_pool_close_twice(self): except AttributeError: pytest.fail("Pool of the ConnectionPool is None and has no attribute get.") - def test_pool_timeouts(self): + def test_pool_timeouts(self) -> None: with HTTPConnectionPool(host="localhost") as pool: conn = pool._new_conn() assert conn.__class__ == HTTPConnection assert pool.timeout.__class__ == Timeout - assert pool.timeout._read == Timeout.DEFAULT_TIMEOUT - assert pool.timeout._connect == Timeout.DEFAULT_TIMEOUT + assert pool.timeout._read == _DEFAULT_TIMEOUT + assert pool.timeout._connect == _DEFAULT_TIMEOUT assert pool.timeout.total is None pool = HTTPConnectionPool(host="localhost", 
timeout=SHORT_TIMEOUT) @@ -364,11 +427,11 @@ def test_pool_timeouts(self): assert pool.timeout._connect == SHORT_TIMEOUT assert pool.timeout.total is None - def test_no_host(self): + def test_no_host(self) -> None: with pytest.raises(LocationValueError): - HTTPConnectionPool(None) + HTTPConnectionPool(None) # type: ignore[arg-type] - def test_contextmanager(self): + def test_contextmanager(self) -> None: with connection_from_url("http://google.com:80") as pool: # Populate with some connections conn1 = pool._get_conn() @@ -387,20 +450,20 @@ def test_contextmanager(self): with pytest.raises(ClosedPoolError): pool._get_conn() with pytest.raises(Empty): + assert old_pool_queue is not None old_pool_queue.get(block=False) - def test_absolute_url(self): - with connection_from_url("http://google.com:80") as c: - assert "http://google.com:80/path?query=foo" == c._absolute_url( - "path?query=foo" - ) + def test_url_from_pool(self) -> None: + with connection_from_url("http://google.com:80") as pool: + path = "path?query=foo" + assert f"http://google.com:80/{path}" == _url_from_pool(pool, path) - def test_ca_certs_default_cert_required(self): + def test_ca_certs_default_cert_required(self) -> None: with connection_from_url("https://google.com:80", ca_certs=DEFAULT_CA) as pool: conn = pool._get_conn() - assert conn.cert_reqs == ssl.CERT_REQUIRED + assert conn.cert_reqs == ssl.CERT_REQUIRED # type: ignore[attr-defined] - def test_cleanup_on_extreme_connection_error(self): + def test_cleanup_on_extreme_connection_error(self) -> None: """ This test validates that we clean up properly even on exceptions that we'd not otherwise catch, i.e. 
those that inherit from BaseException @@ -410,25 +473,25 @@ def test_cleanup_on_extreme_connection_error(self): class RealBad(BaseException): pass - def kaboom(*args, **kwargs): + def kaboom(*args: typing.Any, **kwargs: typing.Any) -> None: raise RealBad() with connection_from_url("http://localhost:80") as c: - c._make_request = kaboom + with patch.object(c, "_make_request", kaboom): + assert c.pool is not None + initial_pool_size = c.pool.qsize() - initial_pool_size = c.pool.qsize() - - try: - # We need to release_conn this way or we'd put it away - # regardless. - c.urlopen("GET", "/", release_conn=False) - except RealBad: - pass + try: + # We need to release_conn this way or we'd put it away + # regardless. + c.urlopen("GET", "/", release_conn=False) + except RealBad: + pass new_pool_size = c.pool.qsize() assert initial_pool_size == new_pool_size - def test_release_conn_param_is_respected_after_http_error_retry(self): + def test_release_conn_param_is_respected_after_http_error_retry(self) -> None: """For successful ```urlopen(release_conn=False)```, the connection isn't released, even after a retry. @@ -439,40 +502,73 @@ def test_release_conn_param_is_respected_after_http_error_retry(self): [1] """ - class _raise_once_make_request_function(object): + class _raise_once_make_request_function: """Callable that can mimic `_make_request()`. Raises the given exception on its first call, but returns a successful response on subsequent calls. 
""" - def __init__(self, ex): - super(_raise_once_make_request_function, self).__init__() - self._ex = ex - - def __call__(self, *args, **kwargs): + def __init__( + self, ex: type[BaseException], pool: HTTPConnectionPool + ) -> None: + super().__init__() + self._ex: type[BaseException] | None = ex + self._pool = pool + + def __call__( + self, + conn: HTTPConnection, + method: str, + url: str, + *args: typing.Any, + retries: Retry, + **kwargs: typing.Any, + ) -> HTTPResponse: if self._ex: ex, self._ex = self._ex, None raise ex() - response = httplib.HTTPResponse(MockSock) - response.fp = MockChunkedEncodingResponse([b"f", b"o", b"o"]) - response.headers = response.msg = HTTPHeaderDict() + httplib_response = httplib.HTTPResponse(MockSock) # type: ignore[arg-type] + httplib_response.fp = MockChunkedEncodingResponse([b"f", b"o", b"o"]) # type: ignore[assignment] + httplib_response.headers = httplib_response.msg = httplib.HTTPMessage() + + response_conn: HTTPConnection | None = kwargs.get("response_conn") + + response = HTTPResponse( + body=httplib_response, + headers=httplib_response.headers, # type: ignore[arg-type] + status=httplib_response.status, + version=httplib_response.version, + reason=httplib_response.reason, + original_response=httplib_response, + retries=retries, + request_method=method, + request_url=url, + preload_content=False, + connection=response_conn, + pool=self._pool, + ) return response - def _test(exception): + def _test(exception: type[BaseException]) -> None: with HTTPConnectionPool(host="localhost", maxsize=1, block=True) as pool: # Verify that the request succeeds after two attempts, and that the # connection is left on the response object, instead of being # released back into the pool. 
- pool._make_request = _raise_once_make_request_function(exception) - response = pool.urlopen( - "GET", - "/", - retries=1, - release_conn=False, - preload_content=False, - chunked=True, - ) + with patch.object( + pool, + "_make_request", + _raise_once_make_request_function(exception, pool), + ): + response = pool.urlopen( + "GET", + "/", + retries=1, + release_conn=False, + preload_content=False, + chunked=True, + ) + assert pool.pool is not None assert pool.pool.qsize() == 0 assert pool.num_connections == 2 assert response.connection is not None @@ -487,21 +583,12 @@ def _test(exception): _test(SocketError) _test(ProtocolError) - def test_custom_http_response_class(self): - class CustomHTTPResponse(HTTPResponse): - pass - - class CustomConnectionPool(HTTPConnectionPool): - ResponseCls = CustomHTTPResponse - - def _make_request(self, *args, **kwargs): - httplib_response = httplib.HTTPResponse(MockSock) - httplib_response.fp = MockChunkedEncodingResponse([b"f", b"o", b"o"]) - httplib_response.headers = httplib_response.msg = HTTPHeaderDict() - return httplib_response - - with CustomConnectionPool(host="localhost", maxsize=1, block=True) as pool: - response = pool.request( - "GET", "/", retries=False, chunked=True, preload_content=False - ) - assert isinstance(response, CustomHTTPResponse) + def test_read_timeout_0_does_not_raise_bad_status_line_error(self) -> None: + with HTTPConnectionPool(host="localhost", maxsize=1) as pool: + conn = Mock(spec=HTTPConnection) + # Needed to tell the pool that the connection is alive. 
+ conn.is_closed = False + with patch.object(Timeout, "read_timeout", 0): + timeout = Timeout(1, 1, 1) + with pytest.raises(ReadTimeoutError): + pool._make_request(conn, "", "", timeout=timeout) diff --git a/test/test_exceptions.py b/test/test_exceptions.py index 9fd0eb0fb0..8d9eb1093f 100644 --- a/test/test_exceptions.py +++ b/test/test_exceptions.py @@ -1,7 +1,12 @@ +from __future__ import annotations + import pickle +from email.errors import MessageDefect +from test import DUMMY_POOL import pytest +from urllib3.connection import HTTPConnection from urllib3.connectionpool import HTTPConnectionPool from urllib3.exceptions import ( ClosedPoolError, @@ -12,36 +17,53 @@ HTTPError, LocationParseError, MaxRetryError, + NewConnectionError, ReadTimeoutError, ) -class TestPickle(object): +class TestPickle: @pytest.mark.parametrize( "exception", [ HTTPError(None), - MaxRetryError(None, None, None), - LocationParseError(None), + MaxRetryError(DUMMY_POOL, "", None), + LocationParseError(""), ConnectTimeoutError(None), HTTPError("foo"), HTTPError("foo", IOError("foo")), MaxRetryError(HTTPConnectionPool("localhost"), "/", None), LocationParseError("fake location"), - ClosedPoolError(HTTPConnectionPool("localhost"), None), - EmptyPoolError(HTTPConnectionPool("localhost"), None), - HostChangedError(HTTPConnectionPool("localhost"), "/", None), - ReadTimeoutError(HTTPConnectionPool("localhost"), "/", None), + ClosedPoolError(HTTPConnectionPool("localhost"), ""), + EmptyPoolError(HTTPConnectionPool("localhost"), ""), + HostChangedError(HTTPConnectionPool("localhost"), "/", 0), + ReadTimeoutError(HTTPConnectionPool("localhost"), "/", ""), ], ) - def test_exceptions(self, exception): + def test_exceptions(self, exception: Exception) -> None: result = pickle.loads(pickle.dumps(exception)) assert isinstance(result, type(exception)) -class TestFormat(object): - def test_header_parsing_errors(self): - hpe = HeaderParsingError("defects", "unparsed_data") +class TestFormat: + def 
test_header_parsing_errors(self) -> None: + hpe = HeaderParsingError([MessageDefect("defects")], "unparsed_data") assert "defects" in str(hpe) assert "unparsed_data" in str(hpe) + + +class TestNewConnectionError: + def test_pool_property_deprecation_warning(self) -> None: + err = NewConnectionError(HTTPConnection("localhost"), "test") + with pytest.warns(DeprecationWarning) as records: + err_pool = err.pool + + assert err_pool is err.conn + msg = ( + "The 'pool' property is deprecated and will be removed " + "in urllib3 v2.1.0. Use 'conn' instead." + ) + record = records[0] + assert isinstance(record.message, Warning) + assert record.message.args[0] == msg diff --git a/test/test_fields.py b/test/test_fields.py index 98ce17c1f4..faf74a34b3 100644 --- a/test/test_fields.py +++ b/test/test_fields.py @@ -1,10 +1,18 @@ +from __future__ import annotations + import pytest -from urllib3.fields import RequestField, format_header_param_rfc2231, guess_content_type -from urllib3.packages.six import u +from urllib3.fields import ( + RequestField, + format_header_param, + format_header_param_html5, + format_header_param_rfc2231, + format_multipart_header_param, + guess_content_type, +) -class TestRequestField(object): +class TestRequestField: @pytest.mark.parametrize( "filename, content_types", [ @@ -13,18 +21,22 @@ class TestRequestField(object): (None, ["application/octet-stream"]), ], ) - def test_guess_content_type(self, filename, content_types): + def test_guess_content_type( + self, filename: str | None, content_types: list[str] + ) -> None: assert guess_content_type(filename) in content_types - def test_create(self): + def test_create(self) -> None: simple_field = RequestField("somename", "data") assert simple_field.render_headers() == "\r\n" filename_field = RequestField("somename", "data", filename="somefile.txt") assert filename_field.render_headers() == "\r\n" - headers_field = RequestField("somename", "data", headers={"Content-Length": 4}) + headers_field = 
RequestField( + "somename", "data", headers={"Content-Length": "4"} + ) assert headers_field.render_headers() == "Content-Length: 4\r\n\r\n" - def test_make_multipart(self): + def test_make_multipart(self) -> None: field = RequestField("somename", "data") field.make_multipart(content_type="image/jpg", content_location="/test") assert ( @@ -35,7 +47,7 @@ def test_make_multipart(self): "\r\n" ) - def test_make_multipart_empty_filename(self): + def test_make_multipart_empty_filename(self) -> None: field = RequestField("somename", "data", "") field.make_multipart(content_type="application/octet-stream") assert ( @@ -45,7 +57,7 @@ def test_make_multipart_empty_filename(self): "\r\n" ) - def test_render_parts(self): + def test_render_parts(self) -> None: field = RequestField("somename", "data") parts = field._render_parts({"name": "value", "filename": "value"}) assert 'name="value"' in parts @@ -53,50 +65,56 @@ def test_render_parts(self): parts = field._render_parts([("name", "value"), ("filename", "value")]) assert parts == 'name="value"; filename="value"' - def test_render_part_rfc2231_unicode(self): - field = RequestField( - "somename", "data", header_formatter=format_header_param_rfc2231 - ) - param = field._render_part("filename", u("n\u00e4me")) - assert param == "filename*=utf-8''n%C3%A4me" + @pytest.mark.parametrize( + ("value", "expect"), + [("näme", "filename*=utf-8''n%C3%A4me"), (b"name", 'filename="name"')], + ) + def test_format_header_param_rfc2231_deprecated( + self, value: bytes | str, expect: str + ) -> None: + with pytest.deprecated_call(match=r"urllib3 v2\.1\.0"): + param = format_header_param_rfc2231("filename", value) - def test_render_part_rfc2231_ascii(self): - field = RequestField( - "somename", "data", header_formatter=format_header_param_rfc2231 - ) - param = field._render_part("filename", b"name") - assert param == 'filename="name"' + assert param == expect - def test_render_part_html5_unicode(self): - field = RequestField("somename", "data") 
- param = field._render_part("filename", u("n\u00e4me")) - assert param == u('filename="n\u00e4me"') + def test_format_header_param_html5_deprecated(self) -> None: + with pytest.deprecated_call(match=r"urllib3 v2\.1\.0"): + param2 = format_header_param_html5("filename", "name") - def test_render_part_html5_ascii(self): - field = RequestField("somename", "data") - param = field._render_part("filename", b"name") - assert param == 'filename="name"' + with pytest.deprecated_call(match=r"urllib3 v2\.1\.0"): + param1 = format_header_param("filename", "name") - def test_render_part_html5_unicode_escape(self): - field = RequestField("somename", "data") - param = field._render_part("filename", u("hello\\world\u0022")) - assert param == u('filename="hello\\\\world%22"') + assert param1 == param2 - def test_render_part_html5_unicode_with_control_character(self): - field = RequestField("somename", "data") - param = field._render_part("filename", u("hello\x1A\x1B\x1C")) - assert param == u('filename="hello%1A\x1B%1C"') - - def test_from_tuples_rfc2231(self): - field = RequestField.from_tuples( - u("fieldname"), - (u("filen\u00e4me"), "data"), - header_formatter=format_header_param_rfc2231, - ) + @pytest.mark.parametrize( + ("value", "expect"), + [ + ("name", "name"), + ("näme", "näme"), + (b"n\xc3\xa4me", "näme"), + ("ski ⛷.txt", "ski ⛷.txt"), + ("control \x1A\x1B\x1C", "control \x1A\x1B\x1C"), + ("backslash \\", "backslash \\"), + ("quotes '\"", "quotes '%22"), + ("newline \n\r", "newline %0A%0D"), + ], + ) + def test_format_multipart_header_param( + self, value: bytes | str, expect: str + ) -> None: + param = format_multipart_header_param("filename", value) + assert param == f'filename="{expect}"' + + def test_from_tuples(self) -> None: + field = RequestField.from_tuples("file", ("スキー旅行.txt", "data")) cd = field.headers["Content-Disposition"] - assert cd == u("form-data; name=\"fieldname\"; filename*=utf-8''filen%C3%A4me") + assert cd == 'form-data; name="file"; 
filename="スキー旅行.txt"' + + def test_from_tuples_rfc2231(self) -> None: + with pytest.deprecated_call(match=r"urllib3 v2\.1\.0"): + field = RequestField.from_tuples( + "file", ("näme", "data"), header_formatter=format_header_param_rfc2231 + ) - def test_from_tuples_html5(self): - field = RequestField.from_tuples(u("fieldname"), (u("filen\u00e4me"), "data")) cd = field.headers["Content-Disposition"] - assert cd == u('form-data; name="fieldname"; filename="filen\u00e4me"') + assert cd == "form-data; name=\"file\"; filename*=utf-8''n%C3%A4me" diff --git a/test/test_filepost.py b/test/test_filepost.py index 5b0cfe1cb6..b6da4b9447 100644 --- a/test/test_filepost.py +++ b/test/test_filepost.py @@ -1,112 +1,100 @@ +from __future__ import annotations + import pytest from urllib3.fields import RequestField -from urllib3.filepost import encode_multipart_formdata, iter_fields -from urllib3.packages.six import b, u +from urllib3.filepost import _TYPE_FIELDS, encode_multipart_formdata BOUNDARY = "!! test boundary !!" 
+BOUNDARY_BYTES = BOUNDARY.encode() -class TestIterfields(object): - def test_dict(self): - for fieldname, value in iter_fields(dict(a="b")): - assert (fieldname, value) == ("a", "b") - - assert list(sorted(iter_fields(dict(a="b", c="d")))) == [("a", "b"), ("c", "d")] - - def test_tuple_list(self): - for fieldname, value in iter_fields([("a", "b")]): - assert (fieldname, value) == ("a", "b") - - assert list(iter_fields([("a", "b"), ("c", "d")])) == [("a", "b"), ("c", "d")] - - -class TestMultipartEncoding(object): +class TestMultipartEncoding: @pytest.mark.parametrize( "fields", [dict(k="v", k2="v2"), [("k", "v"), ("k2", "v2")]] ) - def test_input_datastructures(self, fields): + def test_input_datastructures(self, fields: _TYPE_FIELDS) -> None: encoded, _ = encode_multipart_formdata(fields, boundary=BOUNDARY) - assert encoded.count(b(BOUNDARY)) == 3 + assert encoded.count(BOUNDARY_BYTES) == 3 @pytest.mark.parametrize( "fields", [ [("k", "v"), ("k2", "v2")], - [("k", b"v"), (u("k2"), b"v2")], - [("k", b"v"), (u("k2"), "v2")], + [("k", b"v"), ("k2", b"v2")], + [("k", b"v"), ("k2", "v2")], ], ) - def test_field_encoding(self, fields): + def test_field_encoding(self, fields: _TYPE_FIELDS) -> None: encoded, content_type = encode_multipart_formdata(fields, boundary=BOUNDARY) expected = ( - b"--" + b(BOUNDARY) + b"\r\n" + b"--" + BOUNDARY_BYTES + b"\r\n" b'Content-Disposition: form-data; name="k"\r\n' b"\r\n" b"v\r\n" - b"--" + b(BOUNDARY) + b"\r\n" + b"--" + BOUNDARY_BYTES + b"\r\n" b'Content-Disposition: form-data; name="k2"\r\n' b"\r\n" b"v2\r\n" - b"--" + b(BOUNDARY) + b"--\r\n" + b"--" + BOUNDARY_BYTES + b"--\r\n" ) assert encoded == expected assert content_type == "multipart/form-data; boundary=" + str(BOUNDARY) - def test_filename(self): + def test_filename(self) -> None: fields = [("k", ("somename", b"v"))] encoded, content_type = encode_multipart_formdata(fields, boundary=BOUNDARY) expected = ( - b"--" + b(BOUNDARY) + b"\r\n" + b"--" + BOUNDARY_BYTES + b"\r\n" 
b'Content-Disposition: form-data; name="k"; filename="somename"\r\n' b"Content-Type: application/octet-stream\r\n" b"\r\n" b"v\r\n" - b"--" + b(BOUNDARY) + b"--\r\n" + b"--" + BOUNDARY_BYTES + b"--\r\n" ) assert encoded == expected assert content_type == "multipart/form-data; boundary=" + str(BOUNDARY) - def test_textplain(self): + def test_textplain(self) -> None: fields = [("k", ("somefile.txt", b"v"))] encoded, content_type = encode_multipart_formdata(fields, boundary=BOUNDARY) expected = ( - b"--" + b(BOUNDARY) + b"\r\n" + b"--" + BOUNDARY_BYTES + b"\r\n" b'Content-Disposition: form-data; name="k"; filename="somefile.txt"\r\n' b"Content-Type: text/plain\r\n" b"\r\n" b"v\r\n" - b"--" + b(BOUNDARY) + b"--\r\n" + b"--" + BOUNDARY_BYTES + b"--\r\n" ) assert encoded == expected assert content_type == "multipart/form-data; boundary=" + str(BOUNDARY) - def test_explicit(self): + def test_explicit(self) -> None: fields = [("k", ("somefile.txt", b"v", "image/jpeg"))] encoded, content_type = encode_multipart_formdata(fields, boundary=BOUNDARY) expected = ( - b"--" + b(BOUNDARY) + b"\r\n" + b"--" + BOUNDARY_BYTES + b"\r\n" b'Content-Disposition: form-data; name="k"; filename="somefile.txt"\r\n' b"Content-Type: image/jpeg\r\n" b"\r\n" b"v\r\n" - b"--" + b(BOUNDARY) + b"--\r\n" + b"--" + BOUNDARY_BYTES + b"--\r\n" ) assert encoded == expected assert content_type == "multipart/form-data; boundary=" + str(BOUNDARY) - def test_request_fields(self): + def test_request_fields(self) -> None: fields = [ RequestField( "k", @@ -118,11 +106,11 @@ def test_request_fields(self): encoded, content_type = encode_multipart_formdata(fields, boundary=BOUNDARY) expected = ( - b"--" + b(BOUNDARY) + b"\r\n" + b"--" + BOUNDARY_BYTES + b"\r\n" b"Content-Type: image/jpeg\r\n" b"\r\n" b"v\r\n" - b"--" + b(BOUNDARY) + b"--\r\n" + b"--" + BOUNDARY_BYTES + b"--\r\n" ) assert encoded == expected diff --git a/test/test_no_ssl.py b/test/test_no_ssl.py index 7cf6260e49..e793f79189 100644 --- 
a/test/test_no_ssl.py +++ b/test/test_no_ssl.py @@ -5,83 +5,36 @@ * HTTPS requests must fail with an error that points at the ssl module """ +from __future__ import annotations + import sys +from test import ImportBlocker, ModuleStash import pytest - -class ImportBlocker(object): - """ - Block Imports - - To be placed on ``sys.meta_path``. This ensures that the modules - specified cannot be imported, even if they are a builtin. - """ - - def __init__(self, *namestoblock): - self.namestoblock = namestoblock - - def find_module(self, fullname, path=None): - if fullname in self.namestoblock: - return self - return None - - def load_module(self, fullname): - raise ImportError("import of {0} is blocked".format(fullname)) - - -class ModuleStash(object): - """ - Stashes away previously imported modules - - If we reimport a module the data from coverage is lost, so we reuse the old - modules - """ - - def __init__(self, namespace, modules=sys.modules): - self.namespace = namespace - self.modules = modules - self._data = {} - - def stash(self): - self._data[self.namespace] = self.modules.pop(self.namespace, None) - - for module in list(self.modules.keys()): - if module.startswith(self.namespace + "."): - self._data[module] = self.modules.pop(module) - - def pop(self): - self.modules.pop(self.namespace, None) - - for module in list(self.modules.keys()): - if module.startswith(self.namespace + "."): - self.modules.pop(module) - - self.modules.update(self._data) - - ssl_blocker = ImportBlocker("ssl", "_ssl") module_stash = ModuleStash("urllib3") -class TestWithoutSSL(object): +class TestWithoutSSL: @classmethod - def setup_class(cls): + def setup_class(cls) -> None: sys.modules.pop("ssl", None) sys.modules.pop("_ssl", None) module_stash.stash() sys.meta_path.insert(0, ssl_blocker) - def teardown_class(cls): + @classmethod + def teardown_class(cls) -> None: sys.meta_path.remove(ssl_blocker) module_stash.pop() class TestImportWithoutSSL(TestWithoutSSL): - def 
test_cannot_import_ssl(self): + def test_cannot_import_ssl(self) -> None: with pytest.raises(ImportError): import ssl # noqa: F401 - def test_import_urllib3(self): + def test_import_urllib3(self) -> None: import urllib3 # noqa: F401 diff --git a/test/test_poolmanager.py b/test/test_poolmanager.py index c54ad7ed9a..e67c11f39a 100644 --- a/test/test_poolmanager.py +++ b/test/test_poolmanager.py @@ -1,17 +1,29 @@ +from __future__ import annotations + +import gc import socket from test import resolvesLocalhostFQDN +from unittest import mock +from unittest.mock import MagicMock, patch import pytest from urllib3 import connection_from_url -from urllib3.exceptions import ClosedPoolError, LocationValueError -from urllib3.poolmanager import PoolKey, PoolManager, key_fn_by_scheme +from urllib3.connectionpool import HTTPSConnectionPool +from urllib3.exceptions import LocationValueError +from urllib3.poolmanager import ( + _DEFAULT_BLOCKSIZE, + PoolKey, + PoolManager, + key_fn_by_scheme, +) from urllib3.util import retry, timeout +from urllib3.util.url import Url -class TestPoolManager(object): - @resolvesLocalhostFQDN - def test_same_url(self): +class TestPoolManager: + @resolvesLocalhostFQDN() + def test_same_url(self) -> None: # Convince ourselves that normally we don't get the same object conn1 = connection_from_url("http://localhost:8081/foo") conn2 = connection_from_url("http://localhost:8081/bar") @@ -34,7 +46,7 @@ def test_same_url(self): assert conn1 != conn2 - def test_many_urls(self): + def test_many_urls(self) -> None: urls = [ "http://localhost:8081/foo", "http://www.google.com/mail", @@ -56,59 +68,36 @@ def test_many_urls(self): assert len(connections) == 5 - def test_manager_clear(self): + def test_manager_clear(self) -> None: p = PoolManager(5) - conn_pool = p.connection_from_url("http://google.com") + p.connection_from_url("http://google.com") assert len(p.pools) == 1 - conn = conn_pool._get_conn() - p.clear() assert len(p.pools) == 0 - with 
pytest.raises(ClosedPoolError): - conn_pool._get_conn() - - conn_pool._put_conn(conn) - - with pytest.raises(ClosedPoolError): - conn_pool._get_conn() - - assert len(p.pools) == 0 - @pytest.mark.parametrize("url", ["http://@", None]) - def test_nohost(self, url): + def test_nohost(self, url: str | None) -> None: p = PoolManager(5) with pytest.raises(LocationValueError): - p.connection_from_url(url=url) + p.connection_from_url(url=url) # type: ignore[arg-type] - def test_contextmanager(self): + def test_contextmanager(self) -> None: with PoolManager(1) as p: - conn_pool = p.connection_from_url("http://google.com") + p.connection_from_url("http://google.com") assert len(p.pools) == 1 - conn = conn_pool._get_conn() - - assert len(p.pools) == 0 - - with pytest.raises(ClosedPoolError): - conn_pool._get_conn() - - conn_pool._put_conn(conn) - - with pytest.raises(ClosedPoolError): - conn_pool._get_conn() assert len(p.pools) == 0 - def test_http_pool_key_fields(self): + def test_http_pool_key_fields(self) -> None: """Assert the HTTPPoolKey fields are honored when selecting a pool.""" connection_pool_kw = { "timeout": timeout.Timeout(3.14), "retries": retry.Retry(total=6, connect=2), "block": True, - "strict": True, "source_address": "127.0.0.1", + "blocksize": _DEFAULT_BLOCKSIZE + 1, } p = PoolManager() conn_pools = [ @@ -129,19 +118,19 @@ def test_http_pool_key_fields(self): ) assert all(isinstance(key, PoolKey) for key in p.pools.keys()) - def test_https_pool_key_fields(self): + def test_https_pool_key_fields(self) -> None: """Assert the HTTPSPoolKey fields are honored when selecting a pool.""" connection_pool_kw = { "timeout": timeout.Timeout(3.14), "retries": retry.Retry(total=6, connect=2), "block": True, - "strict": True, "source_address": "127.0.0.1", "key_file": "/root/totally_legit.key", "cert_file": "/root/totally_legit.crt", "cert_reqs": "CERT_REQUIRED", "ca_certs": "/root/path_to_pem", "ssl_version": "SSLv23_METHOD", + "blocksize": _DEFAULT_BLOCKSIZE + 1, } p = 
PoolManager() conn_pools = [ @@ -167,13 +156,13 @@ def test_https_pool_key_fields(self): assert all(pool in conn_pools for pool in dup_pools) assert all(isinstance(key, PoolKey) for key in p.pools.keys()) - def test_default_pool_key_funcs_copy(self): + def test_default_pool_key_funcs_copy(self) -> None: """Assert each PoolManager gets a copy of ``pool_keys_by_scheme``.""" p = PoolManager() assert p.key_fn_by_scheme == p.key_fn_by_scheme assert p.key_fn_by_scheme is not key_fn_by_scheme - def test_pools_keyed_with_from_host(self): + def test_pools_keyed_with_from_host(self) -> None: """Assert pools are still keyed correctly with connection_from_host.""" ssl_kw = { "key_file": "/root/totally_legit.key", @@ -182,7 +171,7 @@ def test_pools_keyed_with_from_host(self): "ca_certs": "/root/path_to_pem", "ssl_version": "SSLv23_METHOD", } - p = PoolManager(5, **ssl_kw) + p = PoolManager(5, **ssl_kw) # type: ignore[arg-type] conns = [p.connection_from_host("example.com", 443, scheme="https")] for k in ssl_kw: @@ -196,7 +185,7 @@ def test_pools_keyed_with_from_host(self): if i != j ) - def test_https_connection_from_url_case_insensitive(self): + def test_https_connection_from_url_case_insensitive(self) -> None: """Assert scheme case is ignored when pooling HTTPS connections.""" p = PoolManager() pool = p.connection_from_url("https://example.com/") @@ -206,7 +195,7 @@ def test_https_connection_from_url_case_insensitive(self): assert pool is other_pool assert all(isinstance(key, PoolKey) for key in p.pools.keys()) - def test_https_connection_from_host_case_insensitive(self): + def test_https_connection_from_host_case_insensitive(self) -> None: """Assert scheme case is ignored when getting the https key class.""" p = PoolManager() pool = p.connection_from_host("example.com", scheme="https") @@ -216,7 +205,7 @@ def test_https_connection_from_host_case_insensitive(self): assert pool is other_pool assert all(isinstance(key, PoolKey) for key in p.pools.keys()) - def 
test_https_connection_from_context_case_insensitive(self): + def test_https_connection_from_context_case_insensitive(self) -> None: """Assert scheme case is ignored when getting the https key class.""" p = PoolManager() context = {"scheme": "https", "host": "example.com", "port": "443"} @@ -228,7 +217,7 @@ def test_https_connection_from_context_case_insensitive(self): assert pool is other_pool assert all(isinstance(key, PoolKey) for key in p.pools.keys()) - def test_http_connection_from_url_case_insensitive(self): + def test_http_connection_from_url_case_insensitive(self) -> None: """Assert scheme case is ignored when pooling HTTP connections.""" p = PoolManager() pool = p.connection_from_url("http://example.com/") @@ -238,7 +227,7 @@ def test_http_connection_from_url_case_insensitive(self): assert pool is other_pool assert all(isinstance(key, PoolKey) for key in p.pools.keys()) - def test_http_connection_from_host_case_insensitive(self): + def test_http_connection_from_host_case_insensitive(self) -> None: """Assert scheme case is ignored when getting the https key class.""" p = PoolManager() pool = p.connection_from_host("example.com", scheme="http") @@ -248,16 +237,17 @@ def test_http_connection_from_host_case_insensitive(self): assert pool is other_pool assert all(isinstance(key, PoolKey) for key in p.pools.keys()) - def test_assert_hostname_and_fingerprint_flag(self): + def test_assert_hostname_and_fingerprint_flag(self) -> None: """Assert that pool manager can accept hostname and fingerprint flags.""" fingerprint = "92:81:FE:85:F7:0C:26:60:EC:D6:B3:BF:93:CF:F9:71:CC:07:7D:0A" p = PoolManager(assert_hostname=True, assert_fingerprint=fingerprint) pool = p.connection_from_url("https://example.com/") assert 1 == len(p.pools) + assert isinstance(pool, HTTPSConnectionPool) assert pool.assert_hostname assert fingerprint == pool.assert_fingerprint - def test_http_connection_from_context_case_insensitive(self): + def 
test_http_connection_from_context_case_insensitive(self) -> None: """Assert scheme case is ignored when getting the https key class.""" p = PoolManager() context = {"scheme": "http", "host": "example.com", "port": "8080"} @@ -269,11 +259,40 @@ def test_http_connection_from_context_case_insensitive(self): assert pool is other_pool assert all(isinstance(key, PoolKey) for key in p.pools.keys()) - def test_custom_pool_key(self): + @patch("urllib3.poolmanager.PoolManager.connection_from_pool_key") + def test_connection_from_context_strict_param( + self, connection_from_pool_key: mock.MagicMock + ) -> None: + p = PoolManager() + context = { + "scheme": "http", + "host": "example.com", + "port": 8080, + "strict": True, + } + with pytest.warns(DeprecationWarning) as records: + p.connection_from_context(context) + + msg = ( + "The 'strict' parameter is no longer needed on Python 3+. " + "This will raise an error in urllib3 v2.1.0." + ) + record = records[0] + assert isinstance(record.message, Warning) + assert record.message.args[0] == msg + + _, kwargs = connection_from_pool_key.call_args + assert kwargs["request_context"] == { + "scheme": "http", + "host": "example.com", + "port": 8080, + } + + def test_custom_pool_key(self) -> None: """Assert it is possible to define a custom key function.""" p = PoolManager(10) - p.key_fn_by_scheme["http"] = lambda x: tuple(x["key"]) + p.key_fn_by_scheme["http"] = lambda x: tuple(x["key"]) # type: ignore[assignment] pool1 = p.connection_from_url( "http://example.com", pool_kwargs={"key": "value"} ) @@ -288,43 +307,39 @@ def test_custom_pool_key(self): assert pool1 is pool3 assert pool1 is not pool2 - def test_override_pool_kwargs_url(self): + def test_override_pool_kwargs_url(self) -> None: """Assert overriding pool kwargs works with connection_from_url.""" - p = PoolManager(strict=True) - pool_kwargs = {"strict": False, "retries": 100, "block": True} + p = PoolManager() + pool_kwargs = {"retries": 100, "block": True} default_pool = 
p.connection_from_url("http://example.com/") override_pool = p.connection_from_url( "http://example.com/", pool_kwargs=pool_kwargs ) - assert default_pool.strict assert retry.Retry.DEFAULT == default_pool.retries assert not default_pool.block - assert not override_pool.strict assert 100 == override_pool.retries assert override_pool.block - def test_override_pool_kwargs_host(self): + def test_override_pool_kwargs_host(self) -> None: """Assert overriding pool kwargs works with connection_from_host""" - p = PoolManager(strict=True) - pool_kwargs = {"strict": False, "retries": 100, "block": True} + p = PoolManager() + pool_kwargs = {"retries": 100, "block": True} default_pool = p.connection_from_host("example.com", scheme="http") override_pool = p.connection_from_host( "example.com", scheme="http", pool_kwargs=pool_kwargs ) - assert default_pool.strict assert retry.Retry.DEFAULT == default_pool.retries assert not default_pool.block - assert not override_pool.strict assert 100 == override_pool.retries assert override_pool.block - def test_pool_kwargs_socket_options(self): + def test_pool_kwargs_socket_options(self) -> None: """Assert passing socket options works with connection_from_host""" p = PoolManager(socket_options=[]) override_opts = [ @@ -341,34 +356,121 @@ def test_pool_kwargs_socket_options(self): assert default_pool.conn_kw["socket_options"] == [] assert override_pool.conn_kw["socket_options"] == override_opts - def test_merge_pool_kwargs(self): + def test_merge_pool_kwargs(self) -> None: """Assert _merge_pool_kwargs works in the happy case""" - p = PoolManager(strict=True) + p = PoolManager(retries=100) merged = p._merge_pool_kwargs({"new_key": "value"}) - assert {"strict": True, "new_key": "value"} == merged + assert {"retries": 100, "new_key": "value"} == merged - def test_merge_pool_kwargs_none(self): + def test_merge_pool_kwargs_none(self) -> None: """Assert false-y values to _merge_pool_kwargs result in defaults""" - p = PoolManager(strict=True) + p = 
PoolManager(retries=100) merged = p._merge_pool_kwargs({}) assert p.connection_pool_kw == merged merged = p._merge_pool_kwargs(None) assert p.connection_pool_kw == merged - def test_merge_pool_kwargs_remove_key(self): + def test_merge_pool_kwargs_remove_key(self) -> None: """Assert keys can be removed with _merge_pool_kwargs""" - p = PoolManager(strict=True) - merged = p._merge_pool_kwargs({"strict": None}) - assert "strict" not in merged + p = PoolManager(retries=100) + merged = p._merge_pool_kwargs({"retries": None}) + assert "retries" not in merged - def test_merge_pool_kwargs_invalid_key(self): + def test_merge_pool_kwargs_invalid_key(self) -> None: """Assert removing invalid keys with _merge_pool_kwargs doesn't break""" - p = PoolManager(strict=True) + p = PoolManager(retries=100) merged = p._merge_pool_kwargs({"invalid_key": None}) assert p.connection_pool_kw == merged - def test_pool_manager_no_url_absolute_form(self): + def test_pool_manager_no_url_absolute_form(self) -> None: """Valides we won't send a request with absolute form without a proxy""" - p = PoolManager(strict=True) - assert p._proxy_requires_url_absolute_form("http://example.com") is False - assert p._proxy_requires_url_absolute_form("https://example.com") is False + p = PoolManager() + assert p._proxy_requires_url_absolute_form(Url("http://example.com")) is False + assert p._proxy_requires_url_absolute_form(Url("https://example.com")) is False + + @pytest.mark.parametrize( + "input_blocksize,expected_blocksize", + [ + (_DEFAULT_BLOCKSIZE, _DEFAULT_BLOCKSIZE), + (None, _DEFAULT_BLOCKSIZE), + (8192, 8192), + ], + ) + def test_poolmanager_blocksize( + self, input_blocksize: int, expected_blocksize: int + ) -> None: + """Assert PoolManager sets blocksize properly""" + p = PoolManager() + + pool_blocksize = p.connection_from_url( + "http://example.com", {"blocksize": input_blocksize} + ) + assert pool_blocksize.conn_kw["blocksize"] == expected_blocksize + assert 
pool_blocksize._get_conn().blocksize == expected_blocksize + + @pytest.mark.parametrize( + "url", + [ + "[a::b%zone]", + "[a::b%25zone]", + "http://[a::b%zone]", + "http://[a::b%25zone]", + ], + ) + @patch("urllib3.util.connection.create_connection") + def test_e2e_connect_to_ipv6_scoped( + self, create_connection: MagicMock, url: str + ) -> None: + """Checks that IPv6 scoped addresses are properly handled end-to-end. + + This is not strictly speaking a pool manager unit test - this test + lives here in absence of a better code location for e2e/integration + tests. + """ + p = PoolManager() + conn_pool = p.connection_from_url(url) + conn = conn_pool._get_conn() + conn.connect() + + assert create_connection.call_args[0][0] == ("a::b%zone", 80) + + @patch("urllib3.connection.ssl_wrap_socket") + @patch("urllib3.util.connection.create_connection") + def test_e2e_connect_to_ipv6_scoped_tls( + self, create_connection: MagicMock, ssl_wrap_socket: MagicMock + ) -> None: + p = PoolManager() + conn_pool = p.connection_from_url( + "https://[a::b%zone]", pool_kwargs={"assert_hostname": False} + ) + conn = conn_pool._get_conn() + conn.connect() + + assert ssl_wrap_socket.call_args[1]["server_hostname"] == "a::b" + + def test_thread_safty(self) -> None: + pool_manager = PoolManager(num_pools=2) + + # thread 1 gets a pool for host x + pool_1 = pool_manager.connection_from_url("http://host_x:80/") + + # thread 2 gets a pool for host y + pool_2 = pool_manager.connection_from_url("http://host_y:80/") + + # thread 3 gets a pool for host z + pool_3 = pool_manager.connection_from_url("http://host_z:80") + + # None of the pools should be closed, since all of them are referenced. + assert pool_1.pool is not None + assert pool_2.pool is not None + assert pool_3.pool is not None + + conn_queue = pool_1.pool + assert conn_queue.qsize() > 0 + + # thread 1 stops. + del pool_1 + gc.collect() + + # Connection should be closed, because reference to pool_1 is gone. 
+ assert conn_queue.qsize() == 0 diff --git a/test/test_proxymanager.py b/test/test_proxymanager.py index 7f1c396cce..140ca9fb67 100644 --- a/test/test_proxymanager.py +++ b/test/test_proxymanager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from urllib3.exceptions import MaxRetryError, NewConnectionError, ProxyError @@ -8,11 +10,11 @@ from .port_helpers import find_unused_port -class TestProxyManager(object): +class TestProxyManager: @pytest.mark.parametrize("proxy_scheme", ["http", "https"]) - def test_proxy_headers(self, proxy_scheme): + def test_proxy_headers(self, proxy_scheme: str) -> None: url = "http://pypi.org/project/urllib3/" - proxy_url = "{}://something:1234".format(proxy_scheme) + proxy_url = f"{proxy_scheme}://something:1234" with ProxyManager(proxy_url) as p: # Verify default headers default_headers = {"Accept": "*/*", "Host": "pypi.org"} @@ -39,19 +41,21 @@ def test_proxy_headers(self, proxy_scheme): assert headers == expected_headers - def test_default_port(self): + def test_default_port(self) -> None: with ProxyManager("http://something") as p: + assert p.proxy is not None assert p.proxy.port == 80 with ProxyManager("https://something") as p: + assert p.proxy is not None assert p.proxy.port == 443 - def test_invalid_scheme(self): + def test_invalid_scheme(self) -> None: with pytest.raises(AssertionError): ProxyManager("invalid://host/p") with pytest.raises(ValueError): ProxyManager("invalid://host/p") - def test_proxy_tunnel(self): + def test_proxy_tunnel(self) -> None: http_url = parse_url("http://example.com") https_url = parse_url("https://example.com") with ProxyManager("http://proxy:8080") as p: @@ -66,16 +70,18 @@ def test_proxy_tunnel(self): assert p._proxy_requires_url_absolute_form(http_url) assert p._proxy_requires_url_absolute_form(https_url) - def test_proxy_connect_retry(self): + def test_proxy_connect_retry(self) -> None: retry = Retry(total=None, connect=False) port = find_unused_port() - with 
ProxyManager("http://localhost:{}".format(port)) as p: + with ProxyManager(f"http://localhost:{port}") as p: with pytest.raises(ProxyError) as ei: p.urlopen("HEAD", url="http://localhost/", retries=retry) assert isinstance(ei.value.original_error, NewConnectionError) retry = Retry(total=None, connect=2) - with ProxyManager("http://localhost:{}".format(port)) as p: - with pytest.raises(MaxRetryError) as ei: + with ProxyManager(f"http://localhost:{port}") as p: + with pytest.raises(MaxRetryError) as ei1: p.urlopen("HEAD", url="http://localhost/", retries=retry) - assert isinstance(ei.value.reason.original_error, NewConnectionError) + assert ei1.value.reason is not None + assert isinstance(ei1.value.reason, ProxyError) + assert isinstance(ei1.value.reason.original_error, NewConnectionError) diff --git a/test/test_queue_monkeypatch.py b/test/test_queue_monkeypatch.py index f8420a0eb6..508136d1a5 100644 --- a/test/test_queue_monkeypatch.py +++ b/test/test_queue_monkeypatch.py @@ -1,11 +1,12 @@ -from __future__ import absolute_import +from __future__ import annotations + +import queue +from unittest import mock -import mock import pytest from urllib3 import HTTPConnectionPool from urllib3.exceptions import EmptyPoolError -from urllib3.packages.six.moves import queue class BadError(Exception): @@ -13,16 +14,14 @@ class BadError(Exception): This should not be raised. """ - pass - -class TestMonkeypatchResistance(object): +class TestMonkeypatchResistance: """ Test that connection pool works even with a monkey patched Queue module, see obspy/obspy#1599, psf/requests#3742, urllib3/urllib3#1061. 
""" - def test_queue_monkeypatching(self): + def test_queue_monkeypatching(self) -> None: with mock.patch.object(queue, "Empty", BadError): with HTTPConnectionPool(host="localhost", block=True) as http: http._get_conn() diff --git a/test/test_response.py b/test/test_response.py index 03f2780c75..1a296b6844 100644 --- a/test/test_response.py +++ b/test/test_response.py @@ -1,19 +1,23 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations import contextlib -import re +import http.client as httplib import socket import ssl +import sys +import typing import zlib from base64 import b64decode +from http.client import IncompleteRead as httplib_IncompleteRead from io import BufferedReader, BytesIO, TextIOWrapper -from test import onlyBrotlipy +from test import onlyBrotli, onlyZstd +from unittest import mock -import mock import pytest -import six +from urllib3 import HTTPHeaderDict from urllib3.exceptions import ( + BodyNotHttplibCompatible, DecodeError, IncompleteRead, InvalidChunkLength, @@ -21,15 +25,69 @@ ProtocolError, ResponseNotChunked, SSLError, - httplib_IncompleteRead, ) -from urllib3.packages.six.moves import http_client as httplib -from urllib3.response import HTTPResponse, brotli +from urllib3.response import ( # type: ignore[attr-defined] + BaseHTTPResponse, + BytesQueueBuffer, + HTTPResponse, + brotli, + zstd, +) from urllib3.util.response import is_fp_closed from urllib3.util.retry import RequestHistory, Retry + +class TestBytesQueueBuffer: + def test_single_chunk(self) -> None: + buffer = BytesQueueBuffer() + assert len(buffer) == 0 + with pytest.raises(RuntimeError, match="buffer is empty"): + assert buffer.get(10) + + buffer.put(b"foo") + with pytest.raises(ValueError, match="n should be > 0"): + buffer.get(-1) + + assert buffer.get(1) == b"f" + assert buffer.get(2) == b"oo" + with pytest.raises(RuntimeError, match="buffer is empty"): + assert buffer.get(10) + + def test_read_too_much(self) -> None: + buffer = BytesQueueBuffer() + 
buffer.put(b"foo") + assert buffer.get(100) == b"foo" + + def test_multiple_chunks(self) -> None: + buffer = BytesQueueBuffer() + buffer.put(b"foo") + buffer.put(b"bar") + buffer.put(b"baz") + assert len(buffer) == 9 + + assert buffer.get(1) == b"f" + assert len(buffer) == 8 + assert buffer.get(4) == b"ooba" + assert len(buffer) == 4 + assert buffer.get(4) == b"rbaz" + assert len(buffer) == 0 + + @pytest.mark.skipif( + sys.version_info < (3, 8), reason="pytest-memray requires Python 3.8+" + ) + @pytest.mark.limit_memory("12.5 MB") # assert that we're not doubling memory usage + def test_memory_usage(self) -> None: + # Allocate 10 1MiB chunks + buffer = BytesQueueBuffer() + for i in range(10): + # This allocates 2MiB, putting the max at around 12MiB. Not sure why. + buffer.put(bytes(2**20)) + + assert len(buffer.get(10 * 2**20)) == 10 * 2**20 + + # A known random (i.e, not-too-compressible) payload generated with: -# "".join(random.choice(string.printable) for i in xrange(512)) +# "".join(random.choice(string.printable) for i in range(512)) # .encode("zlib").encode("base64") # Randomness in tests == bad, and fixing a seed may not be sufficient. 
ZLIB_PAYLOAD = b64decode( @@ -47,39 +105,57 @@ @pytest.fixture -def sock(): +def sock() -> typing.Generator[socket.socket, None, None]: s = socket.socket() yield s s.close() -class TestLegacyResponse(object): - def test_getheaders(self): +class TestLegacyResponse: + def test_getheaders(self) -> None: headers = {"host": "example.com"} r = HTTPResponse(headers=headers) - assert r.getheaders() == headers + with pytest.warns( + DeprecationWarning, + match=r"HTTPResponse.getheaders\(\) is deprecated", + ): + assert r.getheaders() == HTTPHeaderDict(headers) - def test_getheader(self): + def test_getheader(self) -> None: headers = {"host": "example.com"} r = HTTPResponse(headers=headers) - assert r.getheader("host") == "example.com" + with pytest.warns( + DeprecationWarning, + match=r"HTTPResponse.getheader\(\) is deprecated", + ): + assert r.getheader("host") == "example.com" -class TestResponse(object): - def test_cache_content(self): - r = HTTPResponse("foo") - assert r.data == "foo" - assert r._body == "foo" +class TestResponse: + def test_cache_content(self) -> None: + r = HTTPResponse(b"foo") + assert r._body == b"foo" + assert r.data == b"foo" + assert r._body == b"foo" + + def test_cache_content_preload_false(self) -> None: + fp = BytesIO(b"foo") + r = HTTPResponse(fp, preload_content=False) + + assert not r._body + assert r.data == b"foo" + assert r._body == b"foo" + assert r.data == b"foo" - def test_default(self): + def test_default(self) -> None: r = HTTPResponse() assert r.data is None - def test_none(self): - r = HTTPResponse(None) + def test_none(self) -> None: + r = HTTPResponse(None) # type: ignore[arg-type] assert r.data is None - def test_preload(self): + def test_preload(self) -> None: fp = BytesIO(b"foo") r = HTTPResponse(fp, preload_content=True) @@ -87,7 +163,7 @@ def test_preload(self): assert fp.tell() == len(b"foo") assert r.data == b"foo" - def test_no_preload(self): + def test_no_preload(self) -> None: fp = BytesIO(b"foo") r = HTTPResponse(fp, 
preload_content=False) @@ -96,12 +172,12 @@ def test_no_preload(self): assert r.data == b"foo" assert fp.tell() == len(b"foo") - def test_decode_bad_data(self): + def test_decode_bad_data(self) -> None: fp = BytesIO(b"\x00" * 10) with pytest.raises(DecodeError): HTTPResponse(fp, headers={"content-encoding": "deflate"}) - def test_reference_read(self): + def test_reference_read(self) -> None: fp = BytesIO(b"foo") r = HTTPResponse(fp, preload_content=False) @@ -110,7 +186,7 @@ def test_reference_read(self): assert r.read() == b"" assert r.read() == b"" - def test_decode_deflate(self): + def test_decode_deflate(self) -> None: data = zlib.compress(b"foo") fp = BytesIO(data) @@ -118,7 +194,7 @@ def test_decode_deflate(self): assert r.data == b"foo" - def test_decode_deflate_case_insensitve(self): + def test_decode_deflate_case_insensitve(self) -> None: data = zlib.compress(b"foo") fp = BytesIO(data) @@ -126,7 +202,7 @@ def test_decode_deflate_case_insensitve(self): assert r.data == b"foo" - def test_chunked_decoding_deflate(self): + def test_chunked_decoding_deflate(self) -> None: data = zlib.compress(b"foo") fp = BytesIO(data) @@ -134,17 +210,12 @@ def test_chunked_decoding_deflate(self): fp, headers={"content-encoding": "deflate"}, preload_content=False ) - assert r.read(3) == b"" - # Buffer in case we need to switch to the raw stream - assert r._decoder._data is not None assert r.read(1) == b"f" - # Now that we've decoded data, we just stream through the decoder - assert r._decoder._data is None assert r.read(2) == b"oo" assert r.read() == b"" assert r.read() == b"" - def test_chunked_decoding_deflate2(self): + def test_chunked_decoding_deflate2(self) -> None: compress = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS) data = compress.compress(b"foo") data += compress.flush() @@ -154,15 +225,12 @@ def test_chunked_decoding_deflate2(self): fp, headers={"content-encoding": "deflate"}, preload_content=False ) - assert r.read(1) == b"" assert r.read(1) == b"f" - # 
Once we've decoded data, we just stream to the decoder; no buffering - assert r._decoder._data is None assert r.read(2) == b"oo" assert r.read() == b"" assert r.read() == b"" - def test_chunked_decoding_gzip(self): + def test_chunked_decoding_gzip(self) -> None: compress = zlib.compressobj(6, zlib.DEFLATED, 16 + zlib.MAX_WBITS) data = compress.compress(b"foo") data += compress.flush() @@ -172,13 +240,12 @@ def test_chunked_decoding_gzip(self): fp, headers={"content-encoding": "gzip"}, preload_content=False ) - assert r.read(11) == b"" assert r.read(1) == b"f" assert r.read(2) == b"oo" assert r.read() == b"" assert r.read() == b"" - def test_decode_gzip_multi_member(self): + def test_decode_gzip_multi_member(self) -> None: compress = zlib.compressobj(6, zlib.DEFLATED, 16 + zlib.MAX_WBITS) data = compress.compress(b"foo") data += compress.flush() @@ -189,12 +256,12 @@ def test_decode_gzip_multi_member(self): assert r.data == b"foofoofoo" - def test_decode_gzip_error(self): + def test_decode_gzip_error(self) -> None: fp = BytesIO(b"foo") with pytest.raises(DecodeError): HTTPResponse(fp, headers={"content-encoding": "gzip"}) - def test_decode_gzip_swallow_garbage(self): + def test_decode_gzip_swallow_garbage(self) -> None: # When data comes from multiple calls to read(), data after # the first zlib error (here triggered by garbage) should be # ignored. 
@@ -215,7 +282,7 @@ def test_decode_gzip_swallow_garbage(self): assert ret == b"foofoofoo" - def test_chunked_decoding_gzip_swallow_garbage(self): + def test_chunked_decoding_gzip_swallow_garbage(self) -> None: compress = zlib.compressobj(6, zlib.DEFLATED, 16 + zlib.MAX_WBITS) data = compress.compress(b"foo") data += compress.flush() @@ -226,16 +293,16 @@ def test_chunked_decoding_gzip_swallow_garbage(self): assert r.data == b"foofoofoo" - @onlyBrotlipy() - def test_decode_brotli(self): + @onlyBrotli() + def test_decode_brotli(self) -> None: data = brotli.compress(b"foo") fp = BytesIO(data) r = HTTPResponse(fp, headers={"content-encoding": "br"}) assert r.data == b"foo" - @onlyBrotlipy() - def test_chunked_decoding_brotli(self): + @onlyBrotli() + def test_chunked_decoding_brotli(self) -> None: data = brotli.compress(b"foobarbaz") fp = BytesIO(data) @@ -248,13 +315,55 @@ def test_chunked_decoding_brotli(self): break assert ret == b"foobarbaz" - @onlyBrotlipy() - def test_decode_brotli_error(self): + @onlyBrotli() + def test_decode_brotli_error(self) -> None: fp = BytesIO(b"foo") with pytest.raises(DecodeError): HTTPResponse(fp, headers={"content-encoding": "br"}) - def test_multi_decoding_deflate_deflate(self): + @onlyZstd() + def test_decode_zstd(self) -> None: + data = zstd.compress(b"foo") + + fp = BytesIO(data) + r = HTTPResponse(fp, headers={"content-encoding": "zstd"}) + assert r.data == b"foo" + + @onlyZstd() + def test_chunked_decoding_zstd(self) -> None: + data = zstd.compress(b"foobarbaz") + + fp = BytesIO(data) + r = HTTPResponse( + fp, headers={"content-encoding": "zstd"}, preload_content=False + ) + + ret = b"" + + for _ in range(100): + ret += r.read(1) + if r.closed: + break + assert ret == b"foobarbaz" + + @onlyZstd() + @pytest.mark.parametrize("data", [b"foo", b"x" * 100]) + def test_decode_zstd_error(self, data: bytes) -> None: + fp = BytesIO(data) + + with pytest.raises(DecodeError): + HTTPResponse(fp, headers={"content-encoding": "zstd"}) + + 
@onlyZstd() + @pytest.mark.parametrize("data", [b"foo", b"x" * 100]) + def test_decode_zstd_incomplete(self, data: bytes) -> None: + data = zstd.compress(data) + fp = BytesIO(data[:-1]) + + with pytest.raises(DecodeError): + HTTPResponse(fp, headers={"content-encoding": "zstd"}) + + def test_multi_decoding_deflate_deflate(self) -> None: data = zlib.compress(zlib.compress(b"foo")) fp = BytesIO(data) @@ -262,7 +371,7 @@ def test_multi_decoding_deflate_deflate(self): assert r.data == b"foo" - def test_multi_decoding_deflate_gzip(self): + def test_multi_decoding_deflate_gzip(self) -> None: compress = zlib.compressobj(6, zlib.DEFLATED, 16 + zlib.MAX_WBITS) data = compress.compress(zlib.compress(b"foo")) data += compress.flush() @@ -272,7 +381,7 @@ def test_multi_decoding_deflate_gzip(self): assert r.data == b"foo" - def test_multi_decoding_gzip_gzip(self): + def test_multi_decoding_gzip_gzip(self) -> None: compress = zlib.compressobj(6, zlib.DEFLATED, 16 + zlib.MAX_WBITS) data = compress.compress(b"foo") data += compress.flush() @@ -286,12 +395,47 @@ def test_multi_decoding_gzip_gzip(self): assert r.data == b"foo" - def test_body_blob(self): + def test_read_multi_decoding_deflate_deflate(self) -> None: + msg = b"foobarbaz" * 42 + data = zlib.compress(zlib.compress(msg)) + + fp = BytesIO(data) + r = HTTPResponse( + fp, headers={"content-encoding": "deflate, deflate"}, preload_content=False + ) + + assert r.read(3) == b"foo" + assert r.read(3) == b"bar" + assert r.read(3) == b"baz" + assert r.read(9) == b"foobarbaz" + assert r.read(9 * 3) == b"foobarbaz" * 3 + assert r.read(9 * 37) == b"foobarbaz" * 37 + assert r.read() == b"" + + def test_body_blob(self) -> None: resp = HTTPResponse(b"foo") assert resp.data == b"foo" assert resp.closed - def test_io(self, sock): + def test_base_io(self) -> None: + resp = BaseHTTPResponse( + status=200, + version=11, + reason=None, + decode_content=False, + request_url=None, + ) + + assert not resp.closed + assert not resp.readable() + 
assert not resp.writable() + + with pytest.raises(NotImplementedError): + resp.read() + with pytest.raises(NotImplementedError): + resp.close() + + def test_io(self, sock: socket.socket) -> None: fp = BytesIO(b"foo") resp = HTTPResponse(fp, preload_content=False) @@ -327,14 +471,15 @@ def test_io(self, sock): with pytest.raises(IOError): resp3.fileno() - def test_io_closed_consistently(self, sock): + def test_io_closed_consistently(self, sock: socket.socket) -> None: try: hlr = httplib.HTTPResponse(sock) - hlr.fp = BytesIO(b"foo") - hlr.chunked = 0 + hlr.fp = BytesIO(b"foo") # type: ignore[assignment] + hlr.chunked = 0 # type: ignore[assignment] hlr.length = 3 with HTTPResponse(hlr, preload_content=False) as resp: assert not resp.closed + assert resp._fp is not None assert not resp._fp.isclosed() assert not is_fp_closed(resp._fp) assert not resp.isclosed() @@ -346,10 +491,10 @@ def test_io_closed_consistently(self, sock): finally: hlr.close() - def test_io_bufferedreader(self): + def test_io_bufferedreader(self) -> None: fp = BytesIO(b"foo") resp = HTTPResponse(fp, preload_content=False) - br = BufferedReader(resp) + br = BufferedReader(resp) # type: ignore[arg-type] assert br.read() == b"foo" @@ -360,14 +505,13 @@ def test_io_bufferedreader(self): # https://github.com/urllib3/urllib3/issues/1305 fp = BytesIO(b"hello\nworld") resp = HTTPResponse(fp, preload_content=False) - with pytest.raises(ValueError) as ctx: - list(BufferedReader(resp)) - assert str(ctx.value) == "readline of closed file" + with pytest.raises(ValueError, match="readline of closed file"): + list(BufferedReader(resp)) # type: ignore[arg-type] b = b"fooandahalf" fp = BytesIO(b) resp = HTTPResponse(fp, preload_content=False) - br = BufferedReader(resp, 5) + br = BufferedReader(resp, 5) # type: ignore[arg-type] br.read(1) # sets up the buffer, reading 5 assert len(fp.read()) == (len(b) - 5) @@ -377,10 +521,10 @@ def test_io_bufferedreader(self): while not br.closed: br.read(5) - def 
test_io_not_autoclose_bufferedreader(self): + def test_io_not_autoclose_bufferedreader(self) -> None: fp = BytesIO(b"hello\nworld") resp = HTTPResponse(fp, preload_content=False, auto_close=False) - reader = BufferedReader(resp) + reader = BufferedReader(resp) # type: ignore[arg-type] assert list(reader) == [b"hello\n", b"world"] assert not reader.closed @@ -391,16 +535,15 @@ def test_io_not_autoclose_bufferedreader(self): reader.close() assert reader.closed assert resp.closed - with pytest.raises(ValueError) as ctx: + with pytest.raises(ValueError, match="readline of closed file"): next(reader) - assert str(ctx.value) == "readline of closed file" - def test_io_textiowrapper(self): + def test_io_textiowrapper(self) -> None: fp = BytesIO(b"\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f") resp = HTTPResponse(fp, preload_content=False) - br = TextIOWrapper(resp, encoding="utf8") + br = TextIOWrapper(resp, encoding="utf8") # type: ignore[arg-type] - assert br.read() == u"äöüß" + assert br.read() == "äöüß" br.close() assert resp.closed @@ -411,25 +554,16 @@ def test_io_textiowrapper(self): b"\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f\n\xce\xb1\xce\xb2\xce\xb3\xce\xb4" ) resp = HTTPResponse(fp, preload_content=False) - with pytest.raises(ValueError) as ctx: - if six.PY2: - # py2's implementation of TextIOWrapper requires `read1` - # method which is provided by `BufferedReader` wrapper - resp = BufferedReader(resp) - list(TextIOWrapper(resp)) - assert re.match("I/O operation on closed file.?", str(ctx.value)) - - def test_io_not_autoclose_textiowrapper(self): + with pytest.raises(ValueError, match="I/O operation on closed file.?"): + list(TextIOWrapper(resp)) # type: ignore[arg-type] + + def test_io_not_autoclose_textiowrapper(self) -> None: fp = BytesIO( b"\xc3\xa4\xc3\xb6\xc3\xbc\xc3\x9f\n\xce\xb1\xce\xb2\xce\xb3\xce\xb4" ) resp = HTTPResponse(fp, preload_content=False, auto_close=False) - if six.PY2: - # py2's implementation of TextIOWrapper requires `read1` - # method which is provided by 
`BufferedReader` wrapper - resp = BufferedReader(resp) - reader = TextIOWrapper(resp, encoding="utf8") - assert list(reader) == [u"äöüß\n", u"αβγδ"] + reader = TextIOWrapper(resp, encoding="utf8") # type: ignore[arg-type] + assert list(reader) == ["äöüß\n", "αβγδ"] assert not reader.closed assert not resp.closed @@ -439,11 +573,50 @@ def test_io_not_autoclose_textiowrapper(self): reader.close() assert reader.closed assert resp.closed - with pytest.raises(ValueError) as ctx: + with pytest.raises(ValueError, match="I/O operation on closed file.?"): next(reader) - assert re.match("I/O operation on closed file.?", str(ctx.value)) - def test_streaming(self): + def test_read_with_illegal_mix_decode_toggle(self) -> None: + data = zlib.compress(b"foo") + + fp = BytesIO(data) + + resp = HTTPResponse( + fp, headers={"content-encoding": "deflate"}, preload_content=False + ) + + assert resp.read(1) == b"f" + + with pytest.raises( + RuntimeError, + match=( + r"Calling read\(decode_content=False\) is not supported after " + r"read\(decode_content=True\) was called" + ), + ): + resp.read(1, decode_content=False) + + with pytest.raises( + RuntimeError, + match=( + r"Calling read\(decode_content=False\) is not supported after " + r"read\(decode_content=True\) was called" + ), + ): + resp.read(decode_content=False) + + def test_read_with_mix_decode_toggle(self) -> None: + data = zlib.compress(b"foo") + + fp = BytesIO(data) + + resp = HTTPResponse( + fp, headers={"content-encoding": "deflate"}, preload_content=False + ) + assert resp.read(2, decode_content=False) is not None + assert resp.read(1, decode_content=True) == b"f" + + def test_streaming(self) -> None: fp = BytesIO(b"foo") resp = HTTPResponse(fp, preload_content=False) stream = resp.stream(2, decode_content=False) @@ -453,7 +626,7 @@ def test_streaming(self): with pytest.raises(StopIteration): next(stream) - def test_streaming_tell(self): + def test_streaming_tell(self) -> None: fp = BytesIO(b"foo") resp = HTTPResponse(fp, 
preload_content=False) stream = resp.stream(2, decode_content=False) @@ -471,7 +644,7 @@ def test_streaming_tell(self): with pytest.raises(StopIteration): next(stream) - def test_gzipped_streaming(self): + def test_gzipped_streaming(self) -> None: compress = zlib.compressobj(6, zlib.DEFLATED, 16 + zlib.MAX_WBITS) data = compress.compress(b"foo") data += compress.flush() @@ -482,12 +655,12 @@ def test_gzipped_streaming(self): ) stream = resp.stream(2) - assert next(stream) == b"f" - assert next(stream) == b"oo" + assert next(stream) == b"fo" + assert next(stream) == b"o" with pytest.raises(StopIteration): next(stream) - def test_gzipped_streaming_tell(self): + def test_gzipped_streaming_tell(self) -> None: compress = zlib.compressobj(6, zlib.DEFLATED, 16 + zlib.MAX_WBITS) uncompressed_data = b"foo" data = compress.compress(uncompressed_data) @@ -508,10 +681,11 @@ def test_gzipped_streaming_tell(self): with pytest.raises(StopIteration): next(stream) - def test_deflate_streaming_tell_intermediate_point(self): + def test_deflate_streaming_tell_intermediate_point(self) -> None: # Ensure that ``tell()`` returns the correct number of bytes when # part-way through streaming compressed content. NUMBER_OF_READS = 10 + PART_SIZE = 64 class MockCompressedDataReading(BytesIO): """ @@ -519,7 +693,7 @@ class MockCompressedDataReading(BytesIO): calls to ``read``. """ - def __init__(self, payload, payload_part_size): + def __init__(self, payload: bytes, payload_part_size: int) -> None: self.payloads = [ payload[i * payload_part_size : (i + 1) * payload_part_size] for i in range(NUMBER_OF_READS + 1) @@ -527,7 +701,7 @@ def __init__(self, payload, payload_part_size): assert b"".join(self.payloads) == payload - def read(self, _): + def read(self, _: int) -> bytes: # type: ignore[override] # Amount is unused. 
if len(self.payloads) > 0: return self.payloads.pop(0) @@ -540,7 +714,7 @@ def read(self, _): resp = HTTPResponse( fp, headers={"content-encoding": "deflate"}, preload_content=False ) - stream = resp.stream() + stream = resp.stream(PART_SIZE) parts_positions = [(part, resp.tell()) for part in stream] end_of_stream = resp.tell() @@ -555,13 +729,29 @@ def read(self, _): assert uncompressed_data == payload # Check that the positions in the stream are correct - expected = [(i + 1) * payload_part_size for i in range(NUMBER_OF_READS)] - assert expected == list(positions) + # It is difficult to determine programatically what the positions + # returned by `tell` will be because the `HTTPResponse.read` method may + # call socket `read` a couple of times if it doesn't have enough data + # in the buffer or not call socket `read` at all if it has enough. All + # this depends on the message, how it was compressed, what is + # `PART_SIZE` and `payload_part_size`. + # So for simplicity the expected values are hardcoded. 
+ expected = (92, 184, 230, 276, 322, 368, 414, 460) + assert expected == positions # Check that the end of the stream is in the correct place assert len(ZLIB_PAYLOAD) == end_of_stream - def test_deflate_streaming(self): + # Check that all parts have expected length + expected_last_part_size = len(uncompressed_data) % PART_SIZE + whole_parts = len(uncompressed_data) // PART_SIZE + if expected_last_part_size == 0: + expected_lengths = [PART_SIZE] * whole_parts + else: + expected_lengths = [PART_SIZE] * whole_parts + [expected_last_part_size] + assert expected_lengths == [len(part) for part in parts] + + def test_deflate_streaming(self) -> None: data = zlib.compress(b"foo") fp = BytesIO(data) @@ -570,12 +760,12 @@ def test_deflate_streaming(self): ) stream = resp.stream(2) - assert next(stream) == b"f" - assert next(stream) == b"oo" + assert next(stream) == b"fo" + assert next(stream) == b"o" with pytest.raises(StopIteration): next(stream) - def test_deflate2_streaming(self): + def test_deflate2_streaming(self) -> None: compress = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS) data = compress.compress(b"foo") data += compress.flush() @@ -586,12 +776,12 @@ def test_deflate2_streaming(self): ) stream = resp.stream(2) - assert next(stream) == b"f" - assert next(stream) == b"oo" + assert next(stream) == b"fo" + assert next(stream) == b"o" with pytest.raises(StopIteration): next(stream) - def test_empty_stream(self): + def test_empty_stream(self) -> None: fp = BytesIO(b"") resp = HTTPResponse(fp, preload_content=False) stream = resp.stream(2, decode_content=False) @@ -599,19 +789,51 @@ def test_empty_stream(self): with pytest.raises(StopIteration): next(stream) - def test_length_no_header(self): + @pytest.mark.parametrize( + "preload_content, amt", + [(True, None), (False, None), (False, 10 * 2**20)], + ) + @pytest.mark.limit_memory("25 MB") + def test_buffer_memory_usage_decode_one_chunk( + self, preload_content: bool, amt: int + ) -> None: + content_length = 10 * 
2**20 # 10 MiB + fp = BytesIO(zlib.compress(bytes(content_length))) + resp = HTTPResponse( + fp, + preload_content=preload_content, + headers={"content-encoding": "deflate"}, + ) + data = resp.data if preload_content else resp.read(amt) + assert len(data) == content_length + + @pytest.mark.parametrize( + "preload_content, amt", + [(True, None), (False, None), (False, 10 * 2**20)], + ) + @pytest.mark.limit_memory("10.5 MB") + def test_buffer_memory_usage_no_decoding( + self, preload_content: bool, amt: int + ) -> None: + content_length = 10 * 2**20 # 10 MiB + fp = BytesIO(bytes(content_length)) + resp = HTTPResponse(fp, preload_content=preload_content, decode_content=False) + data = resp.data if preload_content else resp.read(amt) + assert len(data) == content_length + + def test_length_no_header(self) -> None: fp = BytesIO(b"12345") resp = HTTPResponse(fp, preload_content=False) assert resp.length_remaining is None - def test_length_w_valid_header(self): + def test_length_w_valid_header(self) -> None: headers = {"content-length": "5"} fp = BytesIO(b"12345") resp = HTTPResponse(fp, headers=headers, preload_content=False) assert resp.length_remaining == 5 - def test_length_w_bad_header(self): + def test_length_w_bad_header(self) -> None: garbage = {"content-length": "foo"} fp = BytesIO(b"12345") @@ -622,7 +844,7 @@ def test_length_w_bad_header(self): resp = HTTPResponse(fp, headers=garbage, preload_content=False) assert resp.length_remaining is None - def test_length_when_chunked(self): + def test_length_when_chunked(self) -> None: # This is expressly forbidden in RFC 7230 sec 3.3.2 # We fall back to chunked in this case and try to # handle response ignoring content length. 
@@ -632,7 +854,7 @@ def test_length_when_chunked(self): resp = HTTPResponse(fp, headers=headers, preload_content=False) assert resp.length_remaining is None - def test_length_with_multiple_content_lengths(self): + def test_length_with_multiple_content_lengths(self) -> None: headers = {"content-length": "5, 5, 5"} garbage = {"content-length": "5, 42"} fp = BytesIO(b"abcde") @@ -643,7 +865,7 @@ def test_length_with_multiple_content_lengths(self): with pytest.raises(InvalidHeader): HTTPResponse(fp, headers=garbage, preload_content=False) - def test_length_after_read(self): + def test_length_after_read(self) -> None: headers = {"content-length": "5"} # Test no defined length @@ -665,27 +887,29 @@ def test_length_after_read(self): next(data) assert resp.length_remaining == 3 - def test_mock_httpresponse_stream(self): + def test_mock_httpresponse_stream(self) -> None: # Mock out a HTTP Request that does enough to make it through urllib3's # read() and close() calls, and also exhausts and underlying file # object. 
- class MockHTTPRequest(object): - self.fp = None + class MockHTTPRequest: + def __init__(self) -> None: + self.fp: BytesIO | None = None - def read(self, amt): + def read(self, amt: int) -> bytes: + assert self.fp is not None data = self.fp.read(amt) if not data: self.fp = None return data - def close(self): + def close(self) -> None: self.fp = None bio = BytesIO(b"foo") fp = MockHTTPRequest() fp.fp = bio - resp = HTTPResponse(fp, preload_content=False) + resp = HTTPResponse(fp, preload_content=False) # type: ignore[arg-type] stream = resp.stream(2) assert next(stream) == b"fo" @@ -693,11 +917,11 @@ def close(self): with pytest.raises(StopIteration): next(stream) - def test_mock_transfer_encoding_chunked(self): + def test_mock_transfer_encoding_chunked(self) -> None: stream = [b"fo", b"o", b"bar"] fp = MockChunkedEncodingResponse(stream) - r = httplib.HTTPResponse(MockSock) - r.fp = fp + r = httplib.HTTPResponse(MockSock) # type: ignore[arg-type] + r.fp = fp # type: ignore[assignment] resp = HTTPResponse( r, preload_content=False, headers={"transfer-encoding": "chunked"} ) @@ -705,10 +929,10 @@ def test_mock_transfer_encoding_chunked(self): for i, c in enumerate(resp.stream()): assert c == stream[i] - def test_mock_gzipped_transfer_encoding_chunked_decoded(self): + def test_mock_gzipped_transfer_encoding_chunked_decoded(self) -> None: """Show that we can decode the gzipped and chunked body.""" - def stream(): + def stream() -> typing.Generator[bytes, None, None]: # Set up a generator to chunk the gzipped body compress = zlib.compressobj(6, zlib.DEFLATED, 16 + zlib.MAX_WBITS) data = compress.compress(b"foobar") @@ -717,8 +941,8 @@ def stream(): yield data[i : i + 2] fp = MockChunkedEncodingResponse(list(stream())) - r = httplib.HTTPResponse(MockSock) - r.fp = fp + r = httplib.HTTPResponse(MockSock) # type: ignore[arg-type] + r.fp = fp # type: ignore[assignment] headers = {"transfer-encoding": "chunked", "content-encoding": "gzip"} resp = HTTPResponse(r, 
preload_content=False, headers=headers) @@ -728,11 +952,11 @@ def stream(): assert b"foobar" == data - def test_mock_transfer_encoding_chunked_custom_read(self): + def test_mock_transfer_encoding_chunked_custom_read(self) -> None: stream = [b"foooo", b"bbbbaaaaar"] fp = MockChunkedEncodingResponse(stream) - r = httplib.HTTPResponse(MockSock) - r.fp = fp + r = httplib.HTTPResponse(MockSock) # type: ignore[arg-type] + r.fp = fp # type: ignore[assignment] r.chunked = True r.chunk_left = None resp = HTTPResponse( @@ -742,11 +966,11 @@ def test_mock_transfer_encoding_chunked_custom_read(self): response = list(resp.read_chunked(2)) assert expected_response == response - def test_mock_transfer_encoding_chunked_unlmtd_read(self): + def test_mock_transfer_encoding_chunked_unlmtd_read(self) -> None: stream = [b"foooo", b"bbbbaaaaar"] fp = MockChunkedEncodingResponse(stream) - r = httplib.HTTPResponse(MockSock) - r.fp = fp + r = httplib.HTTPResponse(MockSock) # type: ignore[arg-type] + r.fp = fp # type: ignore[assignment] r.chunked = True r.chunk_left = None resp = HTTPResponse( @@ -754,14 +978,23 @@ def test_mock_transfer_encoding_chunked_unlmtd_read(self): ) assert stream == list(resp.read_chunked()) - def test_read_not_chunked_response_as_chunks(self): + def test_read_not_chunked_response_as_chunks(self) -> None: fp = BytesIO(b"foo") resp = HTTPResponse(fp, preload_content=False) r = resp.read_chunked() with pytest.raises(ResponseNotChunked): next(r) - def test_buggy_incomplete_read(self): + def test_read_chunked_not_supported(self) -> None: + fp = BytesIO(b"foo") + resp = HTTPResponse( + fp, preload_content=False, headers={"transfer-encoding": "chunked"} + ) + r = resp.read_chunked() + with pytest.raises(BodyNotHttplibCompatible): + next(r) + + def test_buggy_incomplete_read(self) -> None: # Simulate buggy versions of Python (<2.7.4) # See http://bugs.python.org/issue16298 content_length = 1337 @@ -777,14 +1010,14 @@ def test_buggy_incomplete_read(self): orig_ex = 
ctx.value.args[1] assert isinstance(orig_ex, IncompleteRead) - assert orig_ex.partial == 0 + assert orig_ex.partial == 0 # type: ignore[comparison-overlap] assert orig_ex.expected == content_length - def test_incomplete_chunk(self): + def test_incomplete_chunk(self) -> None: stream = [b"foooo", b"bbbbaaaaar"] fp = MockChunkedIncompleteRead(stream) - r = httplib.HTTPResponse(MockSock) - r.fp = fp + r = httplib.HTTPResponse(MockSock) # type: ignore[arg-type] + r.fp = fp # type: ignore[assignment] r.chunked = True r.chunk_left = None resp = HTTPResponse( @@ -796,11 +1029,11 @@ def test_incomplete_chunk(self): orig_ex = ctx.value.args[1] assert isinstance(orig_ex, httplib_IncompleteRead) - def test_invalid_chunk_length(self): + def test_invalid_chunk_length(self) -> None: stream = [b"foooo", b"bbbbaaaaar"] fp = MockChunkedInvalidChunkLength(stream) - r = httplib.HTTPResponse(MockSock) - r.fp = fp + r = httplib.HTTPResponse(MockSock) # type: ignore[arg-type] + r.fp = fp # type: ignore[assignment] r.chunked = True r.chunk_left = None resp = HTTPResponse( @@ -810,14 +1043,19 @@ def test_invalid_chunk_length(self): next(resp.read_chunked()) orig_ex = ctx.value.args[1] + msg = ( + "(\"Connection broken: InvalidChunkLength(got length b'ZZZ\\\\r\\\\n', 0 bytes read)\", " + "InvalidChunkLength(got length b'ZZZ\\r\\n', 0 bytes read))" + ) + assert str(ctx.value) == msg assert isinstance(orig_ex, InvalidChunkLength) - assert orig_ex.length == six.b(fp.BAD_LENGTH_LINE) + assert orig_ex.length == fp.BAD_LENGTH_LINE.encode() - def test_chunked_response_without_crlf_on_end(self): + def test_chunked_response_without_crlf_on_end(self) -> None: stream = [b"foo", b"bar", b"baz"] fp = MockChunkedEncodingWithoutCRLFOnEnd(stream) - r = httplib.HTTPResponse(MockSock) - r.fp = fp + r = httplib.HTTPResponse(MockSock) # type: ignore[arg-type] + r.fp = fp # type: ignore[assignment] r.chunked = True r.chunk_left = None resp = HTTPResponse( @@ -825,11 +1063,11 @@ def 
test_chunked_response_without_crlf_on_end(self): ) assert stream == list(resp.stream()) - def test_chunked_response_with_extensions(self): + def test_chunked_response_with_extensions(self) -> None: stream = [b"foo", b"bar"] fp = MockChunkedEncodingWithExtensions(stream) - r = httplib.HTTPResponse(MockSock) - r.fp = fp + r = httplib.HTTPResponse(MockSock) # type: ignore[arg-type] + r.fp = fp # type: ignore[assignment] r.chunked = True r.chunk_left = None resp = HTTPResponse( @@ -837,8 +1075,8 @@ def test_chunked_response_with_extensions(self): ) assert stream == list(resp.stream()) - def test_chunked_head_response(self): - r = httplib.HTTPResponse(MockSock, method="HEAD") + def test_chunked_head_response(self) -> None: + r = httplib.HTTPResponse(MockSock, method="HEAD") # type: ignore[arg-type] r.chunked = True r.chunk_left = None resp = HTTPResponse( @@ -849,19 +1087,19 @@ def test_chunked_head_response(self): ) assert resp.chunked is True - resp.supports_chunked_reads = lambda: True - resp.release_conn = mock.Mock() + setattr(resp, "supports_chunked_reads", lambda: True) + setattr(resp, "release_conn", mock.Mock()) for _ in resp.stream(): continue - resp.release_conn.assert_called_once_with() + resp.release_conn.assert_called_once_with() # type: ignore[attr-defined] - def test_get_case_insensitive_headers(self): + def test_get_case_insensitive_headers(self) -> None: headers = {"host": "example.com"} r = HTTPResponse(headers=headers) assert r.headers.get("host") == "example.com" assert r.headers.get("Host") == "example.com" - def test_retries(self): + def test_retries(self) -> None: fp = BytesIO(b"") resp = HTTPResponse(fp) assert resp.retries is None @@ -869,16 +1107,24 @@ def test_retries(self): resp = HTTPResponse(fp, retries=retry) assert resp.retries == retry - def test_geturl(self): + def test_geturl(self) -> None: fp = BytesIO(b"") request_url = "https://example.com" resp = HTTPResponse(fp, request_url=request_url) assert resp.geturl() == request_url - def 
test_geturl_retries(self): + def test_url(self) -> None: + fp = BytesIO(b"") + request_url = "https://example.com" + resp = HTTPResponse(fp, request_url=request_url) + assert resp.url == request_url + resp.url = "https://anotherurl.com" + assert resp.url == "https://anotherurl.com" + + def test_geturl_retries(self) -> None: fp = BytesIO(b"") resp = HTTPResponse(fp, request_url="http://example.com") - request_histories = [ + request_histories = ( RequestHistory( method="GET", url="http://example.com", @@ -893,7 +1139,7 @@ def test_geturl_retries(self): status=301, redirect_location="https://www.example.com", ), - ] + ) retry = Retry(history=request_histories) resp = HTTPResponse(fp, retries=retry) assert resp.geturl() == "https://www.example.com" @@ -908,15 +1154,15 @@ def test_geturl_retries(self): (b"Hello\nworld\n\n\n!", [b"Hello\n", b"world\n", b"\n", b"\n", b"!"]), ], ) - def test__iter__(self, payload, expected_stream): + def test__iter__(self, payload: bytes, expected_stream: list[bytes]) -> None: actual_stream = [] for chunk in HTTPResponse(BytesIO(payload), preload_content=False): actual_stream.append(chunk) assert actual_stream == expected_stream - def test__iter__decode_content(self): - def stream(): + def test__iter__decode_content(self) -> None: + def stream() -> typing.Generator[bytes, None, None]: # Set up a generator to chunk the gzipped body compress = zlib.compressobj(6, zlib.DEFLATED, 16 + zlib.MAX_WBITS) data = compress.compress(b"foo\nbar") @@ -925,8 +1171,8 @@ def stream(): yield data[i : i + 2] fp = MockChunkedEncodingResponse(list(stream())) - r = httplib.HTTPResponse(MockSock) - r.fp = fp + r = httplib.HTTPResponse(MockSock) # type: ignore[arg-type] + r.fp = fp # type: ignore[assignment] headers = {"transfer-encoding": "chunked", "content-encoding": "gzip"} resp = HTTPResponse(r, preload_content=False, headers=headers) @@ -936,13 +1182,13 @@ def stream(): assert b"foo\nbar" == data - def test_non_timeout_ssl_error_on_read(self): + def 
test_non_timeout_ssl_error_on_read(self) -> None: mac_error = ssl.SSLError( "SSL routines", "ssl3_get_record", "decryption failed or bad record mac" ) @contextlib.contextmanager - def make_bad_mac_fp(): + def make_bad_mac_fp() -> typing.Generator[BytesIO, None, None]: fp = BytesIO(b"") with mock.patch.object(fp, "read") as fp_read: # mac/decryption error @@ -961,8 +1207,8 @@ def make_bad_mac_fp(): assert e.value.args[0] == mac_error -class MockChunkedEncodingResponse(object): - def __init__(self, content): +class MockChunkedEncodingResponse: + def __init__(self, content: list[bytes]) -> None: """ content: collection of str, each str is a chunk in response """ @@ -972,13 +1218,12 @@ def __init__(self, content): self.cur_chunk = b"" self.chunks_exhausted = False - @staticmethod - def _encode_chunk(chunk): + def _encode_chunk(self, chunk: bytes) -> bytes: # In the general case, we can't decode the chunk to unicode - length = "%X\r\n" % len(chunk) + length = f"{len(chunk):X}\r\n" return length.encode() + chunk + b"\r\n" - def _pop_new_chunk(self): + def _pop_new_chunk(self) -> bytes: if self.chunks_exhausted: return b"" try: @@ -991,9 +1236,10 @@ def _pop_new_chunk(self): chunk = self._encode_chunk(chunk) if not isinstance(chunk, bytes): chunk = chunk.encode() + assert isinstance(chunk, bytes) return chunk - def pop_current_chunk(self, amt=-1, till_crlf=False): + def pop_current_chunk(self, amt: int = -1, till_crlf: bool = False) -> bytes: if amt > 0 and till_crlf: raise ValueError("Can't specify amt and till_crlf.") if len(self.cur_chunk) <= 0: @@ -1023,47 +1269,47 @@ def pop_current_chunk(self, amt=-1, till_crlf=False): self.cur_chunk = self.cur_chunk[amt:] return chunk_part - def readline(self): + def readline(self) -> bytes: return self.pop_current_chunk(till_crlf=True) - def read(self, amt=-1): + def read(self, amt: int = -1) -> bytes: return self.pop_current_chunk(amt) - def flush(self): + def flush(self) -> None: # Python 3 wants this method. 
pass - def close(self): + def close(self) -> None: self.closed = True class MockChunkedIncompleteRead(MockChunkedEncodingResponse): - def _encode_chunk(self, chunk): - return "9999\r\n%s\r\n" % chunk.decode() + def _encode_chunk(self, chunk: bytes) -> bytes: + return f"9999\r\n{chunk.decode()}\r\n".encode() class MockChunkedInvalidChunkLength(MockChunkedEncodingResponse): BAD_LENGTH_LINE = "ZZZ\r\n" - def _encode_chunk(self, chunk): - return "%s%s\r\n" % (self.BAD_LENGTH_LINE, chunk.decode()) + def _encode_chunk(self, chunk: bytes) -> bytes: + return f"{self.BAD_LENGTH_LINE}{chunk.decode()}\r\n".encode() class MockChunkedEncodingWithoutCRLFOnEnd(MockChunkedEncodingResponse): - def _encode_chunk(self, chunk): - return "%X\r\n%s%s" % ( + def _encode_chunk(self, chunk: bytes) -> bytes: + return "{:X}\r\n{}{}".format( len(chunk), chunk.decode(), "\r\n" if len(chunk) > 0 else "", - ) + ).encode() class MockChunkedEncodingWithExtensions(MockChunkedEncodingResponse): - def _encode_chunk(self, chunk): - return "%X;asd=qwe\r\n%s\r\n" % (len(chunk), chunk.decode()) + def _encode_chunk(self, chunk: bytes) -> bytes: + return f"{len(chunk):X};asd=qwe\r\n{chunk.decode()}\r\n".encode() -class MockSock(object): +class MockSock: @classmethod - def makefile(cls, *args, **kwargs): + def makefile(cls, *args: typing.Any, **kwargs: typing.Any) -> None: return diff --git a/test/test_retry.py b/test/test_retry.py index cc36089796..4163a06081 100644 --- a/test/test_retry.py +++ b/test/test_retry.py @@ -1,6 +1,9 @@ -import warnings +from __future__ import annotations -import mock +from test import DUMMY_POOL +from unittest import mock + +import freezegun # type: ignore[import] import pytest from urllib3.exceptions import ( @@ -11,22 +14,13 @@ ResponseError, SSLError, ) -from urllib3.packages import six -from urllib3.packages.six.moves import xrange from urllib3.response import HTTPResponse from urllib3.util.retry import RequestHistory, Retry -@pytest.fixture(scope="function", autouse=True) 
-def no_retry_deprecations(): - with warnings.catch_warnings(record=True) as w: - yield - assert len([str(x.message) for x in w if "Retry" in str(x.message)]) == 0 - - -class TestRetry(object): - def test_string(self): - """ Retry string representation looks the way we expect """ +class TestRetry: + def test_string(self) -> None: + """Retry string representation looks the way we expect""" retry = Retry() assert ( str(retry) @@ -39,7 +33,7 @@ def test_string(self): == "Retry(total=7, connect=None, read=None, redirect=None, status=None)" ) - def test_retry_both_specified(self): + def test_retry_both_specified(self) -> None: """Total can win if it's lower than the connect value""" error = ConnectTimeoutError() retry = Retry(connect=3, total=2) @@ -49,8 +43,8 @@ def test_retry_both_specified(self): retry.increment(error=error) assert e.value.reason == error - def test_retry_higher_total_loses(self): - """ A lower connect timeout than the total is honored """ + def test_retry_higher_total_loses(self) -> None: + """A lower connect timeout than the total is honored""" error = ConnectTimeoutError() retry = Retry(connect=2, total=3) retry = retry.increment(error=error) @@ -58,17 +52,17 @@ def test_retry_higher_total_loses(self): with pytest.raises(MaxRetryError): retry.increment(error=error) - def test_retry_higher_total_loses_vs_read(self): - """ A lower read timeout than the total is honored """ - error = ReadTimeoutError(None, "/", "read timed out") + def test_retry_higher_total_loses_vs_read(self) -> None: + """A lower read timeout than the total is honored""" + error = ReadTimeoutError(DUMMY_POOL, "/", "read timed out") retry = Retry(read=2, total=3) retry = retry.increment(method="GET", error=error) retry = retry.increment(method="GET", error=error) with pytest.raises(MaxRetryError): retry.increment(method="GET", error=error) - def test_retry_total_none(self): - """ if Total is none, connect error should take precedence """ + def test_retry_total_none(self) -> None: + 
"""if Total is none, connect error should take precedence""" error = ConnectTimeoutError() retry = Retry(connect=2, total=None) retry = retry.increment(error=error) @@ -77,15 +71,15 @@ def test_retry_total_none(self): retry.increment(error=error) assert e.value.reason == error - error = ReadTimeoutError(None, "/", "read timed out") + timeout_error = ReadTimeoutError(DUMMY_POOL, "/", "read timed out") retry = Retry(connect=2, total=None) - retry = retry.increment(method="GET", error=error) - retry = retry.increment(method="GET", error=error) - retry = retry.increment(method="GET", error=error) + retry = retry.increment(method="GET", error=timeout_error) + retry = retry.increment(method="GET", error=timeout_error) + retry = retry.increment(method="GET", error=timeout_error) assert not retry.is_exhausted() - def test_retry_default(self): - """ If no value is specified, should retry connects 3 times """ + def test_retry_default(self) -> None: + """If no value is specified, should retry connects 3 times""" retry = Retry() assert retry.total == 10 assert retry.connect is None @@ -106,8 +100,8 @@ def test_retry_default(self): assert Retry(0).raise_on_redirect assert not Retry(False).raise_on_redirect - def test_retry_other(self): - """ If an unexpected error is raised, should retry other times """ + def test_retry_other(self) -> None: + """If an unexpected error is raised, should retry other times""" other_error = SSLError() retry = Retry(connect=1) retry = retry.increment(error=other_error) @@ -120,28 +114,26 @@ def test_retry_other(self): retry.increment(error=other_error) assert e.value.reason == other_error - def test_retry_read_zero(self): - """ No second chances on read timeouts, by default """ - error = ReadTimeoutError(None, "/", "read timed out") + def test_retry_read_zero(self) -> None: + """No second chances on read timeouts, by default""" + error = ReadTimeoutError(DUMMY_POOL, "/", "read timed out") retry = Retry(read=0) with pytest.raises(MaxRetryError) as e: 
retry.increment(method="GET", error=error) assert e.value.reason == error - def test_status_counter(self): + def test_status_counter(self) -> None: resp = HTTPResponse(status=400) retry = Retry(status=2) retry = retry.increment(response=resp) retry = retry.increment(response=resp) - with pytest.raises(MaxRetryError) as e: + msg = ResponseError.SPECIFIC_ERROR.format(status_code=400) + with pytest.raises(MaxRetryError, match=msg): retry.increment(response=resp) - assert str(e.value.reason) == ResponseError.SPECIFIC_ERROR.format( - status_code=400 - ) - def test_backoff(self): - """ Backoff is computed correctly """ - max_backoff = Retry.BACKOFF_MAX + def test_backoff(self) -> None: + """Backoff is computed correctly""" + max_backoff = Retry.DEFAULT_BACKOFF_MAX retry = Retry(total=100, backoff_factor=0.2) assert retry.get_backoff_time() == 0 # First request @@ -160,19 +152,72 @@ def test_backoff(self): retry = retry.increment(method="GET") assert retry.get_backoff_time() == 1.6 - for _ in xrange(10): + for _ in range(10): retry = retry.increment(method="GET") assert retry.get_backoff_time() == max_backoff - def test_zero_backoff(self): + def test_configurable_backoff_max(self) -> None: + """Configurable backoff is computed correctly""" + max_backoff = 1 + + retry = Retry(total=100, backoff_factor=0.2, backoff_max=max_backoff) + assert retry.get_backoff_time() == 0 # First request + + retry = retry.increment(method="GET") + assert retry.get_backoff_time() == 0 # First retry + + retry = retry.increment(method="GET") + assert retry.backoff_factor == 0.2 + assert retry.total == 98 + assert retry.get_backoff_time() == 0.4 # Start backoff + + retry = retry.increment(method="GET") + assert retry.get_backoff_time() == 0.8 + + retry = retry.increment(method="GET") + assert retry.get_backoff_time() == max_backoff + + retry = retry.increment(method="GET") + assert retry.get_backoff_time() == max_backoff + + def test_backoff_jitter(self) -> None: + """Backoff with jitter is 
computed correctly""" + max_backoff = 1 + jitter = 0.4 + retry = Retry( + total=100, + backoff_factor=0.2, + backoff_max=max_backoff, + backoff_jitter=jitter, + ) + assert retry.get_backoff_time() == 0 # First request + + retry = retry.increment(method="GET") + assert retry.get_backoff_time() == 0 # First retry + + retry = retry.increment(method="GET") + assert retry.backoff_factor == 0.2 + assert retry.total == 98 + assert 0.4 <= retry.get_backoff_time() <= 0.8 # Start backoff + + retry = retry.increment(method="GET") + assert 0.8 <= retry.get_backoff_time() <= max_backoff + + retry = retry.increment(method="GET") + assert retry.get_backoff_time() == max_backoff + + retry = retry.increment(method="GET") + assert retry.get_backoff_time() == max_backoff + + def test_zero_backoff(self) -> None: retry = Retry() assert retry.get_backoff_time() == 0 retry = retry.increment(method="GET") retry = retry.increment(method="GET") assert retry.get_backoff_time() == 0 - def test_backoff_reset_after_redirect(self): + def test_backoff_reset_after_redirect(self) -> None: retry = Retry(total=100, redirect=5, backoff_factor=0.2) retry = retry.increment(method="GET") retry = retry.increment(method="GET") @@ -184,15 +229,15 @@ def test_backoff_reset_after_redirect(self): retry = retry.increment(method="GET") assert retry.get_backoff_time() == 0.4 - def test_sleep(self): + def test_sleep(self) -> None: # sleep a very small amount of time so our code coverage is happy retry = Retry(backoff_factor=0.0001) retry = retry.increment(method="GET") retry = retry.increment(method="GET") retry.sleep() - def test_status_forcelist(self): - retry = Retry(status_forcelist=xrange(500, 600)) + def test_status_forcelist(self) -> None: + retry = Retry(status_forcelist=range(500, 600)) assert not retry.is_retry("GET", status_code=200) assert not retry.is_retry("GET", status_code=400) assert retry.is_retry("GET", status_code=500) @@ -202,10 +247,10 @@ def test_status_forcelist(self): assert 
retry.is_retry("GET", status_code=418) # String status codes are not matched. - retry = Retry(total=1, status_forcelist=["418"]) + retry = Retry(total=1, status_forcelist=["418"]) # type: ignore[list-item] assert not retry.is_retry("GET", status_code=418) - def test_allowed_methods_with_status_forcelist(self): + def test_allowed_methods_with_status_forcelist(self) -> None: # Falsey allowed_methods means to retry on any method. retry = Retry(status_forcelist=[500], allowed_methods=None) assert retry.is_retry("GET", status_code=500) @@ -216,92 +261,88 @@ def test_allowed_methods_with_status_forcelist(self): assert not retry.is_retry("GET", status_code=500) assert retry.is_retry("POST", status_code=500) - def test_exhausted(self): + def test_exhausted(self) -> None: assert not Retry(0).is_exhausted() assert Retry(-1).is_exhausted() assert Retry(1).increment(method="GET").total == 0 @pytest.mark.parametrize("total", [-1, 0]) - def test_disabled(self, total): + def test_disabled(self, total: int) -> None: with pytest.raises(MaxRetryError): Retry(total).increment(method="GET") - def test_error_message(self): + def test_error_message(self) -> None: retry = Retry(total=0) - with pytest.raises(MaxRetryError) as e: + with pytest.raises(MaxRetryError, match="read timed out") as e: retry = retry.increment( - method="GET", error=ReadTimeoutError(None, "/", "read timed out") + method="GET", error=ReadTimeoutError(DUMMY_POOL, "/", "read timed out") ) assert "Caused by redirect" not in str(e.value) - assert str(e.value.reason) == "None: read timed out" retry = Retry(total=1) - with pytest.raises(MaxRetryError) as e: - retry = retry.increment("POST", "/") + retry = retry.increment("POST", "/") + with pytest.raises(MaxRetryError, match=ResponseError.GENERIC_ERROR) as e: retry = retry.increment("POST", "/") assert "Caused by redirect" not in str(e.value) assert isinstance(e.value.reason, ResponseError) - assert str(e.value.reason) == ResponseError.GENERIC_ERROR retry = Retry(total=1) 
response = HTTPResponse(status=500) - with pytest.raises(MaxRetryError) as e: - retry = retry.increment("POST", "/", response=response) + msg = ResponseError.SPECIFIC_ERROR.format(status_code=500) + retry = retry.increment("POST", "/", response=response) + with pytest.raises(MaxRetryError, match=msg) as e: retry = retry.increment("POST", "/", response=response) assert "Caused by redirect" not in str(e.value) - msg = ResponseError.SPECIFIC_ERROR.format(status_code=500) - assert str(e.value.reason) == msg retry = Retry(connect=1) - with pytest.raises(MaxRetryError) as e: - retry = retry.increment(error=ConnectTimeoutError("conntimeout")) + retry = retry.increment(error=ConnectTimeoutError("conntimeout")) + with pytest.raises(MaxRetryError, match="conntimeout") as e: retry = retry.increment(error=ConnectTimeoutError("conntimeout")) assert "Caused by redirect" not in str(e.value) - assert str(e.value.reason) == "conntimeout" - def test_history(self): + def test_history(self) -> None: retry = Retry(total=10, allowed_methods=frozenset(["GET", "POST"])) assert retry.history == tuple() connection_error = ConnectTimeoutError("conntimeout") retry = retry.increment("GET", "/test1", None, connection_error) - history = (RequestHistory("GET", "/test1", connection_error, None, None),) - assert retry.history == history + test_history1 = (RequestHistory("GET", "/test1", connection_error, None, None),) + assert retry.history == test_history1 - read_error = ReadTimeoutError(None, "/test2", "read timed out") + read_error = ReadTimeoutError(DUMMY_POOL, "/test2", "read timed out") retry = retry.increment("POST", "/test2", None, read_error) - history = ( + test_history2 = ( RequestHistory("GET", "/test1", connection_error, None, None), RequestHistory("POST", "/test2", read_error, None, None), ) - assert retry.history == history + assert retry.history == test_history2 response = HTTPResponse(status=500) retry = retry.increment("GET", "/test3", response, None) - history = ( + test_history3 
= ( RequestHistory("GET", "/test1", connection_error, None, None), RequestHistory("POST", "/test2", read_error, None, None), RequestHistory("GET", "/test3", None, 500, None), ) - assert retry.history == history + assert retry.history == test_history3 - def test_retry_method_not_in_whitelist(self): - error = ReadTimeoutError(None, "/", "read timed out") + def test_retry_method_not_allowed(self) -> None: + error = ReadTimeoutError(DUMMY_POOL, "/", "read timed out") retry = Retry() with pytest.raises(ReadTimeoutError): retry.increment(method="POST", error=error) - def test_retry_default_remove_headers_on_redirect(self): + def test_retry_default_remove_headers_on_redirect(self) -> None: retry = Retry() assert list(retry.remove_headers_on_redirect) == ["authorization"] - def test_retry_set_remove_headers_on_redirect(self): + def test_retry_set_remove_headers_on_redirect(self) -> None: retry = Retry(remove_headers_on_redirect=["X-API-Secret"]) assert list(retry.remove_headers_on_redirect) == ["x-api-secret"] - @pytest.mark.parametrize("value", ["-1", "+1", "1.0", six.u("\xb2")]) # \xb2 = ^2 - def test_parse_retry_after_invalid(self, value): + @pytest.mark.parametrize("value", ["-1", "+1", "1.0", "\xb2"]) # \xb2 = ^2 + def test_parse_retry_after_invalid(self, value: str) -> None: retry = Retry() with pytest.raises(InvalidHeader): retry.parse_retry_after(value) @@ -309,18 +350,18 @@ def test_parse_retry_after_invalid(self, value): @pytest.mark.parametrize( "value, expected", [("0", 0), ("1000", 1000), ("\t42 ", 42)] ) - def test_parse_retry_after(self, value, expected): + def test_parse_retry_after(self, value: str, expected: int) -> None: retry = Retry() assert retry.parse_retry_after(value) == expected @pytest.mark.parametrize("respect_retry_after_header", [True, False]) - def test_respect_retry_after_header_propagated(self, respect_retry_after_header): - + def test_respect_retry_after_header_propagated( + self, respect_retry_after_header: bool + ) -> None: retry = 
Retry(respect_retry_after_header=respect_retry_after_header) new_retry = retry.new() assert new_retry.respect_retry_after_header == respect_retry_after_header - @pytest.mark.freeze_time("2019-06-03 11:00:00", tz_offset=0) @pytest.mark.parametrize( "retry_after_header,respect_retry_after_header,sleep_duration", [ @@ -352,11 +393,16 @@ def test_respect_retry_after_header_propagated(self, respect_retry_after_header) ) @pytest.mark.usefixtures("stub_timezone") def test_respect_retry_after_header_sleep( - self, retry_after_header, respect_retry_after_header, sleep_duration - ): + self, + retry_after_header: str, + respect_retry_after_header: bool, + sleep_duration: int | None, + ) -> None: retry = Retry(respect_retry_after_header=respect_retry_after_header) - with mock.patch("time.sleep") as sleep_mock: + with freezegun.freeze_time("2019-06-03 11:00:00", tz_offset=0), mock.patch( + "time.sleep" + ) as sleep_mock: # for the default behavior, it must be in RETRY_AFTER_STATUS_CODES response = HTTPResponse( status=503, headers={"Retry-After": retry_after_header} diff --git a/test/test_retry_deprecated.py b/test/test_retry_deprecated.py deleted file mode 100644 index 0c8de37661..0000000000 --- a/test/test_retry_deprecated.py +++ /dev/null @@ -1,471 +0,0 @@ -# This is a copy-paste of test_retry.py with extra asserts about deprecated options. It will be removed for v2. -import warnings - -import mock -import pytest - -from urllib3.exceptions import ( - ConnectTimeoutError, - InvalidHeader, - MaxRetryError, - ReadTimeoutError, - ResponseError, - SSLError, -) -from urllib3.packages import six -from urllib3.packages.six.moves import xrange -from urllib3.response import HTTPResponse -from urllib3.util.retry import RequestHistory, Retry - - -# TODO: Remove this entire file once deprecated Retry options are removed in v2. 
-@pytest.fixture(scope="function") -def expect_retry_deprecation(): - with warnings.catch_warnings(record=True) as w: - yield - assert len([str(x.message) for x in w if "Retry" in str(x.message)]) > 0 - - -class TestRetry(object): - def test_string(self): - """ Retry string representation looks the way we expect """ - retry = Retry() - assert ( - str(retry) - == "Retry(total=10, connect=None, read=None, redirect=None, status=None)" - ) - for _ in range(3): - retry = retry.increment(method="GET") - assert ( - str(retry) - == "Retry(total=7, connect=None, read=None, redirect=None, status=None)" - ) - - def test_retry_both_specified(self): - """Total can win if it's lower than the connect value""" - error = ConnectTimeoutError() - retry = Retry(connect=3, total=2) - retry = retry.increment(error=error) - retry = retry.increment(error=error) - with pytest.raises(MaxRetryError) as e: - retry.increment(error=error) - assert e.value.reason == error - - def test_retry_higher_total_loses(self): - """ A lower connect timeout than the total is honored """ - error = ConnectTimeoutError() - retry = Retry(connect=2, total=3) - retry = retry.increment(error=error) - retry = retry.increment(error=error) - with pytest.raises(MaxRetryError): - retry.increment(error=error) - - def test_retry_higher_total_loses_vs_read(self): - """ A lower read timeout than the total is honored """ - error = ReadTimeoutError(None, "/", "read timed out") - retry = Retry(read=2, total=3) - retry = retry.increment(method="GET", error=error) - retry = retry.increment(method="GET", error=error) - with pytest.raises(MaxRetryError): - retry.increment(method="GET", error=error) - - def test_retry_total_none(self): - """ if Total is none, connect error should take precedence """ - error = ConnectTimeoutError() - retry = Retry(connect=2, total=None) - retry = retry.increment(error=error) - retry = retry.increment(error=error) - with pytest.raises(MaxRetryError) as e: - retry.increment(error=error) - assert 
e.value.reason == error - - error = ReadTimeoutError(None, "/", "read timed out") - retry = Retry(connect=2, total=None) - retry = retry.increment(method="GET", error=error) - retry = retry.increment(method="GET", error=error) - retry = retry.increment(method="GET", error=error) - assert not retry.is_exhausted() - - def test_retry_default(self): - """ If no value is specified, should retry connects 3 times """ - retry = Retry() - assert retry.total == 10 - assert retry.connect is None - assert retry.read is None - assert retry.redirect is None - assert retry.other is None - - error = ConnectTimeoutError() - retry = Retry(connect=1) - retry = retry.increment(error=error) - with pytest.raises(MaxRetryError): - retry.increment(error=error) - - retry = Retry(connect=1) - retry = retry.increment(error=error) - assert not retry.is_exhausted() - - assert Retry(0).raise_on_redirect - assert not Retry(False).raise_on_redirect - - def test_retry_other(self): - """ If an unexpected error is raised, should retry other times """ - other_error = SSLError() - retry = Retry(connect=1) - retry = retry.increment(error=other_error) - retry = retry.increment(error=other_error) - assert not retry.is_exhausted() - - retry = Retry(other=1) - retry = retry.increment(error=other_error) - with pytest.raises(MaxRetryError) as e: - retry.increment(error=other_error) - assert e.value.reason == other_error - - def test_retry_read_zero(self): - """ No second chances on read timeouts, by default """ - error = ReadTimeoutError(None, "/", "read timed out") - retry = Retry(read=0) - with pytest.raises(MaxRetryError) as e: - retry.increment(method="GET", error=error) - assert e.value.reason == error - - def test_status_counter(self): - resp = HTTPResponse(status=400) - retry = Retry(status=2) - retry = retry.increment(response=resp) - retry = retry.increment(response=resp) - with pytest.raises(MaxRetryError) as e: - retry.increment(response=resp) - assert str(e.value.reason) == 
ResponseError.SPECIFIC_ERROR.format( - status_code=400 - ) - - def test_backoff(self): - """ Backoff is computed correctly """ - max_backoff = Retry.BACKOFF_MAX - - retry = Retry(total=100, backoff_factor=0.2) - assert retry.get_backoff_time() == 0 # First request - - retry = retry.increment(method="GET") - assert retry.get_backoff_time() == 0 # First retry - - retry = retry.increment(method="GET") - assert retry.backoff_factor == 0.2 - assert retry.total == 98 - assert retry.get_backoff_time() == 0.4 # Start backoff - - retry = retry.increment(method="GET") - assert retry.get_backoff_time() == 0.8 - - retry = retry.increment(method="GET") - assert retry.get_backoff_time() == 1.6 - - for _ in xrange(10): - retry = retry.increment(method="GET") - - assert retry.get_backoff_time() == max_backoff - - def test_zero_backoff(self): - retry = Retry() - assert retry.get_backoff_time() == 0 - retry = retry.increment(method="GET") - retry = retry.increment(method="GET") - assert retry.get_backoff_time() == 0 - - def test_backoff_reset_after_redirect(self): - retry = Retry(total=100, redirect=5, backoff_factor=0.2) - retry = retry.increment(method="GET") - retry = retry.increment(method="GET") - assert retry.get_backoff_time() == 0.4 - redirect_response = HTTPResponse(status=302, headers={"location": "test"}) - retry = retry.increment(method="GET", response=redirect_response) - assert retry.get_backoff_time() == 0 - retry = retry.increment(method="GET") - retry = retry.increment(method="GET") - assert retry.get_backoff_time() == 0.4 - - def test_sleep(self): - # sleep a very small amount of time so our code coverage is happy - retry = Retry(backoff_factor=0.0001) - retry = retry.increment(method="GET") - retry = retry.increment(method="GET") - retry.sleep() - - def test_status_forcelist(self): - retry = Retry(status_forcelist=xrange(500, 600)) - assert not retry.is_retry("GET", status_code=200) - assert not retry.is_retry("GET", status_code=400) - assert retry.is_retry("GET", 
status_code=500) - - retry = Retry(total=1, status_forcelist=[418]) - assert not retry.is_retry("GET", status_code=400) - assert retry.is_retry("GET", status_code=418) - - # String status codes are not matched. - retry = Retry(total=1, status_forcelist=["418"]) - assert not retry.is_retry("GET", status_code=418) - - def test_method_whitelist_with_status_forcelist(self, expect_retry_deprecation): - # Falsey method_whitelist means to retry on any method. - retry = Retry(status_forcelist=[500], method_whitelist=None) - assert retry.is_retry("GET", status_code=500) - assert retry.is_retry("POST", status_code=500) - - # Criteria of method_whitelist and status_forcelist are ANDed. - retry = Retry(status_forcelist=[500], method_whitelist=["POST"]) - assert not retry.is_retry("GET", status_code=500) - assert retry.is_retry("POST", status_code=500) - - def test_exhausted(self): - assert not Retry(0).is_exhausted() - assert Retry(-1).is_exhausted() - assert Retry(1).increment(method="GET").total == 0 - - @pytest.mark.parametrize("total", [-1, 0]) - def test_disabled(self, total): - with pytest.raises(MaxRetryError): - Retry(total).increment(method="GET") - - def test_error_message(self): - retry = Retry(total=0) - with pytest.raises(MaxRetryError) as e: - retry = retry.increment( - method="GET", error=ReadTimeoutError(None, "/", "read timed out") - ) - assert "Caused by redirect" not in str(e.value) - assert str(e.value.reason) == "None: read timed out" - - retry = Retry(total=1) - with pytest.raises(MaxRetryError) as e: - retry = retry.increment("POST", "/") - retry = retry.increment("POST", "/") - assert "Caused by redirect" not in str(e.value) - assert isinstance(e.value.reason, ResponseError) - assert str(e.value.reason) == ResponseError.GENERIC_ERROR - - retry = Retry(total=1) - response = HTTPResponse(status=500) - with pytest.raises(MaxRetryError) as e: - retry = retry.increment("POST", "/", response=response) - retry = retry.increment("POST", "/", response=response) 
- assert "Caused by redirect" not in str(e.value) - msg = ResponseError.SPECIFIC_ERROR.format(status_code=500) - assert str(e.value.reason) == msg - - retry = Retry(connect=1) - with pytest.raises(MaxRetryError) as e: - retry = retry.increment(error=ConnectTimeoutError("conntimeout")) - retry = retry.increment(error=ConnectTimeoutError("conntimeout")) - assert "Caused by redirect" not in str(e.value) - assert str(e.value.reason) == "conntimeout" - - def test_history(self, expect_retry_deprecation): - retry = Retry(total=10, method_whitelist=frozenset(["GET", "POST"])) - assert retry.history == tuple() - connection_error = ConnectTimeoutError("conntimeout") - retry = retry.increment("GET", "/test1", None, connection_error) - history = (RequestHistory("GET", "/test1", connection_error, None, None),) - assert retry.history == history - - read_error = ReadTimeoutError(None, "/test2", "read timed out") - retry = retry.increment("POST", "/test2", None, read_error) - history = ( - RequestHistory("GET", "/test1", connection_error, None, None), - RequestHistory("POST", "/test2", read_error, None, None), - ) - assert retry.history == history - - response = HTTPResponse(status=500) - retry = retry.increment("GET", "/test3", response, None) - history = ( - RequestHistory("GET", "/test1", connection_error, None, None), - RequestHistory("POST", "/test2", read_error, None, None), - RequestHistory("GET", "/test3", None, 500, None), - ) - assert retry.history == history - - def test_retry_method_not_in_whitelist(self): - error = ReadTimeoutError(None, "/", "read timed out") - retry = Retry() - with pytest.raises(ReadTimeoutError): - retry.increment(method="POST", error=error) - - def test_retry_default_remove_headers_on_redirect(self): - retry = Retry() - - assert list(retry.remove_headers_on_redirect) == ["authorization"] - - def test_retry_set_remove_headers_on_redirect(self): - retry = Retry(remove_headers_on_redirect=["X-API-Secret"]) - - assert 
list(retry.remove_headers_on_redirect) == ["x-api-secret"] - - @pytest.mark.parametrize("value", ["-1", "+1", "1.0", six.u("\xb2")]) # \xb2 = ^2 - def test_parse_retry_after_invalid(self, value): - retry = Retry() - with pytest.raises(InvalidHeader): - retry.parse_retry_after(value) - - @pytest.mark.parametrize( - "value, expected", [("0", 0), ("1000", 1000), ("\t42 ", 42)] - ) - def test_parse_retry_after(self, value, expected): - retry = Retry() - assert retry.parse_retry_after(value) == expected - - @pytest.mark.parametrize("respect_retry_after_header", [True, False]) - def test_respect_retry_after_header_propagated(self, respect_retry_after_header): - - retry = Retry(respect_retry_after_header=respect_retry_after_header) - new_retry = retry.new() - assert new_retry.respect_retry_after_header == respect_retry_after_header - - @pytest.mark.freeze_time("2019-06-03 11:00:00", tz_offset=0) - @pytest.mark.parametrize( - "retry_after_header,respect_retry_after_header,sleep_duration", - [ - ("3600", True, 3600), - ("3600", False, None), - # Will sleep due to header is 1 hour in future - ("Mon, 3 Jun 2019 12:00:00 UTC", True, 3600), - # Won't sleep due to not respecting header - ("Mon, 3 Jun 2019 12:00:00 UTC", False, None), - # Won't sleep due to current time reached - ("Mon, 3 Jun 2019 11:00:00 UTC", True, None), - # Won't sleep due to current time reached + not respecting header - ("Mon, 3 Jun 2019 11:00:00 UTC", False, None), - # Handle all the formats in RFC 7231 Section 7.1.1.1 - ("Mon, 03 Jun 2019 11:30:12 GMT", True, 1812), - ("Monday, 03-Jun-19 11:30:12 GMT", True, 1812), - # Assume that datetimes without a timezone are in UTC per RFC 7231 - ("Mon Jun 3 11:30:12 2019", True, 1812), - ], - ) - @pytest.mark.parametrize( - "stub_timezone", - [ - "UTC", - "Asia/Jerusalem", - None, - ], - indirect=True, - ) - @pytest.mark.usefixtures("stub_timezone") - def test_respect_retry_after_header_sleep( - self, retry_after_header, respect_retry_after_header, sleep_duration - 
): - retry = Retry(respect_retry_after_header=respect_retry_after_header) - - with mock.patch("time.sleep") as sleep_mock: - # for the default behavior, it must be in RETRY_AFTER_STATUS_CODES - response = HTTPResponse( - status=503, headers={"Retry-After": retry_after_header} - ) - - retry.sleep(response) - - # The expected behavior is that we'll only sleep if respecting - # this header (since we won't have any backoff sleep attempts) - if respect_retry_after_header and sleep_duration is not None: - sleep_mock.assert_called_with(sleep_duration) - else: - sleep_mock.assert_not_called() - - -class TestRetryDeprecations(object): - def test_cls_get_default_method_whitelist(self, expect_retry_deprecation): - assert Retry.DEFAULT_ALLOWED_METHODS == Retry.DEFAULT_METHOD_WHITELIST - - def test_cls_get_default_redirect_headers_blacklist(self, expect_retry_deprecation): - assert ( - Retry.DEFAULT_REMOVE_HEADERS_ON_REDIRECT - == Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST - ) - - def test_cls_set_default_method_whitelist(self, expect_retry_deprecation): - old_setting = Retry.DEFAULT_METHOD_WHITELIST - try: - Retry.DEFAULT_METHOD_WHITELIST = {"GET"} - retry = Retry() - assert retry.DEFAULT_ALLOWED_METHODS == {"GET"} - assert retry.DEFAULT_METHOD_WHITELIST == {"GET"} - assert retry.allowed_methods == {"GET"} - assert retry.method_whitelist == {"GET"} - - # Test that the default can be overridden both ways - retry = Retry(allowed_methods={"GET", "POST"}) - assert retry.DEFAULT_ALLOWED_METHODS == {"GET"} - assert retry.DEFAULT_METHOD_WHITELIST == {"GET"} - assert retry.allowed_methods == {"GET", "POST"} - assert retry.method_whitelist == {"GET", "POST"} - - retry = Retry(method_whitelist={"POST"}) - assert retry.DEFAULT_ALLOWED_METHODS == {"GET"} - assert retry.DEFAULT_METHOD_WHITELIST == {"GET"} - assert retry.allowed_methods == {"POST"} - assert retry.method_whitelist == {"POST"} - finally: - Retry.DEFAULT_METHOD_WHITELIST = old_setting - assert Retry.DEFAULT_ALLOWED_METHODS == 
old_setting - - def test_cls_set_default_redirect_headers_blacklist(self, expect_retry_deprecation): - old_setting = Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST - try: - Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST = {"test"} - retry = Retry() - assert retry.DEFAULT_REMOVE_HEADERS_ON_REDIRECT == {"test"} - assert retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST == {"test"} - assert retry.remove_headers_on_redirect == {"test"} - assert retry.remove_headers_on_redirect == {"test"} - - retry = Retry(remove_headers_on_redirect={"test2"}) - assert retry.DEFAULT_REMOVE_HEADERS_ON_REDIRECT == {"test"} - assert retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST == {"test"} - assert retry.remove_headers_on_redirect == {"test2"} - assert retry.remove_headers_on_redirect == {"test2"} - finally: - Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST = old_setting - assert Retry.DEFAULT_REDIRECT_HEADERS_BLACKLIST == old_setting - - @pytest.mark.parametrize( - "options", [(None, None), ({"GET"}, None), (None, {"GET"}), ({"GET"}, {"GET"})] - ) - def test_retry_allowed_methods_and_method_whitelist_error(self, options): - with pytest.raises(ValueError) as e: - Retry(allowed_methods=options[0], method_whitelist=options[1]) - assert str(e.value) == ( - "Using both 'allowed_methods' and 'method_whitelist' together " - "is not allowed. 
Instead only use 'allowed_methods'" - ) - - def test_retry_subclass_that_sets_method_whitelist(self, expect_retry_deprecation): - class SubclassRetry(Retry): - def __init__(self, **kwargs): - if "allowed_methods" in kwargs: - raise AssertionError( - "This subclass likely doesn't use 'allowed_methods'" - ) - - super(SubclassRetry, self).__init__(**kwargs) - - # Since we're setting 'method_whiteist' we get fallbacks - # within Retry.new() and Retry._is_method_retryable() - # to use 'method_whitelist' instead of 'allowed_methods' - self.method_whitelist = self.method_whitelist | {"POST"} - - retry = SubclassRetry() - assert retry.method_whitelist == Retry.DEFAULT_ALLOWED_METHODS | {"POST"} - assert retry.new(read=0).method_whitelist == retry.method_whitelist - assert retry._is_method_retryable("POST") - assert not retry._is_method_retryable("CONNECT") - - assert retry.new(method_whitelist={"GET"}).method_whitelist == {"GET", "POST"} - - # urllib3 doesn't do this during normal operation - # so we don't want users passing in 'allowed_methods' - # when their subclass doesn't support the option yet. 
- with pytest.raises(AssertionError) as e: - retry.new(allowed_methods={"GET"}) - assert str(e.value) == "This subclass likely doesn't use 'allowed_methods'" diff --git a/test/test_ssl.py b/test/test_ssl.py index 4a00d355e5..c886d4e51c 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -1,170 +1,217 @@ -from test import notPyPy2 +from __future__ import annotations + +import ssl +import typing +from unittest import mock -import mock import pytest -from urllib3.exceptions import SNIMissingWarning +from urllib3.exceptions import ProxySchemeUnsupported, SSLError from urllib3.util import ssl_ -@pytest.mark.parametrize( - "addr", - [ - # IPv6 - "::1", - "::", - "FE80::8939:7684:D84b:a5A4%251", - # IPv4 - "127.0.0.1", - "8.8.8.8", - b"127.0.0.1", - # IPv6 w/ Zone IDs - "FE80::8939:7684:D84b:a5A4%251", - b"FE80::8939:7684:D84b:a5A4%251", - "FE80::8939:7684:D84b:a5A4%19", - b"FE80::8939:7684:D84b:a5A4%19", - ], -) -def test_is_ipaddress_true(addr): - assert ssl_.is_ipaddress(addr) - - -@pytest.mark.parametrize( - "addr", - [ - "www.python.org", - b"www.python.org", - "v2.sg.media-imdb.com", - b"v2.sg.media-imdb.com", - ], -) -def test_is_ipaddress_false(addr): - assert not ssl_.is_ipaddress(addr) - - -@pytest.mark.parametrize( - ["has_sni", "server_hostname", "uses_sni"], - [ - (True, "127.0.0.1", False), - (False, "www.python.org", False), - (False, "0.0.0.0", False), - (True, "www.google.com", True), - (True, None, False), - (False, None, False), - ], -) -def test_context_sni_with_ip_address(monkeypatch, has_sni, server_hostname, uses_sni): - monkeypatch.setattr(ssl_, "HAS_SNI", has_sni) - - sock = mock.Mock() - context = mock.create_autospec(ssl_.SSLContext) - - ssl_.ssl_wrap_socket(sock, server_hostname=server_hostname, ssl_context=context) - - if uses_sni: - context.wrap_socket.assert_called_with(sock, server_hostname=server_hostname) - else: - context.wrap_socket.assert_called_with(sock) - - -@pytest.mark.parametrize( - ["has_sni", "server_hostname", 
"should_warn"], - [ - (True, "www.google.com", False), - (True, "127.0.0.1", False), - (False, "127.0.0.1", False), - (False, "www.google.com", True), - (True, None, False), - (False, None, False), - ], -) -def test_sni_missing_warning_with_ip_addresses( - monkeypatch, has_sni, server_hostname, should_warn -): - monkeypatch.setattr(ssl_, "HAS_SNI", has_sni) - - sock = mock.Mock() - context = mock.create_autospec(ssl_.SSLContext) - - with mock.patch("warnings.warn") as warn: - ssl_.ssl_wrap_socket(sock, server_hostname=server_hostname, ssl_context=context) - - if should_warn: - assert warn.call_count >= 1 - warnings = [call[0][1] for call in warn.call_args_list] - assert SNIMissingWarning in warnings - else: - assert warn.call_count == 0 - - -@pytest.mark.parametrize( - ["ciphers", "expected_ciphers"], - [ - (None, ssl_.DEFAULT_CIPHERS), - ("ECDH+AESGCM:ECDH+CHACHA20", "ECDH+AESGCM:ECDH+CHACHA20"), - ], -) -def test_create_urllib3_context_set_ciphers(monkeypatch, ciphers, expected_ciphers): - - context = mock.create_autospec(ssl_.SSLContext) - context.set_ciphers = mock.Mock() - context.options = 0 - monkeypatch.setattr(ssl_, "SSLContext", lambda *_, **__: context) - - assert ssl_.create_urllib3_context(ciphers=ciphers) is context - - assert context.set_ciphers.call_count == 1 - assert context.set_ciphers.call_args == mock.call(expected_ciphers) - - -def test_wrap_socket_given_context_no_load_default_certs(): - context = mock.create_autospec(ssl_.SSLContext) - context.load_default_certs = mock.Mock() - - sock = mock.Mock() - ssl_.ssl_wrap_socket(sock, ssl_context=context) - - context.load_default_certs.assert_not_called() - - -@notPyPy2 -def test_wrap_socket_given_ca_certs_no_load_default_certs(monkeypatch): - context = mock.create_autospec(ssl_.SSLContext) - context.load_default_certs = mock.Mock() - context.options = 0 - - monkeypatch.setattr(ssl_, "SSLContext", lambda *_, **__: context) - - sock = mock.Mock() - ssl_.ssl_wrap_socket(sock, 
ca_certs="/tmp/fake-file") - - context.load_default_certs.assert_not_called() - context.load_verify_locations.assert_called_with("/tmp/fake-file", None, None) - - -def test_wrap_socket_default_loads_default_certs(monkeypatch): - context = mock.create_autospec(ssl_.SSLContext) - context.load_default_certs = mock.Mock() - context.options = 0 - - monkeypatch.setattr(ssl_, "SSLContext", lambda *_, **__: context) - - sock = mock.Mock() - ssl_.ssl_wrap_socket(sock) - - context.load_default_certs.assert_called_with() - - -@pytest.mark.parametrize( - ["pha", "expected_pha"], [(None, None), (False, True), (True, True)] -) -def test_create_urllib3_context_pha(monkeypatch, pha, expected_pha): - context = mock.create_autospec(ssl_.SSLContext) - context.set_ciphers = mock.Mock() - context.options = 0 - context.post_handshake_auth = pha - monkeypatch.setattr(ssl_, "SSLContext", lambda *_, **__: context) - - assert ssl_.create_urllib3_context() is context - - assert context.post_handshake_auth == expected_pha +class TestSSL: + @pytest.mark.parametrize( + "addr", + [ + # IPv6 + "::1", + "::", + "FE80::8939:7684:D84b:a5A4%251", + # IPv4 + "127.0.0.1", + "8.8.8.8", + b"127.0.0.1", + # IPv6 w/ Zone IDs + "FE80::8939:7684:D84b:a5A4%251", + b"FE80::8939:7684:D84b:a5A4%251", + "FE80::8939:7684:D84b:a5A4%19", + b"FE80::8939:7684:D84b:a5A4%19", + ], + ) + def test_is_ipaddress_true(self, addr: bytes | str) -> None: + assert ssl_.is_ipaddress(addr) + + @pytest.mark.parametrize( + "addr", + [ + "www.python.org", + b"www.python.org", + "v2.sg.media-imdb.com", + b"v2.sg.media-imdb.com", + ], + ) + def test_is_ipaddress_false(self, addr: bytes | str) -> None: + assert not ssl_.is_ipaddress(addr) + + def test_create_urllib3_context_set_ciphers( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + ciphers = "ECDH+AESGCM:ECDH+CHACHA20" + context = mock.create_autospec(ssl_.SSLContext) + context.set_ciphers = mock.Mock() + context.options = 0 + monkeypatch.setattr(ssl_, "SSLContext", lambda *_, 
**__: context) + + assert ssl_.create_urllib3_context(ciphers=ciphers) is context + + assert context.set_ciphers.call_count == 1 + assert context.set_ciphers.call_args == mock.call(ciphers) + + def test_create_urllib3_no_context(self) -> None: + with mock.patch("urllib3.util.ssl_.SSLContext", None): + with pytest.raises(TypeError): + ssl_.create_urllib3_context() + + def test_wrap_socket_given_context_no_load_default_certs(self) -> None: + context = mock.create_autospec(ssl_.SSLContext) + context.load_default_certs = mock.Mock() + + sock = mock.Mock() + ssl_.ssl_wrap_socket(sock, ssl_context=context) + + context.load_default_certs.assert_not_called() + + def test_wrap_socket_given_ca_certs_no_load_default_certs( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + context = mock.create_autospec(ssl_.SSLContext) + context.load_default_certs = mock.Mock() + context.options = 0 + + monkeypatch.setattr(ssl_, "SSLContext", lambda *_, **__: context) + + sock = mock.Mock() + ssl_.ssl_wrap_socket(sock, ca_certs="/tmp/fake-file") + + context.load_default_certs.assert_not_called() + context.load_verify_locations.assert_called_with("/tmp/fake-file", None, None) + + def test_wrap_socket_default_loads_default_certs( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + context = mock.create_autospec(ssl_.SSLContext) + context.load_default_certs = mock.Mock() + context.options = 0 + + monkeypatch.setattr(ssl_, "SSLContext", lambda *_, **__: context) + + sock = mock.Mock() + ssl_.ssl_wrap_socket(sock) + + context.load_default_certs.assert_called_with() + + def test_wrap_socket_no_ssltransport(self) -> None: + with mock.patch("urllib3.util.ssl_.SSLTransport", None): + with pytest.raises(ProxySchemeUnsupported): + sock = mock.Mock() + ssl_.ssl_wrap_socket(sock, tls_in_tls=True) + + @pytest.mark.parametrize( + ["pha", "expected_pha"], [(None, None), (False, True), (True, True)] + ) + def test_create_urllib3_context_pha( + self, + monkeypatch: pytest.MonkeyPatch, + pha: bool | 
None, + expected_pha: bool | None, + ) -> None: + context = mock.create_autospec(ssl_.SSLContext) + context.set_ciphers = mock.Mock() + context.options = 0 + context.post_handshake_auth = pha + monkeypatch.setattr(ssl_, "SSLContext", lambda *_, **__: context) + + assert ssl_.create_urllib3_context() is context + + assert context.post_handshake_auth == expected_pha + + def test_create_urllib3_context_default_ciphers( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + context = mock.create_autospec(ssl_.SSLContext) + context.set_ciphers = mock.Mock() + context.options = 0 + monkeypatch.setattr(ssl_, "SSLContext", lambda *_, **__: context) + + ssl_.create_urllib3_context() + + context.set_ciphers.assert_not_called() + + @pytest.mark.parametrize( + "kwargs", + [ + { + "ssl_version": ssl.PROTOCOL_TLSv1, + "ssl_minimum_version": ssl.TLSVersion.MINIMUM_SUPPORTED, + }, + { + "ssl_version": ssl.PROTOCOL_TLSv1, + "ssl_maximum_version": ssl.TLSVersion.TLSv1, + }, + { + "ssl_version": ssl.PROTOCOL_TLSv1, + "ssl_minimum_version": ssl.TLSVersion.MINIMUM_SUPPORTED, + "ssl_maximum_version": ssl.TLSVersion.MAXIMUM_SUPPORTED, + }, + ], + ) + def test_create_urllib3_context_ssl_version_and_ssl_min_max_version_errors( + self, kwargs: dict[str, typing.Any] + ) -> None: + with pytest.raises(ValueError) as e: + ssl_.create_urllib3_context(**kwargs) + + assert str(e.value) == ( + "Can't specify both 'ssl_version' and either 'ssl_minimum_version' or 'ssl_maximum_version'" + ) + + @pytest.mark.parametrize( + "kwargs", + [ + { + "ssl_version": ssl.PROTOCOL_TLS, + "ssl_minimum_version": ssl.TLSVersion.MINIMUM_SUPPORTED, + }, + { + "ssl_version": ssl.PROTOCOL_TLS_CLIENT, + "ssl_minimum_version": ssl.TLSVersion.MINIMUM_SUPPORTED, + }, + { + "ssl_version": None, + "ssl_minimum_version": ssl.TLSVersion.MINIMUM_SUPPORTED, + }, + ], + ) + def test_create_urllib3_context_ssl_version_and_ssl_min_max_version_no_warning( + self, kwargs: dict[str, typing.Any] + ) -> None: + 
ssl_.create_urllib3_context(**kwargs) + + @pytest.mark.parametrize( + "kwargs", + [ + {"ssl_version": ssl.PROTOCOL_TLSv1, "ssl_minimum_version": None}, + {"ssl_version": ssl.PROTOCOL_TLSv1, "ssl_maximum_version": None}, + { + "ssl_version": ssl.PROTOCOL_TLSv1, + "ssl_minimum_version": None, + "ssl_maximum_version": None, + }, + ], + ) + def test_create_urllib3_context_ssl_version_and_ssl_min_max_version_no_error( + self, kwargs: dict[str, typing.Any] + ) -> None: + with pytest.warns( + DeprecationWarning, + match=r"'ssl_version' option is deprecated and will be removed in " + r"urllib3 v2\.1\.0\. Instead use 'ssl_minimum_version'", + ): + ssl_.create_urllib3_context(**kwargs) + + def test_assert_fingerprint_raises_exception_on_none_cert(self) -> None: + with pytest.raises(SSLError): + ssl_.assert_fingerprint( + cert=None, fingerprint="55:39:BF:70:05:12:43:FA:1F:D1:BF:4E:E8:1B:07:1D" + ) diff --git a/test/test_ssltransport.py b/test/test_ssltransport.py index 6cf632d57f..cace51db96 100644 --- a/test/test_ssltransport.py +++ b/test/test_ssltransport.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import platform import select import socket import ssl -import sys +import typing +from unittest import mock -import mock import pytest from dummyserver.server import DEFAULT_CA, DEFAULT_CERTS @@ -12,33 +14,36 @@ from urllib3.util import ssl_ from urllib3.util.ssltransport import SSLTransport +if typing.TYPE_CHECKING: + from typing_extensions import Literal + # consume_socket can iterate forever, we add timeouts to prevent halting. PER_TEST_TIMEOUT = 60 -def server_client_ssl_contexts(): +def server_client_ssl_contexts() -> tuple[ssl.SSLContext, ssl.SSLContext]: if hasattr(ssl, "PROTOCOL_TLS_SERVER"): server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - else: - # python 2.7 and 3.5 workaround. 
- # PROTOCOL_TLS_SERVER was added in 3.6 - server_context = ssl.SSLContext(ssl.PROTOCOL_TLS) server_context.load_cert_chain(DEFAULT_CERTS["certfile"], DEFAULT_CERTS["keyfile"]) if hasattr(ssl, "PROTOCOL_TLS_CLIENT"): client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - else: - # python 2.7 and 3.5 workaround. - # PROTOCOL_TLS_SERVER was added in 3.6 - client_context = ssl.SSLContext(ssl.PROTOCOL_TLS) - client_context.verify_mode = ssl.CERT_REQUIRED - client_context.check_hostname = True client_context.load_verify_locations(DEFAULT_CA) return server_context, client_context -def sample_request(binary=True): +@typing.overload +def sample_request(binary: Literal[True] = ...) -> bytes: + ... + + +@typing.overload +def sample_request(binary: Literal[False]) -> str: + ... + + +def sample_request(binary: bool = True) -> bytes | str: request = ( b"GET http://www.testing.com/ HTTP/1.1\r\n" b"Host: www.testing.com\r\n" @@ -48,25 +53,43 @@ def sample_request(binary=True): return request if binary else request.decode("utf-8") -def validate_request(provided_request, binary=True): +def validate_request( + provided_request: bytearray, binary: Literal[False, True] = True +) -> None: assert provided_request is not None expected_request = sample_request(binary) assert provided_request == expected_request -def sample_response(binary=True): +@typing.overload +def sample_response(binary: Literal[True] = ...) -> bytes: + ... + + +@typing.overload +def sample_response(binary: Literal[False]) -> str: + ... + + +@typing.overload +def sample_response(binary: bool = ...) -> bytes | str: + ... 
+ + +def sample_response(binary: bool = True) -> bytes | str: response = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" return response if binary else response.decode("utf-8") -def validate_response(provided_response, binary=True): +def validate_response( + provided_response: bytes | bytearray | str, binary: bool = True +) -> None: assert provided_response is not None expected_response = sample_response(binary) assert provided_response == expected_response -def validate_peercert(ssl_socket): - +def validate_peercert(ssl_socket: SSLTransport) -> None: binary_cert = ssl_socket.getpeercert(binary_form=True) assert type(binary_cert) == bytes assert len(binary_cert) > 0 @@ -77,7 +100,6 @@ def validate_peercert(ssl_socket): assert cert["serialNumber"] != "" -@pytest.mark.skipif(sys.version_info < (3, 5), reason="requires python3.5 or higher") class SingleTLSLayerTestCase(SocketDummyServerTestCase): """ Uses the SocketDummyServer to validate a single TLS layer can be @@ -85,23 +107,28 @@ class SingleTLSLayerTestCase(SocketDummyServerTestCase): """ @classmethod - def setup_class(cls): + def setup_class(cls) -> None: cls.server_context, cls.client_context = server_client_ssl_contexts() - def start_dummy_server(self, handler=None): - def socket_handler(listener): + def start_dummy_server( + self, handler: typing.Callable[[socket.socket], None] | None = None + ) -> None: + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] - with self.server_context.wrap_socket(sock, server_side=True) as ssock: - request = consume_socket(ssock) - validate_request(request) - ssock.send(sample_response()) + try: + with self.server_context.wrap_socket(sock, server_side=True) as ssock: + request = consume_socket(ssock) + validate_request(request) + ssock.send(sample_response()) + except (ConnectionAbortedError, ConnectionResetError): + return chosen_handler = handler if handler else socket_handler self._start_server(chosen_handler) 
@pytest.mark.timeout(PER_TEST_TIMEOUT) - def test_start_closed_socket(self): - """ Errors generated from an unconnected socket should bubble up.""" + def test_start_closed_socket(self) -> None: + """Errors generated from an unconnected socket should bubble up.""" sock = socket.socket(socket.AF_INET) context = ssl.create_default_context() sock.close() @@ -109,8 +136,8 @@ def test_start_closed_socket(self): SSLTransport(sock, context) @pytest.mark.timeout(PER_TEST_TIMEOUT) - def test_close_after_handshake(self): - """ Socket errors should be bubbled up """ + def test_close_after_handshake(self) -> None: + """Socket errors should be bubbled up""" self.start_dummy_server() sock = socket.create_connection((self.host, self.port)) @@ -122,8 +149,8 @@ def test_close_after_handshake(self): ssock.send(b"blaaargh") @pytest.mark.timeout(PER_TEST_TIMEOUT) - def test_wrap_existing_socket(self): - """ Validates a single TLS layer can be established. """ + def test_wrap_existing_socket(self) -> None: + """Validates a single TLS layer can be established.""" self.start_dummy_server() sock = socket.create_connection((self.host, self.port)) @@ -136,7 +163,7 @@ def test_wrap_existing_socket(self): validate_response(response) @pytest.mark.timeout(PER_TEST_TIMEOUT) - def test_unbuffered_text_makefile(self): + def test_unbuffered_text_makefile(self) -> None: self.start_dummy_server() sock = socket.create_connection((self.host, self.port)) @@ -150,13 +177,13 @@ def test_unbuffered_text_makefile(self): validate_response(response) @pytest.mark.timeout(PER_TEST_TIMEOUT) - def test_unwrap_existing_socket(self): + def test_unwrap_existing_socket(self) -> None: """ Validates we can break up the TLS layer A full request/response is sent over TLS, and later over plain text. 
""" - def shutdown_handler(listener): + def shutdown_handler(listener: socket.socket) -> None: sock = listener.accept()[0] ssl_sock = self.server_context.wrap_socket(sock, server_side=True) @@ -186,8 +213,8 @@ def shutdown_handler(listener): validate_response(response) @pytest.mark.timeout(PER_TEST_TIMEOUT) - def test_ssl_object_attributes(self): - """ Ensures common ssl attributes are exposed """ + def test_ssl_object_attributes(self) -> None: + """Ensures common ssl attributes are exposed""" self.start_dummy_server() sock = socket.create_connection((self.host, self.port)) @@ -202,8 +229,11 @@ def test_ssl_object_attributes(self): assert ssock.selected_npn_protocol() is None shared_ciphers = ssock.shared_ciphers() - assert type(shared_ciphers) == list - assert len(shared_ciphers) > 0 + # SSLContext.shared_ciphers() changed behavior completely in a patch version. + # See: https://github.com/python/cpython/issues/96931 + assert shared_ciphers is None or ( + type(shared_ciphers) is list and len(shared_ciphers) > 0 + ) assert ssock.compression() is None @@ -214,8 +244,8 @@ def test_ssl_object_attributes(self): validate_response(response) @pytest.mark.timeout(PER_TEST_TIMEOUT) - def test_socket_object_attributes(self): - """ Ensures common socket attributes are exposed """ + def test_socket_object_attributes(self) -> None: + """Ensures common socket attributes are exposed""" self.start_dummy_server() sock = socket.create_connection((self.host, self.port)) @@ -239,29 +269,37 @@ class SocketProxyDummyServer(SocketDummyServerTestCase): socket. 
""" - def __init__(self, destination_server_host, destination_server_port): + def __init__( + self, destination_server_host: str, destination_server_port: int + ) -> None: self.destination_server_host = destination_server_host self.destination_server_port = destination_server_port - self.server_context, self.client_context = server_client_ssl_contexts() + self.server_ctx, _ = server_client_ssl_contexts() - def start_proxy_handler(self): + def start_proxy_handler(self) -> None: """ Socket handler for the proxy. Terminates the first TLS layer and tunnels any bytes needed for client <-> server communicatin. """ - def proxy_handler(listener): + def proxy_handler(listener: socket.socket) -> None: sock = listener.accept()[0] - with self.server_context.wrap_socket(sock, server_side=True) as client_sock: + with self.server_ctx.wrap_socket(sock, server_side=True) as client_sock: upstream_sock = socket.create_connection( (self.destination_server_host, self.destination_server_port) ) self._read_write_loop(client_sock, upstream_sock) upstream_sock.close() + client_sock.close() self._start_server(proxy_handler) - def _read_write_loop(self, client_sock, server_sock, chunks=65536): + def _read_write_loop( + self, + client_sock: socket.socket, + server_sock: socket.socket, + chunks: int = 65536, + ) -> None: inputs = [client_sock, server_sock] output = [client_sock, server_sock] @@ -269,7 +307,7 @@ def _read_write_loop(self, client_sock, server_sock, chunks=65536): readable, writable, exception = select.select(inputs, output, inputs) if exception: - # Error ocurred with either of the sockets, time to + # Error occurred with either of the sockets, time to # wrap up, parent func will close sockets. 
break @@ -282,10 +320,14 @@ def _read_write_loop(self, client_sock, server_sock, chunks=65536): read_socket = server_sock write_socket = client_sock - # Ensure buffer is not full before writting + # Ensure buffer is not full before writing if write_socket in writable: try: b = read_socket.recv(chunks) + if len(b) == 0: + # One of the sockets has EOFed, we return to close + # both. + return write_socket.send(b) except ssl.SSLEOFError: # It's possible, depending on shutdown order, that we'll @@ -294,7 +336,6 @@ def _read_write_loop(self, client_sock, server_sock, chunks=65536): return -@pytest.mark.skipif(sys.version_info < (3, 5), reason="requires python3.5 or higher") class TlsInTlsTestCase(SocketDummyServerTestCase): """ Creates a TLS in TLS tunnel by chaining a 'SocketProxyDummyServer' and a @@ -306,40 +347,44 @@ class TlsInTlsTestCase(SocketDummyServerTestCase): """ @classmethod - def setup_class(cls): + def setup_class(cls) -> None: cls.server_context, cls.client_context = server_client_ssl_contexts() @classmethod - def start_proxy_server(cls): + def start_proxy_server(cls) -> None: # Proxy server will handle the first TLS connection and create a # connection to the destination server. cls.proxy_server = SocketProxyDummyServer(cls.host, cls.port) cls.proxy_server.start_proxy_handler() @classmethod - def teardown_class(cls): + def teardown_class(cls) -> None: if hasattr(cls, "proxy_server"): cls.proxy_server.teardown_class() - super(TlsInTlsTestCase, cls).teardown_class() + super().teardown_class() @classmethod - def start_destination_server(cls): + def start_destination_server(cls) -> None: """ Socket handler for the destination_server. Terminates the second TLS layer and send a basic HTTP response. 
""" - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] - with cls.server_context.wrap_socket(sock, server_side=True) as ssock: - request = consume_socket(ssock) - validate_request(request) - ssock.send(sample_response()) + try: + with cls.server_context.wrap_socket(sock, server_side=True) as ssock: + request = consume_socket(ssock) + validate_request(request) + ssock.send(sample_response()) + except (ssl.SSLEOFError, ssl.SSLZeroReturnError, OSError): + return + sock.close() cls._start_server(socket_handler) @pytest.mark.timeout(PER_TEST_TIMEOUT) - def test_tls_in_tls_tunnel(self): + def test_tls_in_tls_tunnel(self) -> None: """ Basic communication over the TLS in TLS tunnel. """ @@ -361,7 +406,7 @@ def test_tls_in_tls_tunnel(self): validate_response(response) @pytest.mark.timeout(PER_TEST_TIMEOUT) - def test_wrong_sni_hint(self): + def test_wrong_sni_hint(self) -> None: """ Provides a wrong sni hint to validate an exception is thrown. """ @@ -374,17 +419,14 @@ def test_wrong_sni_hint(self): with self.client_context.wrap_socket( sock, server_hostname="localhost" ) as proxy_sock: - with pytest.raises(Exception) as e: + with pytest.raises(ssl.SSLCertVerificationError): SSLTransport( proxy_sock, self.client_context, server_hostname="veryverywrong" ) - # ssl.CertificateError is a child of ValueError in python3.6 or - # before. After python3.7 it's a child of SSLError - assert e.type in [ssl.SSLError, ssl.CertificateError] @pytest.mark.timeout(PER_TEST_TIMEOUT) @pytest.mark.parametrize("buffering", [None, 0]) - def test_tls_in_tls_makefile_raw_rw_binary(self, buffering): + def test_tls_in_tls_makefile_raw_rw_binary(self, buffering: int | None) -> None: """ Uses makefile with read, write and binary modes without buffering. 
""" @@ -400,13 +442,12 @@ def test_tls_in_tls_makefile_raw_rw_binary(self, buffering): with SSLTransport( proxy_sock, self.client_context, server_hostname="localhost" ) as destination_sock: - file = destination_sock.makefile("rwb", buffering) - file.write(sample_request()) + file.write(sample_request()) # type: ignore[call-overload] file.flush() response = bytearray(65536) - wrote = file.readinto(response) + wrote = file.readinto(response) # type: ignore[union-attr] assert wrote is not None # Allocated response is bigger than the actual response, we # rtrim remaining x00 bytes. @@ -419,7 +460,7 @@ def test_tls_in_tls_makefile_raw_rw_binary(self, buffering): reason="Skipping windows due to text makefile support", ) @pytest.mark.timeout(PER_TEST_TIMEOUT) - def test_tls_in_tls_makefile_rw_text(self): + def test_tls_in_tls_makefile_rw_text(self) -> None: """ Creates a separate buffer for reading and writing using text mode and utf-8 encoding. @@ -436,22 +477,23 @@ def test_tls_in_tls_makefile_rw_text(self): with SSLTransport( proxy_sock, self.client_context, server_hostname="localhost" ) as destination_sock: - read = destination_sock.makefile("r", encoding="utf-8") write = destination_sock.makefile("w", encoding="utf-8") - write.write(sample_request(binary=False)) + write.write(sample_request(binary=False)) # type: ignore[arg-type, call-overload] write.flush() response = read.read() + assert isinstance(response, str) if "\r" not in response: # Carriage return will be removed when reading as a file on # some platforms. We add it before the comparison. + assert isinstance(response, str) response = response.replace("\n", "\r\n") validate_response(response, binary=False) @pytest.mark.timeout(PER_TEST_TIMEOUT) - def test_tls_in_tls_recv_into_sendall(self): + def test_tls_in_tls_recv_into_sendall(self) -> None: """ Valides recv_into and sendall also work as expected. Other tests are using recv/send. 
@@ -468,46 +510,63 @@ def test_tls_in_tls_recv_into_sendall(self): with SSLTransport( proxy_sock, self.client_context, server_hostname="localhost" ) as destination_sock: - destination_sock.sendall(sample_request()) response = bytearray(65536) destination_sock.recv_into(response) str_response = response.decode("utf-8").rstrip("\x00") validate_response(str_response, binary=False) - @pytest.mark.timeout(PER_TEST_TIMEOUT) - def test_tls_in_tls_recv_into_unbuffered(self): - """ - Valides recv_into without a preallocated buffer. - """ - self.start_destination_server() - self.start_proxy_server() - sock = socket.create_connection( - (self.proxy_server.host, self.proxy_server.port) +class TestSSLTransportWithMock: + def test_constructor_params(self) -> None: + server_hostname = "example-domain.com" + sock = mock.Mock() + context = mock.create_autospec(ssl_.SSLContext) + ssl_transport = SSLTransport( + sock, context, server_hostname=server_hostname, suppress_ragged_eofs=False ) - with self.client_context.wrap_socket( - sock, server_hostname="localhost" - ) as proxy_sock: - with SSLTransport( - proxy_sock, self.client_context, server_hostname="localhost" - ) as destination_sock: + context.wrap_bio.assert_called_with( + mock.ANY, mock.ANY, server_hostname=server_hostname + ) + assert not ssl_transport.suppress_ragged_eofs - destination_sock.send(sample_request()) - response = destination_sock.recv_into(None) - validate_response(response) + def test_various_flags_errors(self) -> None: + server_hostname = "example-domain.com" + sock = mock.Mock() + context = mock.create_autospec(ssl_.SSLContext) + ssl_transport = SSLTransport( + sock, context, server_hostname=server_hostname, suppress_ragged_eofs=False + ) + with pytest.raises(ValueError): + ssl_transport.recv(flags=1) + + with pytest.raises(ValueError): + ssl_transport.recv_into(bytearray(), flags=1) + with pytest.raises(ValueError): + ssl_transport.sendall(bytearray(), flags=1) -@pytest.mark.skipif(sys.version_info < (3, 5), 
reason="requires python3.5 or higher") -class TestSSLTransportWithMock(object): - def test_constructor_params(self): + with pytest.raises(ValueError): + ssl_transport.send(None, flags=1) # type: ignore[arg-type] + + def test_makefile_wrong_mode_error(self) -> None: server_hostname = "example-domain.com" sock = mock.Mock() context = mock.create_autospec(ssl_.SSLContext) ssl_transport = SSLTransport( sock, context, server_hostname=server_hostname, suppress_ragged_eofs=False ) - context.wrap_bio.assert_called_with( - mock.ANY, mock.ANY, server_hostname=server_hostname + with pytest.raises(ValueError): + ssl_transport.makefile(mode="x") + + def test_wrap_ssl_read_error(self) -> None: + server_hostname = "example-domain.com" + sock = mock.Mock() + context = mock.create_autospec(ssl_.SSLContext) + ssl_transport = SSLTransport( + sock, context, server_hostname=server_hostname, suppress_ragged_eofs=False ) - assert not ssl_transport.suppress_ragged_eofs + with mock.patch.object(ssl_transport, "_ssl_io_loop") as _ssl_io_loop: + _ssl_io_loop.side_effect = ssl.SSLError() + with pytest.raises(ssl.SSLError): + ssl_transport._wrap_ssl_read(1) diff --git a/test/test_util.py b/test/test_util.py index 827df42726..38f271296d 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -1,50 +1,56 @@ -# coding: utf-8 -import hashlib +from __future__ import annotations + import io import logging import socket import ssl +import sys +import typing import warnings from itertools import chain -from test import notBrotlipy, onlyBrotlipy, onlyPy2, onlyPy3 +from test import ImportBlocker, ModuleStash, notBrotli, notZstd, onlyBrotli, onlyZstd +from unittest import mock +from unittest.mock import MagicMock, Mock, patch +from urllib.parse import urlparse import pytest -from mock import Mock, patch -from urllib3 import add_stderr_logger, disable_warnings, util +from urllib3 import add_stderr_logger, disable_warnings +from urllib3.connection import ProxyConfig from urllib3.exceptions import ( 
InsecureRequestWarning, LocationParseError, - SNIMissingWarning, TimeoutStateError, UnrewindableBodyError, ) -from urllib3.packages import six -from urllib3.poolmanager import ProxyConfig from urllib3.util import is_fp_closed from urllib3.util.connection import _has_ipv6, allowed_gai_family, create_connection -from urllib3.util.proxy import connection_requires_http_tunnel, create_proxy_ssl_context +from urllib3.util.proxy import connection_requires_http_tunnel from urllib3.util.request import _FAILEDTELL, make_headers, rewind_body from urllib3.util.response import assert_header_parsing from urllib3.util.ssl_ import ( - _const_compare_digest_backport, + _TYPE_VERSION_INFO, + _is_has_never_check_common_name_reliable, resolve_cert_reqs, resolve_ssl_version, ssl_wrap_socket, ) -from urllib3.util.timeout import Timeout -from urllib3.util.url import Url, get_host, parse_url, split_first +from urllib3.util.timeout import _DEFAULT_TIMEOUT, Timeout +from urllib3.util.url import Url, _encode_invalid_chars, parse_url +from urllib3.util.util import to_bytes, to_str from . import clear_warnings +if typing.TYPE_CHECKING: + from typing_extensions import Literal + # This number represents a time in seconds, it doesn't mean anything in # isolation. Setting to a high-ish value to avoid conflicts with the smaller # numbers used for timeouts TIMEOUT_EPOCH = 1000 -class TestUtil(object): - +class TestUtil: url_host_map = [ # Hosts ("http://google.com/mail", ("http", "google.com", None)), @@ -106,6 +112,9 @@ class TestUtil(object): "http://[2010:836b:4179::836b:4179]", ("http", "[2010:836b:4179::836b:4179]", None), ), + # Scoped IPv6 (with ZoneID), both RFC 6874 compliant and not. 
+ ("http://[a::b%25zone]", ("http", "[a::b%zone]", None)), + ("http://[a::b%zone]", ("http", "[a::b%zone]", None)), # Hosts ("HTTP://GOOGLE.COM/mail/", ("http", "google.com", None)), ("GOogle.COM/mail", ("http", "google.com", None)), @@ -132,38 +141,38 @@ class TestUtil(object): ), ] - @pytest.mark.parametrize("url, expected_host", url_host_map) - def test_get_host(self, url, expected_host): - returned_host = get_host(url) - assert returned_host == expected_host + @pytest.mark.parametrize(["url", "scheme_host_port"], url_host_map) + def test_scheme_host_port( + self, url: str, scheme_host_port: tuple[str, str, int | None] + ) -> None: + parsed_url = parse_url(url) + scheme, host, port = scheme_host_port + + assert (parsed_url.scheme or "http") == scheme + assert parsed_url.hostname == parsed_url.host == host + assert parsed_url.port == port + + def test_encode_invalid_chars_none(self) -> None: + assert _encode_invalid_chars(None, set()) is None - # TODO: Add more tests @pytest.mark.parametrize( - "location", + "url", [ "http://google.com:foo", "http://::1/", "http://::1:80/", "http://google.com:-80", - six.u("http://google.com:\xb2\xb2"), # \xb2 = ^2 - ], - ) - def test_invalid_host(self, location): - with pytest.raises(LocationParseError): - get_host(location) - - @pytest.mark.parametrize( - "url", - [ + "http://google.com:65536", + "http://google.com:\xb2\xb2", # \xb2 = ^2 # Invalid IDNA labels - u"http://\uD7FF.com", - u"http://❤️", + "http://\uD7FF.com", + "http://❤️", # Unicode surrogates - u"http://\uD800.com", - u"http://\uDC00.com", + "http://\uD800.com", + "http://\uDC00.com", ], ) - def test_invalid_url(self, url): + def test_invalid_url(self, url: str) -> None: with pytest.raises(LocationParseError): parse_url(url) @@ -181,6 +190,10 @@ def test_invalid_url(self, url): ), ("HTTPS://Example.Com/?Key=Value", "https://example.com/?Key=Value"), ("Https://Example.Com/#Fragment", "https://example.com/#Fragment"), + # IPv6 addresses with zone IDs. 
Both RFC 6874 (%25) as well as + # non-standard (unquoted %) variants. + ("[::1%zone]", "[::1%zone]"), + ("[::1%25zone]", "[::1%zone]"), ("[::1%25]", "[::1%25]"), ("[::Ff%etH0%Ff]/%ab%Af", "[::ff%etH0%FF]/%AB%AF"), ( @@ -200,16 +213,18 @@ def test_invalid_url(self, url): ), ], ) - def test_parse_url_normalization(self, url, expected_normalized_url): + def test_parse_url_normalization( + self, url: str, expected_normalized_url: str + ) -> None: """Assert parse_url normalizes the scheme/host, and only the scheme/host""" actual_normalized_url = parse_url(url).url assert actual_normalized_url == expected_normalized_url @pytest.mark.parametrize("char", [chr(i) for i in range(0x00, 0x21)] + ["\x7F"]) - def test_control_characters_are_percent_encoded(self, char): + def test_control_characters_are_percent_encoded(self, char: str) -> None: percent_char = "%" + (hex(ord(char))[2:].zfill(2).upper()) url = parse_url( - "http://user{0}@example.com/path{0}?query{0}#fragment{0}".format(char) + f"http://user{char}@example.com/path{char}?query{char}#fragment{char}" ) assert url == Url( @@ -257,15 +272,6 @@ def test_control_characters_are_percent_encoded(self, char): "http://foo:bar@localhost/", Url("http", auth="foo:bar", host="localhost", path="/"), ), - # Unicode type (Python 2.x) - ( - u"http://foo:bar@localhost/", - Url(u"http", auth=u"foo:bar", host=u"localhost", path=u"/"), - ), - ( - "http://foo:bar@localhost/", - Url("http", auth="foo:bar", host="localhost", path="/"), - ), ] non_round_tripping_parse_url_host_map = [ @@ -279,26 +285,26 @@ def test_control_characters_are_percent_encoded(self, char): ("http://google.com:/", Url("http", host="google.com", path="/")), # Uppercase IRI ( - u"http://Königsgäßchen.de/straße", + "http://Königsgäßchen.de/straße", Url("http", host="xn--knigsgchen-b4a3dun.de", path="/stra%C3%9Fe"), ), # Percent-encode in userinfo ( - u"http://user@email.com:password@example.com/", + "http://user@email.com:password@example.com/", Url("http", 
auth="user%40email.com:password", host="example.com", path="/"), ), ( - u'http://user":quoted@example.com/', + 'http://user":quoted@example.com/', Url("http", auth="user%22:quoted", host="example.com", path="/"), ), # Unicode Surrogates - (u"http://google.com/\uD800", Url("http", host="google.com", path="%ED%A0%80")), + ("http://google.com/\uD800", Url("http", host="google.com", path="%ED%A0%80")), ( - u"http://google.com?q=\uDC00", + "http://google.com?q=\uDC00", Url("http", host="google.com", path="", query="q=%ED%B0%80"), ), ( - u"http://google.com#\uDC00", + "http://google.com#\uDC00", Url("http", host="google.com", path="", fragment="%ED%B0%80"), ), ] @@ -307,12 +313,13 @@ def test_control_characters_are_percent_encoded(self, char): "url, expected_url", chain(parse_url_host_map, non_round_tripping_parse_url_host_map), ) - def test_parse_url(self, url, expected_url): + def test_parse_url(self, url: str, expected_url: Url) -> None: returned_url = parse_url(url) assert returned_url == expected_url + assert returned_url.hostname == returned_url.host == expected_url.host @pytest.mark.parametrize("url, expected_url", parse_url_host_map) - def test_unparse_url(self, url, expected_url): + def test_unparse_url(self, url: str, expected_url: Url) -> None: assert url == expected_url.url @pytest.mark.parametrize( @@ -327,20 +334,31 @@ def test_unparse_url(self, url, expected_url): ("/abc/./.././d/././e/.././f/./../../ghi", Url(path="/ghi")), ], ) - def test_parse_and_normalize_url_paths(self, url, expected_url): + def test_parse_and_normalize_url_paths(self, url: str, expected_url: Url) -> None: actual_url = parse_url(url) assert actual_url == expected_url assert actual_url.url == expected_url.url - def test_parse_url_invalid_IPv6(self): + def test_parse_url_invalid_IPv6(self) -> None: with pytest.raises(LocationParseError): parse_url("[::1") - def test_parse_url_negative_port(self): + def test_parse_url_negative_port(self) -> None: with pytest.raises(LocationParseError): 
parse_url("https://www.google.com:-80/") - def test_Url_str(self): + def test_parse_url_remove_leading_zeros(self) -> None: + url = parse_url("https://example.com:0000000000080") + assert url.port == 80 + + def test_parse_url_only_zeros(self) -> None: + url = parse_url("https://example.com:0") + assert url.port == 0 + + url = parse_url("https://example.com:000000000000") + assert url.port == 0 + + def test_Url_str(self) -> None: U = Url("http", host="google.com") assert str(U) == U.url @@ -357,19 +375,60 @@ def test_Url_str(self): ] @pytest.mark.parametrize("url, expected_request_uri", request_uri_map) - def test_request_uri(self, url, expected_request_uri): + def test_request_uri(self, url: str, expected_request_uri: str) -> None: returned_url = parse_url(url) assert returned_url.request_uri == expected_request_uri + url_authority_map: list[tuple[str, str | None]] = [ + ("http://user:pass@google.com/mail", "user:pass@google.com"), + ("http://user:pass@google.com:80/mail", "user:pass@google.com:80"), + ("http://user@google.com:80/mail", "user@google.com:80"), + ("http://user:pass@192.168.1.1/path", "user:pass@192.168.1.1"), + ("http://user:pass@192.168.1.1:80/path", "user:pass@192.168.1.1:80"), + ("http://user@192.168.1.1:80/path", "user@192.168.1.1:80"), + ("http://user:pass@[::1]/path", "user:pass@[::1]"), + ("http://user:pass@[::1]:80/path", "user:pass@[::1]:80"), + ("http://user@[::1]:80/path", "user@[::1]:80"), + ("http://user:pass@localhost/path", "user:pass@localhost"), + ("http://user:pass@localhost:80/path", "user:pass@localhost:80"), + ("http://user@localhost:80/path", "user@localhost:80"), + ] + url_netloc_map = [ ("http://google.com/mail", "google.com"), ("http://google.com:80/mail", "google.com:80"), + ("http://192.168.0.1/path", "192.168.0.1"), + ("http://192.168.0.1:80/path", "192.168.0.1:80"), + ("http://[::1]/path", "[::1]"), + ("http://[::1]:80/path", "[::1]:80"), + ("http://localhost", "localhost"), + ("http://localhost:80", "localhost:80"), 
("google.com/foobar", "google.com"), ("google.com:12345", "google.com:12345"), + ("/", None), ] + combined_netloc_authority_map = url_authority_map + url_netloc_map + + # We compose this list due to variances between parse_url + # and urlparse when URIs don't provide a scheme. + url_authority_with_schemes_map = [ + u for u in combined_netloc_authority_map if u[0].startswith("http") + ] + + @pytest.mark.parametrize("url, expected_authority", combined_netloc_authority_map) + def test_authority(self, url: str, expected_authority: str | None) -> None: + assert parse_url(url).authority == expected_authority + + @pytest.mark.parametrize("url, expected_authority", url_authority_with_schemes_map) + def test_authority_matches_urllib_netloc( + self, url: str, expected_authority: str | None + ) -> None: + """Validate this matches the behavior of urlparse().netloc""" + assert urlparse(url).netloc == expected_authority + @pytest.mark.parametrize("url, expected_netloc", url_netloc_map) - def test_netloc(self, url, expected_netloc): + def test_netloc(self, url: str, expected_netloc: str | None) -> None: assert parse_url(url).netloc == expected_netloc url_vulnerabilities = [ @@ -385,7 +444,7 @@ def test_netloc(self, url, expected_netloc): ), # NodeJS unicode -> double dot ( - u"http://google.com/\uff2e\uff2e/abc", + "http://google.com/\uff2e\uff2e/abc", Url("http", host="google.com", path="/%EF%BC%AE%EF%BC%AE/abc"), ), # Scheme without :// @@ -396,14 +455,14 @@ def test_netloc(self, url, expected_netloc): ("//google.com/a/b/c", Url(host="google.com", path="/a/b/c")), # International URLs ( - u"http://ヒ:キ@ヒ.abc.ニ/ヒ?キ#ワ", + "http://ヒ:キ@ヒ.abc.ニ/ヒ?キ#ワ", Url( - u"http", - host=u"xn--pdk.abc.xn--idk", - auth=u"%E3%83%92:%E3%82%AD", - path=u"/%E3%83%92", - query=u"%E3%82%AD", - fragment=u"%E3%83%AF", + "http", + host="xn--pdk.abc.xn--idk", + auth="%E3%83%92:%E3%82%AD", + path="/%E3%83%92", + query="%E3%82%AD", + fragment="%E3%83%AF", ), ), # Injected headers (CVE-2016-5699, 
CVE-2019-9740, CVE-2019-9947) @@ -438,63 +497,81 @@ def test_netloc(self, url, expected_netloc): fragment="hash", ), ), + # Tons of '@' causing backtracking + pytest.param( + "https://" + ("@" * 10000) + "[", + False, + id="Tons of '@' causing backtracking 1", + ), + pytest.param( + "https://user:" + ("@" * 10000) + "example.com", + Url( + scheme="https", + auth="user:" + ("%40" * 9999), + host="example.com", + ), + id="Tons of '@' causing backtracking 2", + ), ] @pytest.mark.parametrize("url, expected_url", url_vulnerabilities) - def test_url_vulnerabilities(self, url, expected_url): + def test_url_vulnerabilities( + self, url: str, expected_url: Literal[False] | Url + ) -> None: if expected_url is False: with pytest.raises(LocationParseError): parse_url(url) else: assert parse_url(url) == expected_url - @onlyPy2 - def test_parse_url_bytes_to_str_python_2(self): - url = parse_url(b"https://www.google.com/") - assert url == Url("https", host="www.google.com", path="/") - - assert isinstance(url.scheme, str) - assert isinstance(url.host, str) - assert isinstance(url.path, str) - - @onlyPy2 - def test_parse_url_unicode_python_2(self): - url = parse_url(u"https://www.google.com/") - assert url == Url(u"https", host=u"www.google.com", path=u"/") - - assert isinstance(url.scheme, six.text_type) - assert isinstance(url.host, six.text_type) - assert isinstance(url.path, six.text_type) - - @onlyPy3 - def test_parse_url_bytes_type_error_python_3(self): + def test_parse_url_bytes_type_error(self) -> None: with pytest.raises(TypeError): - parse_url(b"https://www.google.com/") + parse_url(b"https://www.google.com/") # type: ignore[arg-type] @pytest.mark.parametrize( "kwargs, expected", [ + pytest.param( + {"accept_encoding": True}, + {"accept-encoding": "gzip,deflate,br,zstd"}, + marks=[onlyBrotli(), onlyZstd()], # type: ignore[list-item] + ), pytest.param( {"accept_encoding": True}, {"accept-encoding": "gzip,deflate,br"}, - marks=onlyBrotlipy(), + marks=[onlyBrotli(), 
notZstd()], # type: ignore[list-item] + ), + pytest.param( + {"accept_encoding": True}, + {"accept-encoding": "gzip,deflate,zstd"}, + marks=[notBrotli(), onlyZstd()], # type: ignore[list-item] ), pytest.param( {"accept_encoding": True}, {"accept-encoding": "gzip,deflate"}, - marks=notBrotlipy(), + marks=[notBrotli(), notZstd()], # type: ignore[list-item] ), ({"accept_encoding": "foo,bar"}, {"accept-encoding": "foo,bar"}), ({"accept_encoding": ["foo", "bar"]}, {"accept-encoding": "foo,bar"}), + pytest.param( + {"accept_encoding": True, "user_agent": "banana"}, + {"accept-encoding": "gzip,deflate,br,zstd", "user-agent": "banana"}, + marks=[onlyBrotli(), onlyZstd()], # type: ignore[list-item] + ), pytest.param( {"accept_encoding": True, "user_agent": "banana"}, {"accept-encoding": "gzip,deflate,br", "user-agent": "banana"}, - marks=onlyBrotlipy(), + marks=[onlyBrotli(), notZstd()], # type: ignore[list-item] + ), + pytest.param( + {"accept_encoding": True, "user_agent": "banana"}, + {"accept-encoding": "gzip,deflate,zstd", "user-agent": "banana"}, + marks=[notBrotli(), onlyZstd()], # type: ignore[list-item] ), pytest.param( {"accept_encoding": True, "user_agent": "banana"}, {"accept-encoding": "gzip,deflate", "user-agent": "banana"}, - marks=notBrotlipy(), + marks=[notBrotli(), notZstd()], # type: ignore[list-item] ), ({"user_agent": "banana"}, {"user-agent": "banana"}), ({"keep_alive": True}, {"connection": "keep-alive"}), @@ -506,10 +583,12 @@ def test_parse_url_bytes_type_error_python_3(self): ({"disable_cache": True}, {"cache-control": "no-cache"}), ], ) - def test_make_headers(self, kwargs, expected): - assert make_headers(**kwargs) == expected + def test_make_headers( + self, kwargs: dict[str, bool | str], expected: dict[str, str] + ) -> None: + assert make_headers(**kwargs) == expected # type: ignore[arg-type] - def test_rewind_body(self): + def test_rewind_body(self) -> None: body = io.BytesIO(b"test data") assert body.read() == b"test data" @@ -520,7 +599,7 @@ 
def test_rewind_body(self): rewind_body(body, 5) assert body.read() == b"data" - def test_rewind_body_failed_tell(self): + def test_rewind_body_failed_tell(self) -> None: body = io.BytesIO(b"test data") body.read() # Consume body @@ -529,40 +608,25 @@ def test_rewind_body_failed_tell(self): with pytest.raises(UnrewindableBodyError): rewind_body(body, body_pos) - def test_rewind_body_bad_position(self): + def test_rewind_body_bad_position(self) -> None: body = io.BytesIO(b"test data") body.read() # Consume body # Pass non-integer position with pytest.raises(ValueError): - rewind_body(body, body_pos=None) + rewind_body(body, body_pos=None) # type: ignore[arg-type] with pytest.raises(ValueError): - rewind_body(body, body_pos=object()) + rewind_body(body, body_pos=object()) # type: ignore[arg-type] - def test_rewind_body_failed_seek(self): - class BadSeek: - def seek(self, pos, offset=0): - raise IOError + def test_rewind_body_failed_seek(self) -> None: + class BadSeek(io.StringIO): + def seek(self, offset: int, whence: int = 0) -> typing.NoReturn: + raise OSError with pytest.raises(UnrewindableBodyError): rewind_body(BadSeek(), body_pos=2) - @pytest.mark.parametrize( - "input, expected", - [ - (("abcd", "b"), ("a", "cd", "b")), - (("abcd", "cb"), ("a", "cd", "b")), - (("abcd", ""), ("abcd", "", None)), - (("abcd", "a"), ("", "bcd", "a")), - (("abcd", "ab"), ("", "bcd", "a")), - (("abcd", "eb"), ("a", "cd", "b")), - ], - ) - def test_split_first(self, input, expected): - output = split_first(*input) - assert output == expected - - def test_add_stderr_logger(self): + def test_add_stderr_logger(self) -> None: handler = add_stderr_logger(level=logging.INFO) # Don't actually print debug logger = logging.getLogger("urllib3") assert handler in logger.handlers @@ -570,17 +634,20 @@ def test_add_stderr_logger(self): logger.debug("Testing add_stderr_logger") logger.removeHandler(handler) - def test_disable_warnings(self): + def test_disable_warnings(self) -> None: with 
warnings.catch_warnings(record=True) as w: clear_warnings() + warnings.simplefilter("default", InsecureRequestWarning) warnings.warn("This is a test.", InsecureRequestWarning) assert len(w) == 1 disable_warnings() warnings.warn("This is a test.", InsecureRequestWarning) assert len(w) == 1 - def _make_time_pass(self, seconds, timeout, time_mock): - """ Make some time pass for the timeout object """ + def _make_time_pass( + self, seconds: int, timeout: Timeout, time_mock: Mock + ) -> Timeout: + """Make some time pass for the timeout object""" time_mock.return_value = TIMEOUT_EPOCH timeout.start_connect() time_mock.return_value = TIMEOUT_EPOCH + seconds @@ -596,20 +663,22 @@ def _make_time_pass(self, seconds, timeout, time_mock): ({"read": True}, "cannot be a boolean"), ({"connect": 0}, "less than or equal"), ({"read": "foo"}, "int, float or None"), + ({"read": "1.0"}, "int, float or None"), ], ) - def test_invalid_timeouts(self, kwargs, message): - with pytest.raises(ValueError) as e: + def test_invalid_timeouts( + self, kwargs: dict[str, int | bool], message: str + ) -> None: + with pytest.raises(ValueError, match=message): Timeout(**kwargs) - assert message in str(e.value) - @patch("urllib3.util.timeout.current_time") - def test_timeout(self, current_time): + @patch("time.monotonic") + def test_timeout(self, time_monotonic: MagicMock) -> None: timeout = Timeout(total=3) # make 'no time' elapse timeout = self._make_time_pass( - seconds=0, timeout=timeout, time_mock=current_time + seconds=0, timeout=timeout, time_mock=time_monotonic ) assert timeout.read_timeout == 3 assert timeout.connect_timeout == 3 @@ -618,19 +687,19 @@ def test_timeout(self, current_time): assert timeout.connect_timeout == 2 timeout = Timeout() - assert timeout.connect_timeout == Timeout.DEFAULT_TIMEOUT + assert timeout.connect_timeout == _DEFAULT_TIMEOUT # Connect takes 5 seconds, leaving 5 seconds for read timeout = Timeout(total=10, read=7) timeout = self._make_time_pass( - seconds=5, 
timeout=timeout, time_mock=current_time + seconds=5, timeout=timeout, time_mock=time_monotonic ) assert timeout.read_timeout == 5 # Connect takes 2 seconds, read timeout still 7 seconds timeout = Timeout(total=10, read=7) timeout = self._make_time_pass( - seconds=2, timeout=timeout, time_mock=current_time + seconds=2, timeout=timeout, time_mock=time_monotonic ) assert timeout.read_timeout == 7 @@ -645,15 +714,24 @@ def test_timeout(self, current_time): timeout = Timeout(5) assert timeout.total == 5 - def test_timeout_str(self): + def test_timeout_default_resolve(self) -> None: + """The timeout default is resolved when read_timeout is accessed.""" + timeout = Timeout() + with patch("urllib3.util.timeout.getdefaulttimeout", return_value=2): + assert timeout.read_timeout == 2 + + with patch("urllib3.util.timeout.getdefaulttimeout", return_value=3): + assert timeout.read_timeout == 3 + + def test_timeout_str(self) -> None: timeout = Timeout(connect=1, read=2, total=3) assert str(timeout) == "Timeout(connect=1, read=2, total=3)" timeout = Timeout(connect=1, read=None, total=3) assert str(timeout) == "Timeout(connect=1, read=None, total=3)" - @patch("urllib3.util.timeout.current_time") - def test_timeout_elapsed(self, current_time): - current_time.return_value = TIMEOUT_EPOCH + @patch("time.monotonic") + def test_timeout_elapsed(self, time_monotonic: MagicMock) -> None: + time_monotonic.return_value = TIMEOUT_EPOCH timeout = Timeout(total=3) with pytest.raises(TimeoutStateError): timeout.get_connect_duration() @@ -662,101 +740,88 @@ def test_timeout_elapsed(self, current_time): with pytest.raises(TimeoutStateError): timeout.start_connect() - current_time.return_value = TIMEOUT_EPOCH + 2 + time_monotonic.return_value = TIMEOUT_EPOCH + 2 assert timeout.get_connect_duration() == 2 - current_time.return_value = TIMEOUT_EPOCH + 37 + time_monotonic.return_value = TIMEOUT_EPOCH + 37 assert timeout.get_connect_duration() == 37 - def 
test_is_fp_closed_object_supports_closed(self): - class ClosedFile(object): + def test_is_fp_closed_object_supports_closed(self) -> None: + class ClosedFile: @property - def closed(self): + def closed(self) -> Literal[True]: return True assert is_fp_closed(ClosedFile()) - def test_is_fp_closed_object_has_none_fp(self): - class NoneFpFile(object): + def test_is_fp_closed_object_has_none_fp(self) -> None: + class NoneFpFile: @property - def fp(self): + def fp(self) -> None: return None assert is_fp_closed(NoneFpFile()) - def test_is_fp_closed_object_has_fp(self): - class FpFile(object): + def test_is_fp_closed_object_has_fp(self) -> None: + class FpFile: @property - def fp(self): + def fp(self) -> Literal[True]: return True assert not is_fp_closed(FpFile()) - def test_is_fp_closed_object_has_neither_fp_nor_closed(self): - class NotReallyAFile(object): + def test_is_fp_closed_object_has_neither_fp_nor_closed(self) -> None: + class NotReallyAFile: pass with pytest.raises(ValueError): is_fp_closed(NotReallyAFile()) - def test_const_compare_digest_fallback(self): - target = hashlib.sha256(b"abcdef").digest() - assert _const_compare_digest_backport(target, target) - - prefix = target[:-1] - assert not _const_compare_digest_backport(target, prefix) - - suffix = target + b"0" - assert not _const_compare_digest_backport(target, suffix) - - incorrect = hashlib.sha256(b"xyz").digest() - assert not _const_compare_digest_backport(target, incorrect) - - def test_has_ipv6_disabled_on_compile(self): + def test_has_ipv6_disabled_on_compile(self) -> None: with patch("socket.has_ipv6", False): assert not _has_ipv6("::1") - def test_has_ipv6_enabled_but_fails(self): + def test_has_ipv6_enabled_but_fails(self) -> None: with patch("socket.has_ipv6", True): with patch("socket.socket") as mock: instance = mock.return_value instance.bind = Mock(side_effect=Exception("No IPv6 here!")) assert not _has_ipv6("::1") - def test_has_ipv6_enabled_and_working(self): + def 
test_has_ipv6_enabled_and_working(self) -> None: with patch("socket.has_ipv6", True): with patch("socket.socket") as mock: instance = mock.return_value instance.bind.return_value = True assert _has_ipv6("::1") - def test_has_ipv6_disabled_on_appengine(self): - gae_patch = patch( - "urllib3.contrib._appengine_environ.is_appengine_sandbox", return_value=True - ) - with gae_patch: - assert not _has_ipv6("::1") - - def test_ip_family_ipv6_enabled(self): + def test_ip_family_ipv6_enabled(self) -> None: with patch("urllib3.util.connection.HAS_IPV6", True): assert allowed_gai_family() == socket.AF_UNSPEC - def test_ip_family_ipv6_disabled(self): + def test_ip_family_ipv6_disabled(self) -> None: with patch("urllib3.util.connection.HAS_IPV6", False): assert allowed_gai_family() == socket.AF_INET @pytest.mark.parametrize("headers", [b"foo", None, object]) - def test_assert_header_parsing_throws_typeerror_with_non_headers(self, headers): + def test_assert_header_parsing_throws_typeerror_with_non_headers( + self, headers: bytes | object | None + ) -> None: with pytest.raises(TypeError): - assert_header_parsing(headers) + assert_header_parsing(headers) # type: ignore[arg-type] - def test_connection_requires_http_tunnel_no_proxy(self): + def test_connection_requires_http_tunnel_no_proxy(self) -> None: assert not connection_requires_http_tunnel( proxy_url=None, proxy_config=None, destination_scheme=None ) - def test_connection_requires_http_tunnel_http_proxy(self): + def test_connection_requires_http_tunnel_http_proxy(self) -> None: proxy = parse_url("http://proxy:8080") - proxy_config = ProxyConfig(ssl_context=None, use_forwarding_for_https=False) + proxy_config = ProxyConfig( + ssl_context=None, + use_forwarding_for_https=False, + assert_hostname=None, + assert_fingerprint=None, + ) destination_scheme = "http" assert not connection_requires_http_tunnel( proxy, proxy_config, destination_scheme @@ -765,20 +830,20 @@ def test_connection_requires_http_tunnel_http_proxy(self): 
destination_scheme = "https" assert connection_requires_http_tunnel(proxy, proxy_config, destination_scheme) - def test_connection_requires_http_tunnel_https_proxy(self): + def test_connection_requires_http_tunnel_https_proxy(self) -> None: proxy = parse_url("https://proxy:8443") - proxy_config = ProxyConfig(ssl_context=None, use_forwarding_for_https=False) + proxy_config = ProxyConfig( + ssl_context=None, + use_forwarding_for_https=False, + assert_hostname=None, + assert_fingerprint=None, + ) destination_scheme = "http" assert not connection_requires_http_tunnel( proxy, proxy_config, destination_scheme ) - def test_create_proxy_ssl_context(self): - ssl_context = create_proxy_ssl_context(ssl_version=None, cert_reqs=None) - ssl_context.verify_mode = ssl.CERT_REQUIRED - - @onlyPy3 - def test_assert_header_parsing_no_error_on_multipart(self): + def test_assert_header_parsing_no_error_on_multipart(self) -> None: from http import client header_msg = io.BytesIO() @@ -792,10 +857,12 @@ def test_assert_header_parsing_no_error_on_multipart(self): assert_header_parsing(client.parse_headers(header_msg)) @pytest.mark.parametrize("host", [".localhost", "...", "t" * 64]) - def test_create_connection_with_invalid_idna_labels(self, host): - with pytest.raises(LocationParseError) as ctx: + def test_create_connection_with_invalid_idna_labels(self, host: str) -> None: + with pytest.raises( + LocationParseError, + match=f"Failed to parse: '{host}', label empty or too long", + ): create_connection((host, 80)) - assert str(ctx.value) == "Failed to parse: '%s', label empty or too long" % host @pytest.mark.parametrize( "host", @@ -809,13 +876,96 @@ def test_create_connection_with_invalid_idna_labels(self, host): ) @patch("socket.getaddrinfo") @patch("socket.socket") - def test_create_connection_with_valid_idna_labels(self, socket, getaddrinfo, host): + def test_create_connection_with_valid_idna_labels( + self, socket: MagicMock, getaddrinfo: MagicMock, host: str + ) -> None: 
getaddrinfo.return_value = [(None, None, None, None, None)] socket.return_value = Mock() create_connection((host, 80)) + @patch("socket.getaddrinfo") + def test_create_connection_error(self, getaddrinfo: MagicMock) -> None: + getaddrinfo.return_value = [] + with pytest.raises(OSError, match="getaddrinfo returns an empty list"): + create_connection(("example.com", 80)) + + @patch("socket.getaddrinfo") + def test_dnsresolver_forced_error(self, getaddrinfo: MagicMock) -> None: + getaddrinfo.side_effect = socket.gaierror() + with pytest.raises(socket.gaierror): + # dns is valid but we force the error just for the sake of the test + create_connection(("example.com", 80)) + + def test_dnsresolver_expected_error(self) -> None: + with pytest.raises(socket.gaierror): + # windows: [Errno 11001] getaddrinfo failed in windows + # linux: [Errno -2] Name or service not known + # macos: [Errno 8] nodename nor servname provided, or not known + create_connection(("badhost.invalid", 80)) + + @patch("socket.getaddrinfo") + @patch("socket.socket") + def test_create_connection_with_scoped_ipv6( + self, socket: MagicMock, getaddrinfo: MagicMock + ) -> None: + # Check that providing create_connection with a scoped IPv6 address + # properly propagates the scope to getaddrinfo, and that the returned + # scoped ID makes it to the socket creation call. 
+ fake_scoped_sa6 = ("a::b", 80, 0, 42) + getaddrinfo.return_value = [ + ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + fake_scoped_sa6, + ) + ] + socket.return_value = fake_sock = MagicMock() + + create_connection(("a::b%iface", 80)) + assert getaddrinfo.call_args[0][0] == "a::b%iface" + fake_sock.connect.assert_called_once_with(fake_scoped_sa6) + + @pytest.mark.parametrize( + "input,params,expected", + ( + ("test", {}, "test"), # str input + (b"test", {}, "test"), # bytes input + (b"test", {"encoding": "utf-8"}, "test"), # bytes input with utf-8 + (b"test", {"encoding": "ascii"}, "test"), # bytes input with ascii + ), + ) + def test_to_str( + self, input: bytes | str, params: dict[str, str], expected: str + ) -> None: + assert to_str(input, **params) == expected + + def test_to_str_error(self) -> None: + with pytest.raises(TypeError, match="not expecting type int"): + to_str(1) # type: ignore[arg-type] -class TestUtilSSL(object): + @pytest.mark.parametrize( + "input,params,expected", + ( + (b"test", {}, b"test"), # str input + ("test", {}, b"test"), # bytes input + ("é", {}, b"\xc3\xa9"), # bytes input + ("test", {"encoding": "utf-8"}, b"test"), # bytes input with utf-8 + ("test", {"encoding": "ascii"}, b"test"), # bytes input with ascii + ), + ) + def test_to_bytes( + self, input: bytes | str, params: dict[str, str], expected: bytes + ) -> None: + assert to_bytes(input, **params) == expected + + def test_to_bytes_error(self) -> None: + with pytest.raises(TypeError, match="not expecting type int"): + to_bytes(1) # type: ignore[arg-type] + + +class TestUtilSSL: """Test utils that use an SSL backend.""" @pytest.mark.parametrize( @@ -828,7 +978,9 @@ class TestUtilSSL(object): ("CERT_REQUIRED", ssl.CERT_REQUIRED), ], ) - def test_resolve_cert_reqs(self, candidate, requirements): + def test_resolve_cert_reqs( + self, candidate: int | str | None, requirements: int + ) -> None: assert resolve_cert_reqs(candidate) == requirements 
@pytest.mark.parametrize( @@ -840,11 +992,11 @@ def test_resolve_cert_reqs(self, candidate, requirements): (ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23), ], ) - def test_resolve_ssl_version(self, candidate, version): + def test_resolve_ssl_version(self, candidate: int | str, version: int) -> None: assert resolve_ssl_version(candidate) == version - def test_ssl_wrap_socket_loads_the_cert_chain(self): - socket = object() + def test_ssl_wrap_socket_loads_the_cert_chain(self) -> None: + socket = Mock() mock_context = Mock() ssl_wrap_socket( ssl_context=mock_context, sock=socket, certfile="/path/to/certfile" @@ -853,24 +1005,24 @@ def test_ssl_wrap_socket_loads_the_cert_chain(self): mock_context.load_cert_chain.assert_called_once_with("/path/to/certfile", None) @patch("urllib3.util.ssl_.create_urllib3_context") - def test_ssl_wrap_socket_creates_new_context(self, create_urllib3_context): - socket = object() - ssl_wrap_socket(sock=socket, cert_reqs="CERT_REQUIRED") + def test_ssl_wrap_socket_creates_new_context( + self, create_urllib3_context: mock.MagicMock + ) -> None: + socket = Mock() + ssl_wrap_socket(socket, cert_reqs=ssl.CERT_REQUIRED) - create_urllib3_context.assert_called_once_with( - None, "CERT_REQUIRED", ciphers=None - ) + create_urllib3_context.assert_called_once_with(None, 2, ciphers=None) - def test_ssl_wrap_socket_loads_verify_locations(self): - socket = object() + def test_ssl_wrap_socket_loads_verify_locations(self) -> None: + socket = Mock() mock_context = Mock() ssl_wrap_socket(ssl_context=mock_context, ca_certs="/path/to/pem", sock=socket) mock_context.load_verify_locations.assert_called_once_with( "/path/to/pem", None, None ) - def test_ssl_wrap_socket_loads_certificate_directories(self): - socket = object() + def test_ssl_wrap_socket_loads_certificate_directories(self) -> None: + socket = Mock() mock_context = Mock() ssl_wrap_socket( ssl_context=mock_context, ca_cert_dir="/path/to/pems", sock=socket @@ -879,8 +1031,8 @@ def 
test_ssl_wrap_socket_loads_certificate_directories(self): None, "/path/to/pems", None ) - def test_ssl_wrap_socket_loads_certificate_data(self): - socket = object() + def test_ssl_wrap_socket_loads_certificate_data(self) -> None: + socket = Mock() mock_context = Mock() ssl_wrap_socket( ssl_context=mock_context, ca_cert_data="TOTALLY PEM DATA", sock=socket @@ -889,7 +1041,9 @@ def test_ssl_wrap_socket_loads_certificate_data(self): None, None, "TOTALLY PEM DATA" ) - def _wrap_socket_and_mock_warn(self, sock, server_hostname): + def _wrap_socket_and_mock_warn( + self, sock: socket.socket, server_hostname: str | None + ) -> tuple[Mock, MagicMock]: mock_context = Mock() with patch("warnings.warn") as warn: ssl_wrap_socket( @@ -899,34 +1053,69 @@ def _wrap_socket_and_mock_warn(self, sock, server_hostname): ) return mock_context, warn - def test_ssl_wrap_socket_sni_hostname_use_or_warn(self): - """Test that either an SNI hostname is used or a warning is made.""" - sock = object() - context, warn = self._wrap_socket_and_mock_warn(sock, "www.google.com") - if util.HAS_SNI: - warn.assert_not_called() - context.wrap_socket.assert_called_once_with( - sock, server_hostname="www.google.com" - ) - else: - assert warn.call_count >= 1 - warnings = [call[0][1] for call in warn.call_args_list] - assert SNIMissingWarning in warnings - context.wrap_socket.assert_called_once_with(sock) - - def test_ssl_wrap_socket_sni_ip_address_no_warn(self): + def test_ssl_wrap_socket_sni_ip_address_no_warn(self) -> None: """Test that a warning is not made if server_hostname is an IP address.""" - sock = object() + sock = Mock() context, warn = self._wrap_socket_and_mock_warn(sock, "8.8.8.8") - if util.IS_SECURETRANSPORT: - context.wrap_socket.assert_called_once_with(sock, server_hostname="8.8.8.8") - else: - context.wrap_socket.assert_called_once_with(sock) + context.wrap_socket.assert_called_once_with(sock, server_hostname="8.8.8.8") warn.assert_not_called() - def 
test_ssl_wrap_socket_sni_none_no_warn(self): + def test_ssl_wrap_socket_sni_none_no_warn(self) -> None: """Test that a warning is not made if server_hostname is not given.""" - sock = object() + sock = Mock() context, warn = self._wrap_socket_and_mock_warn(sock, None) - context.wrap_socket.assert_called_once_with(sock) + context.wrap_socket.assert_called_once_with(sock, server_hostname=None) warn.assert_not_called() + + @pytest.mark.parametrize( + "openssl_version_number, implementation_name, version_info, reliable", + [ + # OpenSSL and Python OK -> reliable + (0x101010CF, "cpython", (3, 9, 3), True), + # Python OK -> reliable + (0x10101000, "cpython", (3, 9, 3), True), + (0x10101000, "pypy", (3, 6, 9), False), + # OpenSSL OK -> reliable + (0x101010CF, "cpython", (3, 9, 2), True), + # unreliable + (0x10101000, "cpython", (3, 9, 2), False), + ], + ) + def test_is_has_never_check_common_name_reliable( + self, + openssl_version_number: int, + implementation_name: str, + version_info: _TYPE_VERSION_INFO, + reliable: bool, + ) -> None: + assert ( + _is_has_never_check_common_name_reliable( + openssl_version_number, + implementation_name, + version_info, + ) + == reliable + ) + + +idna_blocker = ImportBlocker("idna") +module_stash = ModuleStash("urllib3") + + +class TestUtilWithoutIdna: + @classmethod + def setup_class(cls) -> None: + sys.modules.pop("idna", None) + + module_stash.stash() + sys.meta_path.insert(0, idna_blocker) + + @classmethod + def teardown_class(cls) -> None: + sys.meta_path.remove(idna_blocker) + module_stash.pop() + + def test_parse_url_without_idna(self) -> None: + url = "http://\uD7FF.com" + with pytest.raises(LocationParseError, match=f"Failed to parse: {url}"): + parse_url(url) diff --git a/test/test_wait.py b/test/test_wait.py index 38dad79dee..b5e9d1be14 100644 --- a/test/test_wait.py +++ b/test/test_wait.py @@ -1,13 +1,11 @@ +from __future__ import annotations + import signal -import socket import threading - -try: - from time import 
monotonic -except ImportError: - from time import time as monotonic - import time +import typing +from socket import socket, socketpair +from types import FrameType import pytest @@ -20,24 +18,25 @@ wait_for_write, ) -from .socketpair_helper import socketpair +TYPE_SOCKET_PAIR = typing.Tuple[socket, socket] +TYPE_WAIT_FOR = typing.Callable[..., bool] @pytest.fixture -def spair(): +def spair() -> typing.Generator[TYPE_SOCKET_PAIR, None, None]: a, b = socketpair() yield a, b a.close() b.close() -variants = [wait_for_socket, select_wait_for_socket] +variants: list[TYPE_WAIT_FOR] = [wait_for_socket, select_wait_for_socket] if _have_working_poll(): variants.append(poll_wait_for_socket) @pytest.mark.parametrize("wfs", variants) -def test_wait_for_socket(wfs, spair): +def test_wait_for_socket(wfs: TYPE_WAIT_FOR, spair: TYPE_SOCKET_PAIR) -> None: a, b = spair with pytest.raises(RuntimeError): @@ -56,7 +55,7 @@ def test_wait_for_socket(wfs, spair): try: while True: a.send(b"x" * 999999) - except (OSError, socket.error): + except OSError: pass # Now it's not writable anymore @@ -80,7 +79,7 @@ def test_wait_for_socket(wfs, spair): wfs(b, read=True) -def test_wait_for_read_write(spair): +def test_wait_for_read_write(spair: TYPE_SOCKET_PAIR) -> None: a, b = spair assert not wait_for_read(a, 0) @@ -96,7 +95,7 @@ def test_wait_for_read_write(spair): try: while True: a.send(b"x" * 999999) - except (OSError, socket.error): + except OSError: pass # Now it's not writable anymore @@ -105,18 +104,18 @@ def test_wait_for_read_write(spair): @pytest.mark.skipif(not hasattr(signal, "setitimer"), reason="need setitimer() support") @pytest.mark.parametrize("wfs", variants) -def test_eintr(wfs, spair): +def test_eintr(wfs: TYPE_WAIT_FOR, spair: TYPE_SOCKET_PAIR) -> None: a, b = spair interrupt_count = [0] - def handler(sig, frame): + def handler(sig: int, frame: FrameType | None) -> typing.Any: assert sig == signal.SIGALRM interrupt_count[0] += 1 old_handler = signal.signal(signal.SIGALRM, 
handler) try: assert not wfs(a, read=True, timeout=0) - start = monotonic() + start = time.monotonic() try: # Start delivering SIGALRM 10 times per second signal.setitimer(signal.ITIMER_REAL, 0.1, 0.1) @@ -125,7 +124,7 @@ def handler(sig, frame): finally: # Stop delivering SIGALRM signal.setitimer(signal.ITIMER_REAL, 0) - end = monotonic() + end = time.monotonic() dur = end - start assert 0.9 < dur < 3 finally: @@ -136,11 +135,11 @@ def handler(sig, frame): @pytest.mark.skipif(not hasattr(signal, "setitimer"), reason="need setitimer() support") @pytest.mark.parametrize("wfs", variants) -def test_eintr_zero_timeout(wfs, spair): +def test_eintr_zero_timeout(wfs: TYPE_WAIT_FOR, spair: TYPE_SOCKET_PAIR) -> None: a, b = spair interrupt_count = [0] - def handler(sig, frame): + def handler(sig: int, frame: FrameType | None) -> typing.Any: assert sig == signal.SIGALRM interrupt_count[0] += 1 @@ -154,8 +153,11 @@ def handler(sig, frame): signal.setitimer(signal.ITIMER_REAL, 0.001, 0.001) # Hammer the system call for a while to trigger the # race. 
+ end = time.monotonic() + 5 for i in range(100000): wfs(a, read=True, timeout=0) + if time.monotonic() >= end: + break finally: # Stop delivering SIGALRM signal.setitimer(signal.ITIMER_REAL, 0) @@ -167,22 +169,22 @@ def handler(sig, frame): @pytest.mark.skipif(not hasattr(signal, "setitimer"), reason="need setitimer() support") @pytest.mark.parametrize("wfs", variants) -def test_eintr_infinite_timeout(wfs, spair): +def test_eintr_infinite_timeout(wfs: TYPE_WAIT_FOR, spair: TYPE_SOCKET_PAIR) -> None: a, b = spair interrupt_count = [0] - def handler(sig, frame): + def handler(sig: int, frame: FrameType | None) -> typing.Any: assert sig == signal.SIGALRM interrupt_count[0] += 1 - def make_a_readable_after_one_second(): + def make_a_readable_after_one_second() -> None: time.sleep(1) b.send(b"x") old_handler = signal.signal(signal.SIGALRM, handler) try: assert not wfs(a, read=True, timeout=0) - start = monotonic() + start = time.monotonic() try: # Start delivering SIGALRM 10 times per second signal.setitimer(signal.ITIMER_REAL, 0.1, 0.1) @@ -194,7 +196,7 @@ def make_a_readable_after_one_second(): # Stop delivering SIGALRM signal.setitimer(signal.ITIMER_REAL, 0) thread.join() - end = monotonic() + end = time.monotonic() dur = end - start assert 0.9 < dur < 3 finally: diff --git a/test/tz_stub.py b/test/tz_stub.py index c48f5df024..41b114bb2a 100644 --- a/test/tz_stub.py +++ b/test/tz_stub.py @@ -1,14 +1,22 @@ +from __future__ import annotations + import datetime import os import time +import typing from contextlib import contextmanager import pytest -from dateutil import tz + +try: + import zoneinfo # type: ignore[import] +except ImportError: + # Python < 3.9 + from backports import zoneinfo # type: ignore[no-redef] @contextmanager -def stub_timezone_ctx(tzname): +def stub_timezone_ctx(tzname: str | None) -> typing.Generator[None, None, None]: """ Switch to a locally-known timezone specified by `tzname`. On exit, restore the previous timezone. 
@@ -22,16 +30,16 @@ def stub_timezone_ctx(tzname): if not hasattr(time, "tzset"): pytest.skip("Timezone patching is not supported") - # Make sure the new timezone exists, at least in dateutil - new_tz = tz.gettz(tzname) - if new_tz is None: - raise ValueError("Invalid timezone specified: %r" % (tzname,)) + # Make sure the new timezone exists + try: + zoneinfo.ZoneInfo(tzname) + except zoneinfo.ZoneInfoNotFoundError: + raise ValueError(f"Invalid timezone specified: {tzname!r}") # Get the current timezone - local_tz = tz.tzlocal() - if local_tz is None: - raise EnvironmentError("Cannot determine current timezone") - old_tzname = datetime.datetime.now(local_tz).tzname() + old_tzname = datetime.datetime.now().astimezone().tzname() + if old_tzname is None: + raise OSError("Cannot determine current timezone") os.environ["TZ"] = tzname time.tzset() diff --git a/test/with_dummyserver/test_chunked_transfer.py b/test/with_dummyserver/test_chunked_transfer.py index 907a714603..c2dc12e769 100644 --- a/test/with_dummyserver/test_chunked_transfer.py +++ b/test/with_dummyserver/test_chunked_transfer.py @@ -1,4 +1,6 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations + +import socket import pytest @@ -11,15 +13,12 @@ from urllib3.util import SKIP_HEADER from urllib3.util.retry import Retry -# Retry failed tests -pytestmark = pytest.mark.flaky - class TestChunkedTransfer(SocketDummyServerTestCase): - def start_chunked_handler(self): + def start_chunked_handler(self) -> None: self.buffer = b"" - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] while not self.buffer.endswith(b"\r\n0\r\n\r\n"): @@ -35,22 +34,30 @@ def socket_handler(listener): self._start_server(socket_handler) - def test_chunks(self): + @pytest.mark.parametrize( + "chunks", + [ + ["foo", "bar", "", "bazzzzzzzzzzzzzzzzzzzzzz"], + [b"foo", b"bar", b"", b"bazzzzzzzzzzzzzzzzzzzzzz"], + ], + ) + def test_chunks(self, chunks: list[bytes | str]) -> 
None: self.start_chunked_handler() - chunks = ["foo", "bar", "", "bazzzzzzzzzzzzzzzzzzzzzz"] with HTTPConnectionPool(self.host, self.port, retries=False) as pool: - pool.urlopen("GET", "/", chunks, headers=dict(DNT="1"), chunked=True) + pool.urlopen("GET", "/", body=chunks, headers=dict(DNT="1"), chunked=True) # type: ignore[arg-type] assert b"Transfer-Encoding" in self.buffer body = self.buffer.split(b"\r\n\r\n", 1)[1] lines = body.split(b"\r\n") # Empty chunks should have been skipped, as this could not be distinguished # from terminating the transmission - for i, chunk in enumerate([c for c in chunks if c]): + for i, chunk in enumerate( + [c.decode() if isinstance(c, bytes) else c for c in chunks if c] + ): assert lines[i * 2] == hex(len(chunk))[2:].encode("utf-8") assert lines[i * 2 + 1] == chunk.encode("utf-8") - def _test_body(self, data): + def _test_body(self, data: bytes | str | None) -> None: self.start_chunked_handler() with HTTPConnectionPool(self.host, self.port, retries=False) as pool: pool.urlopen("GET", "/", data, chunked=True) @@ -68,61 +75,82 @@ def _test_body(self, data): else: assert body == b"0\r\n\r\n" - def test_bytestring_body(self): + def test_bytestring_body(self) -> None: self._test_body(b"thisshouldbeonechunk\r\nasdf") - def test_unicode_body(self): - self._test_body(u"thisshouldbeonechunk\r\näöüß") + def test_unicode_body(self) -> None: + self._test_body("thisshouldbeonechunk\r\näöüß") - def test_empty_body(self): + def test_empty_body(self) -> None: self._test_body(None) - def test_empty_string_body(self): + def test_empty_string_body(self) -> None: self._test_body("") - def test_empty_iterable_body(self): - self._test_body([]) + def test_empty_iterable_body(self) -> None: + self._test_body(None) - def _get_header_lines(self, prefix): + def _get_header_lines(self, prefix: bytes) -> list[bytes]: header_block = self.buffer.split(b"\r\n\r\n", 1)[0].lower() header_lines = header_block.split(b"\r\n")[1:] return [x for x in header_lines if 
x.startswith(prefix)] - def test_removes_duplicate_host_header(self): + def test_removes_duplicate_host_header(self) -> None: self.start_chunked_handler() - chunks = ["foo", "bar", "", "bazzzzzzzzzzzzzzzzzzzzzz"] + chunks = [b"foo", b"bar", b"", b"bazzzzzzzzzzzzzzzzzzzzzz"] with HTTPConnectionPool(self.host, self.port, retries=False) as pool: - pool.urlopen("GET", "/", chunks, headers={"Host": "test.org"}, chunked=True) + pool.urlopen( + "GET", "/", body=chunks, headers={"Host": "test.org"}, chunked=True + ) host_headers = self._get_header_lines(b"host") assert len(host_headers) == 1 - def test_provides_default_host_header(self): + def test_provides_default_host_header(self) -> None: self.start_chunked_handler() - chunks = ["foo", "bar", "", "bazzzzzzzzzzzzzzzzzzzzzz"] + chunks = [b"foo", b"bar", b"", b"bazzzzzzzzzzzzzzzzzzzzzz"] with HTTPConnectionPool(self.host, self.port, retries=False) as pool: - pool.urlopen("GET", "/", chunks, chunked=True) + pool.urlopen("GET", "/", body=chunks, chunked=True) host_headers = self._get_header_lines(b"host") assert len(host_headers) == 1 - def test_provides_default_user_agent_header(self): + def test_provides_default_user_agent_header(self) -> None: self.start_chunked_handler() - chunks = ["foo", "bar", "", "bazzzzzzzzzzzzzzzzzzzzzz"] + chunks = [b"foo", b"bar", b"", b"bazzzzzzzzzzzzzzzzzzzzzz"] with HTTPConnectionPool(self.host, self.port, retries=False) as pool: - pool.urlopen("GET", "/", chunks, chunked=True) + pool.urlopen("GET", "/", body=chunks, chunked=True) ua_headers = self._get_header_lines(b"user-agent") assert len(ua_headers) == 1 - def test_remove_user_agent_header(self): + def test_preserve_user_agent_header(self) -> None: self.start_chunked_handler() - chunks = ["foo", "bar", "", "bazzzzzzzzzzzzzzzzzzzzzz"] + chunks = [b"foo", b"bar", b"", b"bazzzzzzzzzzzzzzzzzzzzzz"] with HTTPConnectionPool(self.host, self.port, retries=False) as pool: pool.urlopen( "GET", "/", - chunks, + body=chunks, + headers={"user-Agent": 
"test-agent"}, + chunked=True, + ) + + ua_headers = self._get_header_lines(b"user-agent") + # Validate that there is only one User-Agent header. + assert len(ua_headers) == 1 + # Validate that the existing User-Agent header is the one that was + # provided. + assert ua_headers[0] == b"user-agent: test-agent" + + def test_remove_user_agent_header(self) -> None: + self.start_chunked_handler() + chunks = [b"foo", b"bar", b"", b"bazzzzzzzzzzzzzzzzzzzzzz"] + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.urlopen( + "GET", + "/", + body=chunks, headers={"User-Agent": SKIP_HEADER}, chunked=True, ) @@ -130,11 +158,39 @@ def test_remove_user_agent_header(self): ua_headers = self._get_header_lines(b"user-agent") assert len(ua_headers) == 0 - def test_preserve_chunked_on_retry_after(self): + def test_provides_default_transfer_encoding_header(self) -> None: + self.start_chunked_handler() + chunks = [b"foo", b"bar", b"", b"bazzzzzzzzzzzzzzzzzzzzzz"] + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.urlopen("GET", "/", body=chunks, chunked=True) + + te_headers = self._get_header_lines(b"transfer-encoding") + assert len(te_headers) == 1 + + def test_preserve_transfer_encoding_header(self) -> None: + self.start_chunked_handler() + chunks = [b"foo", b"bar", b"", b"bazzzzzzzzzzzzzzzzzzzzzz"] + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.urlopen( + "GET", + "/", + body=chunks, + headers={"transfer-Encoding": "test-transfer-encoding"}, + chunked=True, + ) + + te_headers = self._get_header_lines(b"transfer-encoding") + # Validate that there is only one Transfer-Encoding header. + assert len(te_headers) == 1 + # Validate that the existing Transfer-Encoding header is the one that + # was provided. 
+ assert te_headers[0] == b"transfer-encoding: test-transfer-encoding" + + def test_preserve_chunked_on_retry_after(self) -> None: self.chunked_requests = 0 - self.socks = [] + self.socks: list[socket.socket] = [] - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: for _ in range(2): sock = listener.accept()[0] self.socks.append(sock) @@ -159,10 +215,12 @@ def socket_handler(listener): sock.close() assert self.chunked_requests == 2 - def test_preserve_chunked_on_redirect(self, monkeypatch): + def test_preserve_chunked_on_redirect( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: self.chunked_requests = 0 - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: for i in range(2): sock = listener.accept()[0] request = ConnectionMarker.consume_request(sock) @@ -187,10 +245,12 @@ def socket_handler(listener): ) assert self.chunked_requests == 2 - def test_preserve_chunked_on_broken_connection(self, monkeypatch): + def test_preserve_chunked_on_broken_connection( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: self.chunked_requests = 0 - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: for i in range(2): sock = listener.accept()[0] request = ConnectionMarker.consume_request(sock) diff --git a/test/with_dummyserver/test_connection.py b/test/with_dummyserver/test_connection.py new file mode 100644 index 0000000000..d06a7551b5 --- /dev/null +++ b/test/with_dummyserver/test_connection.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import typing +from http.client import ResponseNotReady + +import pytest + +from dummyserver.testcase import HTTPDummyServerTestCase as server +from urllib3 import HTTPConnectionPool +from urllib3.response import HTTPResponse + + +@pytest.fixture() +def pool() -> typing.Generator[HTTPConnectionPool, None, None]: + server.setup_class() + + with HTTPConnectionPool(server.host, server.port) as pool: + yield pool + + 
server.teardown_class() + + +def test_returns_urllib3_HTTPResponse(pool: HTTPConnectionPool) -> None: + conn = pool._get_conn() + + method = "GET" + path = "/" + + conn.request(method, path) + + response = conn.getresponse() + + assert isinstance(response, HTTPResponse) + + +def test_does_not_release_conn(pool: HTTPConnectionPool) -> None: + conn = pool._get_conn() + + method = "GET" + path = "/" + + conn.request(method, path) + + response = conn.getresponse() + + response.release_conn() + assert pool.pool.qsize() == 0 # type: ignore[union-attr] + + +def test_releases_conn(pool: HTTPConnectionPool) -> None: + conn = pool._get_conn() + assert conn is not None + + method = "GET" + path = "/" + + conn.request(method, path) + + response = conn.getresponse() + # If these variables are set by the pool + # then the response can release the connection + # back into the pool. + response._pool = pool # type: ignore[attr-defined] + response._connection = conn # type: ignore[attr-defined] + + response.release_conn() + assert pool.pool.qsize() == 1 # type: ignore[union-attr] + + +def test_double_getresponse(pool: HTTPConnectionPool) -> None: + conn = pool._get_conn() + + method = "GET" + path = "/" + + conn.request(method, path) + + _ = conn.getresponse() + + # Calling getrepsonse() twice should cause an error + with pytest.raises(ResponseNotReady): + conn.getresponse() + + +def test_connection_state_properties(pool: HTTPConnectionPool) -> None: + conn = pool._get_conn() + + assert conn.is_closed is True + assert conn.is_connected is False + assert conn.has_connected_to_proxy is False + assert conn.is_verified is False + assert conn.proxy_is_verified is None + + conn.connect() + + assert conn.is_closed is False + assert conn.is_connected is True + assert conn.has_connected_to_proxy is False + assert conn.is_verified is False + assert conn.proxy_is_verified is None + + conn.request("GET", "/") + resp = conn.getresponse() + assert resp.status == 200 + + conn.close() + + assert 
conn.is_closed is True + assert conn.is_connected is False + assert conn.has_connected_to_proxy is False + assert conn.is_verified is False + assert conn.proxy_is_verified is None + + +def test_set_tunnel_is_reset(pool: HTTPConnectionPool) -> None: + conn = pool._get_conn() + + assert conn.is_closed is True + assert conn.is_connected is False + assert conn.has_connected_to_proxy is False + assert conn.is_verified is False + assert conn.proxy_is_verified is None + + conn.set_tunnel(host="host", port=8080, scheme="http") + + assert conn._tunnel_host == "host" # type: ignore[attr-defined] + assert conn._tunnel_port == 8080 # type: ignore[attr-defined] + assert conn._tunnel_scheme == "http" # type: ignore[attr-defined] + + conn.close() + + assert conn._tunnel_host is None # type: ignore[attr-defined] + assert conn._tunnel_port is None # type: ignore[attr-defined] + assert conn._tunnel_scheme is None # type: ignore[attr-defined] + + +def test_invalid_tunnel_scheme(pool: HTTPConnectionPool) -> None: + conn = pool._get_conn() + + with pytest.raises(ValueError) as e: + conn.set_tunnel(host="host", port=8080, scheme="socks") + assert ( + str(e.value) + == "Invalid proxy scheme for tunneling: 'socks', must be either 'http' or 'https'" + ) diff --git a/test/with_dummyserver/test_connectionpool.py b/test/with_dummyserver/test_connectionpool.py index 408a23802d..1748fbdaae 100644 --- a/test/with_dummyserver/test_connectionpool.py +++ b/test/with_dummyserver/test_connectionpool.py @@ -1,18 +1,16 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations import io -import json -import logging import socket -import sys import time +import typing import warnings -from test import LONG_TIMEOUT, SHORT_TIMEOUT, onlyPy2 +from test import LONG_TIMEOUT, SHORT_TIMEOUT from threading import Event +from unittest import mock +from urllib.parse import urlencode -import mock import pytest -import six from dummyserver.server import HAS_IPV6_AND_DNS, NoIPv6Warning from dummyserver.testcase 
import HTTPDummyServerTestCase, SocketDummyServerTestCase @@ -24,33 +22,27 @@ DecodeError, EmptyPoolError, MaxRetryError, + NameResolutionError, NewConnectionError, ReadTimeoutError, UnrewindableBodyError, ) -from urllib3.packages.six import b, u -from urllib3.packages.six.moves.urllib.parse import urlencode +from urllib3.fields import _TYPE_FIELD_VALUE_TUPLE from urllib3.util import SKIP_HEADER, SKIPPABLE_HEADERS from urllib3.util.retry import RequestHistory, Retry -from urllib3.util.timeout import Timeout +from urllib3.util.timeout import _TYPE_TIMEOUT, Timeout from .. import INVALID_SOURCE_ADDRESSES, TARPIT_HOST, VALID_SOURCE_ADDRESSES from ..port_helpers import find_unused_port -pytestmark = pytest.mark.flaky -log = logging.getLogger("urllib3.connectionpool") -log.setLevel(logging.NOTSET) -log.addHandler(logging.StreamHandler(sys.stdout)) - - -def wait_for_socket(ready_event): +def wait_for_socket(ready_event: Event) -> None: ready_event.wait() ready_event.clear() class TestConnectionPoolTimeouts(SocketDummyServerTestCase): - def test_timeout_float(self): + def test_timeout_float(self) -> None: block_event = Event() ready_event = self.start_basic_handler(block_send=block_event, num=2) @@ -65,7 +57,7 @@ def test_timeout_float(self): block_event.set() # Pre-release block pool.request("GET", "/", timeout=LONG_TIMEOUT) - def test_conn_closed(self): + def test_conn_closed(self) -> None: block_event = Event() self.start_basic_handler(block_send=block_event, num=1) @@ -77,15 +69,15 @@ def test_conn_closed(self): try: with pytest.raises(ReadTimeoutError): pool.urlopen("GET", "/") - if conn.sock: + if not conn.is_closed: with pytest.raises(socket.error): - conn.sock.recv(1024) + conn.sock.recv(1024) # type: ignore[attr-defined] finally: pool._put_conn(conn) block_event.set() - def test_timeout(self): + def test_timeout(self) -> None: # Requests should time out when expected block_event = Event() ready_event = self.start_basic_handler(block_send=block_event, num=3) @@ 
-121,7 +113,7 @@ def test_timeout(self): pool.request("GET", "/", timeout=SHORT_TIMEOUT) block_event.set() # Release request - def test_connect_timeout(self): + def test_connect_timeout(self) -> None: url = "/" host, port = TARPIT_HOST, 80 timeout = Timeout(connect=SHORT_TIMEOUT) @@ -148,7 +140,7 @@ def test_connect_timeout(self): with pytest.raises(ConnectTimeoutError): pool.request("GET", url, timeout=timeout) - def test_total_applies_connect(self): + def test_total_applies_connect(self) -> None: host, port = TARPIT_HOST, 80 timeout = Timeout(total=None, connect=SHORT_TIMEOUT) @@ -169,7 +161,7 @@ def test_total_applies_connect(self): finally: conn.close() - def test_total_timeout(self): + def test_total_timeout(self) -> None: block_event = Event() ready_event = self.start_basic_handler(block_send=block_event, num=2) @@ -194,7 +186,7 @@ def test_total_timeout(self): with pytest.raises(ReadTimeoutError): pool.request("GET", "/") - def test_create_connection_timeout(self): + def test_create_connection_timeout(self) -> None: self.start_basic_handler(block_send=Event(), num=0) # needed for self.port timeout = Timeout(connect=SHORT_TIMEOUT, total=LONG_TIMEOUT) @@ -207,22 +199,22 @@ def test_create_connection_timeout(self): class TestConnectionPool(HTTPDummyServerTestCase): - def test_get(self): + def test_get(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request("GET", "/specific_method", fields={"method": "GET"}) assert r.status == 200, r.data - def test_post_url(self): + def test_post_url(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request("POST", "/specific_method", fields={"method": "POST"}) assert r.status == 200, r.data - def test_urlopen_put(self): + def test_urlopen_put(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: r = pool.urlopen("PUT", "/specific_method?method=PUT") assert r.status == 200, r.data - def test_wrong_specific_method(self): + def 
test_wrong_specific_method(self) -> None: # To make sure the dummy server is actually returning failed responses with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request("GET", "/specific_method", fields={"method": "POST"}) @@ -232,20 +224,20 @@ def test_wrong_specific_method(self): r = pool.request("POST", "/specific_method", fields={"method": "GET"}) assert r.status == 400, r.data - def test_upload(self): + def test_upload(self) -> None: data = "I'm in ur multipart form-data, hazing a cheezburgr" - fields = { + fields: dict[str, _TYPE_FIELD_VALUE_TUPLE] = { "upload_param": "filefield", "upload_filename": "lolcat.txt", - "upload_size": len(data), "filefield": ("lolcat.txt", data), } + fields["upload_size"] = len(data) # type: ignore with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request("POST", "/upload", fields=fields) assert r.status == 200, r.data - def test_one_name_multiple_values(self): + def test_one_name_multiple_values(self) -> None: fields = [("foo", "a"), ("foo", "b")] with HTTPConnectionPool(self.host, self.port) as pool: @@ -257,7 +249,7 @@ def test_one_name_multiple_values(self): r = pool.request("POST", "/echo", fields=fields) assert r.data.count(b'name="foo"') == 2 - def test_request_method_body(self): + def test_request_method_body(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: body = b"hi" r = pool.request("POST", "/echo", body=body) @@ -267,47 +259,55 @@ def test_request_method_body(self): with pytest.raises(TypeError): pool.request("POST", "/echo", body=body, fields=fields) - def test_unicode_upload(self): - fieldname = u("myfile") - filename = u("\xe2\x99\xa5.txt") - data = u("\xe2\x99\xa5").encode("utf8") + def test_unicode_upload(self) -> None: + fieldname = "myfile" + filename = "\xe2\x99\xa5.txt" + data = "\xe2\x99\xa5".encode() size = len(data) - fields = { - u("upload_param"): fieldname, - u("upload_filename"): filename, - u("upload_size"): size, + fields: dict[str, 
_TYPE_FIELD_VALUE_TUPLE] = { + "upload_param": fieldname, + "upload_filename": filename, fieldname: (filename, data), } + fields["upload_size"] = size # type: ignore with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request("POST", "/upload", fields=fields) assert r.status == 200, r.data - def test_nagle(self): - """ Test that connections have TCP_NODELAY turned on """ + def test_nagle(self) -> None: + """Test that connections have TCP_NODELAY turned on""" # This test needs to be here in order to be run. socket.create_connection actually tries # to connect to the host provided so we need a dummyserver to be running. with HTTPConnectionPool(self.host, self.port) as pool: conn = pool._get_conn() try: pool._make_request(conn, "GET", "/") - tcp_nodelay_setting = conn.sock.getsockopt( + tcp_nodelay_setting = conn.sock.getsockopt( # type: ignore[attr-defined] socket.IPPROTO_TCP, socket.TCP_NODELAY ) assert tcp_nodelay_setting finally: conn.close() - def test_socket_options(self): + @pytest.mark.parametrize( + "socket_options", + [ + [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)], + ((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),), + ], + ) + def test_socket_options(self, socket_options: tuple[int, int, int]) -> None: """Test that connections accept socket options.""" # This test needs to be here in order to be run. socket.create_connection actually tries to # connect to the host provided so we need a dummyserver to be running. with HTTPConnectionPool( self.host, self.port, - socket_options=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)], + socket_options=socket_options, ) as pool: - s = pool._new_conn()._new_conn() # Get the socket + # Get the socket of a new connection. 
+ s = pool._new_conn()._new_conn() # type: ignore[attr-defined] try: using_keepalive = ( s.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) > 0 @@ -316,19 +316,24 @@ def test_socket_options(self): finally: s.close() - def test_disable_default_socket_options(self): - """Test that passing None disables all socket options.""" + @pytest.mark.parametrize("socket_options", [None, []]) + def test_disable_default_socket_options( + self, socket_options: list[int] | None + ) -> None: + """Test that passing None or empty list disables all socket options.""" # This test needs to be here in order to be run. socket.create_connection actually tries # to connect to the host provided so we need a dummyserver to be running. - with HTTPConnectionPool(self.host, self.port, socket_options=None) as pool: - s = pool._new_conn()._new_conn() + with HTTPConnectionPool( + self.host, self.port, socket_options=socket_options + ) as pool: + s = pool._new_conn()._new_conn() # type: ignore[attr-defined] try: using_nagle = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) == 0 assert using_nagle finally: s.close() - def test_defaults_are_applied(self): + def test_defaults_are_applied(self) -> None: """Test that modifying the default socket options works.""" # This test needs to be here in order to be run. socket.create_connection actually tries # to connect to the host provided so we need a dummyserver to be running. 
@@ -337,10 +342,9 @@ def test_defaults_are_applied(self): conn = pool._new_conn() try: # Update the default socket options - conn.default_socket_options += [ - (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - ] - s = conn._new_conn() + assert conn.socket_options is not None + conn.socket_options += [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)] # type: ignore[operator] + s = conn._new_conn() # type: ignore[attr-defined] nagle_disabled = ( s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) > 0 ) @@ -353,15 +357,15 @@ def test_defaults_are_applied(self): conn.close() s.close() - def test_connection_error_retries(self): - """ ECONNREFUSED error should raise a connection error, with retries """ + def test_connection_error_retries(self) -> None: + """ECONNREFUSED error should raise a connection error, with retries""" port = find_unused_port() with HTTPConnectionPool(self.host, port) as pool: with pytest.raises(MaxRetryError) as e: pool.request("GET", "/", retries=Retry(connect=3)) assert type(e.value.reason) == NewConnectionError - def test_timeout_success(self): + def test_timeout_success(self) -> None: timeout = Timeout(connect=3, read=5, total=None) with HTTPConnectionPool(self.host, self.port, timeout=timeout) as pool: pool.request("GET", "/") @@ -376,16 +380,75 @@ def test_timeout_success(self): with HTTPConnectionPool(self.host, self.port, timeout=timeout) as pool: pool.request("GET", "/") - def test_tunnel(self): + socket_timeout_reuse_testdata = pytest.mark.parametrize( + ["timeout", "expect_settimeout_calls"], + [ + (1, (1, 1)), + (None, (None, None)), + (Timeout(read=4), (None, 4)), + (Timeout(read=4, connect=5), (5, 4)), + (Timeout(connect=6), (6, None)), + ], + ) + + @socket_timeout_reuse_testdata + def test_socket_timeout_updated_on_reuse_constructor( + self, + timeout: _TYPE_TIMEOUT, + expect_settimeout_calls: typing.Sequence[float | None], + ) -> None: + with HTTPConnectionPool(self.host, self.port, timeout=timeout) as pool: + # Make a request to create a 
new connection. + pool.urlopen("GET", "/") + + # Grab the connection and mock the inner socket. + assert pool.pool is not None + conn = pool.pool.get_nowait() + conn_sock = mock.Mock(wraps=conn.sock) + conn.sock = conn_sock + pool._put_conn(conn) + + # Assert that sock.settimeout() is called with the new connect timeout, then the read timeout. + pool.urlopen("GET", "/", timeout=timeout) + conn_sock.settimeout.assert_has_calls( + [mock.call(x) for x in expect_settimeout_calls] + ) + + @socket_timeout_reuse_testdata + def test_socket_timeout_updated_on_reuse_parameter( + self, + timeout: _TYPE_TIMEOUT, + expect_settimeout_calls: typing.Sequence[float | None], + ) -> None: + with HTTPConnectionPool(self.host, self.port) as pool: + # Make a request to create a new connection. + pool.urlopen("GET", "/", timeout=LONG_TIMEOUT) + + # Grab the connection and mock the inner socket. + assert pool.pool is not None + conn = pool.pool.get_nowait() + conn_sock = mock.Mock(wraps=conn.sock) + conn.sock = conn_sock + pool._put_conn(conn) + + # Assert that sock.settimeout() is called with the new connect timeout, then the read timeout. 
+ pool.urlopen("GET", "/", timeout=timeout) + conn_sock.settimeout.assert_has_calls( + [mock.call(x) for x in expect_settimeout_calls] + ) + + def test_tunnel(self) -> None: # note the actual httplib.py has no tests for this functionality timeout = Timeout(total=None) with HTTPConnectionPool(self.host, self.port, timeout=timeout) as pool: conn = pool._get_conn() try: conn.set_tunnel(self.host, self.port) - conn._tunnel = mock.Mock(return_value=None) - pool._make_request(conn, "GET", "/") - conn._tunnel.assert_called_once_with() + with mock.patch.object( + conn, "_tunnel", create=True, return_value=None + ) as conn_tunnel: + pool._make_request(conn, "GET", "/") + conn_tunnel.assert_called_once_with() finally: conn.close() @@ -394,13 +457,15 @@ def test_tunnel(self): with HTTPConnectionPool(self.host, self.port, timeout=timeout) as pool: conn = pool._get_conn() try: - conn._tunnel = mock.Mock(return_value=None) - pool._make_request(conn, "GET", "/") - assert not conn._tunnel.called + with mock.patch.object( + conn, "_tunnel", create=True, return_value=None + ) as conn_tunnel: + pool._make_request(conn, "GET", "/") + assert not conn_tunnel.called finally: conn.close() - def test_redirect(self): + def test_redirect(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request("GET", "/redirect", fields={"target": "/"}, redirect=False) assert r.status == 303 @@ -409,13 +474,13 @@ def test_redirect(self): assert r.status == 200 assert r.data == b"Dummy server!" 
- def test_bad_connect(self): + def test_bad_connect(self) -> None: with HTTPConnectionPool("badhost.invalid", self.port) as pool: with pytest.raises(MaxRetryError) as e: pool.request("GET", "/", retries=5) - assert type(e.value.reason) == NewConnectionError + assert type(e.value.reason) == NameResolutionError - def test_keepalive(self): + def test_keepalive(self) -> None: with HTTPConnectionPool(self.host, self.port, block=True, maxsize=1) as pool: r = pool.request("GET", "/keepalive?close=0") r = pool.request("GET", "/keepalive?close=0") @@ -424,7 +489,7 @@ def test_keepalive(self): assert pool.num_connections == 1 assert pool.num_requests == 2 - def test_keepalive_close(self): + def test_keepalive_close(self) -> None: with HTTPConnectionPool( self.host, self.port, block=True, maxsize=1, timeout=2 ) as pool: @@ -439,7 +504,7 @@ def test_keepalive_close(self): # We grab the HTTPConnection object straight from the Queue, # because _get_conn() is where the check & reset occurs - # pylint: disable-msg=W0212 + assert pool.pool is not None conn = pool.pool.get() assert conn.sock is None pool._put_conn(conn) @@ -473,13 +538,13 @@ def test_keepalive_close(self): # Next request r = pool.request("GET", "/keepalive?close=0") - def test_post_with_urlencode(self): + def test_post_with_urlencode(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: data = {"banana": "hammock", "lol": "cat"} r = pool.request("POST", "/echo", fields=data, encode_multipart=False) assert r.data.decode("utf-8") == urlencode(data) - def test_post_with_multipart(self): + def test_post_with_multipart(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: data = {"banana": "hammock", "lol": "cat"} r = pool.request("POST", "/echo", fields=data, encode_multipart=True) @@ -503,7 +568,7 @@ def test_post_with_multipart(self): assert body[i] == expected_body[i] - def test_post_with_multipart__iter__(self): + def test_post_with_multipart__iter__(self) -> None: with 
HTTPConnectionPool(self.host, self.port) as pool: data = {"hello": "world"} r = pool.request( @@ -524,7 +589,7 @@ def test_post_with_multipart__iter__(self): b"--boundary--\r\n", ] - def test_check_gzip(self): + def test_check_gzip(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request( "GET", "/encodingrequest", headers={"accept-encoding": "gzip"} @@ -532,7 +597,7 @@ def test_check_gzip(self): assert r.headers.get("content-encoding") == "gzip" assert r.data == b"hello, world!" - def test_check_deflate(self): + def test_check_deflate(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request( "GET", "/encodingrequest", headers={"accept-encoding": "deflate"} @@ -540,7 +605,7 @@ def test_check_deflate(self): assert r.headers.get("content-encoding") == "deflate" assert r.data == b"hello, world!" - def test_bad_decode(self): + def test_bad_decode(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: with pytest.raises(DecodeError): pool.request( @@ -556,7 +621,7 @@ def test_bad_decode(self): headers={"accept-encoding": "garbage-gzip"}, ) - def test_connection_count(self): + def test_connection_count(self) -> None: with HTTPConnectionPool(self.host, self.port, maxsize=1) as pool: pool.request("GET", "/") pool.request("GET", "/") @@ -565,7 +630,7 @@ def test_connection_count(self): assert pool.num_connections == 1 assert pool.num_requests == 3 - def test_connection_count_bigpool(self): + def test_connection_count_bigpool(self) -> None: with HTTPConnectionPool(self.host, self.port, maxsize=16) as http_pool: http_pool.request("GET", "/") http_pool.request("GET", "/") @@ -574,7 +639,7 @@ def test_connection_count_bigpool(self): assert http_pool.num_connections == 1 assert http_pool.num_requests == 3 - def test_partial_response(self): + def test_partial_response(self) -> None: with HTTPConnectionPool(self.host, self.port, maxsize=1) as pool: req_data = {"lol": "cat"} resp_data = 
urlencode(req_data).encode("utf-8") @@ -584,7 +649,7 @@ def test_partial_response(self): assert r.read(5) == resp_data[:5] assert r.read() == resp_data[5:] - def test_lazy_load_twice(self): + def test_lazy_load_twice(self) -> None: # This test is sad and confusing. Need to figure out what's # going on with partial reads and socket reuse. @@ -637,12 +702,13 @@ def test_lazy_load_twice(self): assert pool.num_connections == 1 - def test_for_double_release(self): + def test_for_double_release(self) -> None: MAXSIZE = 5 # Check default state with HTTPConnectionPool(self.host, self.port, maxsize=MAXSIZE) as pool: assert pool.num_connections == 0 + assert pool.pool is not None assert pool.pool.qsize() == MAXSIZE # Make an empty slot for testing @@ -667,16 +733,17 @@ def test_for_double_release(self): pool.urlopen("GET", "/") assert pool.pool.qsize() == MAXSIZE - 2 - def test_release_conn_parameter(self): + def test_release_conn_parameter(self) -> None: MAXSIZE = 5 with HTTPConnectionPool(self.host, self.port, maxsize=MAXSIZE) as pool: + assert pool.pool is not None assert pool.pool.qsize() == MAXSIZE # Make request without releasing connection pool.request("GET", "/", release_conn=False, preload_content=False) assert pool.pool.qsize() == MAXSIZE - 1 - def test_dns_error(self): + def test_dns_error(self) -> None: with HTTPConnectionPool( "thishostdoesnotexist.invalid", self.port, timeout=0.001 ) as pool: @@ -684,17 +751,17 @@ def test_dns_error(self): pool.request("GET", "/test", retries=2) @pytest.mark.parametrize("char", [" ", "\r", "\n", "\x00"]) - def test_invalid_method_not_allowed(self, char): + def test_invalid_method_not_allowed(self, char: str) -> None: with pytest.raises(ValueError): with HTTPConnectionPool(self.host, self.port) as pool: pool.request("GET" + char, "/") - def test_percent_encode_invalid_target_chars(self): + def test_percent_encode_invalid_target_chars(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request("GET", 
"/echo_params?q=\r&k=\n \n") assert r.data == b"[('k', '\\n \\n'), ('q', '\\r')]" - def test_source_address(self): + def test_source_address(self) -> None: for addr, is_ipv6 in VALID_SOURCE_ADDRESSES: if is_ipv6 and not HAS_IPV6_AND_DNS: warnings.warn("No IPv6 support: skipping.", NoIPv6Warning) @@ -703,17 +770,25 @@ def test_source_address(self): self.host, self.port, source_address=addr, retries=False ) as pool: r = pool.request("GET", "/source_address") - assert r.data == b(addr[0]) + assert r.data == addr[0].encode() - def test_source_address_error(self): - for addr in INVALID_SOURCE_ADDRESSES: - with HTTPConnectionPool( - self.host, self.port, source_address=addr, retries=False - ) as pool: + @pytest.mark.parametrize( + "invalid_source_address, is_ipv6", INVALID_SOURCE_ADDRESSES + ) + def test_source_address_error( + self, invalid_source_address: tuple[str, int], is_ipv6: bool + ) -> None: + with HTTPConnectionPool( + self.host, self.port, source_address=invalid_source_address, retries=False + ) as pool: + if is_ipv6: + with pytest.raises(NameResolutionError): + pool.request("GET", f"/source_address?{invalid_source_address}") + else: with pytest.raises(NewConnectionError): - pool.request("GET", "/source_address?{0}".format(addr)) + pool.request("GET", f"/source_address?{invalid_source_address}") - def test_stream_keepalive(self): + def test_stream_keepalive(self) -> None: x = 2 with HTTPConnectionPool(self.host, self.port) as pool: @@ -731,21 +806,21 @@ def test_stream_keepalive(self): assert pool.num_connections == 1 assert pool.num_requests == x - def test_read_chunked_short_circuit(self): + def test_read_chunked_short_circuit(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: response = pool.request("GET", "/chunked", preload_content=False) response.read() with pytest.raises(StopIteration): next(response.read_chunked()) - def test_read_chunked_on_closed_response(self): + def test_read_chunked_on_closed_response(self) -> None: with 
HTTPConnectionPool(self.host, self.port) as pool: response = pool.request("GET", "/chunked", preload_content=False) response.close() with pytest.raises(StopIteration): next(response.read_chunked()) - def test_chunked_gzip(self): + def test_chunked_gzip(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: response = pool.request( "GET", "/chunked_gzip", preload_content=False, decode_content=True @@ -753,7 +828,7 @@ def test_chunked_gzip(self): assert b"123" * 4 == response.read() - def test_cleanup_on_connection_error(self): + def test_cleanup_on_connection_error(self) -> None: """ Test that connections are recycled to the pool on connection errors where no http response is received. @@ -762,6 +837,7 @@ def test_cleanup_on_connection_error(self): with HTTPConnectionPool( self.host, self.port, maxsize=poolsize, block=True ) as http: + assert http.pool is not None assert http.pool.qsize() == poolsize # force a connection error by supplying a non-existent @@ -788,62 +864,87 @@ def test_cleanup_on_connection_error(self): # the pool should still contain poolsize elements assert http.pool.qsize() == http.pool.maxsize - def test_mixed_case_hostname(self): + def test_mixed_case_hostname(self) -> None: with HTTPConnectionPool("LoCaLhOsT", self.port) as pool: - response = pool.request("GET", "http://LoCaLhOsT:%d/" % self.port) + response = pool.request("GET", f"http://LoCaLhOsT:{self.port}/") assert response.status == 200 - def test_preserves_path_dot_segments(self): - """ ConnectionPool preserves dot segments in the URI """ + def test_preserves_path_dot_segments(self) -> None: + """ConnectionPool preserves dot segments in the URI""" with HTTPConnectionPool(self.host, self.port) as pool: response = pool.request("GET", "/echo_uri/seg0/../seg2") assert response.data == b"/echo_uri/seg0/../seg2" - def test_default_user_agent_header(self): - """ ConnectionPool has a default user agent """ + def test_default_user_agent_header(self) -> None: + """ConnectionPool has 
a default user agent""" default_ua = _get_default_user_agent() custom_ua = "I'm not a web scraper, what are you talking about?" custom_ua2 = "Yet Another User Agent" with HTTPConnectionPool(self.host, self.port) as pool: # Use default user agent if no user agent was specified. r = pool.request("GET", "/headers") - request_headers = json.loads(r.data.decode("utf8")) + request_headers = r.json() assert request_headers.get("User-Agent") == _get_default_user_agent() # Prefer the request user agent over the default. headers = {"UsEr-AGENt": custom_ua} r = pool.request("GET", "/headers", headers=headers) - request_headers = json.loads(r.data.decode("utf8")) + request_headers = r.json() assert request_headers.get("User-Agent") == custom_ua # Do not modify pool headers when using the default user agent. pool_headers = {"foo": "bar"} pool.headers = pool_headers r = pool.request("GET", "/headers") - request_headers = json.loads(r.data.decode("utf8")) + request_headers = r.json() assert request_headers.get("User-Agent") == default_ua assert "User-Agent" not in pool_headers pool.headers.update({"User-Agent": custom_ua2}) r = pool.request("GET", "/headers") - request_headers = json.loads(r.data.decode("utf8")) + request_headers = r.json() assert request_headers.get("User-Agent") == custom_ua2 - def test_no_user_agent_header(self): - """ ConnectionPool can suppress sending a user agent header """ + @pytest.mark.parametrize( + "headers", + [ + None, + {}, + {"User-Agent": "key"}, + {"user-agent": "key"}, + {b"uSeR-AgEnT": b"key"}, + {b"user-agent": "key"}, + ], + ) + @pytest.mark.parametrize("chunked", [True, False]) + def test_user_agent_header_not_sent_twice( + self, headers: dict[str, str] | None, chunked: bool + ) -> None: + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("GET", "/headers", headers=headers, chunked=chunked) + request_headers = r.json() + + if not headers: + assert request_headers["User-Agent"].startswith("python-urllib3/") + assert 
"key" not in request_headers["User-Agent"] + else: + assert request_headers["User-Agent"] == "key" + + def test_no_user_agent_header(self) -> None: + """ConnectionPool can suppress sending a user agent header""" custom_ua = "I'm not a web scraper, what are you talking about?" with HTTPConnectionPool(self.host, self.port) as pool: # Suppress user agent in the request headers. no_ua_headers = {"User-Agent": SKIP_HEADER} r = pool.request("GET", "/headers", headers=no_ua_headers) - request_headers = json.loads(r.data.decode("utf8")) + request_headers = r.json() assert "User-Agent" not in request_headers assert no_ua_headers["User-Agent"] == SKIP_HEADER # Suppress user agent in the pool headers. pool.headers = no_ua_headers r = pool.request("GET", "/headers") - request_headers = json.loads(r.data.decode("utf8")) + request_headers = r.json() assert "User-Agent" not in request_headers assert no_ua_headers["User-Agent"] == SKIP_HEADER @@ -851,7 +952,7 @@ def test_no_user_agent_header(self): pool_headers = {"User-Agent": custom_ua} pool.headers = pool_headers r = pool.request("GET", "/headers", headers=no_ua_headers) - request_headers = json.loads(r.data.decode("utf8")) + request_headers = r.json() assert "User-Agent" not in request_headers assert no_ua_headers["User-Agent"] == SKIP_HEADER assert pool_headers.get("User-Agent") == custom_ua @@ -871,7 +972,13 @@ def test_no_user_agent_header(self): "user_agent", ["User-Agent", "user-agent", b"User-Agent", b"user-agent", None] ) @pytest.mark.parametrize("chunked", [True, False]) - def test_skip_header(self, accept_encoding, host, user_agent, chunked): + def test_skip_header( + self, + accept_encoding: str | None, + host: str | None, + user_agent: str | None, + chunked: bool, + ) -> None: headers = {} if accept_encoding is not None: @@ -883,7 +990,7 @@ def test_skip_header(self, accept_encoding, host, user_agent, chunked): with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request("GET", "/headers", 
headers=headers, chunked=chunked) - request_headers = json.loads(r.data.decode("utf8")) + request_headers = r.json() if accept_encoding is None: assert "Accept-Encoding" in request_headers @@ -900,17 +1007,15 @@ def test_skip_header(self, accept_encoding, host, user_agent, chunked): @pytest.mark.parametrize("header", ["Content-Length", "content-length"]) @pytest.mark.parametrize("chunked", [True, False]) - def test_skip_header_non_supported(self, header, chunked): + def test_skip_header_non_supported(self, header: str, chunked: bool) -> None: with HTTPConnectionPool(self.host, self.port) as pool: - with pytest.raises(ValueError) as e: + with pytest.raises( + ValueError, + match="urllib3.util.SKIP_HEADER only supports 'Accept-Encoding', 'Host', 'User-Agent'", + ) as e: pool.request( "GET", "/headers", headers={header: SKIP_HEADER}, chunked=chunked ) - assert ( - str(e.value) - == "urllib3.util.SKIP_HEADER only supports 'Accept-Encoding', 'Host', 'User-Agent'" - ) - # Ensure that the error message stays up to date with 'SKIP_HEADER_SUPPORTED_HEADERS' assert all( ("'" + header.title() + "'") in str(e.value) @@ -920,7 +1025,12 @@ def test_skip_header_non_supported(self, header, chunked): @pytest.mark.parametrize("chunked", [True, False]) @pytest.mark.parametrize("pool_request", [True, False]) @pytest.mark.parametrize("header_type", [dict, HTTPHeaderDict]) - def test_headers_not_modified_by_request(self, chunked, pool_request, header_type): + def test_headers_not_modified_by_request( + self, + chunked: bool, + pool_request: bool, + header_type: type[dict[str, str] | HTTPHeaderDict], + ) -> None: # Test that the .request*() methods of ConnectionPool and HTTPConnection # don't modify the given 'headers' structure, instead they should # make their own internal copies at request time. 
@@ -933,10 +1043,7 @@ def test_headers_not_modified_by_request(self, chunked, pool_request, header_typ pool.request("GET", "/headers", chunked=chunked) else: conn = pool._get_conn() - if chunked: - conn.request_chunked("GET", "/headers") - else: - conn.request("GET", "/headers") + conn.request("GET", "/headers", chunked=chunked) assert pool.headers == {"key": "val"} assert isinstance(pool.headers, header_type) @@ -946,63 +1053,58 @@ def test_headers_not_modified_by_request(self, chunked, pool_request, header_typ pool.request("GET", "/headers", headers=headers, chunked=chunked) else: conn = pool._get_conn() - if chunked: - conn.request_chunked("GET", "/headers", headers=headers) - else: - conn.request("GET", "/headers", headers=headers) + conn.request("GET", "/headers", headers=headers, chunked=chunked) assert headers == {"key": "val"} - def test_bytes_header(self): + def test_request_chunked_is_deprecated( + self, + ) -> None: + with HTTPConnectionPool(self.host, self.port) as pool: + conn = pool._get_conn() + + with pytest.warns(DeprecationWarning) as w: + conn.request_chunked("GET", "/headers") # type: ignore[attr-defined] + assert len(w) == 1 and str(w[0].message) == ( + "HTTPConnection.request_chunked() is deprecated and will be removed in urllib3 v2.1.0. " + "Instead use HTTPConnection.request(..., chunked=True)." 
+ ) + + resp = conn.getresponse() + assert resp.status == 200 + assert resp.json()["Transfer-Encoding"] == "chunked" + + def test_bytes_header(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: - headers = {"User-Agent": b"test header"} + headers = {"User-Agent": "test header"} r = pool.request("GET", "/headers", headers=headers) - request_headers = json.loads(r.data.decode("utf8")) + request_headers = r.json() assert "User-Agent" in request_headers assert request_headers["User-Agent"] == "test header" @pytest.mark.parametrize( - "user_agent", [u"Schönefeld/1.18.0", u"Schönefeld/1.18.0".encode("iso-8859-1")] + "user_agent", ["Schönefeld/1.18.0", "Schönefeld/1.18.0".encode("iso-8859-1")] ) - def test_user_agent_non_ascii_user_agent(self, user_agent): - if six.PY2 and not isinstance(user_agent, str): - pytest.skip( - "Python 2 raises UnicodeEncodeError when passed a unicode header" - ) - + def test_user_agent_non_ascii_user_agent(self, user_agent: str) -> None: with HTTPConnectionPool(self.host, self.port, retries=False) as pool: r = pool.urlopen( "GET", "/headers", headers={"User-Agent": user_agent}, ) - request_headers = json.loads(r.data.decode("utf8")) + request_headers = r.json() assert "User-Agent" in request_headers - assert request_headers["User-Agent"] == u"Schönefeld/1.18.0" - - @onlyPy2 - def test_user_agent_non_ascii_fails_on_python_2(self): - with HTTPConnectionPool(self.host, self.port, retries=False) as pool: - with pytest.raises(UnicodeEncodeError) as e: - pool.urlopen( - "GET", - "/headers", - headers={"User-Agent": u"Schönefeld/1.18.0"}, - ) - assert str(e.value) == ( - "'ascii' codec can't encode character u'\\xf6' in " - "position 3: ordinal not in range(128)" - ) + assert request_headers["User-Agent"] == "Schönefeld/1.18.0" class TestRetry(HTTPDummyServerTestCase): - def test_max_retry(self): + def test_max_retry(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: with pytest.raises(MaxRetryError): 
pool.request("GET", "/redirect", fields={"target": "/"}, retries=0) - def test_disabled_retry(self): - """ Disabled retries should disable redirect handling. """ + def test_disabled_retry(self) -> None: + """Disabled retries should disable redirect handling.""" with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request("GET", "/redirect", fields={"target": "/"}, retries=False) assert r.status == 303 @@ -1018,11 +1120,11 @@ def test_disabled_retry(self): with HTTPConnectionPool( "thishostdoesnotexist.invalid", self.port, timeout=0.001 ) as pool: - with pytest.raises(NewConnectionError): + with pytest.raises(NameResolutionError): pool.request("GET", "/test", retries=False) - def test_read_retries(self): - """ Should retry for status codes in the whitelist """ + def test_read_retries(self) -> None: + """Should retry for status codes in the forcelist""" with HTTPConnectionPool(self.host, self.port) as pool: retry = Retry(read=1, status_forcelist=[418]) resp = pool.request( @@ -1033,8 +1135,8 @@ def test_read_retries(self): ) assert resp.status == 200 - def test_read_total_retries(self): - """ HTTP response w/ status code in the whitelist should be retried """ + def test_read_total_retries(self) -> None: + """HTTP response w/ status code in the forcelist should be retried""" with HTTPConnectionPool(self.host, self.port) as pool: headers = {"test-name": "test_read_total_retries"} retry = Retry(total=1, status_forcelist=[418]) @@ -1043,48 +1145,48 @@ def test_read_total_retries(self): ) assert resp.status == 200 - def test_retries_wrong_whitelist(self): - """HTTP response w/ status code not in whitelist shouldn't be retried""" + def test_retries_wrong_forcelist(self) -> None: + """HTTP response w/ status code not in forcelist shouldn't be retried""" with HTTPConnectionPool(self.host, self.port) as pool: retry = Retry(total=1, status_forcelist=[202]) resp = pool.request( "GET", "/successful_retry", - headers={"test-name": "test_wrong_whitelist"}, + 
headers={"test-name": "test_wrong_forcelist"}, retries=retry, ) assert resp.status == 418 - def test_default_method_whitelist_retried(self): - """ urllib3 should retry methods in the default method whitelist """ + def test_default_method_forcelist_retried(self) -> None: + """urllib3 should retry methods in the default method forcelist""" with HTTPConnectionPool(self.host, self.port) as pool: retry = Retry(total=1, status_forcelist=[418]) resp = pool.request( "OPTIONS", "/successful_retry", - headers={"test-name": "test_default_whitelist"}, + headers={"test-name": "test_default_forcelist"}, retries=retry, ) assert resp.status == 200 - def test_retries_wrong_method_list(self): - """Method not in our whitelist should not be retried, even if code matches""" + def test_retries_wrong_method_list(self) -> None: + """Method not in our allowed list should not be retried, even if code matches""" with HTTPConnectionPool(self.host, self.port) as pool: - headers = {"test-name": "test_wrong_method_whitelist"} - retry = Retry(total=1, status_forcelist=[418], method_whitelist=["POST"]) + headers = {"test-name": "test_wrong_allowed_method"} + retry = Retry(total=1, status_forcelist=[418], allowed_methods=["POST"]) resp = pool.request( "GET", "/successful_retry", headers=headers, retries=retry ) assert resp.status == 418 - def test_read_retries_unsuccessful(self): + def test_read_retries_unsuccessful(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: headers = {"test-name": "test_read_retries_unsuccessful"} resp = pool.request("GET", "/successful_retry", headers=headers, retries=1) assert resp.status == 418 - def test_retry_reuse_safe(self): - """ It should be possible to reuse a Retry object across requests """ + def test_retry_reuse_safe(self) -> None: + """It should be possible to reuse a Retry object across requests""" with HTTPConnectionPool(self.host, self.port) as pool: headers = {"test-name": "test_retry_safe"} retry = Retry(total=1, 
status_forcelist=[418]) @@ -1099,7 +1201,7 @@ def test_retry_reuse_safe(self): ) assert resp.status == 200 - def test_retry_return_in_response(self): + def test_retry_return_in_response(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: headers = {"test-name": "test_retry_return_in_response"} retry = Retry(total=2, status_forcelist=[418]) @@ -1107,20 +1209,22 @@ def test_retry_return_in_response(self): "GET", "/successful_retry", headers=headers, retries=retry ) assert resp.status == 200 + assert resp.retries is not None assert resp.retries.total == 1 assert resp.retries.history == ( RequestHistory("GET", "/successful_retry", None, 418, None), ) - def test_retry_redirect_history(self): + def test_retry_redirect_history(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: resp = pool.request("GET", "/redirect", fields={"target": "/"}) assert resp.status == 200 + assert resp.retries is not None assert resp.retries.history == ( RequestHistory("GET", "/redirect?target=%2F", None, 303, "/"), ) - def test_multi_redirect_history(self): + def test_multi_redirect_history(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request( "GET", @@ -1129,6 +1233,7 @@ def test_multi_redirect_history(self): redirect=False, ) assert r.status == 303 + assert r.retries is not None assert r.retries.history == tuple() with HTTPConnectionPool(self.host, self.port) as pool: @@ -1148,6 +1253,7 @@ def test_multi_redirect_history(self): (307, "/multi_redirect?redirect_codes=302,200"), (302, "/multi_redirect?redirect_codes=200"), ] + assert r.retries is not None actual = [ (history.status, history.redirect_location) for history in r.retries.history @@ -1156,7 +1262,7 @@ def test_multi_redirect_history(self): class TestRetryAfter(HTTPDummyServerTestCase): - def test_retry_after(self): + def test_retry_after(self) -> None: # Request twice in a second to get a 429 response. 
with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request( @@ -1214,7 +1320,7 @@ def test_retry_after(self): ) assert r.status == 418 - def test_redirect_after(self): + def test_redirect_after(self) -> None: with HTTPConnectionPool(self.host, self.port) as pool: r = pool.request("GET", "/redirect_after", retries=False) assert r.status == 303 @@ -1242,7 +1348,7 @@ def test_redirect_after(self): class TestFileBodiesOnRetryOrRedirect(HTTPDummyServerTestCase): - def test_retries_put_filehandle(self): + def test_retries_put_filehandle(self) -> None: """HTTP PUT retry with a file-like object should not timeout""" with HTTPConnectionPool(self.host, self.port, timeout=0.1) as pool: retry = Retry(total=3, status_forcelist=[418]) @@ -1265,7 +1371,7 @@ def test_retries_put_filehandle(self): ) assert resp.status == 200 - def test_redirect_put_file(self): + def test_redirect_put_file(self) -> None: """PUT with file object should work with a redirection response""" with HTTPConnectionPool(self.host, self.port, timeout=0.1) as pool: retry = Retry(total=3, status_forcelist=[418]) @@ -1290,12 +1396,12 @@ def test_redirect_put_file(self): assert resp.status == 200 assert resp.data == data - def test_redirect_with_failed_tell(self): + def test_redirect_with_failed_tell(self) -> None: """Abort request if failed to get a position from tell()""" class BadTellObject(io.BytesIO): - def tell(self): - raise IOError + def tell(self) -> typing.NoReturn: + raise OSError body = BadTellObject(b"the data") url = "/redirect?target=/successful_retry" @@ -1303,13 +1409,14 @@ def tell(self): # which is unsupported by BytesIO. 
headers = {"Content-Length": "8"} with HTTPConnectionPool(self.host, self.port, timeout=0.1) as pool: - with pytest.raises(UnrewindableBodyError) as e: + with pytest.raises( + UnrewindableBodyError, match="Unable to record file position for" + ): pool.urlopen("PUT", url, headers=headers, body=body) - assert "Unable to record file position for" in str(e.value) class TestRetryPoolSize(HTTPDummyServerTestCase): - def test_pool_size_retry(self): + def test_pool_size_retry(self) -> None: retries = Retry(total=1, raise_on_status=False, status_forcelist=[404]) with HTTPConnectionPool( self.host, self.port, maxsize=10, retries=retries, block=True @@ -1319,7 +1426,7 @@ def test_pool_size_retry(self): class TestRedirectPoolSize(HTTPDummyServerTestCase): - def test_pool_size_redirect(self): + def test_pool_size_redirect(self) -> None: retries = Retry( total=1, raise_on_status=False, status_forcelist=[404], redirect=True ) diff --git a/test/with_dummyserver/test_https.py b/test/with_dummyserver/test_https.py index 92e23c93f8..de71c5a664 100644 --- a/test/with_dummyserver/test_https.py +++ b/test/with_dummyserver/test_https.py @@ -1,29 +1,31 @@ +from __future__ import annotations + +import contextlib import datetime -import json -import logging import os.path import shutil import ssl import sys import tempfile import warnings +from pathlib import Path from test import ( LONG_TIMEOUT, SHORT_TIMEOUT, TARPIT_HOST, - notOpenSSL098, notSecureTransport, - onlyPy279OrNewer, requires_network, requires_ssl_context_keyfile_password, resolvesLocalhostFQDN, ) +from test.conftest import ServerConfig +from unittest import mock -import mock import pytest import trustme import urllib3.util as util +import urllib3.util.ssl_ from dummyserver.server import ( DEFAULT_CA, DEFAULT_CA_KEY, @@ -32,34 +34,20 @@ ) from dummyserver.testcase import HTTPSDummyServerTestCase from urllib3 import HTTPSConnectionPool -from urllib3.connection import RECENT_DATE, VerifiedHTTPSConnection +from urllib3.connection 
import RECENT_DATE, HTTPSConnection, VerifiedHTTPSConnection from urllib3.exceptions import ( ConnectTimeoutError, - InsecurePlatformWarning, InsecureRequestWarning, MaxRetryError, ProtocolError, SSLError, SystemTimeWarning, ) -from urllib3.packages import six +from urllib3.util.ssl_match_hostname import CertificateError from urllib3.util.timeout import Timeout from .. import has_alpn -# Retry failed tests -pytestmark = pytest.mark.flaky - -ResourceWarning = getattr( - six.moves.builtins, "ResourceWarning", type("ResourceWarning", (), {}) -) - - -log = logging.getLogger("urllib3.connectionpool") -log.setLevel(logging.NOTSET) -log.addHandler(logging.StreamHandler(sys.stdout)) - - TLSv1_CERTS = DEFAULT_CERTS.copy() TLSv1_CERTS["ssl_version"] = getattr(ssl, "PROTOCOL_TLSv1", None) @@ -81,14 +69,32 @@ class TestHTTPS(HTTPSDummyServerTestCase): - tls_protocol_name = None + tls_protocol_name: str | None = None - def tls_protocol_deprecated(self): + def tls_protocol_not_default(self) -> bool: return self.tls_protocol_name in {"TLSv1", "TLSv1.1"} + def tls_version(self) -> ssl.TLSVersion: + if self.tls_protocol_name is None: + return pytest.skip("Skipping base test class") + try: + from ssl import TLSVersion + except ImportError: + return pytest.skip("ssl.TLSVersion isn't available") + return TLSVersion[self.tls_protocol_name.replace(".", "_")] + + def ssl_version(self) -> int: + if self.tls_protocol_name is None: + return pytest.skip("Skipping base test class") + attribute = f"PROTOCOL_{self.tls_protocol_name.replace('.', '_')}" + ssl_version = getattr(ssl, attribute, None) + if ssl_version is None: + return pytest.skip(f"ssl.{attribute} isn't available") + return ssl_version # type: ignore[no-any-return] + @classmethod - def setup_class(cls): - super(TestHTTPS, cls).setup_class() + def setup_class(cls) -> None: + super().setup_class() cls.certs_dir = tempfile.mkdtemp() # Start from existing root CA as we don't want to change the server certificate yet @@ -102,7 +108,7 @@ 
def setup_class(cls): # client cert chain intermediate_ca = root_ca.create_child_ca() - cert = intermediate_ca.issue_cert(u"example.com") + cert = intermediate_ca.issue_cert("example.com") encrypted_key = encrypt_key_pem(cert.private_key_pem, b"letmein") cert.private_key_pem.write_to_path( @@ -121,27 +127,33 @@ def setup_class(cls): ) @classmethod - def teardown_class(cls): - super(TestHTTPS, cls).teardown_class() + def teardown_class(cls) -> None: + super().teardown_class() shutil.rmtree(cls.certs_dir) - def test_simple(self): + def test_simple(self) -> None: with HTTPSConnectionPool( - self.host, self.port, ca_certs=DEFAULT_CA + self.host, + self.port, + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: r = https_pool.request("GET", "/") assert r.status == 200, r.data - @resolvesLocalhostFQDN - def test_dotted_fqdn(self): + @resolvesLocalhostFQDN() + def test_dotted_fqdn(self) -> None: with HTTPSConnectionPool( - self.host + ".", self.port, ca_certs=DEFAULT_CA + self.host + ".", + self.port, + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as pool: r = pool.request("GET", "/") assert r.status == 200, r.data - def test_client_intermediate(self): + def test_client_intermediate(self) -> None: """Check that certificate chains work well with client certs We generate an intermediate CA from the root CA, and issue a client certificate @@ -155,12 +167,13 @@ def test_client_intermediate(self): key_file=os.path.join(self.certs_dir, CLIENT_INTERMEDIATE_KEY), cert_file=os.path.join(self.certs_dir, CLIENT_INTERMEDIATE_PEM), ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: r = https_pool.request("GET", "/certificate") - subject = json.loads(r.data.decode("utf-8")) + subject = r.json() assert subject["organizationalUnitName"].startswith("Testing cert") - def test_client_no_intermediate(self): + def test_client_no_intermediate(self) -> None: """Check that missing links in certificate chains indeed 
break The only difference with test_client_intermediate is that we don't send the @@ -172,12 +185,13 @@ def test_client_no_intermediate(self): cert_file=os.path.join(self.certs_dir, CLIENT_NO_INTERMEDIATE_PEM), key_file=os.path.join(self.certs_dir, CLIENT_INTERMEDIATE_KEY), ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: with pytest.raises((SSLError, ProtocolError)): https_pool.request("GET", "/certificate", retries=False) - @requires_ssl_context_keyfile_password - def test_client_key_password(self): + @requires_ssl_context_keyfile_password() + def test_client_key_password(self) -> None: with HTTPSConnectionPool( self.host, self.port, @@ -185,29 +199,34 @@ def test_client_key_password(self): key_file=os.path.join(self.certs_dir, PASSWORD_CLIENT_KEYFILE), cert_file=os.path.join(self.certs_dir, CLIENT_CERT), key_password="letmein", + ssl_minimum_version=self.tls_version(), ) as https_pool: r = https_pool.request("GET", "/certificate") - subject = json.loads(r.data.decode("utf-8")) + subject = r.json() assert subject["organizationalUnitName"].startswith("Testing cert") - @requires_ssl_context_keyfile_password - def test_client_encrypted_key_requires_password(self): + @requires_ssl_context_keyfile_password() + def test_client_encrypted_key_requires_password(self) -> None: with HTTPSConnectionPool( self.host, self.port, key_file=os.path.join(self.certs_dir, PASSWORD_CLIENT_KEYFILE), cert_file=os.path.join(self.certs_dir, CLIENT_CERT), key_password=None, + ssl_minimum_version=self.tls_version(), ) as https_pool: - with pytest.raises(MaxRetryError) as e: + with pytest.raises(MaxRetryError, match="password is required") as e: https_pool.request("GET", "/certificate") - assert "password is required" in str(e.value) assert isinstance(e.value.reason, SSLError) - def test_verified(self): + def test_verified(self) -> None: with HTTPSConnectionPool( - self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA + self.host, + self.port, + 
cert_reqs="CERT_REQUIRED", + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: conn = https_pool._new_conn() assert conn.__class__ == VerifiedHTTPSConnection @@ -216,24 +235,12 @@ def test_verified(self): r = https_pool.request("GET", "/") assert r.status == 200 - # If we're using a deprecated TLS version we can remove 'DeprecationWarning' - if self.tls_protocol_deprecated(): - w = [x for x in w if x.category != DeprecationWarning] - - # Modern versions of Python, or systems using PyOpenSSL, don't - # emit warnings. - if ( - sys.version_info >= (2, 7, 9) - or util.IS_PYOPENSSL - or util.IS_SECURETRANSPORT - ): - assert w == [] - else: - assert len(w) > 1 - assert any(x.category == InsecureRequestWarning for x in w) + assert [str(wm) for wm in w] == [] - def test_verified_with_context(self): - ctx = util.ssl_.create_urllib3_context(cert_reqs=ssl.CERT_REQUIRED) + def test_verified_with_context(self) -> None: + ctx = util.ssl_.create_urllib3_context( + cert_reqs=ssl.CERT_REQUIRED, ssl_minimum_version=self.tls_version() + ) ctx.load_verify_locations(cafile=DEFAULT_CA) with HTTPSConnectionPool(self.host, self.port, ssl_context=ctx) as https_pool: conn = https_pool._new_conn() @@ -242,26 +249,12 @@ def test_verified_with_context(self): with mock.patch("warnings.warn") as warn: r = https_pool.request("GET", "/") assert r.status == 200 + assert not warn.called, warn.call_args_list - # Modern versions of Python, or systems using PyOpenSSL, don't - # emit warnings. 
- if ( - sys.version_info >= (2, 7, 9) - or util.IS_PYOPENSSL - or util.IS_SECURETRANSPORT - ): - assert not warn.called, warn.call_args_list - else: - assert warn.called - if util.HAS_SNI: - call = warn.call_args_list[0] - else: - call = warn.call_args_list[1] - error = call[0][1] - assert error == InsecurePlatformWarning - - def test_context_combines_with_ca_certs(self): - ctx = util.ssl_.create_urllib3_context(cert_reqs=ssl.CERT_REQUIRED) + def test_context_combines_with_ca_certs(self) -> None: + ctx = util.ssl_.create_urllib3_context( + cert_reqs=ssl.CERT_REQUIRED, ssl_minimum_version=self.tls_version() + ) with HTTPSConnectionPool( self.host, self.port, ca_certs=DEFAULT_CA, ssl_context=ctx ) as https_pool: @@ -271,35 +264,21 @@ def test_context_combines_with_ca_certs(self): with mock.patch("warnings.warn") as warn: r = https_pool.request("GET", "/") assert r.status == 200 + assert not warn.called, warn.call_args_list - # Modern versions of Python, or systems using PyOpenSSL, don't - # emit warnings. - if ( - sys.version_info >= (2, 7, 9) - or util.IS_PYOPENSSL - or util.IS_SECURETRANSPORT - ): - assert not warn.called, warn.call_args_list - else: - assert warn.called - if util.HAS_SNI: - call = warn.call_args_list[0] - else: - call = warn.call_args_list[1] - error = call[0][1] - assert error == InsecurePlatformWarning - - @onlyPy279OrNewer - @notSecureTransport # SecureTransport does not support cert directories - @notOpenSSL098 # OpenSSL 0.9.8 does not support cert directories - def test_ca_dir_verified(self, tmpdir): + @notSecureTransport() # SecureTransport does not support cert directories + def test_ca_dir_verified(self, tmp_path: Path) -> None: # OpenSSL looks up certificates by the hash for their name, see c_rehash # TODO infer the bytes using `cryptography.x509.Name.public_bytes`. 
# https://github.com/pyca/cryptography/pull/3236 - shutil.copyfile(DEFAULT_CA, str(tmpdir / "81deb5f7.0")) + shutil.copyfile(DEFAULT_CA, str(tmp_path / "81deb5f7.0")) with HTTPSConnectionPool( - self.host, self.port, cert_reqs="CERT_REQUIRED", ca_cert_dir=str(tmpdir) + self.host, + self.port, + cert_reqs="CERT_REQUIRED", + ca_cert_dir=str(tmp_path), + ssl_minimum_version=self.tls_version(), ) as https_pool: conn = https_pool._new_conn() assert conn.__class__ == VerifiedHTTPSConnection @@ -308,38 +287,64 @@ def test_ca_dir_verified(self, tmpdir): r = https_pool.request("GET", "/") assert r.status == 200 - # If we're using a deprecated TLS version we can remove 'DeprecationWarning' - if self.tls_protocol_deprecated(): - w = [x for x in w if x.category != DeprecationWarning] - - assert w == [] + assert [str(wm) for wm in w] == [] - def test_invalid_common_name(self): + def test_invalid_common_name(self) -> None: with HTTPSConnectionPool( - "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA + "127.0.0.1", + self.port, + cert_reqs="CERT_REQUIRED", + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: with pytest.raises(MaxRetryError) as e: - https_pool.request("GET", "/") + https_pool.request("GET", "/", retries=0) assert isinstance(e.value.reason, SSLError) assert "doesn't match" in str( e.value.reason ) or "certificate verify failed" in str(e.value.reason) - def test_verified_with_bad_ca_certs(self): + def test_verified_with_bad_ca_certs(self) -> None: with HTTPSConnectionPool( - self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=self.bad_ca_path + self.host, + self.port, + cert_reqs="CERT_REQUIRED", + ca_certs=self.bad_ca_path, + ssl_minimum_version=self.tls_version(), ) as https_pool: with pytest.raises(MaxRetryError) as e: https_pool.request("GET", "/") assert isinstance(e.value.reason, SSLError) - assert "certificate verify failed" in str(e.value.reason), ( - "Expected 'certificate verify failed', instead got: 
%r" % e.value.reason - ) + assert ( + "certificate verify failed" in str(e.value.reason) + # PyPy is more specific + or "self signed certificate in certificate chain" in str(e.value.reason) + ), f"Expected 'certificate verify failed', instead got: {e.value.reason!r}" + + def test_wrap_socket_failure_resource_leak(self) -> None: + with HTTPSConnectionPool( + self.host, + self.port, + cert_reqs="CERT_REQUIRED", + ca_certs=self.bad_ca_path, + ssl_minimum_version=self.tls_version(), + ) as https_pool: + conn = https_pool._get_conn() + try: + with pytest.raises(ssl.SSLError): + conn.connect() + + assert conn.sock is not None # type: ignore[attr-defined] + finally: + conn.close() - def test_verified_without_ca_certs(self): + def test_verified_without_ca_certs(self) -> None: # default is cert_reqs=None which is ssl.CERT_NONE with HTTPSConnectionPool( - self.host, self.port, cert_reqs="CERT_REQUIRED" + self.host, + self.port, + cert_reqs="CERT_REQUIRED", + ssl_minimum_version=self.tls_version(), ) as https_pool: with pytest.raises(MaxRetryError) as e: https_pool.request("GET", "/") @@ -348,6 +353,8 @@ def test_verified_without_ca_certs(self): # not pyopenssl is injected assert ( "No root certificates specified" in str(e.value.reason) + # PyPy is more specific + or "self signed certificate in certificate chain" in str(e.value.reason) # PyPy sometimes uses all-caps here or "certificate verify failed" in str(e.value.reason).lower() or "invalid certificate chain" in str(e.value.reason) @@ -358,18 +365,22 @@ def test_verified_without_ca_certs(self): "instead got: %r" % e.value.reason ) - def test_no_ssl(self): + def test_no_ssl(self) -> None: with HTTPSConnectionPool(self.host, self.port) as pool: - pool.ConnectionCls = None - with pytest.raises(SSLError): + pool.ConnectionCls = None # type: ignore[assignment] + with pytest.raises(ImportError): pool._new_conn() - with pytest.raises(MaxRetryError) as cm: + with pytest.raises(ImportError): pool.request("GET", "/", retries=0) - 
assert isinstance(cm.value.reason, SSLError) - def test_unverified_ssl(self): - """ Test that bare HTTPSConnection can connect, make requests """ - with HTTPSConnectionPool(self.host, self.port, cert_reqs=ssl.CERT_NONE) as pool: + def test_unverified_ssl(self) -> None: + """Test that bare HTTPSConnection can connect, make requests""" + with HTTPSConnectionPool( + self.host, + self.port, + cert_reqs=ssl.CERT_NONE, + ssl_minimum_version=self.tls_version(), + ) as pool: with mock.patch("warnings.warn") as warn: r = pool.request("GET", "/") assert r.status == 200 @@ -381,9 +392,13 @@ def test_unverified_ssl(self): calls = warn.call_args_list assert InsecureRequestWarning in [x[0][1] for x in calls] - def test_ssl_unverified_with_ca_certs(self): + def test_ssl_unverified_with_ca_certs(self) -> None: with HTTPSConnectionPool( - self.host, self.port, cert_reqs="CERT_NONE", ca_certs=self.bad_ca_path + self.host, + self.port, + cert_reqs="CERT_NONE", + ca_certs=self.bad_ca_path, + ssl_minimum_version=self.tls_version(), ) as pool: with mock.patch("warnings.warn") as warn: r = pool.request("GET", "/") @@ -395,43 +410,39 @@ def test_ssl_unverified_with_ca_certs(self): # warnings, which we want to ignore here. 
calls = warn.call_args_list - # If we're using a deprecated TLS version we can remove 'DeprecationWarning' - if self.tls_protocol_deprecated(): - calls = [call for call in calls if call[0][1] != DeprecationWarning] - - if ( - sys.version_info >= (2, 7, 9) - or util.IS_PYOPENSSL - or util.IS_SECURETRANSPORT - ): - category = calls[0][0][1] - elif util.HAS_SNI: - category = calls[1][0][1] - else: - category = calls[2][0][1] + category = calls[0][0][1] assert category == InsecureRequestWarning - def test_assert_hostname_false(self): + def test_assert_hostname_false(self) -> None: with HTTPSConnectionPool( - "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA + "localhost", + self.port, + cert_reqs="CERT_REQUIRED", + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: https_pool.assert_hostname = False https_pool.request("GET", "/") - def test_assert_specific_hostname(self): + def test_assert_specific_hostname(self) -> None: with HTTPSConnectionPool( - "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA + "localhost", + self.port, + cert_reqs="CERT_REQUIRED", + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: https_pool.assert_hostname = "localhost" https_pool.request("GET", "/") - def test_server_hostname(self): + def test_server_hostname(self) -> None: with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA, server_hostname="localhost", + ssl_minimum_version=self.tls_version(), ) as https_pool: conn = https_pool._new_conn() conn.request("GET", "/") @@ -440,12 +451,16 @@ def test_server_hostname(self): # pyopenssl doesn't let you pull the server_hostname back off the # socket, so only add this assertion if the attribute is there (i.e. # the python ssl module). 
- if hasattr(conn.sock, "server_hostname"): - assert conn.sock.server_hostname == "localhost" + if hasattr(conn.sock, "server_hostname"): # type: ignore[attr-defined] + assert conn.sock.server_hostname == "localhost" # type: ignore[attr-defined] - def test_assert_fingerprint_md5(self): + def test_assert_fingerprint_md5(self) -> None: with HTTPSConnectionPool( - "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA + "localhost", + self.port, + cert_reqs="CERT_REQUIRED", + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: https_pool.assert_fingerprint = ( "55:39:BF:70:05:12:43:FA:1F:D1:BF:4E:E8:1B:07:1D" @@ -453,18 +468,26 @@ def test_assert_fingerprint_md5(self): https_pool.request("GET", "/") - def test_assert_fingerprint_sha1(self): + def test_assert_fingerprint_sha1(self) -> None: with HTTPSConnectionPool( - "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA + "localhost", + self.port, + cert_reqs="CERT_REQUIRED", + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: https_pool.assert_fingerprint = ( "72:8B:55:4C:9A:FC:1E:88:A1:1C:AD:1B:B2:E7:CC:3E:DB:C8:F9:8A" ) https_pool.request("GET", "/") - def test_assert_fingerprint_sha256(self): + def test_assert_fingerprint_sha256(self) -> None: with HTTPSConnectionPool( - "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA + "localhost", + self.port, + cert_reqs="CERT_REQUIRED", + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: https_pool.assert_fingerprint = ( "E3:59:8E:69:FF:C5:9F:C7:88:87:44:58:22:7F:90:8D:D9:BC:12:C4:90:79:D5:" @@ -472,22 +495,30 @@ def test_assert_fingerprint_sha256(self): ) https_pool.request("GET", "/") - def test_assert_invalid_fingerprint(self): - def _test_request(pool): + def test_assert_invalid_fingerprint(self) -> None: + def _test_request(pool: HTTPSConnectionPool) -> SSLError: with pytest.raises(MaxRetryError) as cm: pool.request("GET", "/", 
retries=0) assert isinstance(cm.value.reason, SSLError) return cm.value.reason with HTTPSConnectionPool( - self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA + self.host, + self.port, + cert_reqs="CERT_REQUIRED", + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: - https_pool.assert_fingerprint = ( "AA:AA:AA:AA:AA:AAAA:AA:AAAA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA" ) e = _test_request(https_pool) - assert "Fingerprints did not match." in str(e) + expected = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + got = "728b554c9afc1e88a11cad1bb2e7cc3edbc8f98a" + assert ( + str(e) + == f'Fingerprints did not match. Expected "{expected}", got "{got}"' + ) # Uneven length https_pool.assert_fingerprint = "AA:A" @@ -499,7 +530,7 @@ def _test_request(pool): e = _test_request(https_pool) assert "Fingerprint of invalid length:" in str(e) - def test_verify_none_and_bad_fingerprint(self): + def test_verify_none_and_bad_fingerprint(self) -> None: with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_NONE", ca_certs=self.bad_ca_path ) as https_pool: @@ -510,32 +541,39 @@ def test_verify_none_and_bad_fingerprint(self): https_pool.request("GET", "/", retries=0) assert isinstance(cm.value.reason, SSLError) - def test_verify_none_and_good_fingerprint(self): + def test_verify_none_and_good_fingerprint(self) -> None: with HTTPSConnectionPool( - "127.0.0.1", self.port, cert_reqs="CERT_NONE", ca_certs=self.bad_ca_path + "127.0.0.1", + self.port, + cert_reqs="CERT_NONE", + ca_certs=self.bad_ca_path, + ssl_minimum_version=self.tls_version(), ) as https_pool: https_pool.assert_fingerprint = ( "72:8B:55:4C:9A:FC:1E:88:A1:1C:AD:1B:B2:E7:CC:3E:DB:C8:F9:8A" ) https_pool.request("GET", "/") - @notSecureTransport - def test_good_fingerprint_and_hostname_mismatch(self): + @notSecureTransport() + def test_good_fingerprint_and_hostname_mismatch(self) -> None: # This test doesn't run with SecureTransport because we don't turn off # hostname validation 
without turning off all validation, which this # test doesn't do (deliberately). We should revisit this if we make # new decisions. with HTTPSConnectionPool( - "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA + "127.0.0.1", + self.port, + cert_reqs="CERT_REQUIRED", + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: https_pool.assert_fingerprint = ( "72:8B:55:4C:9A:FC:1E:88:A1:1C:AD:1B:B2:E7:CC:3E:DB:C8:F9:8A" ) https_pool.request("GET", "/") - @requires_network - def test_https_timeout(self): - + @requires_network() + def test_https_timeout(self) -> None: timeout = Timeout(total=None, connect=SHORT_TIMEOUT) with HTTPSConnectionPool( TARPIT_HOST, @@ -543,6 +581,7 @@ def test_https_timeout(self): timeout=timeout, retries=False, cert_reqs="CERT_REQUIRED", + ssl_minimum_version=self.tls_version(), ) as https_pool: with pytest.raises(ConnectTimeoutError): https_pool.request("GET", "/") @@ -554,6 +593,7 @@ def test_https_timeout(self): timeout=timeout, retries=False, cert_reqs="CERT_REQUIRED", + ssl_minimum_version=self.tls_version(), ) as https_pool: https_pool.ca_certs = DEFAULT_CA https_pool.assert_fingerprint = ( @@ -562,27 +602,39 @@ def test_https_timeout(self): timeout = Timeout(total=None) with HTTPSConnectionPool( - self.host, self.port, timeout=timeout, cert_reqs="CERT_NONE" + self.host, + self.port, + timeout=timeout, + cert_reqs="CERT_NONE", + ssl_minimum_version=self.tls_version(), ) as https_pool: - https_pool.request("GET", "/") + with pytest.warns(InsecureRequestWarning): + https_pool.request("GET", "/") - def test_tunnel(self): - """ test the _tunnel behavior """ + def test_tunnel(self) -> None: + """test the _tunnel behavior""" timeout = Timeout(total=None) with HTTPSConnectionPool( - self.host, self.port, timeout=timeout, cert_reqs="CERT_NONE" + self.host, + self.port, + timeout=timeout, + cert_reqs="CERT_NONE", + ssl_minimum_version=self.tls_version(), ) as https_pool: conn = https_pool._new_conn() 
try: conn.set_tunnel(self.host, self.port) - conn._tunnel = mock.Mock() - https_pool._make_request(conn, "GET", "/") - conn._tunnel.assert_called_once_with() + with mock.patch.object( + conn, "_tunnel", create=True, return_value=None + ) as conn_tunnel: + with pytest.warns(InsecureRequestWarning): + https_pool._make_request(conn, "GET", "/") + conn_tunnel.assert_called_once_with() finally: conn.close() - @requires_network - def test_enhanced_timeout(self): + @requires_network() + def test_enhanced_timeout(self) -> None: with HTTPSConnectionPool( TARPIT_HOST, self.port, @@ -625,7 +677,7 @@ def test_enhanced_timeout(self): finally: conn.close() - def test_enhanced_ssl_connection(self): + def test_enhanced_ssl_connection(self) -> None: fingerprint = "72:8B:55:4C:9A:FC:1E:88:A1:1C:AD:1B:B2:E7:CC:3E:DB:C8:F9:8A" with HTTPSConnectionPool( @@ -634,14 +686,17 @@ def test_enhanced_ssl_connection(self): cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA, assert_fingerprint=fingerprint, + ssl_minimum_version=self.tls_version(), ) as https_pool: r = https_pool.request("GET", "/") assert r.status == 200 - @onlyPy279OrNewer - def test_ssl_correct_system_time(self): + def test_ssl_correct_system_time(self) -> None: with HTTPSConnectionPool( - self.host, self.port, ca_certs=DEFAULT_CA + self.host, + self.port, + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: https_pool.cert_reqs = "CERT_REQUIRED" https_pool.ca_certs = DEFAULT_CA @@ -649,10 +704,12 @@ def test_ssl_correct_system_time(self): w = self._request_without_resource_warnings("GET", "/") assert [] == w - @onlyPy279OrNewer - def test_ssl_wrong_system_time(self): + def test_ssl_wrong_system_time(self) -> None: with HTTPSConnectionPool( - self.host, self.port, ca_certs=DEFAULT_CA + self.host, + self.port, + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: https_pool.cert_reqs = "CERT_REQUIRED" https_pool.ca_certs = DEFAULT_CA @@ -665,89 +722,137 @@ def 
test_ssl_wrong_system_time(self): warning = w[0] assert SystemTimeWarning == warning.category + assert isinstance(warning.message, Warning) assert str(RECENT_DATE) in warning.message.args[0] - def _request_without_resource_warnings(self, method, url): + def _request_without_resource_warnings( + self, method: str, url: str + ) -> list[warnings.WarningMessage]: with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") with HTTPSConnectionPool( - self.host, self.port, ca_certs=DEFAULT_CA + self.host, + self.port, + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: https_pool.request(method, url) w = [x for x in w if not isinstance(x.message, ResourceWarning)] - # If we're using a deprecated TLS version we can remove 'DeprecationWarning' - if self.tls_protocol_deprecated(): - w = [x for x in w if x.category != DeprecationWarning] - return w - def test_set_ssl_version_to_tls_version(self): + def test_set_ssl_version_to_tls_version(self) -> None: if self.tls_protocol_name is None: pytest.skip("Skipping base test class") with HTTPSConnectionPool( self.host, self.port, ca_certs=DEFAULT_CA ) as https_pool: - https_pool.ssl_version = self.certs["ssl_version"] - r = https_pool.request("GET", "/") + https_pool.ssl_version = ssl_version = self.certs["ssl_version"] + if ssl_version is getattr(ssl, "PROTOCOL_TLS", object()): + cmgr: contextlib.AbstractContextManager[ + object + ] = contextlib.nullcontext() + else: + cmgr = pytest.warns( + DeprecationWarning, + match=r"'ssl_version' option is deprecated and will be removed " + r"in urllib3 v2\.1\.0\. 
Instead use 'ssl_minimum_version'", + ) + with cmgr: + r = https_pool.request("GET", "/") assert r.status == 200, r.data - def test_set_cert_default_cert_required(self): + def test_set_cert_default_cert_required(self) -> None: conn = VerifiedHTTPSConnection(self.host, self.port) - conn.set_cert() + with pytest.warns(DeprecationWarning) as w: + conn.set_cert() assert conn.cert_reqs == ssl.CERT_REQUIRED + assert len(w) == 1 and str(w[0].message) == ( + "HTTPSConnection.set_cert() is deprecated and will be removed in urllib3 v2.1.0. " + "Instead provide the parameters to the HTTPSConnection constructor." + ) + + @pytest.mark.parametrize("verify_mode", [ssl.CERT_NONE, ssl.CERT_REQUIRED]) + def test_set_cert_inherits_cert_reqs_from_ssl_context( + self, verify_mode: int + ) -> None: + ssl_context = urllib3.util.ssl_.create_urllib3_context(cert_reqs=verify_mode) + assert ssl_context.verify_mode == verify_mode + + conn = HTTPSConnection(self.host, self.port, ssl_context=ssl_context) + with pytest.warns(DeprecationWarning) as w: + conn.set_cert() + + assert conn.cert_reqs == verify_mode + assert ( + conn.ssl_context is not None and conn.ssl_context.verify_mode == verify_mode + ) + assert len(w) == 1 and str(w[0].message) == ( + "HTTPSConnection.set_cert() is deprecated and will be removed in urllib3 v2.1.0. " + "Instead provide the parameters to the HTTPSConnection constructor." 
+ ) - def test_tls_protocol_name_of_socket(self): + def test_tls_protocol_name_of_socket(self) -> None: if self.tls_protocol_name is None: pytest.skip("Skipping base test class") with HTTPSConnectionPool( - self.host, self.port, ca_certs=DEFAULT_CA + self.host, + self.port, + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: conn = https_pool._get_conn() try: conn.connect() - if not hasattr(conn.sock, "version"): + if not hasattr(conn.sock, "version"): # type: ignore[attr-defined] pytest.skip("SSLSocket.version() not available") - assert conn.sock.version() == self.tls_protocol_name + assert conn.sock.version() == self.tls_protocol_name # type: ignore[attr-defined] finally: conn.close() - def test_default_tls_version_deprecations(self): + def test_ssl_version_is_deprecated(self) -> None: if self.tls_protocol_name is None: pytest.skip("Skipping base test class") with HTTPSConnectionPool( - self.host, self.port, ca_certs=DEFAULT_CA + self.host, self.port, ca_certs=DEFAULT_CA, ssl_version=self.ssl_version() ) as https_pool: conn = https_pool._get_conn() try: - with warnings.catch_warnings(record=True) as w: + with pytest.warns(DeprecationWarning) as w: conn.connect() - if not hasattr(conn.sock, "version"): - pytest.skip("SSLSocket.version() not available") finally: conn.close() - if self.tls_protocol_deprecated(): - assert len(w) == 1 - assert str(w[0].message) == ( - "Negotiating TLSv1/TLSv1.1 by default is deprecated " - "and will be disabled in urllib3 v2.0.0. Connecting to " - "'%s' with '%s' can be enabled by explicitly opting-in " - "with 'ssl_version'" % (self.host, self.tls_protocol_name) + assert len(w) >= 1 + assert any(x.category == DeprecationWarning for x in w) + assert any( + str(x.message) + == ( + "'ssl_version' option is deprecated and will be removed in " + "urllib3 v2.1.0. 
Instead use 'ssl_minimum_version'" ) - else: - assert w == [] + for x in w + ) - def test_no_tls_version_deprecation_with_ssl_version(self): + @pytest.mark.parametrize( + "ssl_version", [None, ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_CLIENT] + ) + def test_ssl_version_with_protocol_tls_or_client_not_deprecated( + self, ssl_version: int | None + ) -> None: if self.tls_protocol_name is None: pytest.skip("Skipping base test class") + if self.tls_protocol_not_default(): + pytest.skip( + f"Skipping because '{self.tls_protocol_name}' isn't set by default" + ) with HTTPSConnectionPool( - self.host, self.port, ca_certs=DEFAULT_CA, ssl_version=util.PROTOCOL_TLS + self.host, self.port, ca_certs=DEFAULT_CA, ssl_version=ssl_version ) as https_pool: conn = https_pool._get_conn() try: @@ -756,17 +861,19 @@ def test_no_tls_version_deprecation_with_ssl_version(self): finally: conn.close() - assert w == [] + assert [str(wm) for wm in w if wm.category != ResourceWarning] == [] - def test_no_tls_version_deprecation_with_ssl_context(self): + def test_no_tls_version_deprecation_with_ssl_context(self) -> None: if self.tls_protocol_name is None: pytest.skip("Skipping base test class") + ctx = util.ssl_.create_urllib3_context(ssl_minimum_version=self.tls_version()) + with HTTPSConnectionPool( self.host, self.port, ca_certs=DEFAULT_CA, - ssl_context=util.ssl_.create_urllib3_context(), + ssl_context=ctx, ) as https_pool: conn = https_pool._get_conn() try: @@ -775,49 +882,125 @@ def test_no_tls_version_deprecation_with_ssl_context(self): finally: conn.close() - assert w == [] + assert [str(wm) for wm in w if wm.category != ResourceWarning] == [] + + def test_tls_version_maximum_and_minimum(self) -> None: + if self.tls_protocol_name is None: + pytest.skip("Skipping base test class") + + from ssl import TLSVersion + + min_max_versions = [ + (self.tls_version(), self.tls_version()), + (TLSVersion.MINIMUM_SUPPORTED, self.tls_version()), + (TLSVersion.MINIMUM_SUPPORTED, TLSVersion.MAXIMUM_SUPPORTED), + 
] + + for minimum_version, maximum_version in min_max_versions: + with HTTPSConnectionPool( + self.host, + self.port, + ca_certs=DEFAULT_CA, + ssl_minimum_version=minimum_version, + ssl_maximum_version=maximum_version, + ) as https_pool: + conn = https_pool._get_conn() + try: + conn.connect() + assert conn.sock.version() == self.tls_protocol_name # type: ignore[attr-defined] + finally: + conn.close() @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python 3.8+") - def test_sslkeylogfile(self, tmpdir, monkeypatch): + def test_sslkeylogfile( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: if not hasattr(util.SSLContext, "keylog_filename"): pytest.skip("requires OpenSSL 1.1.1+") - keylog_file = tmpdir.join("keylogfile.txt") + + keylog_file = tmp_path / "keylogfile.txt" monkeypatch.setenv("SSLKEYLOGFILE", str(keylog_file)) + with HTTPSConnectionPool( - self.host, self.port, ca_certs=DEFAULT_CA + self.host, + self.port, + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), ) as https_pool: r = https_pool.request("GET", "/") assert r.status == 200, r.data - assert keylog_file.check(file=1), "keylogfile '%s' should exist" % str( + assert keylog_file.is_file(), "keylogfile '%s' should exist" % str( keylog_file ) - assert keylog_file.read().startswith( + assert keylog_file.read_text().startswith( "# TLS secrets log file" ), "keylogfile '%s' should start with '# TLS secrets log file'" % str( keylog_file ) @pytest.mark.parametrize("sslkeylogfile", [None, ""]) - def test_sslkeylogfile_empty(self, monkeypatch, sslkeylogfile): + def test_sslkeylogfile_empty( + self, monkeypatch: pytest.MonkeyPatch, sslkeylogfile: str | None + ) -> None: # Assert that an HTTPS connection doesn't error out when given # no SSLKEYLOGFILE or an empty value (ie 'SSLKEYLOGFILE=') if sslkeylogfile is not None: monkeypatch.setenv("SSLKEYLOGFILE", sslkeylogfile) else: monkeypatch.delenv("SSLKEYLOGFILE", raising=False) - with HTTPSConnectionPool(self.host, 
self.port, ca_certs=DEFAULT_CA) as pool: + with HTTPSConnectionPool( + self.host, + self.port, + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), + ) as pool: r = pool.request("GET", "/") assert r.status == 200, r.data - def test_alpn_default(self): + def test_alpn_default(self) -> None: """Default ALPN protocols are sent by default.""" if not has_alpn() or not has_alpn(ssl.SSLContext): pytest.skip("ALPN-support not available") - with HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA) as pool: + with HTTPSConnectionPool( + self.host, + self.port, + ca_certs=DEFAULT_CA, + ssl_minimum_version=self.tls_version(), + ) as pool: r = pool.request("GET", "/alpn_protocol", retries=0) assert r.status == 200 assert r.data.decode("utf-8") == util.ALPN_PROTOCOLS[0] + def test_default_ssl_context_ssl_min_max_versions(self) -> None: + ctx = urllib3.util.ssl_.create_urllib3_context() + assert ctx.minimum_version == ssl.TLSVersion.TLSv1_2 + # urllib3 sets a default maximum version only when it is + # injected with PyOpenSSL- or SecureTransport-backed + # SSL-support. + # Otherwise, the default maximum version is set by Python's + # `ssl.SSLContext`. The value respects OpenSSL configuration and + # can be different from `ssl.TLSVersion.MAXIMUM_SUPPORTED`. + # https://github.com/urllib3/urllib3/issues/2477#issuecomment-1151452150 + if util.IS_PYOPENSSL or util.IS_SECURETRANSPORT: + expected_maximum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED + else: + expected_maximum_version = ssl.SSLContext( + ssl.PROTOCOL_TLS_CLIENT + ).maximum_version + assert ctx.maximum_version == expected_maximum_version + + def test_ssl_context_ssl_version_uses_ssl_min_max_versions(self) -> None: + with pytest.warns( + DeprecationWarning, + match=r"'ssl_version' option is deprecated and will be removed in " + r"urllib3 v2\.1\.0\. 
Instead use 'ssl_minimum_version'", + ): + ctx = urllib3.util.ssl_.create_urllib3_context( + ssl_version=self.ssl_version() + ) + assert ctx.minimum_version == self.tls_version() + assert ctx.maximum_version == self.tls_version() + @pytest.mark.usefixtures("requires_tlsv1") class TestHTTPS_TLSv1(TestHTTPS): @@ -843,63 +1026,117 @@ class TestHTTPS_TLSv1_3(TestHTTPS): certs = TLSv1_3_CERTS -class TestHTTPS_NoSAN: - def test_warning_for_certs_without_a_san(self, no_san_server): - """Ensure that a warning is raised when the cert from the server has - no Subject Alternative Name.""" - with mock.patch("warnings.warn") as warn: +class TestHTTPS_Hostname: + def test_can_validate_san(self, san_server: ServerConfig) -> None: + """Ensure that urllib3 can validate SANs with IP addresses in them.""" + with HTTPSConnectionPool( + san_server.host, + san_server.port, + cert_reqs="CERT_REQUIRED", + ca_certs=san_server.ca_certs, + ) as https_pool: + r = https_pool.request("GET", "/") + assert r.status == 200 + + def test_common_name_without_san_fails(self, no_san_server: ServerConfig) -> None: + with HTTPSConnectionPool( + no_san_server.host, + no_san_server.port, + cert_reqs="CERT_REQUIRED", + ca_certs=no_san_server.ca_certs, + ) as https_pool: + with pytest.raises( + MaxRetryError, + ) as e: + https_pool.request("GET", "/") + assert "mismatch, certificate is not valid" in str( + e.value + ) or "no appropriate subjectAltName" in str(e.value) + + def test_common_name_without_san_with_different_common_name( + self, no_san_server_with_different_commmon_name: ServerConfig + ) -> None: + ctx = urllib3.util.ssl_.create_urllib3_context() + try: + ctx.hostname_checks_common_name = True + except AttributeError: + pytest.skip("Couldn't set 'SSLContext.hostname_checks_common_name'") + + with HTTPSConnectionPool( + no_san_server_with_different_commmon_name.host, + no_san_server_with_different_commmon_name.port, + cert_reqs="CERT_REQUIRED", + 
ca_certs=no_san_server_with_different_commmon_name.ca_certs, + ssl_context=ctx, + ) as https_pool: + with pytest.raises(MaxRetryError) as e: + https_pool.request("GET", "/") + assert "mismatch, certificate is not valid for 'localhost'" in str( + e.value + ) or "hostname 'localhost' doesn't match 'example.com'" in str(e.value) + + @pytest.mark.parametrize("use_assert_hostname", [True, False]) + def test_hostname_checks_common_name_respected( + self, no_san_server: ServerConfig, use_assert_hostname: bool + ) -> None: + ctx = urllib3.util.ssl_.create_urllib3_context() + if not hasattr(ctx, "hostname_checks_common_name"): + pytest.skip("Test requires 'SSLContext.hostname_checks_common_name'") + ctx.load_verify_locations(no_san_server.ca_certs) + try: + ctx.hostname_checks_common_name = True + except AttributeError: + pytest.skip("Couldn't set 'SSLContext.hostname_checks_common_name'") + + err: MaxRetryError | None + try: with HTTPSConnectionPool( no_san_server.host, no_san_server.port, cert_reqs="CERT_REQUIRED", - ca_certs=no_san_server.ca_certs, + ssl_context=ctx, + assert_hostname=no_san_server.host if use_assert_hostname else None, ) as https_pool: - r = https_pool.request("GET", "/") - assert r.status == 200 - assert warn.called + https_pool.request("GET", "/") + except MaxRetryError as e: + err = e + else: + err = None + # commonName is only valid for DNS names, not IP addresses. 
+ if no_san_server.host == "localhost": + assert err is None -class TestHTTPS_IPSAN: - def test_can_validate_ip_san(self, ip_san_server): - """Ensure that urllib3 can validate SANs with IP addresses in them.""" - try: - import ipaddress # noqa: F401 - except ImportError: - pytest.skip("Only runs on systems with an ipaddress module") - - with HTTPSConnectionPool( - ip_san_server.host, - ip_san_server.port, - cert_reqs="CERT_REQUIRED", - ca_certs=ip_san_server.ca_certs, - ) as https_pool: - r = https_pool.request("GET", "/") - assert r.status == 200 + # IP addresses should fail for commonName. + else: + assert err is not None + assert type(err.reason) == SSLError + assert isinstance( + err.reason.args[0], (ssl.SSLCertVerificationError, CertificateError) + ) -class TestHTTPS_IPv6Addr: - def test_strip_square_brackets_before_validating(self, ipv6_addr_server): - """Test that the fix for #760 works.""" +class TestHTTPS_IPV4SAN: + def test_can_validate_ip_san(self, ipv4_san_server: ServerConfig) -> None: + """Ensure that urllib3 can validate SANs with IP addresses in them.""" with HTTPSConnectionPool( - "[::1]", - ipv6_addr_server.port, + ipv4_san_server.host, + ipv4_san_server.port, cert_reqs="CERT_REQUIRED", - ca_certs=ipv6_addr_server.ca_certs, + ca_certs=ipv4_san_server.ca_certs, ) as https_pool: r = https_pool.request("GET", "/") assert r.status == 200 class TestHTTPS_IPV6SAN: - def test_can_validate_ipv6_san(self, ipv6_san_server): + @pytest.mark.parametrize("host", ["::1", "[::1]"]) + def test_can_validate_ipv6_san( + self, ipv6_san_server: ServerConfig, host: str + ) -> None: """Ensure that urllib3 can validate SANs with IPv6 addresses in them.""" - try: - import ipaddress # noqa: F401 - except ImportError: - pytest.skip("Only runs on systems with an ipaddress module") - with HTTPSConnectionPool( - "[::1]", + host, ipv6_san_server.port, cert_reqs="CERT_REQUIRED", ca_certs=ipv6_san_server.ca_certs, diff --git a/test/with_dummyserver/test_no_ssl.py 
b/test/with_dummyserver/test_no_ssl.py index 43e79b70b6..b89f703fac 100644 --- a/test/with_dummyserver/test_no_ssl.py +++ b/test/with_dummyserver/test_no_ssl.py @@ -3,30 +3,31 @@ Note: Import urllib3 inside the test functions to get the importblocker to work """ +from __future__ import annotations + import pytest import urllib3 from dummyserver.testcase import HTTPDummyServerTestCase, HTTPSDummyServerTestCase +from urllib3.exceptions import InsecureRequestWarning from ..test_no_ssl import TestWithoutSSL -# Retry failed tests -pytestmark = pytest.mark.flaky - class TestHTTPWithoutSSL(HTTPDummyServerTestCase, TestWithoutSSL): - def test_simple(self): + def test_simple(self) -> None: with urllib3.HTTPConnectionPool(self.host, self.port) as pool: r = pool.request("GET", "/") assert r.status == 200, r.data class TestHTTPSWithoutSSL(HTTPSDummyServerTestCase, TestWithoutSSL): - def test_simple(self): + def test_simple(self) -> None: with urllib3.HTTPSConnectionPool( self.host, self.port, cert_reqs="NONE" ) as pool: - try: - pool.request("GET", "/") - except urllib3.exceptions.SSLError as e: - assert "SSL module is not available" in str(e) + with pytest.warns(InsecureRequestWarning): + try: + pool.request("GET", "/") + except urllib3.exceptions.SSLError as e: + assert "SSL module is not available" in str(e) diff --git a/test/with_dummyserver/test_poolmanager.py b/test/with_dummyserver/test_poolmanager.py index d877cc99ac..c4f1947037 100644 --- a/test/with_dummyserver/test_poolmanager.py +++ b/test/with_dummyserver/test_poolmanager.py @@ -1,32 +1,33 @@ -import json +from __future__ import annotations + +import gzip from test import LONG_TIMEOUT +from unittest import mock import pytest from dummyserver.server import HAS_IPV6 from dummyserver.testcase import HTTPDummyServerTestCase, IPv6HTTPDummyServerTestCase +from urllib3 import HTTPHeaderDict, HTTPResponse, request from urllib3.connectionpool import port_by_scheme from urllib3.exceptions import MaxRetryError, 
URLSchemeUnknown from urllib3.poolmanager import PoolManager from urllib3.util.retry import Retry -# Retry failed tests -pytestmark = pytest.mark.flaky - class TestPoolManager(HTTPDummyServerTestCase): @classmethod - def setup_class(cls): - super(TestPoolManager, cls).setup_class() - cls.base_url = "http://%s:%d" % (cls.host, cls.port) - cls.base_url_alt = "http://%s:%d" % (cls.host_alt, cls.port) + def setup_class(cls) -> None: + super().setup_class() + cls.base_url = f"http://{cls.host}:{cls.port}" + cls.base_url_alt = f"http://{cls.host_alt}:{cls.port}" - def test_redirect(self): + def test_redirect(self) -> None: with PoolManager() as http: r = http.request( "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/" % self.base_url}, + f"{self.base_url}/redirect", + fields={"target": f"{self.base_url}/"}, redirect=False, ) @@ -34,19 +35,19 @@ def test_redirect(self): r = http.request( "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/" % self.base_url}, + f"{self.base_url}/redirect", + fields={"target": f"{self.base_url}/"}, ) assert r.status == 200 assert r.data == b"Dummy server!" - def test_redirect_twice(self): + def test_redirect_twice(self) -> None: with PoolManager() as http: r = http.request( "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/redirect" % self.base_url}, + f"{self.base_url}/redirect", + fields={"target": f"{self.base_url}/redirect"}, redirect=False, ) @@ -54,20 +55,18 @@ def test_redirect_twice(self): r = http.request( "GET", - "%s/redirect" % self.base_url, - fields={ - "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) - }, + f"{self.base_url}/redirect", + fields={"target": f"{self.base_url}/redirect?target={self.base_url}/"}, ) assert r.status == 200 assert r.data == b"Dummy server!" 
- def test_redirect_to_relative_url(self): + def test_redirect_to_relative_url(self) -> None: with PoolManager() as http: r = http.request( "GET", - "%s/redirect" % self.base_url, + f"{self.base_url}/redirect", fields={"target": "/redirect"}, redirect=False, ) @@ -75,19 +74,19 @@ def test_redirect_to_relative_url(self): assert r.status == 303 r = http.request( - "GET", "%s/redirect" % self.base_url, fields={"target": "/redirect"} + "GET", f"{self.base_url}/redirect", fields={"target": "/redirect"} ) assert r.status == 200 assert r.data == b"Dummy server!" - def test_cross_host_redirect(self): + def test_cross_host_redirect(self) -> None: with PoolManager() as http: - cross_host_location = "%s/echo?a=b" % self.base_url_alt + cross_host_location = f"{self.base_url_alt}/echo?a=b" with pytest.raises(MaxRetryError): http.request( "GET", - "%s/redirect" % self.base_url, + f"{self.base_url}/redirect", fields={"target": cross_host_location}, timeout=LONG_TIMEOUT, retries=0, @@ -95,23 +94,24 @@ def test_cross_host_redirect(self): r = http.request( "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/echo?a=b" % self.base_url_alt}, + f"{self.base_url}/redirect", + fields={"target": f"{self.base_url_alt}/echo?a=b"}, timeout=LONG_TIMEOUT, retries=1, ) + assert isinstance(r, HTTPResponse) + assert r._pool is not None assert r._pool.host == self.host_alt - def test_too_many_redirects(self): + def test_too_many_redirects(self) -> None: with PoolManager() as http: with pytest.raises(MaxRetryError): http.request( "GET", - "%s/redirect" % self.base_url, + f"{self.base_url}/redirect", fields={ - "target": "%s/redirect?target=%s/" - % (self.base_url, self.base_url) + "target": f"{self.base_url}/redirect?target={self.base_url}/" }, retries=1, preload_content=False, @@ -120,10 +120,9 @@ def test_too_many_redirects(self): with pytest.raises(MaxRetryError): http.request( "GET", - "%s/redirect" % self.base_url, + f"{self.base_url}/redirect", fields={ - "target": 
"%s/redirect?target=%s/" - % (self.base_url, self.base_url) + "target": f"{self.base_url}/redirect?target={self.base_url}/" }, retries=Retry(total=None, redirect=1), preload_content=False, @@ -135,103 +134,107 @@ def test_too_many_redirects(self): pool = http.connection_from_host(self.host, self.port) assert pool.num_connections == 1 - def test_redirect_cross_host_remove_headers(self): + def test_redirect_cross_host_remove_headers(self) -> None: with PoolManager() as http: r = http.request( "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, + f"{self.base_url}/redirect", + fields={"target": f"{self.base_url_alt}/headers"}, headers={"Authorization": "foo"}, ) assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = r.json() assert "Authorization" not in data r = http.request( "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, + f"{self.base_url}/redirect", + fields={"target": f"{self.base_url_alt}/headers"}, headers={"authorization": "foo"}, ) assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = r.json() assert "authorization" not in data assert "Authorization" not in data - def test_redirect_cross_host_no_remove_headers(self): + def test_redirect_cross_host_no_remove_headers(self) -> None: with PoolManager() as http: r = http.request( "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, + f"{self.base_url}/redirect", + fields={"target": f"{self.base_url_alt}/headers"}, headers={"Authorization": "foo"}, retries=Retry(remove_headers_on_redirect=[]), ) assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = r.json() assert data["Authorization"] == "foo" - def test_redirect_cross_host_set_removed_headers(self): + def test_redirect_cross_host_set_removed_headers(self) -> None: with PoolManager() as http: r = http.request( "GET", - "%s/redirect" % self.base_url, - fields={"target": 
"%s/headers" % self.base_url_alt}, + f"{self.base_url}/redirect", + fields={"target": f"{self.base_url_alt}/headers"}, headers={"X-API-Secret": "foo", "Authorization": "bar"}, retries=Retry(remove_headers_on_redirect=["X-API-Secret"]), ) assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = r.json() assert "X-API-Secret" not in data assert data["Authorization"] == "bar" + headers = {"x-api-secret": "foo", "authorization": "bar"} r = http.request( "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, - headers={"x-api-secret": "foo", "authorization": "bar"}, + f"{self.base_url}/redirect", + fields={"target": f"{self.base_url_alt}/headers"}, + headers=headers, retries=Retry(remove_headers_on_redirect=["X-API-Secret"]), ) assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = r.json() assert "x-api-secret" not in data assert "X-API-Secret" not in data assert data["Authorization"] == "bar" - def test_redirect_without_preload_releases_connection(self): + # Ensure the header argument itself is not modified in-place. 
+ assert headers == {"x-api-secret": "foo", "authorization": "bar"} + + def test_redirect_without_preload_releases_connection(self) -> None: with PoolManager(block=True, maxsize=2) as http: - r = http.request( - "GET", "%s/redirect" % self.base_url, preload_content=False - ) + r = http.request("GET", f"{self.base_url}/redirect", preload_content=False) + assert isinstance(r, HTTPResponse) + assert r._pool is not None assert r._pool.num_requests == 2 assert r._pool.num_connections == 1 assert len(http.pools) == 1 - def test_unknown_scheme(self): + def test_unknown_scheme(self) -> None: with PoolManager() as http: unknown_scheme = "unknown" - unknown_scheme_url = "%s://host" % unknown_scheme + unknown_scheme_url = f"{unknown_scheme}://host" with pytest.raises(URLSchemeUnknown) as e: r = http.request("GET", unknown_scheme_url) assert e.value.scheme == unknown_scheme r = http.request( "GET", - "%s/redirect" % self.base_url, + f"{self.base_url}/redirect", fields={"target": unknown_scheme_url}, redirect=False, ) @@ -240,31 +243,29 @@ def test_unknown_scheme(self): with pytest.raises(URLSchemeUnknown) as e: r = http.request( "GET", - "%s/redirect" % self.base_url, + f"{self.base_url}/redirect", fields={"target": unknown_scheme_url}, ) assert e.value.scheme == unknown_scheme - def test_raise_on_redirect(self): + def test_raise_on_redirect(self) -> None: with PoolManager() as http: r = http.request( "GET", - "%s/redirect" % self.base_url, - fields={ - "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) - }, + f"{self.base_url}/redirect", + fields={"target": f"{self.base_url}/redirect?target={self.base_url}/"}, retries=Retry(total=None, redirect=1, raise_on_redirect=False), ) assert r.status == 303 - def test_raise_on_status(self): + def test_raise_on_status(self) -> None: with PoolManager() as http: with pytest.raises(MaxRetryError): # the default is to raise r = http.request( "GET", - "%s/status" % self.base_url, + f"{self.base_url}/status", 
fields={"status": "500 Internal Server Error"}, retries=Retry(total=1, status_forcelist=range(500, 600)), ) @@ -273,7 +274,7 @@ def test_raise_on_status(self): # raise explicitly r = http.request( "GET", - "%s/status" % self.base_url, + f"{self.base_url}/status", fields={"status": "500 Internal Server Error"}, retries=Retry( total=1, status_forcelist=range(500, 600), raise_on_status=True @@ -283,7 +284,7 @@ def test_raise_on_status(self): # don't raise r = http.request( "GET", - "%s/status" % self.base_url, + f"{self.base_url}/status", fields={"status": "500 Internal Server Error"}, retries=Retry( total=1, status_forcelist=range(500, 600), raise_on_status=False @@ -292,7 +293,7 @@ def test_raise_on_status(self): assert r.status == 500 - def test_missing_port(self): + def test_missing_port(self) -> None: # Can a URL that lacks an explicit port like ':80' succeed, or # will all such URLs fail with an error? @@ -302,53 +303,140 @@ def test_missing_port(self): # our test server happens to be listening. port_by_scheme["http"] = self.port try: - r = http.request("GET", "http://%s/" % self.host, retries=0) + r = http.request("GET", f"http://{self.host}/", retries=0) finally: port_by_scheme["http"] = 80 assert r.status == 200 assert r.data == b"Dummy server!" 
- def test_headers(self): + def test_headers(self) -> None: with PoolManager(headers={"Foo": "bar"}) as http: - r = http.request("GET", "%s/headers" % self.base_url) - returned_headers = json.loads(r.data.decode()) + r = http.request("GET", f"{self.base_url}/headers") + returned_headers = r.json() assert returned_headers.get("Foo") == "bar" - r = http.request("POST", "%s/headers" % self.base_url) - returned_headers = json.loads(r.data.decode()) + r = http.request("POST", f"{self.base_url}/headers") + returned_headers = r.json() assert returned_headers.get("Foo") == "bar" - r = http.request_encode_url("GET", "%s/headers" % self.base_url) - returned_headers = json.loads(r.data.decode()) + r = http.request_encode_url("GET", f"{self.base_url}/headers") + returned_headers = r.json() assert returned_headers.get("Foo") == "bar" - r = http.request_encode_body("POST", "%s/headers" % self.base_url) - returned_headers = json.loads(r.data.decode()) + r = http.request_encode_body("POST", f"{self.base_url}/headers") + returned_headers = r.json() assert returned_headers.get("Foo") == "bar" r = http.request_encode_url( - "GET", "%s/headers" % self.base_url, headers={"Baz": "quux"} + "GET", f"{self.base_url}/headers", headers={"Baz": "quux"} ) - returned_headers = json.loads(r.data.decode()) + returned_headers = r.json() assert returned_headers.get("Foo") is None assert returned_headers.get("Baz") == "quux" r = http.request_encode_body( - "GET", "%s/headers" % self.base_url, headers={"Baz": "quux"} + "GET", f"{self.base_url}/headers", headers={"Baz": "quux"} ) - returned_headers = json.loads(r.data.decode()) + returned_headers = r.json() assert returned_headers.get("Foo") is None assert returned_headers.get("Baz") == "quux" - def test_http_with_ssl_keywords(self): + def test_headers_http_header_dict(self) -> None: + # Test uses a list of headers to assert the order + # that headers are sent in the request too. 
+ + headers = HTTPHeaderDict() + headers.add("Foo", "bar") + headers.add("Multi", "1") + headers.add("Baz", "quux") + headers.add("Multi", "2") + + with PoolManager(headers=headers) as http: + r = http.request("GET", f"{self.base_url}/multi_headers") + returned_headers = r.json()["headers"] + assert returned_headers[-4:] == [ + ["Foo", "bar"], + ["Multi", "1"], + ["Multi", "2"], + ["Baz", "quux"], + ] + + r = http.request( + "GET", + f"{self.base_url}/multi_headers", + headers={ + **headers, + "Extra": "extra", + "Foo": "new", + }, + ) + returned_headers = r.json()["headers"] + assert returned_headers[-4:] == [ + ["Foo", "new"], + ["Multi", "1, 2"], + ["Baz", "quux"], + ["Extra", "extra"], + ] + + def test_headers_http_multi_header_multipart(self) -> None: + headers = HTTPHeaderDict() + headers.add("Multi", "1") + headers.add("Multi", "2") + old_headers = headers.copy() + + with PoolManager(headers=headers) as http: + r = http.request( + "POST", + f"{self.base_url}/multi_headers", + fields={"k": "v"}, + multipart_boundary="b", + encode_multipart=True, + ) + returned_headers = r.json()["headers"] + assert returned_headers[4:] == [ + ["Multi", "1"], + ["Multi", "2"], + ["Content-Type", "multipart/form-data; boundary=b"], + ] + # Assert that the previous headers weren't modified. + assert headers == old_headers + + # Set a default value for the Content-Type + headers["Content-Type"] = "multipart/form-data; boundary=b; field=value" + r = http.request( + "POST", + f"{self.base_url}/multi_headers", + fields={"k": "v"}, + multipart_boundary="b", + encode_multipart=True, + ) + returned_headers = r.json()["headers"] + assert returned_headers[4:] == [ + ["Multi", "1"], + ["Multi", "2"], + # Uses the set value, not the one that would be generated. 
+ ["Content-Type", "multipart/form-data; boundary=b; field=value"], + ] + + def test_body(self) -> None: + with PoolManager() as http: + r = http.request("POST", f"{self.base_url}/echo", body=b"test") + assert r.data == b"test" + + def test_http_with_ssl_keywords(self) -> None: with PoolManager(ca_certs="REQUIRED") as http: - r = http.request("GET", "http://%s:%s/" % (self.host, self.port)) + r = http.request("GET", f"http://{self.host}:{self.port}/") + assert r.status == 200 + + def test_http_with_server_hostname(self) -> None: + with PoolManager(server_hostname="example.com") as http: + r = http.request("GET", f"http://{self.host}:{self.port}/") assert r.status == 200 - def test_http_with_ca_cert_dir(self): + def test_http_with_ca_cert_dir(self) -> None: with PoolManager(ca_certs="REQUIRED", ca_cert_dir="/nosuchdir") as http: - r = http.request("GET", "http://%s:%s/" % (self.host, self.port)) + r = http.request("GET", f"http://{self.host}:{self.port}/") assert r.status == 200 @pytest.mark.parametrize( @@ -364,20 +452,167 @@ def test_http_with_ca_cert_dir(self): ("/echo_uri?[]", b"/echo_uri?%5B%5D"), ], ) - def test_encode_http_target(self, target, expected_target): + def test_encode_http_target(self, target: str, expected_target: bytes) -> None: with PoolManager() as http: - url = "http://%s:%d%s" % (self.host, self.port, target) + url = f"http://{self.host}:{self.port}{target}" r = http.request("GET", url) assert r.data == expected_target + def test_top_level_request(self) -> None: + r = request("GET", f"{self.base_url}/") + assert r.status == 200 + assert r.data == b"Dummy server!" 
+ + def test_top_level_request_without_keyword_args(self) -> None: + body = "" + with pytest.raises(TypeError): + request("GET", f"{self.base_url}/", body) # type: ignore[misc] + + def test_top_level_request_with_body(self) -> None: + r = request("POST", f"{self.base_url}/echo", body=b"test") + assert r.status == 200 + assert r.data == b"test" + + def test_top_level_request_with_preload_content(self) -> None: + r = request("GET", f"{self.base_url}/echo", preload_content=False) + assert r.status == 200 + assert r.connection is not None + r.data + assert r.connection is None + + def test_top_level_request_with_decode_content(self) -> None: + r = request( + "GET", + f"{self.base_url}/encodingrequest", + headers={"accept-encoding": "gzip"}, + decode_content=False, + ) + assert r.status == 200 + assert gzip.decompress(r.data) == b"hello, world!" + + r = request( + "GET", + f"{self.base_url}/encodingrequest", + headers={"accept-encoding": "gzip"}, + decode_content=True, + ) + assert r.status == 200 + assert r.data == b"hello, world!" + + def test_top_level_request_with_redirect(self) -> None: + r = request( + "GET", + f"{self.base_url}/redirect", + fields={"target": f"{self.base_url}/"}, + redirect=False, + ) + + assert r.status == 303 + + r = request( + "GET", + f"{self.base_url}/redirect", + fields={"target": f"{self.base_url}/"}, + redirect=True, + ) + + assert r.status == 200 + assert r.data == b"Dummy server!" 
+ + def test_top_level_request_with_retries(self) -> None: + r = request("GET", f"{self.base_url}/redirect", retries=False) + assert r.status == 303 + + r = request("GET", f"{self.base_url}/redirect", retries=3) + assert r.status == 200 + + def test_top_level_request_with_timeout(self) -> None: + with mock.patch("urllib3.poolmanager.RequestMethods.request") as mockRequest: + mockRequest.return_value = HTTPResponse(status=200) + + r = request("GET", f"{self.base_url}/redirect", timeout=2.5) + + assert r.status == 200 + + mockRequest.assert_called_with( + "GET", + f"{self.base_url}/redirect", + body=None, + fields=None, + headers=None, + preload_content=True, + decode_content=True, + redirect=True, + retries=None, + timeout=2.5, + json=None, + ) + + @pytest.mark.parametrize( + "headers", + [ + None, + {"content-Type": "application/json"}, + {"content-Type": "text/plain"}, + {"attribute": "value", "CONTENT-TYPE": "application/json"}, + HTTPHeaderDict(cookie="foo, bar"), + ], + ) + def test_request_with_json(self, headers: HTTPHeaderDict) -> None: + body = {"attribute": "value"} + r = request( + method="POST", url=f"{self.base_url}/echo_json", headers=headers, json=body + ) + assert r.status == 200 + assert r.json() == body + if headers is not None and "application/json" not in headers.values(): + assert "text/plain" in r.headers["Content-Type"].replace(" ", "").split(",") + else: + assert "application/json" in r.headers["Content-Type"].replace( + " ", "" + ).split(",") + + def test_top_level_request_with_json_with_httpheaderdict(self) -> None: + body = {"attribute": "value"} + header = HTTPHeaderDict(cookie="foo, bar") + with PoolManager(headers=header) as http: + r = http.request(method="POST", url=f"{self.base_url}/echo_json", json=body) + assert r.status == 200 + assert r.json() == body + assert "application/json" in r.headers["Content-Type"].replace( + " ", "" + ).split(",") + + def test_top_level_request_with_body_and_json(self) -> None: + match = "request got 
values for both 'body' and 'json' parameters which are mutually exclusive" + with pytest.raises(TypeError, match=match): + body = {"attribute": "value"} + request(method="POST", url=f"{self.base_url}/echo", body="", json=body) + + def test_top_level_request_with_invalid_body(self) -> None: + class BadBody: + def __repr__(self) -> str: + return "" + + with pytest.raises(TypeError) as e: + request( + method="POST", + url=f"{self.base_url}/echo", + body=BadBody(), # type: ignore[arg-type] + ) + assert str(e.value) == ( + "'body' must be a bytes-like object, file-like " + "object, or iterable. Instead was " + ) + @pytest.mark.skipif(not HAS_IPV6, reason="IPv6 is not supported on this system") class TestIPv6PoolManager(IPv6HTTPDummyServerTestCase): @classmethod - def setup_class(cls): - super(TestIPv6PoolManager, cls).setup_class() - cls.base_url = "http://[%s]:%d" % (cls.host, cls.port) + def setup_class(cls) -> None: + super().setup_class() + cls.base_url = f"http://[{cls.host}]:{cls.port}" - def test_ipv6(self): + def test_ipv6(self) -> None: with PoolManager() as http: http.request("GET", self.base_url) diff --git a/test/with_dummyserver/test_proxy_poolmanager.py b/test/with_dummyserver/test_proxy_poolmanager.py index 67cee77a58..f4620643f5 100644 --- a/test/with_dummyserver/test_proxy_poolmanager.py +++ b/test/with_dummyserver/test_proxy_poolmanager.py @@ -1,53 +1,54 @@ -import json +from __future__ import annotations + +import binascii +import hashlib +import ipaddress import os.path +import pathlib import shutil import socket +import ssl import tempfile -from test import ( - LONG_TIMEOUT, - SHORT_TIMEOUT, - onlyPy2, - onlyPy3, - onlySecureTransport, - withPyOpenSSL, -) +from test import LONG_TIMEOUT, SHORT_TIMEOUT, onlySecureTransport, withPyOpenSSL +from test.conftest import ServerConfig import pytest import trustme +import urllib3.exceptions from dummyserver.server import DEFAULT_CA, HAS_IPV6, get_unreachable_address from dummyserver.testcase import 
HTTPDummyProxyTestCase, IPv6HTTPDummyProxyTestCase +from urllib3 import HTTPResponse from urllib3._collections import HTTPHeaderDict -from urllib3.connectionpool import VerifiedHTTPSConnection, connection_from_url +from urllib3.connection import VerifiedHTTPSConnection +from urllib3.connectionpool import connection_from_url from urllib3.exceptions import ( ConnectTimeoutError, + InsecureRequestWarning, MaxRetryError, ProxyError, + ProxySchemeUnknown, ProxySchemeUnsupported, + ReadTimeoutError, SSLError, ) from urllib3.poolmanager import ProxyManager, proxy_from_url from urllib3.util.ssl_ import create_urllib3_context +from urllib3.util.timeout import Timeout from .. import TARPIT_HOST, requires_network -# Retry failed tests -pytestmark = pytest.mark.flaky - class TestHTTPProxyManager(HTTPDummyProxyTestCase): @classmethod - def setup_class(cls): - super(TestHTTPProxyManager, cls).setup_class() - cls.http_url = "http://%s:%d" % (cls.http_host, cls.http_port) - cls.http_url_alt = "http://%s:%d" % (cls.http_host_alt, cls.http_port) - cls.https_url = "https://%s:%d" % (cls.https_host, cls.https_port) - cls.https_url_alt = "https://%s:%d" % (cls.https_host_alt, cls.https_port) - cls.proxy_url = "http://%s:%d" % (cls.proxy_host, cls.proxy_port) - cls.https_proxy_url = "https://%s:%d" % ( - cls.proxy_host, - cls.https_proxy_port, - ) + def setup_class(cls) -> None: + super().setup_class() + cls.http_url = f"http://{cls.http_host}:{int(cls.http_port)}" + cls.http_url_alt = f"http://{cls.http_host_alt}:{int(cls.http_port)}" + cls.https_url = f"https://{cls.https_host}:{int(cls.https_port)}" + cls.https_url_alt = f"https://{cls.https_host_alt}:{int(cls.https_port)}" + cls.proxy_url = f"http://{cls.proxy_host}:{int(cls.proxy_port)}" + cls.https_proxy_url = f"https://{cls.proxy_host}:{int(cls.https_proxy_port)}" # Generate another CA to test verification failure cls.certs_dir = tempfile.mkdtemp() @@ -57,29 +58,27 @@ def setup_class(cls): 
bad_ca.cert_pem.write_to_path(cls.bad_ca_path) @classmethod - def teardown_class(cls): - super(TestHTTPProxyManager, cls).teardown_class() + def teardown_class(cls) -> None: + super().teardown_class() shutil.rmtree(cls.certs_dir) - def test_basic_proxy(self): + def test_basic_proxy(self) -> None: with proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) as http: - r = http.request("GET", "%s/" % self.http_url) + r = http.request("GET", f"{self.http_url}/") assert r.status == 200 - r = http.request("GET", "%s/" % self.https_url) + r = http.request("GET", f"{self.https_url}/") assert r.status == 200 - @onlyPy3 - def test_https_proxy(self): + def test_https_proxy(self) -> None: with proxy_from_url(self.https_proxy_url, ca_certs=DEFAULT_CA) as https: - r = https.request("GET", "%s/" % self.https_url) + r = https.request("GET", f"{self.https_url}/") assert r.status == 200 - r = https.request("GET", "%s/" % self.http_url) + r = https.request("GET", f"{self.http_url}/") assert r.status == 200 - @onlyPy3 - def test_https_proxy_with_proxy_ssl_context(self): + def test_https_proxy_with_proxy_ssl_context(self) -> None: proxy_ssl_context = create_urllib3_context() proxy_ssl_context.load_verify_locations(DEFAULT_CA) with proxy_from_url( @@ -87,67 +86,54 @@ def test_https_proxy_with_proxy_ssl_context(self): proxy_ssl_context=proxy_ssl_context, ca_certs=DEFAULT_CA, ) as https: - r = https.request("GET", "%s/" % self.https_url) + r = https.request("GET", f"{self.https_url}/") assert r.status == 200 - r = https.request("GET", "%s/" % self.http_url) - assert r.status == 200 - - @onlyPy2 - def test_https_proxy_not_supported(self): - with proxy_from_url(self.https_proxy_url, ca_certs=DEFAULT_CA) as https: - r = https.request("GET", "%s/" % self.http_url) + r = https.request("GET", f"{self.http_url}/") assert r.status == 200 - with pytest.raises(ProxySchemeUnsupported) as excinfo: - https.request("GET", "%s/" % self.https_url) - - assert "is not supported in Python 2" in 
str(excinfo.value) - @withPyOpenSSL - @onlyPy3 - def test_https_proxy_pyopenssl_not_supported(self): + def test_https_proxy_pyopenssl_not_supported(self) -> None: with proxy_from_url(self.https_proxy_url, ca_certs=DEFAULT_CA) as https: - r = https.request("GET", "%s/" % self.http_url) + r = https.request("GET", f"{self.http_url}/") assert r.status == 200 - with pytest.raises(ProxySchemeUnsupported) as excinfo: - https.request("GET", "%s/" % self.https_url) - - assert "isn't available on non-native SSLContext" in str(excinfo.value) + with pytest.raises( + ProxySchemeUnsupported, match="isn't available on non-native SSLContext" + ): + https.request("GET", f"{self.https_url}/") - @onlySecureTransport - @onlyPy3 - def test_https_proxy_securetransport_not_supported(self): + @onlySecureTransport() + def test_https_proxy_securetransport_not_supported(self) -> None: with proxy_from_url(self.https_proxy_url, ca_certs=DEFAULT_CA) as https: - r = https.request("GET", "%s/" % self.http_url) + r = https.request("GET", f"{self.http_url}/") assert r.status == 200 - with pytest.raises(ProxySchemeUnsupported) as excinfo: - https.request("GET", "%s/" % self.https_url) + with pytest.raises( + ProxySchemeUnsupported, match="isn't available on non-native SSLContext" + ): + https.request("GET", f"{self.https_url}/") - assert "isn't available on non-native SSLContext" in str(excinfo.value) - - def test_https_proxy_forwarding_for_https(self): + def test_https_proxy_forwarding_for_https(self) -> None: with proxy_from_url( self.https_proxy_url, ca_certs=DEFAULT_CA, use_forwarding_for_https=True, ) as https: - r = https.request("GET", "%s/" % self.http_url) + r = https.request("GET", f"{self.http_url}/") assert r.status == 200 - r = https.request("GET", "%s/" % self.https_url) + r = https.request("GET", f"{self.https_url}/") assert r.status == 200 - def test_nagle_proxy(self): - """ Test that proxy connections do not have TCP_NODELAY turned on """ + def test_nagle_proxy(self) -> None: + 
"""Test that proxy connections do not have TCP_NODELAY turned on""" with ProxyManager(self.proxy_url) as http: hc2 = http.connection_from_host(self.http_host, self.http_port) conn = hc2._get_conn() try: hc2._make_request(conn, "GET", "/") - tcp_nodelay_setting = conn.sock.getsockopt( + tcp_nodelay_setting = conn.sock.getsockopt( # type: ignore[attr-defined] socket.IPPROTO_TCP, socket.TCP_NODELAY ) assert tcp_nodelay_setting == 0, ( @@ -157,31 +143,39 @@ def test_nagle_proxy(self): finally: conn.close() - def test_proxy_conn_fail(self): + @pytest.mark.parametrize("proxy_scheme", ["http", "https"]) + @pytest.mark.parametrize("target_scheme", ["http", "https"]) + def test_proxy_conn_fail_from_dns( + self, proxy_scheme: str, target_scheme: str + ) -> None: host, port = get_unreachable_address() with proxy_from_url( - "http://%s:%s/" % (host, port), retries=1, timeout=LONG_TIMEOUT + f"{proxy_scheme}://{host}:{port}/", retries=1, timeout=LONG_TIMEOUT ) as http: - with pytest.raises(MaxRetryError): - http.request("GET", "%s/" % self.https_url) - with pytest.raises(MaxRetryError): - http.request("GET", "%s/" % self.http_url) + if target_scheme == "https": + target_url = self.https_url + else: + target_url = self.http_url with pytest.raises(MaxRetryError) as e: - http.request("GET", "%s/" % self.http_url) + http.request("GET", f"{target_url}/") assert type(e.value.reason) == ProxyError + assert ( + type(e.value.reason.original_error) + == urllib3.exceptions.NameResolutionError + ) - def test_oldapi(self): + def test_oldapi(self) -> None: with ProxyManager( - connection_from_url(self.proxy_url), ca_certs=DEFAULT_CA + connection_from_url(self.proxy_url), ca_certs=DEFAULT_CA # type: ignore[arg-type] ) as http: - r = http.request("GET", "%s/" % self.http_url) + r = http.request("GET", f"{self.http_url}/") assert r.status == 200 - r = http.request("GET", "%s/" % self.https_url) + r = http.request("GET", f"{self.https_url}/") assert r.status == 200 - def 
test_proxy_verified(self): + def test_proxy_verified(self) -> None: with proxy_from_url( self.proxy_url, cert_reqs="REQUIRED", ca_certs=self.bad_ca_path ) as http: @@ -189,9 +183,11 @@ def test_proxy_verified(self): with pytest.raises(MaxRetryError) as e: https_pool.request("GET", "/", retries=0) assert isinstance(e.value.reason, SSLError) - assert "certificate verify failed" in str(e.value.reason), ( - "Expected 'certificate verify failed', instead got: %r" % e.value.reason - ) + assert ( + "certificate verify failed" in str(e.value.reason) + # PyPy is more specific + or "self signed certificate in certificate chain" in str(e.value.reason) + ), f"Expected 'certificate verify failed', instead got: {e.value.reason!r}" http = proxy_from_url( self.proxy_url, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA @@ -207,17 +203,18 @@ def test_proxy_verified(self): ) https_fail_pool = http._new_pool("https", "127.0.0.1", self.https_port) - with pytest.raises(MaxRetryError) as e: + with pytest.raises( + MaxRetryError, match="doesn't match|IP address mismatch" + ) as e: https_fail_pool.request("GET", "/", retries=0) assert isinstance(e.value.reason, SSLError) - assert "doesn't match" in str(e.value.reason) - def test_redirect(self): + def test_redirect(self) -> None: with proxy_from_url(self.proxy_url) as http: r = http.request( "GET", - "%s/redirect" % self.http_url, - fields={"target": "%s/" % self.http_url}, + f"{self.http_url}/redirect", + fields={"target": f"{self.http_url}/"}, redirect=False, ) @@ -225,183 +222,163 @@ def test_redirect(self): r = http.request( "GET", - "%s/redirect" % self.http_url, - fields={"target": "%s/" % self.http_url}, + f"{self.http_url}/redirect", + fields={"target": f"{self.http_url}/"}, ) assert r.status == 200 assert r.data == b"Dummy server!" 
- def test_cross_host_redirect(self): + def test_cross_host_redirect(self) -> None: with proxy_from_url(self.proxy_url) as http: - cross_host_location = "%s/echo?a=b" % self.http_url_alt + cross_host_location = f"{self.http_url_alt}/echo?a=b" with pytest.raises(MaxRetryError): http.request( "GET", - "%s/redirect" % self.http_url, + f"{self.http_url}/redirect", fields={"target": cross_host_location}, retries=0, ) r = http.request( "GET", - "%s/redirect" % self.http_url, - fields={"target": "%s/echo?a=b" % self.http_url_alt}, + f"{self.http_url}/redirect", + fields={"target": f"{self.http_url_alt}/echo?a=b"}, retries=1, ) + assert isinstance(r, HTTPResponse) + assert r._pool is not None assert r._pool.host != self.http_host_alt - def test_cross_protocol_redirect(self): + def test_cross_protocol_redirect(self) -> None: with proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) as http: - cross_protocol_location = "%s/echo?a=b" % self.https_url + cross_protocol_location = f"{self.https_url}/echo?a=b" with pytest.raises(MaxRetryError): http.request( "GET", - "%s/redirect" % self.http_url, + f"{self.http_url}/redirect", fields={"target": cross_protocol_location}, retries=0, ) r = http.request( "GET", - "%s/redirect" % self.http_url, - fields={"target": "%s/echo?a=b" % self.https_url}, + f"{self.http_url}/redirect", + fields={"target": f"{self.https_url}/echo?a=b"}, retries=1, ) + assert isinstance(r, HTTPResponse) + assert r._pool is not None assert r._pool.host == self.https_host - def test_headers(self): + def test_headers(self) -> None: with proxy_from_url( self.proxy_url, headers={"Foo": "bar"}, proxy_headers={"Hickory": "dickory"}, ca_certs=DEFAULT_CA, ) as http: - - r = http.request_encode_url("GET", "%s/headers" % self.http_url) - returned_headers = json.loads(r.data.decode()) + r = http.request_encode_url("GET", f"{self.http_url}/headers") + returned_headers = r.json() assert returned_headers.get("Foo") == "bar" assert returned_headers.get("Hickory") == "dickory" - 
assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host, - self.http_port, - ) + assert returned_headers.get("Host") == f"{self.http_host}:{self.http_port}" - r = http.request_encode_url("GET", "%s/headers" % self.http_url_alt) - returned_headers = json.loads(r.data.decode()) + r = http.request_encode_url("GET", f"{self.http_url_alt}/headers") + returned_headers = r.json() assert returned_headers.get("Foo") == "bar" assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host_alt, - self.http_port, + assert ( + returned_headers.get("Host") == f"{self.http_host_alt}:{self.http_port}" ) - r = http.request_encode_url("GET", "%s/headers" % self.https_url) - returned_headers = json.loads(r.data.decode()) + r = http.request_encode_url("GET", f"{self.https_url}/headers") + returned_headers = r.json() assert returned_headers.get("Foo") == "bar" assert returned_headers.get("Hickory") is None - assert returned_headers.get("Host") == "%s:%s" % ( - self.https_host, - self.https_port, + assert ( + returned_headers.get("Host") == f"{self.https_host}:{self.https_port}" ) - r = http.request_encode_body("POST", "%s/headers" % self.http_url) - returned_headers = json.loads(r.data.decode()) + r = http.request_encode_body("POST", f"{self.http_url}/headers") + returned_headers = r.json() assert returned_headers.get("Foo") == "bar" assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host, - self.http_port, - ) + assert returned_headers.get("Host") == f"{self.http_host}:{self.http_port}" r = http.request_encode_url( - "GET", "%s/headers" % self.http_url, headers={"Baz": "quux"} + "GET", f"{self.http_url}/headers", headers={"Baz": "quux"} ) - returned_headers = json.loads(r.data.decode()) + returned_headers = r.json() assert returned_headers.get("Foo") is None assert returned_headers.get("Baz") == "quux" assert returned_headers.get("Hickory") == 
"dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host, - self.http_port, - ) + assert returned_headers.get("Host") == f"{self.http_host}:{self.http_port}" r = http.request_encode_url( - "GET", "%s/headers" % self.https_url, headers={"Baz": "quux"} + "GET", f"{self.https_url}/headers", headers={"Baz": "quux"} ) - returned_headers = json.loads(r.data.decode()) + returned_headers = r.json() assert returned_headers.get("Foo") is None assert returned_headers.get("Baz") == "quux" assert returned_headers.get("Hickory") is None - assert returned_headers.get("Host") == "%s:%s" % ( - self.https_host, - self.https_port, + assert ( + returned_headers.get("Host") == f"{self.https_host}:{self.https_port}" ) r = http.request_encode_body( - "GET", "%s/headers" % self.http_url, headers={"Baz": "quux"} + "GET", f"{self.http_url}/headers", headers={"Baz": "quux"} ) - returned_headers = json.loads(r.data.decode()) + returned_headers = r.json() assert returned_headers.get("Foo") is None assert returned_headers.get("Baz") == "quux" assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host, - self.http_port, - ) + assert returned_headers.get("Host") == f"{self.http_host}:{self.http_port}" r = http.request_encode_body( - "GET", "%s/headers" % self.https_url, headers={"Baz": "quux"} + "GET", f"{self.https_url}/headers", headers={"Baz": "quux"} ) - returned_headers = json.loads(r.data.decode()) + returned_headers = r.json() assert returned_headers.get("Foo") is None assert returned_headers.get("Baz") == "quux" assert returned_headers.get("Hickory") is None - assert returned_headers.get("Host") == "%s:%s" % ( - self.https_host, - self.https_port, + assert ( + returned_headers.get("Host") == f"{self.https_host}:{self.https_port}" ) - @onlyPy3 - def test_https_headers(self): + def test_https_headers(self) -> None: with proxy_from_url( self.https_proxy_url, headers={"Foo": "bar"}, proxy_headers={"Hickory": 
"dickory"}, ca_certs=DEFAULT_CA, ) as http: - - r = http.request_encode_url("GET", "%s/headers" % self.http_url) - returned_headers = json.loads(r.data.decode()) + r = http.request_encode_url("GET", f"{self.http_url}/headers") + returned_headers = r.json() assert returned_headers.get("Foo") == "bar" assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host, - self.http_port, - ) + assert returned_headers.get("Host") == f"{self.http_host}:{self.http_port}" - r = http.request_encode_url("GET", "%s/headers" % self.http_url_alt) - returned_headers = json.loads(r.data.decode()) + r = http.request_encode_url("GET", f"{self.http_url_alt}/headers") + returned_headers = r.json() assert returned_headers.get("Foo") == "bar" assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host_alt, - self.http_port, + assert ( + returned_headers.get("Host") == f"{self.http_host_alt}:{self.http_port}" ) r = http.request_encode_body( - "GET", "%s/headers" % self.https_url, headers={"Baz": "quux"} + "GET", f"{self.https_url}/headers", headers={"Baz": "quux"} ) - returned_headers = json.loads(r.data.decode()) + returned_headers = r.json() assert returned_headers.get("Foo") is None assert returned_headers.get("Baz") == "quux" assert returned_headers.get("Hickory") is None - assert returned_headers.get("Host") == "%s:%s" % ( - self.https_host, - self.https_port, + assert ( + returned_headers.get("Host") == f"{self.https_host}:{self.https_port}" ) - def test_https_headers_forwarding_for_https(self): + def test_https_headers_forwarding_for_https(self) -> None: with proxy_from_url( self.https_proxy_url, headers={"Foo": "bar"}, @@ -409,17 +386,15 @@ def test_https_headers_forwarding_for_https(self): ca_certs=DEFAULT_CA, use_forwarding_for_https=True, ) as http: - - r = http.request_encode_url("GET", "%s/headers" % self.https_url) - returned_headers = 
json.loads(r.data.decode()) + r = http.request_encode_url("GET", f"{self.https_url}/headers") + returned_headers = r.json() assert returned_headers.get("Foo") == "bar" assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.https_host, - self.https_port, + assert ( + returned_headers.get("Host") == f"{self.https_host}:{self.https_port}" ) - def test_headerdict(self): + def test_headerdict(self) -> None: default_headers = HTTPHeaderDict(a="b") proxy_headers = HTTPHeaderDict() proxy_headers.add("foo", "bar") @@ -428,14 +403,12 @@ def test_headerdict(self): self.proxy_url, headers=default_headers, proxy_headers=proxy_headers ) as http: request_headers = HTTPHeaderDict(baz="quux") - r = http.request( - "GET", "%s/headers" % self.http_url, headers=request_headers - ) - returned_headers = json.loads(r.data.decode()) + r = http.request("GET", f"{self.http_url}/headers", headers=request_headers) + returned_headers = r.json() assert returned_headers.get("Foo") == "bar" assert returned_headers.get("Baz") == "quux" - def test_proxy_pooling(self): + def test_proxy_pooling(self) -> None: with proxy_from_url(self.proxy_url, cert_reqs="NONE") as http: for x in range(2): http.urlopen("GET", self.http_url) @@ -446,14 +419,16 @@ def test_proxy_pooling(self): assert len(http.pools) == 1 for x in range(2): - http.urlopen("GET", self.https_url) + with pytest.warns(InsecureRequestWarning): + http.urlopen("GET", self.https_url) assert len(http.pools) == 2 for x in range(2): - http.urlopen("GET", self.https_url_alt) + with pytest.warns(InsecureRequestWarning): + http.urlopen("GET", self.https_url_alt) assert len(http.pools) == 3 - def test_proxy_pooling_ext(self): + def test_proxy_pooling_ext(self) -> None: with proxy_from_url(self.proxy_url) as http: hc1 = http.connection_from_url(self.http_url) hc2 = http.connection_from_host(self.http_host, self.http_port) @@ -475,49 +450,386 @@ def test_proxy_pooling_ext(self): assert sc2 != sc3 
assert sc3 == sc4 - @pytest.mark.timeout(0.5) - @requires_network - def test_https_proxy_timeout(self): - with proxy_from_url("https://{host}".format(host=TARPIT_HOST)) as https: + @requires_network() + @pytest.mark.parametrize( + ["proxy_scheme", "target_scheme", "use_forwarding_for_https"], + [ + ("http", "http", False), + ("https", "http", False), + # 'use_forwarding_for_https' is only valid for HTTPS+HTTPS. + ("https", "https", True), + ], + ) + def test_forwarding_proxy_request_timeout( + self, proxy_scheme: str, target_scheme: str, use_forwarding_for_https: bool + ) -> None: + proxy_url = self.https_proxy_url if proxy_scheme == "https" else self.proxy_url + target_url = f"{target_scheme}://{TARPIT_HOST}" + + with proxy_from_url( + proxy_url, + ca_certs=DEFAULT_CA, + use_forwarding_for_https=use_forwarding_for_https, + ) as proxy: with pytest.raises(MaxRetryError) as e: - https.request("GET", self.http_url, timeout=SHORT_TIMEOUT) - assert type(e.value.reason) == ConnectTimeoutError + timeout = Timeout(connect=LONG_TIMEOUT, read=SHORT_TIMEOUT) + proxy.request("GET", target_url, timeout=timeout) + + # We sent the request to the proxy but didn't get any response + # so we're not sure if that's being caused by the proxy or the + # target so we put the blame on the target. 
+ assert type(e.value.reason) == ReadTimeoutError + + @requires_network() + @pytest.mark.parametrize( + ["proxy_scheme", "target_scheme"], [("http", "https"), ("https", "https")] + ) + def test_tunneling_proxy_request_timeout( + self, proxy_scheme: str, target_scheme: str + ) -> None: + proxy_url = self.https_proxy_url if proxy_scheme == "https" else self.proxy_url + target_url = f"{target_scheme}://{TARPIT_HOST}" - @pytest.mark.timeout(0.5) - @requires_network - def test_https_proxy_pool_timeout(self): with proxy_from_url( - "https://{host}".format(host=TARPIT_HOST), timeout=SHORT_TIMEOUT - ) as https: + proxy_url, + ca_certs=DEFAULT_CA, + ) as proxy: + with pytest.raises(MaxRetryError) as e: + timeout = Timeout(connect=LONG_TIMEOUT, read=SHORT_TIMEOUT) + proxy.request("GET", target_url, timeout=timeout) + + assert type(e.value.reason) == ReadTimeoutError + + @requires_network() + @pytest.mark.parametrize( + ["proxy_scheme", "target_scheme", "use_forwarding_for_https"], + [ + ("http", "http", False), + ("https", "http", False), + # 'use_forwarding_for_https' is only valid for HTTPS+HTTPS. 
+ ("https", "https", True), + ], + ) + def test_forwarding_proxy_connect_timeout( + self, proxy_scheme: str, target_scheme: str, use_forwarding_for_https: bool + ) -> None: + proxy_url = f"{proxy_scheme}://{TARPIT_HOST}" + target_url = self.https_url if target_scheme == "https" else self.http_url + + with proxy_from_url( + proxy_url, + ca_certs=DEFAULT_CA, + timeout=SHORT_TIMEOUT, + use_forwarding_for_https=use_forwarding_for_https, + ) as proxy: + with pytest.raises(MaxRetryError) as e: + proxy.request("GET", target_url) + + assert type(e.value.reason) == ProxyError + assert type(e.value.reason.original_error) == ConnectTimeoutError + + @requires_network() + @pytest.mark.parametrize( + ["proxy_scheme", "target_scheme"], [("http", "https"), ("https", "https")] + ) + def test_tunneling_proxy_connect_timeout( + self, proxy_scheme: str, target_scheme: str + ) -> None: + proxy_url = f"{proxy_scheme}://{TARPIT_HOST}" + target_url = self.https_url if target_scheme == "https" else self.http_url + + with proxy_from_url( + proxy_url, ca_certs=DEFAULT_CA, timeout=SHORT_TIMEOUT + ) as proxy: + with pytest.raises(MaxRetryError) as e: + proxy.request("GET", target_url) + + assert type(e.value.reason) == ProxyError + assert type(e.value.reason.original_error) == ConnectTimeoutError + + @requires_network() + @pytest.mark.parametrize( + ["target_scheme", "use_forwarding_for_https"], + [ + ("http", False), + ("https", False), + ("https", True), + ], + ) + def test_https_proxy_tls_error( + self, target_scheme: str, use_forwarding_for_https: str + ) -> None: + target_url = self.https_url if target_scheme == "https" else self.http_url + proxy_ctx = ssl.create_default_context() + with proxy_from_url( + self.https_proxy_url, + proxy_ssl_context=proxy_ctx, + use_forwarding_for_https=use_forwarding_for_https, + ) as proxy: + with pytest.raises(MaxRetryError) as e: + proxy.request("GET", target_url) + assert type(e.value.reason) == ProxyError + assert type(e.value.reason.original_error) == 
SSLError + + @requires_network() + @pytest.mark.parametrize( + ["proxy_scheme", "use_forwarding_for_https"], + [ + ("http", False), + ("https", False), + ("https", True), + ], + ) + def test_proxy_https_target_tls_error( + self, proxy_scheme: str, use_forwarding_for_https: str + ) -> None: + if proxy_scheme == "https" and use_forwarding_for_https: + pytest.skip("Test is expected to fail due to urllib3/urllib3#2577") + + proxy_url = self.https_proxy_url if proxy_scheme == "https" else self.proxy_url + proxy_ctx = ssl.create_default_context() + proxy_ctx.load_verify_locations(DEFAULT_CA) + ctx = ssl.create_default_context() + + with proxy_from_url( + proxy_url, + proxy_ssl_context=proxy_ctx, + ssl_context=ctx, + use_forwarding_for_https=use_forwarding_for_https, + ) as proxy: with pytest.raises(MaxRetryError) as e: - https.request("GET", self.http_url) - assert type(e.value.reason) == ConnectTimeoutError + proxy.request("GET", self.https_url) + assert type(e.value.reason) == SSLError - def test_scheme_host_case_insensitive(self): + def test_scheme_host_case_insensitive(self) -> None: """Assert that upper-case schemes and hosts are normalized.""" with proxy_from_url(self.proxy_url.upper(), ca_certs=DEFAULT_CA) as http: - r = http.request("GET", "%s/" % self.http_url.upper()) + r = http.request("GET", f"{self.http_url.upper()}/") assert r.status == 200 - r = http.request("GET", "%s/" % self.https_url.upper()) + r = http.request("GET", f"{self.https_url.upper()}/") assert r.status == 200 + @pytest.mark.parametrize( + "url, error_msg", + [ + ( + "127.0.0.1", + "Proxy URL had no scheme, should start with http:// or https://", + ), + ( + "localhost:8080", + "Proxy URL had no scheme, should start with http:// or https://", + ), + ( + "ftp://google.com", + "Proxy URL had unsupported scheme ftp, should use http:// or https://", + ), + ], + ) + def test_invalid_schema(self, url: str, error_msg: str) -> None: + with pytest.raises(ProxySchemeUnknown, match=error_msg): + 
proxy_from_url(url) + @pytest.mark.skipif(not HAS_IPV6, reason="Only runs on IPv6 systems") class TestIPv6HTTPProxyManager(IPv6HTTPDummyProxyTestCase): @classmethod - def setup_class(cls): + def setup_class(cls) -> None: HTTPDummyProxyTestCase.setup_class() - cls.http_url = "http://%s:%d" % (cls.http_host, cls.http_port) - cls.http_url_alt = "http://%s:%d" % (cls.http_host_alt, cls.http_port) - cls.https_url = "https://%s:%d" % (cls.https_host, cls.https_port) - cls.https_url_alt = "https://%s:%d" % (cls.https_host_alt, cls.https_port) - cls.proxy_url = "http://[%s]:%d" % (cls.proxy_host, cls.proxy_port) + cls.http_url = f"http://{cls.http_host}:{int(cls.http_port)}" + cls.http_url_alt = f"http://{cls.http_host_alt}:{int(cls.http_port)}" + cls.https_url = f"https://{cls.https_host}:{int(cls.https_port)}" + cls.https_url_alt = f"https://{cls.https_host_alt}:{int(cls.https_port)}" + cls.proxy_url = f"http://[{cls.proxy_host}]:{int(cls.proxy_port)}" - def test_basic_ipv6_proxy(self): + def test_basic_ipv6_proxy(self) -> None: with proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) as http: - r = http.request("GET", "%s/" % self.http_url) + r = http.request("GET", f"{self.http_url}/") assert r.status == 200 - r = http.request("GET", "%s/" % self.https_url) + r = http.request("GET", f"{self.https_url}/") assert r.status == 200 + + +class TestHTTPSProxyVerification: + @staticmethod + def _get_proxy_fingerprint_md5(ca_path: str) -> str: + proxy_pem_path = pathlib.Path(ca_path).parent / "proxy.pem" + proxy_der = ssl.PEM_cert_to_DER_cert(proxy_pem_path.read_text()) + proxy_hashed = hashlib.md5(proxy_der).digest() + fingerprint = binascii.hexlify(proxy_hashed).decode("ascii") + return fingerprint + + @staticmethod + def _get_certificate_formatted_proxy_host(host: str) -> str: + try: + addr = ipaddress.ip_address(host) + except ValueError: + return host + + if addr.version != 6: + return host + + # Transform ipv6 like '::1' to 0:0:0:0:0:0:0:1 via 
'0000:0000:0000:0000:0000:0000:0000:0001' + return addr.exploded.replace("0000", "0").replace("000", "") + + def test_https_proxy_assert_fingerprint_md5( + self, no_san_proxy_with_server: tuple[ServerConfig, ServerConfig] + ) -> None: + proxy, server = no_san_proxy_with_server + proxy_url = f"https://{proxy.host}:{proxy.port}" + destination_url = f"https://{server.host}:{server.port}" + + proxy_fingerprint = self._get_proxy_fingerprint_md5(proxy.ca_certs) + with proxy_from_url( + proxy_url, + ca_certs=proxy.ca_certs, + proxy_assert_fingerprint=proxy_fingerprint, + ) as https: + https.request("GET", destination_url) + + def test_https_proxy_assert_fingerprint_md5_non_matching( + self, no_san_proxy_with_server: tuple[ServerConfig, ServerConfig] + ) -> None: + proxy, server = no_san_proxy_with_server + proxy_url = f"https://{proxy.host}:{proxy.port}" + destination_url = f"https://{server.host}:{server.port}" + + proxy_fingerprint = self._get_proxy_fingerprint_md5(proxy.ca_certs) + new_char = "b" if proxy_fingerprint[5] == "a" else "a" + proxy_fingerprint = proxy_fingerprint[:5] + new_char + proxy_fingerprint[6:] + + with proxy_from_url( + proxy_url, + ca_certs=proxy.ca_certs, + proxy_assert_fingerprint=proxy_fingerprint, + ) as https: + with pytest.raises(MaxRetryError) as e: + https.request("GET", destination_url) + + assert "Fingerprints did not match" in str(e) + + def test_https_proxy_assert_hostname( + self, san_proxy_with_server: tuple[ServerConfig, ServerConfig] + ) -> None: + proxy, server = san_proxy_with_server + destination_url = f"https://{server.host}:{server.port}" + + with proxy_from_url( + proxy.base_url, ca_certs=proxy.ca_certs, proxy_assert_hostname=proxy.host + ) as https: + https.request("GET", destination_url) + + def test_https_proxy_assert_hostname_non_matching( + self, san_proxy_with_server: tuple[ServerConfig, ServerConfig] + ) -> None: + proxy, server = san_proxy_with_server + destination_url = f"https://{server.host}:{server.port}" + + 
proxy_hostname = "example.com" + with proxy_from_url( + proxy.base_url, + ca_certs=proxy.ca_certs, + proxy_assert_hostname=proxy_hostname, + ) as https: + with pytest.raises(MaxRetryError) as e: + https.request("GET", destination_url) + + proxy_host = self._get_certificate_formatted_proxy_host(proxy.host) + msg = f"hostname \\'{proxy_hostname}\\' doesn\\'t match \\'{proxy_host}\\'" + assert msg in str(e) + + def test_https_proxy_hostname_verification( + self, no_localhost_san_server: ServerConfig + ) -> None: + bad_server = no_localhost_san_server + bad_proxy_url = f"https://{bad_server.host}:{bad_server.port}" + + # An exception will be raised before we contact the destination domain. + test_url = "testing.com" + with proxy_from_url(bad_proxy_url, ca_certs=bad_server.ca_certs) as https: + with pytest.raises(MaxRetryError) as e: + https.request("GET", "http://%s/" % test_url) + assert isinstance(e.value.reason, ProxyError) + + ssl_error = e.value.reason.original_error + assert isinstance(ssl_error, SSLError) + assert "hostname 'localhost' doesn't match" in str( + ssl_error + ) or "Hostname mismatch" in str(ssl_error) + + with pytest.raises(MaxRetryError) as e: + https.request("GET", "https://%s/" % test_url) + assert isinstance(e.value.reason, ProxyError) + + ssl_error = e.value.reason.original_error + assert isinstance(ssl_error, SSLError) + assert "hostname 'localhost' doesn't match" in str( + ssl_error + ) or "Hostname mismatch" in str(ssl_error) + + def test_https_proxy_ipv4_san( + self, ipv4_san_proxy_with_server: tuple[ServerConfig, ServerConfig] + ) -> None: + proxy, server = ipv4_san_proxy_with_server + proxy_url = f"https://{proxy.host}:{proxy.port}" + destination_url = f"https://{server.host}:{server.port}" + with proxy_from_url(proxy_url, ca_certs=proxy.ca_certs) as https: + r = https.request("GET", destination_url) + assert r.status == 200 + + def test_https_proxy_ipv6_san( + self, ipv6_san_proxy_with_server: tuple[ServerConfig, ServerConfig] + ) -> 
None: + proxy, server = ipv6_san_proxy_with_server + proxy_url = f"https://[{proxy.host}]:{proxy.port}" + destination_url = f"https://{server.host}:{server.port}" + with proxy_from_url(proxy_url, ca_certs=proxy.ca_certs) as https: + r = https.request("GET", destination_url) + assert r.status == 200 + + @pytest.mark.parametrize("target_scheme", ["http", "https"]) + def test_https_proxy_no_san( + self, + no_san_proxy_with_server: tuple[ServerConfig, ServerConfig], + target_scheme: str, + ) -> None: + proxy, server = no_san_proxy_with_server + proxy_url = f"https://{proxy.host}:{proxy.port}" + destination_url = f"{target_scheme}://{server.host}:{server.port}" + + with proxy_from_url(proxy_url, ca_certs=proxy.ca_certs) as https: + with pytest.raises(MaxRetryError) as e: + https.request("GET", destination_url) + assert isinstance(e.value.reason, ProxyError) + + ssl_error = e.value.reason.original_error + assert isinstance(ssl_error, SSLError) + assert "no appropriate subjectAltName fields were found" in str( + ssl_error + ) or "Hostname mismatch, certificate is not valid for 'localhost'" in str( + ssl_error + ) + + def test_https_proxy_no_san_hostname_checks_common_name( + self, no_san_proxy_with_server: tuple[ServerConfig, ServerConfig] + ) -> None: + proxy, server = no_san_proxy_with_server + proxy_url = f"https://{proxy.host}:{proxy.port}" + destination_url = f"https://{server.host}:{server.port}" + + proxy_ctx = urllib3.util.ssl_.create_urllib3_context() + try: + proxy_ctx.hostname_checks_common_name = True + # PyPy doesn't like us setting 'hostname_checks_common_name' + # but also has it enabled by default so we need to handle that. 
+ except AttributeError: + pass + if getattr(proxy_ctx, "hostname_checks_common_name", False) is not True: + pytest.skip("Test requires 'SSLContext.hostname_checks_common_name=True'") + + with proxy_from_url( + proxy_url, ca_certs=proxy.ca_certs, proxy_ssl_context=proxy_ctx + ) as https: + https.request("GET", destination_url) diff --git a/test/with_dummyserver/test_socketlevel.py b/test/with_dummyserver/test_socketlevel.py index 7c06875439..0f7e24c264 100644 --- a/test/with_dummyserver/test_socketlevel.py +++ b/test/with_dummyserver/test_socketlevel.py @@ -1,68 +1,71 @@ # TODO: Break this module up into pieces. Maybe group by functionality tested # rather than the socket level-ness of it. -from dummyserver.server import ( - DEFAULT_CA, - DEFAULT_CERTS, - encrypt_key_pem, - get_unreachable_address, -) -from dummyserver.testcase import SocketDummyServerTestCase, consume_socket -from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, util -from urllib3._collections import HTTPHeaderDict -from urllib3.connection import HTTPConnection, _get_default_user_agent -from urllib3.exceptions import ( - MaxRetryError, - ProtocolError, - ProxyError, - ReadTimeoutError, - SSLError, -) -from urllib3.packages.six.moves import http_client as httplib -from urllib3.poolmanager import proxy_from_url -from urllib3.util import ssl_, ssl_wrap_socket -from urllib3.util.retry import Retry -from urllib3.util.timeout import Timeout - -from .. 
import LogRecorder, has_alpn, onlyPy3 - -try: - from mimetools import Message as MimeToolMessage -except ImportError: - - class MimeToolMessage(object): - pass - +from __future__ import annotations +import contextlib +import errno +import io import os import os.path import select import shutil import socket import ssl +import sys import tempfile +import time +import typing from collections import OrderedDict +from pathlib import Path from test import ( LONG_TIMEOUT, SHORT_TIMEOUT, - notPyPy2, notSecureTransport, notWindows, requires_ssl_context_keyfile_password, resolvesLocalhostFQDN, ) from threading import Event +from unittest import mock -import mock import pytest import trustme -# Retry failed tests -pytestmark = pytest.mark.flaky +from dummyserver.server import ( + DEFAULT_CA, + DEFAULT_CERTS, + encrypt_key_pem, + get_unreachable_address, +) +from dummyserver.testcase import SocketDummyServerTestCase, consume_socket +from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager, util +from urllib3._collections import HTTPHeaderDict +from urllib3.connection import HTTPConnection, _get_default_user_agent +from urllib3.connectionpool import _url_from_pool +from urllib3.exceptions import ( + InsecureRequestWarning, + MaxRetryError, + ProtocolError, + ProxyError, + ReadTimeoutError, + SSLError, +) +from urllib3.poolmanager import proxy_from_url +from urllib3.util import ssl_, ssl_wrap_socket +from urllib3.util.retry import Retry +from urllib3.util.timeout import Timeout + +from .. 
import LogRecorder, has_alpn + +if typing.TYPE_CHECKING: + from _typeshed import StrOrBytesPath +else: + StrOrBytesPath = object class TestCookies(SocketDummyServerTestCase): - def test_multi_setcookie(self): - def multicookie_response_handler(listener): + def test_multi_setcookie(self) -> None: + def multicookie_response_handler(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" @@ -85,13 +88,11 @@ def multicookie_response_handler(listener): class TestSNI(SocketDummyServerTestCase): - def test_hostname_in_first_request_packet(self): - if not util.HAS_SNI: - pytest.skip("SNI-support not available") + def test_hostname_in_first_request_packet(self) -> None: done_receiving = Event() self.buf = b"" - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] self.buf = sock.recv(65536) # We only accept one packet @@ -112,14 +113,14 @@ def socket_handler(listener): class TestALPN(SocketDummyServerTestCase): - def test_alpn_protocol_in_first_request_packet(self): + def test_alpn_protocol_in_first_request_packet(self) -> None: if not has_alpn(): pytest.skip("ALPN-support not available") done_receiving = Event() self.buf = b"" - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] self.buf = sock.recv(65536) # We only accept one packet @@ -140,16 +141,48 @@ def socket_handler(listener): ), "missing ALPN protocol in SSL handshake" +def original_ssl_wrap_socket( + sock: socket.socket, + keyfile: StrOrBytesPath | None = None, + certfile: StrOrBytesPath | None = None, + server_side: bool = False, + cert_reqs: ssl.VerifyMode = ssl.CERT_NONE, + ssl_version: int = ssl.PROTOCOL_TLS, + ca_certs: str | None = None, + do_handshake_on_connect: bool = True, + suppress_ragged_eofs: bool = True, + ciphers: str | None = None, +) -> ssl.SSLSocket: + if server_side and not certfile: + raise ValueError("certfile must be specified for server-side operations") 
+ if keyfile and not certfile: + raise ValueError("certfile must be specified") + context = ssl.SSLContext(ssl_version) + context.verify_mode = cert_reqs + if ca_certs: + context.load_verify_locations(ca_certs) + if certfile: + context.load_cert_chain(certfile, keyfile) + if ciphers: + context.set_ciphers(ciphers) + return context.wrap_socket( + sock=sock, + server_side=server_side, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + ) + + class TestClientCerts(SocketDummyServerTestCase): """ Tests for client certificate support. """ @classmethod - def setup_class(cls): + def setup_class(cls) -> None: cls.tmpdir = tempfile.mkdtemp() ca = trustme.CA() - cert = ca.issue_cert(u"localhost") + cert = ca.issue_cert("localhost") encrypted_key = encrypt_key_pem(cert.private_key_pem, b"letmein") cls.ca_path = os.path.join(cls.tmpdir, "ca.pem") @@ -164,14 +197,15 @@ def setup_class(cls): cert.private_key_pem.write_to_path(cls.key_path) encrypted_key.write_to_path(cls.password_key_path) - def teardown_class(cls): + @classmethod + def teardown_class(cls) -> None: shutil.rmtree(cls.tmpdir) - def _wrap_in_ssl(self, sock): + def _wrap_in_ssl(self, sock: socket.socket) -> ssl.SSLSocket: """ Given a single socket, wraps it in TLS. """ - return ssl.wrap_socket( + return original_ssl_wrap_socket( sock, ssl_version=ssl.PROTOCOL_SSLv23, cert_reqs=ssl.CERT_REQUIRED, @@ -181,7 +215,7 @@ def _wrap_in_ssl(self, sock): server_side=True, ) - def test_client_certs_two_files(self): + def test_client_certs_two_files(self) -> None: """ Having a client cert in a separate file to its associated key works properly. 
@@ -189,7 +223,7 @@ def test_client_certs_two_files(self): done_receiving = Event() client_certs = [] - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] sock = self._wrap_in_ssl(sock) @@ -225,7 +259,7 @@ def socket_handler(listener): assert len(client_certs) == 1 - def test_client_certs_one_file(self): + def test_client_certs_one_file(self) -> None: """ Having a client cert and its associated private key in just one file works properly. @@ -233,7 +267,7 @@ def test_client_certs_one_file(self): done_receiving = Event() client_certs = [] - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] sock = self._wrap_in_ssl(sock) @@ -268,13 +302,13 @@ def socket_handler(listener): assert len(client_certs) == 1 - def test_missing_client_certs_raises_error(self): + def test_missing_client_certs_raises_error(self) -> None: """ Having client certs not be present causes an error. 
""" done_receiving = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] try: @@ -294,22 +328,22 @@ def socket_handler(listener): done_receiving.set() done_receiving.set() - @requires_ssl_context_keyfile_password - def test_client_cert_with_string_password(self): - self.run_client_cert_with_password_test(u"letmein") + @requires_ssl_context_keyfile_password() + def test_client_cert_with_string_password(self) -> None: + self.run_client_cert_with_password_test("letmein") - @requires_ssl_context_keyfile_password - def test_client_cert_with_bytes_password(self): + @requires_ssl_context_keyfile_password() + def test_client_cert_with_bytes_password(self) -> None: self.run_client_cert_with_password_test(b"letmein") - def run_client_cert_with_password_test(self, password): + def run_client_cert_with_password_test(self, password: bytes | str) -> None: """ Tests client certificate password functionality """ done_receiving = Event() client_certs = [] - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] sock = self._wrap_in_ssl(sock) @@ -332,6 +366,7 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) + assert ssl_.SSLContext is not None ssl_context = ssl_.SSLContext(ssl_.PROTOCOL_SSLv23) ssl_context.load_cert_chain( certfile=self.cert_path, keyfile=self.password_key_path, password=password @@ -349,35 +384,36 @@ def socket_handler(listener): assert len(client_certs) == 1 - @requires_ssl_context_keyfile_password - def test_load_keyfile_with_invalid_password(self): + @requires_ssl_context_keyfile_password() + def test_load_keyfile_with_invalid_password(self) -> None: + assert ssl_.SSLContext is not None context = ssl_.SSLContext(ssl_.PROTOCOL_SSLv23) - - # Different error is raised depending on context. 
- if ssl_.IS_PYOPENSSL: - from OpenSSL.SSL import Error - - expected_error = Error - else: - expected_error = ssl.SSLError - - with pytest.raises(expected_error): + with pytest.raises(ssl.SSLError): context.load_cert_chain( certfile=self.cert_path, keyfile=self.password_key_path, password=b"letmei", ) + # For SecureTransport, the validation that would raise an error in + # this case is deferred. + @notSecureTransport() + def test_load_invalid_cert_file(self) -> None: + assert ssl_.SSLContext is not None + context = ssl_.SSLContext(ssl_.PROTOCOL_SSLv23) + with pytest.raises(ssl.SSLError): + context.load_cert_chain(certfile=self.password_key_path) + class TestSocketClosing(SocketDummyServerTestCase): - def test_recovery_when_server_closes_connection(self): + def test_recovery_when_server_closes_connection(self) -> None: # Does the pool work seamlessly if an open connection in the # connection pool gets hung up on by the server, then reaches # the front of the queue again? done_closing = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: for i in 0, 1: sock = listener.accept()[0] @@ -385,7 +421,7 @@ def socket_handler(listener): while not buf.endswith(b"\r\n\r\n"): buf = sock.recv(65536) - body = "Response %d" % i + body = f"Response {int(i)}" sock.send( ( "HTTP/1.1 200 OK\r\n" @@ -411,18 +447,19 @@ def socket_handler(listener): assert response.status == 200 assert response.data == b"Response 1" - def test_connection_refused(self): + def test_connection_refused(self) -> None: # Does the pool retry if there is no listener on the port? 
host, port = get_unreachable_address() with HTTPConnectionPool(host, port, maxsize=3, block=True) as http: with pytest.raises(MaxRetryError): http.request("GET", "/", retries=0, release_conn=False) + assert http.pool is not None assert http.pool.qsize() == http.pool.maxsize - def test_connection_read_timeout(self): + def test_connection_read_timeout(self) -> None: timed_out = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] while not sock.recv(65536).endswith(b"\r\n\r\n"): pass @@ -445,12 +482,13 @@ def socket_handler(listener): finally: timed_out.set() + assert http.pool is not None assert http.pool.qsize() == http.pool.maxsize - def test_read_timeout_dont_retry_method_not_in_allowlist(self): + def test_read_timeout_dont_retry_method_not_in_allowlist(self) -> None: timed_out = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] sock.recv(65536) timed_out.wait() @@ -466,11 +504,11 @@ def socket_handler(listener): finally: timed_out.set() - def test_https_connection_read_timeout(self): - """ Handshake timeouts should fail with a Timeout""" + def test_https_connection_read_timeout(self) -> None: + """Handshake timeouts should fail with a Timeout""" timed_out = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] while not sock.recv(65536): pass @@ -478,6 +516,7 @@ def socket_handler(listener): timed_out.wait() sock.close() + # first ReadTimeoutError due to SocketTimeout self._start_server(socket_handler) with HTTPSConnectionPool( self.host, self.port, timeout=LONG_TIMEOUT, retries=False @@ -488,8 +527,15 @@ def socket_handler(listener): finally: timed_out.set() - def test_timeout_errors_cause_retries(self): - def socket_handler(listener): + # second ReadTimeoutError due to errno + with HTTPSConnectionPool(host=self.host): + err = OSError() + err.errno = 
errno.EAGAIN + with pytest.raises(ReadTimeoutError): + pool._raise_timeout(err, "", 0) + + def test_timeout_errors_cause_retries(self) -> None: + def socket_handler(listener: socket.socket) -> None: sock_timeout = listener.accept()[0] # Wait for a second request before closing the first socket. @@ -532,10 +578,10 @@ def socket_handler(listener): finally: socket.setdefaulttimeout(default_timeout) - def test_delayed_body_read_timeout(self): + def test_delayed_body_read_timeout(self) -> None: timed_out = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" body = "Hi" @@ -569,10 +615,10 @@ def socket_handler(listener): finally: timed_out.set() - def test_delayed_body_read_timeout_with_preload(self): + def test_delayed_body_read_timeout_with_preload(self) -> None: timed_out = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" body = "Hi" @@ -599,11 +645,11 @@ def socket_handler(listener): finally: timed_out.set() - def test_incomplete_response(self): + def test_incomplete_response(self) -> None: body = "Response" partial_body = body[:2] - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] # Consume request @@ -629,10 +675,10 @@ def socket_handler(listener): with pytest.raises(ProtocolError): response.read() - def test_retry_weird_http_version(self): - """ Retry class should handle httplib.BadStatusLine errors properly """ + def test_retry_weird_http_version(self) -> None: + """Retry class should handle httplib.BadStatusLine errors properly""" - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] # First request. # Pause before responding so the first request times out. 
@@ -679,10 +725,10 @@ def socket_handler(listener): assert response.status == 200 assert response.data == b"foo" - def test_connection_cleanup_on_read_timeout(self): + def test_connection_cleanup_on_read_timeout(self) -> None: timed_out = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" body = "Hi" @@ -702,6 +748,7 @@ def socket_handler(listener): self._start_server(socket_handler) with HTTPConnectionPool(self.host, self.port) as pool: + assert pool.pool is not None poolsize = pool.pool.qsize() response = pool.urlopen( "GET", "/", retries=0, preload_content=False, timeout=LONG_TIMEOUT @@ -713,11 +760,11 @@ def socket_handler(listener): finally: timed_out.set() - def test_connection_cleanup_on_protocol_error_during_read(self): + def test_connection_cleanup_on_protocol_error_during_read(self) -> None: body = "Response" partial_body = body[:2] - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] # Consume request @@ -739,6 +786,7 @@ def socket_handler(listener): self._start_server(socket_handler) with HTTPConnectionPool(self.host, self.port) as pool: + assert pool.pool is not None poolsize = pool.pool.qsize() response = pool.request("GET", "/", retries=0, preload_content=False) @@ -746,10 +794,10 @@ def socket_handler(listener): response.read() assert poolsize == pool.pool.qsize() - def test_connection_closed_on_read_timeout_preload_false(self): + def test_connection_closed_on_read_timeout_preload_false(self) -> None: timed_out = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] # Consume request @@ -759,14 +807,12 @@ def socket_handler(listener): # Send partial chunked response and then hang. 
sock.send( - ( - "HTTP/1.1 200 OK\r\n" - "Content-Type: text/plain\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "8\r\n" - "12345678\r\n" - ).encode("utf-8") + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"8\r\n" + b"12345678\r\n" ) timed_out.wait(5) @@ -785,15 +831,13 @@ def socket_handler(listener): # Send complete chunked response. new_sock.send( - ( - "HTTP/1.1 200 OK\r\n" - "Content-Type: text/plain\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "8\r\n" - "12345678\r\n" - "0\r\n\r\n" - ).encode("utf-8") + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"8\r\n" + b"12345678\r\n" + b"0\r\n\r\n" ) new_sock.close() @@ -817,11 +861,11 @@ def socket_handler(listener): ) assert len(response.read()) == 8 - def test_closing_response_actually_closes_connection(self): + def test_closing_response_actually_closes_connection(self) -> None: done_closing = Event() complete = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" @@ -829,12 +873,10 @@ def socket_handler(listener): buf = sock.recv(65536) sock.send( - ( - "HTTP/1.1 200 OK\r\n" - "Content-Type: text/plain\r\n" - "Content-Length: 0\r\n" - "\r\n" - ).encode("utf-8") + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Content-Length: 0\r\n" + b"\r\n" ) # Wait for the socket to close. @@ -858,7 +900,7 @@ def socket_handler(listener): successful = complete.wait(timeout=LONG_TIMEOUT) assert successful, "Timed out waiting for connection close" - def test_release_conn_param_is_respected_after_timeout_retry(self): + def test_release_conn_param_is_respected_after_timeout_retry(self) -> None: """For successful ```urlopen(release_conn=False)```, the connection isn't released, even after a retry. 
@@ -871,7 +913,7 @@ def test_release_conn_param_is_respected_after_timeout_retry(self): [1] """ - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] consume_socket(sock) @@ -890,15 +932,13 @@ def socket_handler(listener): # Send complete chunked response. sock.send( - ( - "HTTP/1.1 200 OK\r\n" - "Content-Type: text/plain\r\n" - "Transfer-Encoding: chunked\r\n" - "\r\n" - "8\r\n" - "12345678\r\n" - "0\r\n\r\n" - ).encode("utf-8") + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"8\r\n" + b"12345678\r\n" + b"0\r\n\r\n" ) sock.close() @@ -919,6 +959,7 @@ def socket_handler(listener): # The connection should still be on the response object, and none # should be in the pool. We opened two though. assert pool.num_connections == 2 + assert pool.pool is not None assert pool.pool.qsize() == 0 assert response.connection is not None @@ -927,10 +968,67 @@ def socket_handler(listener): assert pool.pool.qsize() == 1 assert response.connection is None + def test_socket_close_socket_then_file(self) -> None: + def consume_ssl_socket(listener: socket.socket) -> None: + try: + with listener.accept()[0] as sock, original_ssl_wrap_socket( + sock, + server_side=True, + keyfile=DEFAULT_CERTS["keyfile"], + certfile=DEFAULT_CERTS["certfile"], + ca_certs=DEFAULT_CA, + ) as ssl_sock: + consume_socket(ssl_sock) + except (ConnectionResetError, ConnectionAbortedError, OSError): + pass + + self._start_server(consume_ssl_socket) + with socket.create_connection( + (self.host, self.port) + ) as sock, contextlib.closing( + ssl_wrap_socket(sock, server_hostname=self.host, ca_certs=DEFAULT_CA) + ) as ssl_sock, ssl_sock.makefile( + "rb" + ) as f: + ssl_sock.close() + f.close() + # SecureTransport is supposed to raise OSError but raises + # ssl.SSLError when closed because ssl_sock.context is None + with pytest.raises((OSError, ssl.SSLError)): + ssl_sock.sendall(b"hello") + assert 
ssl_sock.fileno() == -1 + + def test_socket_close_stays_open_with_makefile_open(self) -> None: + def consume_ssl_socket(listener: socket.socket) -> None: + try: + with listener.accept()[0] as sock, original_ssl_wrap_socket( + sock, + server_side=True, + keyfile=DEFAULT_CERTS["keyfile"], + certfile=DEFAULT_CERTS["certfile"], + ca_certs=DEFAULT_CA, + ) as ssl_sock: + consume_socket(ssl_sock) + except (ConnectionResetError, ConnectionAbortedError, OSError): + pass + + self._start_server(consume_ssl_socket) + with socket.create_connection( + (self.host, self.port) + ) as sock, contextlib.closing( + ssl_wrap_socket(sock, server_hostname=self.host, ca_certs=DEFAULT_CA) + ) as ssl_sock, ssl_sock.makefile( + "rb" + ): + ssl_sock.close() + ssl_sock.close() + ssl_sock.sendall(b"hello") + assert ssl_sock.fileno() > 0 + class TestProxyManager(SocketDummyServerTestCase): - def test_simple(self): - def echo_socket_handler(listener): + def test_simple(self) -> None: + def echo_socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" @@ -949,7 +1047,7 @@ def echo_socket_handler(listener): sock.close() self._start_server(echo_socket_handler) - base_url = "http://%s:%d" % (self.host, self.port) + base_url = f"http://{self.host}:{self.port}" with proxy_from_url(base_url) as proxy: r = proxy.request("GET", "http://google.com/") @@ -969,8 +1067,8 @@ def echo_socket_handler(listener): ] ) - def test_headers(self): - def echo_socket_handler(listener): + def test_headers(self) -> None: + def echo_socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" @@ -989,7 +1087,7 @@ def echo_socket_handler(listener): sock.close() self._start_server(echo_socket_handler) - base_url = "http://%s:%d" % (self.host, self.port) + base_url = f"http://{self.host}:{self.port}" # Define some proxy headers. proxy_headers = HTTPHeaderDict({"For The Proxy": "YEAH!"}) @@ -1004,10 +1102,10 @@ def echo_socket_handler(listener): # OrderedDict/MultiDict). 
assert b"For The Proxy: YEAH!\r\n" in r.data - def test_retries(self): + def test_retries(self) -> None: close_event = Event() - def echo_socket_handler(listener): + def echo_socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] # First request, which should fail sock.close() @@ -1032,7 +1130,7 @@ def echo_socket_handler(listener): close_event.set() self._start_server(echo_socket_handler) - base_url = "http://%s:%d" % (self.host, self.port) + base_url = f"http://{self.host}:{self.port}" with proxy_from_url(base_url) as proxy: conn = proxy.connection_from_url("http://www.google.com") @@ -1051,8 +1149,8 @@ def echo_socket_handler(listener): retries=False, ) - def test_connect_reconn(self): - def proxy_ssl_one(listener): + def test_connect_reconn(self) -> None: + def proxy_ssl_one(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" @@ -1060,21 +1158,17 @@ def proxy_ssl_one(listener): buf += sock.recv(65536) s = buf.decode("utf-8") if not s.startswith("CONNECT "): - sock.send( - ( - "HTTP/1.1 405 Method not allowed\r\nAllow: CONNECT\r\n\r\n" - ).encode("utf-8") - ) + sock.send(b"HTTP/1.1 405 Method not allowed\r\nAllow: CONNECT\r\n\r\n") sock.close() return - if not s.startswith("CONNECT %s:443" % (self.host,)): - sock.send(("HTTP/1.1 403 Forbidden\r\n\r\n").encode("utf-8")) + if not s.startswith(f"CONNECT {self.host}:443"): + sock.send(b"HTTP/1.1 403 Forbidden\r\n\r\n") sock.close() return - sock.send(("HTTP/1.1 200 Connection Established\r\n\r\n").encode("utf-8")) - ssl_sock = ssl.wrap_socket( + sock.send(b"HTTP/1.1 200 Connection Established\r\n\r\n") + ssl_sock = original_ssl_wrap_socket( sock, server_side=True, keyfile=DEFAULT_CERTS["keyfile"], @@ -1087,36 +1181,34 @@ def proxy_ssl_one(listener): buf += ssl_sock.recv(65536) ssl_sock.send( - ( - "HTTP/1.1 200 OK\r\n" - "Content-Type: text/plain\r\n" - "Content-Length: 2\r\n" - "Connection: close\r\n" - "\r\n" - "Hi" - ).encode("utf-8") + b"HTTP/1.1 200 OK\r\n" + 
b"Content-Type: text/plain\r\n" + b"Content-Length: 2\r\n" + b"Connection: close\r\n" + b"\r\n" + b"Hi" ) ssl_sock.close() - def echo_socket_handler(listener): + def echo_socket_handler(listener: socket.socket) -> None: proxy_ssl_one(listener) proxy_ssl_one(listener) self._start_server(echo_socket_handler) - base_url = "http://%s:%d" % (self.host, self.port) + base_url = f"http://{self.host}:{self.port}" with proxy_from_url(base_url, ca_certs=DEFAULT_CA) as proxy: - url = "https://{0}".format(self.host) + url = f"https://{self.host}" conn = proxy.connection_from_url(url) r = conn.urlopen("GET", url, retries=0) assert r.status == 200 r = conn.urlopen("GET", url, retries=0) assert r.status == 200 - def test_connect_ipv6_addr(self): + def test_connect_ipv6_addr(self) -> None: ipv6_addr = "2001:4998:c:a06::2:4008" - def echo_socket_handler(listener): + def echo_socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" @@ -1124,9 +1216,9 @@ def echo_socket_handler(listener): buf += sock.recv(65536) s = buf.decode("utf-8") - if s.startswith("CONNECT [%s]:443" % (ipv6_addr,)): + if s.startswith(f"CONNECT [{ipv6_addr}]:443"): sock.send(b"HTTP/1.1 200 Connection Established\r\n\r\n") - ssl_sock = ssl.wrap_socket( + ssl_sock = original_ssl_wrap_socket( sock, server_side=True, keyfile=DEFAULT_CERTS["keyfile"], @@ -1149,30 +1241,60 @@ def echo_socket_handler(listener): sock.close() self._start_server(echo_socket_handler) - base_url = "http://%s:%d" % (self.host, self.port) + base_url = f"http://{self.host}:{self.port}" with proxy_from_url(base_url, cert_reqs="NONE") as proxy: - url = "https://[{0}]".format(ipv6_addr) + url = f"https://[{ipv6_addr}]" conn = proxy.connection_from_url(url) try: - r = conn.urlopen("GET", url, retries=0) + with pytest.warns(InsecureRequestWarning): + r = conn.urlopen("GET", url, retries=0) assert r.status == 200 except MaxRetryError: - self.fail("Invalid IPv6 format in HTTP CONNECT request") + pytest.fail("Invalid IPv6 
format in HTTP CONNECT request") + + @pytest.mark.parametrize("target_scheme", ["http", "https"]) + def test_https_proxymanager_connected_to_http_proxy( + self, target_scheme: str + ) -> None: + errored = Event() + + def http_socket_handler(listener: socket.socket) -> None: + sock = listener.accept()[0] + sock.send(b"HTTP/1.0 501 Not Implemented\r\nConnection: close\r\n\r\n") + errored.wait() + sock.close() + + self._start_server(http_socket_handler) + base_url = f"https://{self.host}:{self.port}" + + with ProxyManager(base_url, cert_reqs="NONE") as proxy: + with pytest.raises(MaxRetryError) as e: + proxy.request("GET", f"{target_scheme}://example.com", retries=0) + + errored.set() # Avoid a ConnectionAbortedError on Windows. + + assert type(e.value.reason) == ProxyError + assert "Your proxy appears to only use HTTP and not HTTPS" in str( + e.value.reason + ) class TestSSL(SocketDummyServerTestCase): - def test_ssl_failure_midway_through_conn(self): - def socket_handler(listener): + def test_ssl_failure_midway_through_conn(self) -> None: + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] sock2 = sock.dup() - ssl_sock = ssl.wrap_socket( - sock, - server_side=True, - keyfile=DEFAULT_CERTS["keyfile"], - certfile=DEFAULT_CERTS["certfile"], - ca_certs=DEFAULT_CA, - ) + try: + ssl_sock = original_ssl_wrap_socket( + sock, + server_side=True, + keyfile=DEFAULT_CERTS["keyfile"], + certfile=DEFAULT_CERTS["certfile"], + ca_certs=DEFAULT_CA, + ) + except ssl.SSLError: + return buf = b"" while not buf.endswith(b"\r\n\r\n"): @@ -1180,13 +1302,11 @@ def socket_handler(listener): # Deliberately send from the non-SSL socket. 
sock2.send( - ( - "HTTP/1.1 200 OK\r\n" - "Content-Type: text/plain\r\n" - "Content-Length: 2\r\n" - "\r\n" - "Hi" - ).encode("utf-8") + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Content-Length: 2\r\n" + b"\r\n" + b"Hi" ) sock2.close() ssl_sock.close() @@ -1197,13 +1317,15 @@ def socket_handler(listener): pool.request("GET", "/", retries=0) assert isinstance(cm.value.reason, SSLError) - @notSecureTransport - def test_ssl_read_timeout(self): + @notSecureTransport() + def test_ssl_read_timeout(self) -> None: timed_out = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] - ssl_sock = ssl.wrap_socket( + # disable Nagle's algorithm so there's no delay in sending a partial body + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) + ssl_sock = original_ssl_wrap_socket( sock, server_side=True, keyfile=DEFAULT_CERTS["keyfile"], @@ -1216,13 +1338,11 @@ def socket_handler(listener): # Send incomplete message (note Content-Length) ssl_sock.send( - ( - "HTTP/1.1 200 OK\r\n" - "Content-Type: text/plain\r\n" - "Content-Length: 10\r\n" - "\r\n" - "Hi-" - ).encode("utf-8") + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Content-Length: 10\r\n" + b"\r\n" + b"Hi-" ) timed_out.wait() @@ -1240,17 +1360,22 @@ def socket_handler(listener): finally: timed_out.set() - def test_ssl_failed_fingerprint_verification(self): - def socket_handler(listener): + def test_ssl_failed_fingerprint_verification(self) -> None: + def socket_handler(listener: socket.socket) -> None: for i in range(2): sock = listener.accept()[0] - ssl_sock = ssl.wrap_socket( - sock, - server_side=True, - keyfile=DEFAULT_CERTS["keyfile"], - certfile=DEFAULT_CERTS["certfile"], - ca_certs=DEFAULT_CA, - ) + try: + ssl_sock = original_ssl_wrap_socket( + sock, + server_side=True, + keyfile=DEFAULT_CERTS["keyfile"], + certfile=DEFAULT_CERTS["certfile"], + ca_certs=DEFAULT_CA, + ) + except (ssl.SSLError, 
ConnectionResetError): + if i == 1: + raise + return ssl_sock.send( b"HTTP/1.1 200 OK\r\n" @@ -1266,7 +1391,7 @@ def socket_handler(listener): # GitHub's fingerprint. Valid, but not matching. fingerprint = "A0:C4:A7:46:00:ED:A7:2D:C0:BE:CB:9A:8C:B6:07:CA:58:EE:74:5E" - def request(): + def request() -> None: pool = HTTPSConnectionPool( self.host, self.port, assert_fingerprint=fingerprint ) @@ -1286,12 +1411,12 @@ def request(): with pytest.raises(MaxRetryError): request() - def test_retry_ssl_error(self): - def socket_handler(listener): + def test_retry_ssl_error(self) -> None: + def socket_handler(listener: socket.socket) -> None: # first request, trigger an SSLError sock = listener.accept()[0] sock2 = sock.dup() - ssl_sock = ssl.wrap_socket( + ssl_sock = original_ssl_wrap_socket( sock, server_side=True, keyfile=DEFAULT_CERTS["keyfile"], @@ -1303,20 +1428,18 @@ def socket_handler(listener): # Deliberately send from the non-SSL socket to trigger an SSLError sock2.send( - ( - "HTTP/1.1 200 OK\r\n" - "Content-Type: text/plain\r\n" - "Content-Length: 4\r\n" - "\r\n" - "Fail" - ).encode("utf-8") + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Content-Length: 4\r\n" + b"\r\n" + b"Fail" ) sock2.close() ssl_sock.close() # retried request sock = listener.accept()[0] - ssl_sock = ssl.wrap_socket( + ssl_sock = original_ssl_wrap_socket( sock, server_side=True, keyfile=DEFAULT_CERTS["keyfile"], @@ -1339,16 +1462,19 @@ def socket_handler(listener): response = pool.urlopen("GET", "/", retries=1) assert response.data == b"Success" - def test_ssl_load_default_certs_when_empty(self): - def socket_handler(listener): + def test_ssl_load_default_certs_when_empty(self) -> None: + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] - ssl_sock = ssl.wrap_socket( - sock, - server_side=True, - keyfile=DEFAULT_CERTS["keyfile"], - certfile=DEFAULT_CERTS["certfile"], - ca_certs=DEFAULT_CA, - ) + try: + ssl_sock = original_ssl_wrap_socket( + sock, + 
server_side=True, + keyfile=DEFAULT_CERTS["keyfile"], + certfile=DEFAULT_CERTS["certfile"], + ca_certs=DEFAULT_CA, + ) + except (ssl.SSLError, OSError): + return buf = b"" while not buf.endswith(b"\r\n\r\n"): @@ -1371,21 +1497,26 @@ def socket_handler(listener): with mock.patch("urllib3.util.ssl_.SSLContext", lambda *_, **__: context): self._start_server(socket_handler) with HTTPSConnectionPool(self.host, self.port) as pool: - with pytest.raises(MaxRetryError): + # Without a proper `SSLContext`, this request will fail in some + # arbitrary way, but we only want to know if load_default_certs() was + # called, which is why we accept any `Exception` here. + with pytest.raises(Exception): pool.request("GET", "/", timeout=SHORT_TIMEOUT) context.load_default_certs.assert_called_with() - @notPyPy2 - def test_ssl_dont_load_default_certs_when_given(self): - def socket_handler(listener): + def test_ssl_dont_load_default_certs_when_given(self) -> None: + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] - ssl_sock = ssl.wrap_socket( - sock, - server_side=True, - keyfile=DEFAULT_CERTS["keyfile"], - certfile=DEFAULT_CERTS["certfile"], - ca_certs=DEFAULT_CA, - ) + try: + ssl_sock = original_ssl_wrap_socket( + sock, + server_side=True, + keyfile=DEFAULT_CERTS["keyfile"], + certfile=DEFAULT_CERTS["certfile"], + ca_certs=DEFAULT_CA, + ) + except (ssl.SSLError, OSError): + return buf = b"" while not buf.endswith(b"\r\n\r\n"): @@ -1412,45 +1543,45 @@ def socket_handler(listener): {"ca_certs": "a", "ca_cert_dir": "a"}, {"ssl_context": context}, ]: - self._start_server(socket_handler) with HTTPSConnectionPool(self.host, self.port, **kwargs) as pool: - with pytest.raises(MaxRetryError): + with pytest.raises(Exception): pool.request("GET", "/", timeout=SHORT_TIMEOUT) context.load_default_certs.assert_not_called() - def test_load_verify_locations_exception(self): + def test_load_verify_locations_exception(self) -> None: """ Ensure that load_verify_locations 
raises SSLError for all backends """ with pytest.raises(SSLError): - ssl_wrap_socket(None, ca_certs="/tmp/fake-file") + ssl_wrap_socket(None, ca_certs="/tmp/fake-file") # type: ignore[call-overload] - def test_ssl_custom_validation_failure_terminates(self, tmpdir): + def test_ssl_custom_validation_failure_terminates(self, tmpdir: Path) -> None: """ Ensure that the underlying socket is terminated if custom validation fails. """ server_closed = Event() - def is_closed_socket(sock): + def is_closed_socket(sock: socket.socket) -> bool: try: - sock.settimeout(SHORT_TIMEOUT) # Python 3 - sock.recv(1) # Python 2 - except (OSError, socket.error): + sock.settimeout(SHORT_TIMEOUT) + except OSError: return True return False - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] try: - _ = ssl.wrap_socket( + _ = original_ssl_wrap_socket( sock, server_side=True, keyfile=DEFAULT_CERTS["keyfile"], certfile=DEFAULT_CERTS["certfile"], ca_certs=DEFAULT_CA, ) + except ConnectionResetError: + return except ssl.SSLError as e: assert "alert unknown ca" in str(e) if is_closed_socket(sock): @@ -1470,9 +1601,66 @@ def socket_handler(listener): pool.request("GET", "/", retries=False, timeout=LONG_TIMEOUT) assert server_closed.wait(LONG_TIMEOUT), "The socket was not terminated" + # SecureTransport can read only small pieces of data at the moment. + # https://github.com/urllib3/urllib3/pull/2674 + @notSecureTransport() + @pytest.mark.skipif( + os.environ.get("CI") == "true" and sys.implementation.name == "pypy", + reason="too slow to run in CI", + ) + @pytest.mark.parametrize( + "preload_content,read_amt", [(True, None), (False, None), (False, 2**31)] + ) + def test_requesting_large_resources_via_ssl( + self, preload_content: bool, read_amt: int | None + ) -> None: + """ + Ensure that it is possible to read 2 GiB or more via an SSL + socket. 
+ https://github.com/urllib3/urllib3/issues/2513 + """ + content_length = 2**31 # (`int` max value in C) + 1. + ssl_ready = Event() + + def socket_handler(listener: socket.socket) -> None: + sock = listener.accept()[0] + ssl_sock = original_ssl_wrap_socket( + sock, + server_side=True, + keyfile=DEFAULT_CERTS["keyfile"], + certfile=DEFAULT_CERTS["certfile"], + ca_certs=DEFAULT_CA, + ) + ssl_ready.set() + + while not ssl_sock.recv(65536).endswith(b"\r\n\r\n"): + continue + + ssl_sock.send( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Content-Length: %d\r\n\r\n" % content_length + ) + + chunks = 2 + for i in range(chunks): + ssl_sock.sendall(bytes(content_length // chunks)) + + ssl_sock.close() + sock.close() + + self._start_server(socket_handler) + ssl_ready.wait(5) + with HTTPSConnectionPool( + self.host, self.port, ca_certs=DEFAULT_CA, retries=False + ) as pool: + response = pool.request("GET", "/", preload_content=preload_content) + data = response.data if preload_content else response.read(read_amt) + assert len(data) == content_length + class TestErrorWrapping(SocketDummyServerTestCase): - def test_bad_statusline(self): + def test_bad_statusline(self) -> None: self.start_response_handler( b"HTTP/1.1 Omg What Is This?\r\n" b"Content-Length: 0\r\n" b"\r\n" ) @@ -1480,7 +1668,7 @@ def test_bad_statusline(self): with pytest.raises(ProtocolError): pool.request("GET", "/") - def test_unknown_protocol(self): + def test_unknown_protocol(self) -> None: self.start_response_handler( b"HTTP/1000 200 OK\r\n" b"Content-Length: 0\r\n" b"\r\n" ) @@ -1490,8 +1678,7 @@ def test_unknown_protocol(self): class TestHeaders(SocketDummyServerTestCase): - @onlyPy3 - def test_httplib_headers_case_insensitive(self): + def test_httplib_headers_case_insensitive(self) -> None: self.start_response_handler( b"HTTP/1.1 200 OK\r\n" b"Content-Length: 0\r\n" @@ -1503,11 +1690,11 @@ def test_httplib_headers_case_insensitive(self): r = pool.request("GET", "/") assert HEADERS == 
dict(r.headers.items()) # to preserve case sensitivity - def start_parsing_handler(self): - self.parsed_headers = OrderedDict() - self.received_headers = [] + def start_parsing_handler(self) -> None: + self.parsed_headers: typing.OrderedDict[str, str] = OrderedDict() + self.received_headers: list[bytes] = [] - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" @@ -1522,21 +1709,19 @@ def socket_handler(listener): (key, value) = header.split(b": ") self.parsed_headers[key.decode("ascii")] = value.decode("ascii") - sock.send( - ("HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n").encode("utf-8") - ) + sock.send(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n") sock.close() self._start_server(socket_handler) - def test_headers_are_sent_with_the_original_case(self): + def test_headers_are_sent_with_the_original_case(self) -> None: headers = {"foo": "bar", "bAz": "quux"} self.start_parsing_handler() expected_headers = { "Accept-Encoding": "identity", - "Host": "{0}:{1}".format(self.host, self.port), + "Host": f"{self.host}:{self.port}", "User-Agent": _get_default_user_agent(), } expected_headers.update(headers) @@ -1545,13 +1730,13 @@ def test_headers_are_sent_with_the_original_case(self): pool.request("GET", "/", headers=HTTPHeaderDict(headers)) assert expected_headers == self.parsed_headers - def test_ua_header_can_be_overridden(self): + def test_ua_header_can_be_overridden(self) -> None: headers = {"uSeR-AgENt": "Definitely not urllib3!"} self.start_parsing_handler() expected_headers = { "Accept-Encoding": "identity", - "Host": "{0}:{1}".format(self.host, self.port), + "Host": f"{self.host}:{self.port}", } expected_headers.update(headers) @@ -1559,49 +1744,48 @@ def test_ua_header_can_be_overridden(self): pool.request("GET", "/", headers=HTTPHeaderDict(headers)) assert expected_headers == self.parsed_headers - def test_request_headers_are_sent_in_the_original_order(self): + def 
test_request_headers_are_sent_in_the_original_order(self) -> None: # NOTE: Probability this test gives a false negative is 1/(K!) K = 16 # NOTE: Provide headers in non-sorted order (i.e. reversed) # so that if the internal implementation tries to sort them, # a change will be detected. expected_request_headers = [ - (u"X-Header-%d" % i, str(i)) for i in reversed(range(K)) + (f"X-Header-{int(i)}", str(i)) for i in reversed(range(K)) ] - def filter_non_x_headers(d): + def filter_non_x_headers( + d: typing.OrderedDict[str, str] + ) -> list[tuple[str, str]]: return [(k, v) for (k, v) in d.items() if k.startswith("X-Header-")] - request_headers = OrderedDict() - self.start_parsing_handler() with HTTPConnectionPool(self.host, self.port, retries=False) as pool: pool.request("GET", "/", headers=OrderedDict(expected_request_headers)) - request_headers = filter_non_x_headers(self.parsed_headers) - assert expected_request_headers == request_headers + assert expected_request_headers == filter_non_x_headers(self.parsed_headers) - @resolvesLocalhostFQDN - def test_request_host_header_ignores_fqdn_dot(self): + @resolvesLocalhostFQDN() + def test_request_host_header_ignores_fqdn_dot(self) -> None: self.start_parsing_handler() with HTTPConnectionPool(self.host + ".", self.port, retries=False) as pool: pool.request("GET", "/") self.assert_header_received( - self.received_headers, "Host", "%s:%s" % (self.host, self.port) + self.received_headers, "Host", f"{self.host}:{self.port}" ) - def test_response_headers_are_returned_in_the_original_order(self): + def test_response_headers_are_returned_in_the_original_order(self) -> None: # NOTE: Probability this test gives a false negative is 1/(K!) K = 16 # NOTE: Provide headers in non-sorted order (i.e. reversed) # so that if the internal implementation tries to sort them, # a change will be detected. 
expected_response_headers = [ - ("X-Header-%d" % i, str(i)) for i in reversed(range(K)) + (f"X-Header-{int(i)}", str(i)) for i in reversed(range(K)) ] - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" @@ -1628,13 +1812,80 @@ def socket_handler(listener): ] assert expected_response_headers == actual_response_headers + @pytest.mark.parametrize( + "method_type, body_type", + [ + ("GET", None), + ("POST", None), + ("POST", "bytes"), + ("POST", "bytes-io"), + ], + ) + def test_headers_sent_with_add( + self, method_type: str, body_type: str | None + ) -> None: + """ + Confirm that when adding headers with combine=True that we simply append to the + most recent value, rather than create a new header line. + """ + body: None | bytes | io.BytesIO + if body_type is None: + body = None + elif body_type == "bytes": + body = b"my-body" + elif body_type == "bytes-io": + body = io.BytesIO(b"bytes-io-body") + body.seek(0, 0) + else: + raise ValueError("Unknown body type") + + buffer: bytes = b"" + + def socket_handler(listener: socket.socket) -> None: + nonlocal buffer + sock = listener.accept()[0] + sock.settimeout(0) + + start = time.time() + while time.time() - start < (LONG_TIMEOUT / 2): + try: + buffer += sock.recv(65536) + except OSError: + continue + + sock.sendall( + b"HTTP/1.1 200 OK\r\n" + b"Server: example.com\r\n" + b"Content-Length: 0\r\n\r\n" + ) + sock.close() + + self._start_server(socket_handler) + + headers = HTTPHeaderDict() + headers.add("A", "1") + headers.add("C", "3") + headers.add("B", "2") + headers.add("B", "3") + headers.add("A", "4", combine=False) + headers.add("C", "5", combine=True) + headers.add("C", "6") + + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + r = pool.request( + method_type, + "/", + body=body, + headers=headers, + ) + assert r.status == 200 + assert b"A: 1\r\nA: 4\r\nC: 3, 5\r\nC: 6\r\nB: 2\r\nB: 3" in buffer + -@pytest.mark.skipif( - 
issubclass(httplib.HTTPMessage, MimeToolMessage), - reason="Header parsing errors not available", -) class TestBrokenHeaders(SocketDummyServerTestCase): - def _test_broken_header_parsing(self, headers, unparsed_data_check=None): + def _test_broken_header_parsing( + self, headers: list[bytes], unparsed_data_check: str | None = None + ) -> None: self.start_response_handler( ( b"HTTP/1.1 200 OK\r\n" @@ -1652,29 +1903,30 @@ def _test_broken_header_parsing(self, headers, unparsed_data_check=None): for record in logs: if ( "Failed to parse headers" in record.msg - and pool._absolute_url("/") == record.args[0] + and isinstance(record.args, tuple) + and _url_from_pool(pool, "/") == record.args[0] ): if ( unparsed_data_check is None or unparsed_data_check in record.getMessage() ): return - self.fail("Missing log about unparsed headers") + pytest.fail("Missing log about unparsed headers") - def test_header_without_name(self): + def test_header_without_name(self) -> None: self._test_broken_header_parsing([b": Value", b"Another: Header"]) - def test_header_without_name_or_value(self): + def test_header_without_name_or_value(self) -> None: self._test_broken_header_parsing([b":", b"Another: Header"]) - def test_header_without_colon_or_value(self): + def test_header_without_colon_or_value(self) -> None: self._test_broken_header_parsing( [b"Broken Header", b"Another: Header"], "Broken Header" ) class TestHeaderParsingContentType(SocketDummyServerTestCase): - def _test_okay_header_parsing(self, header): + def _test_okay_header_parsing(self, header: bytes) -> None: self.start_response_handler( (b"HTTP/1.1 200 OK\r\n" b"Content-Length: 0\r\n") + header + b"\r\n\r\n" ) @@ -1686,15 +1938,15 @@ def _test_okay_header_parsing(self, header): for record in logs: assert "Failed to parse headers" not in record.msg - def test_header_text_plain(self): + def test_header_text_plain(self) -> None: self._test_okay_header_parsing(b"Content-type: text/plain") - def test_header_message_rfc822(self): + 
def test_header_message_rfc822(self) -> None: self._test_okay_header_parsing(b"Content-type: message/rfc822") class TestHEAD(SocketDummyServerTestCase): - def test_chunked_head_response_does_not_hang(self): + def test_chunked_head_response_does_not_hang(self) -> None: self.start_response_handler( b"HTTP/1.1 200 OK\r\n" b"Transfer-Encoding: chunked\r\n" @@ -1707,7 +1959,7 @@ def test_chunked_head_response_does_not_hang(self): # stream will use the read_chunked method here. assert [] == list(r.stream()) - def test_empty_head_response_does_not_hang(self): + def test_empty_head_response_does_not_hang(self) -> None: self.start_response_handler( b"HTTP/1.1 200 OK\r\n" b"Content-Length: 256\r\n" @@ -1722,10 +1974,10 @@ def test_empty_head_response_does_not_hang(self): class TestStream(SocketDummyServerTestCase): - def test_stream_none_unchunked_response_does_not_hang(self): + def test_stream_none_unchunked_response_does_not_hang(self) -> None: done_event = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" @@ -1753,10 +2005,10 @@ def socket_handler(listener): class TestBadContentLength(SocketDummyServerTestCase): - def test_enforce_content_length_get(self): + def test_enforce_content_length_get(self) -> None: done_event = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" @@ -1780,21 +2032,14 @@ def socket_handler(listener): "GET", url="/", preload_content=False, enforce_content_length=True ) data = get_response.stream(100) - # Read "good" data before we try to read again. - # This won't trigger till generator is exhausted. 
- next(data) - try: + with pytest.raises(ProtocolError, match="12 bytes read, 10 more expected"): next(data) - assert False - except ProtocolError as e: - assert "12 bytes read, 10 more expected" in str(e) - done_event.set() - def test_enforce_content_length_no_body(self): + def test_enforce_content_length_no_body(self) -> None: done_event = Event() - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: sock = listener.accept()[0] buf = b"" @@ -1823,8 +2068,8 @@ def socket_handler(listener): class TestRetryPoolSizeDrainFail(SocketDummyServerTestCase): - def test_pool_size_retry_drain_fail(self): - def socket_handler(listener): + def test_pool_size_retry_drain_fail(self) -> None: + def socket_handler(listener: socket.socket) -> None: for _ in range(2): sock = listener.accept()[0] while not sock.recv(65536).endswith(b"\r\n\r\n"): @@ -1850,8 +2095,8 @@ def socket_handler(listener): class TestBrokenPipe(SocketDummyServerTestCase): - @notWindows - def test_ignore_broken_pipe_errors(self, monkeypatch): + @notWindows() + def test_ignore_broken_pipe_errors(self, monkeypatch: pytest.MonkeyPatch) -> None: # On Windows an aborted connection raises an error on # attempts to read data out of a socket that's been closed. 
sock_shut = Event() @@ -1859,12 +2104,12 @@ def test_ignore_broken_pipe_errors(self, monkeypatch): # a buffer that will cause two sendall calls buf = "a" * 1024 * 1024 * 4 - def connect_and_wait(*args, **kw): + def connect_and_wait(*args: typing.Any, **kw: typing.Any) -> None: ret = orig_connect(*args, **kw) assert sock_shut.wait(5) return ret - def socket_handler(listener): + def socket_handler(listener: socket.socket) -> None: for i in range(2): sock = listener.accept()[0] sock.send( @@ -1893,8 +2138,8 @@ def socket_handler(listener): class TestMultipartResponse(SocketDummyServerTestCase): - def test_multipart_assert_header_parsing_no_defects(self): - def socket_handler(listener): + def test_multipart_assert_header_parsing_no_defects(self) -> None: + def socket_handler(listener: socket.socket) -> None: for _ in range(2): sock = listener.accept()[0] while not sock.recv(65536).endswith(b"\r\n\r\n"): @@ -1918,7 +2163,7 @@ def socket_handler(listener): from urllib3.connectionpool import log with mock.patch.object(log, "warning") as log_warning: - with HTTPConnectionPool(self.host, self.port, timeout=3) as pool: + with HTTPConnectionPool(self.host, self.port, timeout=LONG_TIMEOUT) as pool: resp = pool.urlopen("GET", "/") assert resp.status == 404 assert ( @@ -1927,3 +2172,234 @@ def socket_handler(listener): ) assert len(resp.data) == 73 log_warning.assert_not_called() + + +class TestContentFraming(SocketDummyServerTestCase): + @pytest.mark.parametrize("content_length", [None, 0]) + @pytest.mark.parametrize("method", ["POST", "PUT", "PATCH"]) + def test_content_length_0_by_default( + self, method: str, content_length: int | None + ) -> None: + buffer = bytearray() + + def socket_handler(listener: socket.socket) -> None: + nonlocal buffer + sock = listener.accept()[0] + while not buffer.endswith(b"\r\n\r\n"): + buffer += sock.recv(65536) + sock.sendall( + b"HTTP/1.1 200 OK\r\n" + b"Server: example.com\r\n" + b"Content-Length: 0\r\n\r\n" + ) + sock.close() + + 
self._start_server(socket_handler) + + headers = {} + if content_length is not None: + headers["Content-Length"] = str(content_length) + + with HTTPConnectionPool(self.host, self.port, timeout=3) as pool: + resp = pool.request(method, "/") + assert resp.status == 200 + + sent_bytes = bytes(buffer) + assert b"Accept-Encoding: identity\r\n" in sent_bytes + assert b"Content-Length: 0\r\n" in sent_bytes + assert b"transfer-encoding" not in sent_bytes.lower() + + @pytest.mark.parametrize("chunked", [True, False]) + @pytest.mark.parametrize("method", ["POST", "PUT", "PATCH"]) + @pytest.mark.parametrize("body_type", ["file", "generator", "bytes"]) + def test_chunked_specified( + self, method: str, chunked: bool, body_type: str + ) -> None: + buffer = bytearray() + + def socket_handler(listener: socket.socket) -> None: + nonlocal buffer + sock = listener.accept()[0] + sock.settimeout(0) + + start = time.time() + while time.time() - start < (LONG_TIMEOUT / 2): + try: + buffer += sock.recv(65536) + except OSError: + continue + + sock.sendall( + b"HTTP/1.1 200 OK\r\n" + b"Server: example.com\r\n" + b"Content-Length: 0\r\n\r\n" + ) + sock.close() + + self._start_server(socket_handler) + + body: typing.Any + if body_type == "generator": + + def body_generator() -> typing.Generator[bytes, None, None]: + yield b"x" * 10 + + body = body_generator() + elif body_type == "file": + body = io.BytesIO(b"x" * 10) + body.seek(0, 0) + else: + if chunked is False: + pytest.skip("urllib3 uses Content-Length in this case") + body = b"x" * 10 + + with HTTPConnectionPool( + self.host, self.port, timeout=LONG_TIMEOUT, retries=False + ) as pool: + resp = pool.request(method, "/", chunked=chunked, body=body) + assert resp.status == 200 + + sent_bytes = bytes(buffer) + assert sent_bytes.count(b":") == 5 + assert b"Host: localhost:" in sent_bytes + assert b"Accept-Encoding: identity\r\n" in sent_bytes + assert b"Transfer-Encoding: chunked\r\n" in sent_bytes + assert b"User-Agent: python-urllib3/" in 
sent_bytes + assert b"content-length" not in sent_bytes.lower() + assert b"\r\n\r\na\r\nxxxxxxxxxx\r\n0\r\n\r\n" in sent_bytes + + @pytest.mark.parametrize("method", ["POST", "PUT", "PATCH"]) + @pytest.mark.parametrize( + "body_type", ["file", "generator", "bytes", "bytearray", "file_text"] + ) + def test_chunked_not_specified(self, method: str, body_type: str) -> None: + buffer = bytearray() + + def socket_handler(listener: socket.socket) -> None: + nonlocal buffer + sock = listener.accept()[0] + sock.settimeout(0) + + start = time.time() + while time.time() - start < (LONG_TIMEOUT / 2): + try: + buffer += sock.recv(65536) + except OSError: + continue + + sock.sendall( + b"HTTP/1.1 200 OK\r\n" + b"Server: example.com\r\n" + b"Content-Length: 0\r\n\r\n" + ) + sock.close() + + self._start_server(socket_handler) + + body: typing.Any + if body_type == "generator": + + def body_generator() -> typing.Generator[bytes, None, None]: + yield b"x" * 10 + + body = body_generator() + should_be_chunked = True + + elif body_type == "file": + body = io.BytesIO(b"x" * 10) + body.seek(0, 0) + should_be_chunked = True + + elif body_type == "file_text": + body = io.StringIO("x" * 10) + body.seek(0, 0) + should_be_chunked = True + + elif body_type == "bytearray": + body = bytearray(b"x" * 10) + should_be_chunked = False + + else: + body = b"x" * 10 + should_be_chunked = False + + with HTTPConnectionPool( + self.host, self.port, timeout=LONG_TIMEOUT, retries=False + ) as pool: + resp = pool.request(method, "/", body=body) + assert resp.status == 200 + + sent_bytes = bytes(buffer) + assert sent_bytes.count(b":") == 5 + assert b"Host: localhost:" in sent_bytes + assert b"Accept-Encoding: identity\r\n" in sent_bytes + assert b"User-Agent: python-urllib3/" in sent_bytes + + if should_be_chunked: + assert b"content-length" not in sent_bytes.lower() + assert b"Transfer-Encoding: chunked\r\n" in sent_bytes + assert b"\r\n\r\na\r\nxxxxxxxxxx\r\n0\r\n\r\n" in sent_bytes + + else: + assert 
b"Content-Length: 10\r\n" in sent_bytes + assert b"transfer-encoding" not in sent_bytes.lower() + assert sent_bytes.endswith(b"\r\n\r\nxxxxxxxxxx") + + @pytest.mark.parametrize( + "header_transform", + [str.lower, str.title, str.upper], + ) + @pytest.mark.parametrize( + ["header", "header_value", "expected"], + [ + ("content-length", "10", b": 10\r\n\r\nxxxxxxxx"), + ( + "transfer-encoding", + "chunked", + b": chunked\r\n\r\n8\r\nxxxxxxxx\r\n0\r\n\r\n", + ), + ], + ) + def test_framing_set_via_headers( + self, + header_transform: typing.Callable[[str], str], + header: str, + header_value: str, + expected: bytes, + ) -> None: + buffer = bytearray() + + def socket_handler(listener: socket.socket) -> None: + nonlocal buffer + sock = listener.accept()[0] + sock.settimeout(0) + + start = time.time() + while time.time() - start < (LONG_TIMEOUT / 2): + try: + buffer += sock.recv(65536) + except OSError: + continue + + sock.sendall( + b"HTTP/1.1 200 OK\r\n" + b"Server: example.com\r\n" + b"Content-Length: 0\r\n\r\n" + ) + sock.close() + + self._start_server(socket_handler) + + with HTTPConnectionPool( + self.host, self.port, timeout=LONG_TIMEOUT, retries=False + ) as pool: + resp = pool.request( + "POST", + "/", + body=b"xxxxxxxx", + headers={header_transform(header): header_value}, + ) + assert resp.status == 200 + + sent_bytes = bytes(buffer) + assert sent_bytes.endswith(expected) diff --git a/towncrier.toml b/towncrier.toml new file mode 100644 index 0000000000..e1be27a5e9 --- /dev/null +++ b/towncrier.toml @@ -0,0 +1,7 @@ +[tool.towncrier] +package = "urllib3" +package_dir = "src" +filename = "CHANGES.rst" +directory = "changelog" +issue_format = "`#{issue} `__" +title_format = "{version} ({project_date})"