diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 00000000..b8690975 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,19 @@ +[run] +parallel = true +concurrency = + thread + multiprocessing +source = + waitress +omit = + waitress/tests/fixtureapps/getline.py + +[paths] +source = + src/waitress + */src/waitress + */site-packages/waitress + +[report] +show_missing = true +precision = 2 diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..a5f3a73f --- /dev/null +++ b/.flake8 @@ -0,0 +1,36 @@ +# Recommended flake8 settings while editing, we use Black for the final linting/say in how code is formatted +# +# pip install flake8 flake8-bugbear +# +# This will warn/error on things that black does not fix, on purpose. +# +# Run: +# +# tox -e run-flake8 +# +# To have it automatically create and install the appropriate tools, and run +# flake8 across the source code/tests + +[flake8] +# max line length is set to 88 in black, here it is set to 80 and we enable bugbear's B950 warning, which is: +# +# B950: Line too long. This is a pragmatic equivalent of pycodestyle’s E501: it +# considers “max-line-length” but only triggers when the value has been +# exceeded by more than 10%. You will no longer be forced to reformat code due +# to the closing parenthesis being one character too far to satisfy the linter. +# At the same time, if you do significantly violate the line length, you will +# receive a message that states what the actual limit is. This is inspired by +# Raymond Hettinger’s “Beyond PEP 8” talk and highway patrol not stopping you +# if you drive < 5mph too fast. Disable E501 to avoid duplicate warnings. 
+max-line-length = 80 +max-complexity = 12 +select = E,F,W,C,B,B9 +ignore = + # E123 closing bracket does not match indentation of opening bracket’s line + E123 + # E203 whitespace before ‘:’ (Not PEP8 compliant, Python Black) + E203 + # E501 line too long (82 > 79 characters) (replaced by B950 from flake8-bugbear, https://github.com/PyCQA/flake8-bugbear) + E501 + # W503 line break before binary operator (Not PEP8 compliant, Python Black) + W503 diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml new file mode 100644 index 00000000..5a313c2e --- /dev/null +++ b/.github/workflows/ci-tests.yml @@ -0,0 +1,97 @@ +name: Build and test + +on: + # Only on pushes to master or one of the release branches we build on push + push: + branches: + - master + - "[0-9].[0-9]+-branch" + tags: + # Build pull requests + pull_request: + +jobs: + test: + strategy: + matrix: + py: + - "3.7" + - "3.8" + - "3.9" + - "3.10" + - "pypy-3.8" + # Pre-release + - "3.11.0-alpha - 3.11.0" + os: + - "ubuntu-latest" + - "windows-latest" + - "macos-latest" + architecture: + - x64 + - x86 + include: + - py: "pypy-3.8" + toxenv: "pypy38" + exclude: + # Linux and macOS don't have x86 python + - os: "ubuntu-latest" + architecture: x86 + - os: "macos-latest" + architecture: x86 + + name: "Python: ${{ matrix.py }}-${{ matrix.architecture }} on ${{ matrix.os }}" + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Setup python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.py }} + architecture: ${{ matrix.architecture }} + - run: pip install tox + - name: Running tox with specific toxenv + if: ${{ matrix.toxenv != '' }} + env: + TOXENV: ${{ matrix.toxenv }} + run: tox + - name: Running tox for current python version + if: ${{ matrix.toxenv == '' }} + run: tox -e py + + coverage: + runs-on: ubuntu-latest + name: Validate coverage + steps: + - uses: actions/checkout@v2 + - name: Setup python 3.10 + uses: actions/setup-python@v2 + with: + 
python-version: "3.10" + architecture: x64 + + - run: pip install tox + - run: tox -e py310,coverage + docs: + runs-on: ubuntu-latest + name: Build the documentation + steps: + - uses: actions/checkout@v2 + - name: Setup python + uses: actions/setup-python@v2 + with: + python-version: "3.10" + architecture: x64 + - run: pip install tox + - run: tox -e docs + lint: + runs-on: ubuntu-latest + name: Lint the package + steps: + - uses: actions/checkout@v2 + - name: Setup python + uses: actions/setup-python@v2 + with: + python-version: "3.10" + architecture: x64 + - run: pip install tox + - run: tox -e lint diff --git a/.gitignore b/.gitignore index e1dbc2af..146736ff 100644 --- a/.gitignore +++ b/.gitignore @@ -2,13 +2,10 @@ *.pyc env*/ .coverage -.idea/ +.coverage.* .tox/ -nosetests.xml -waitress/coverage.xml dist/ -keep/ build/ coverage.xml -nosetests*.xml -py*-cover.xml +docs/_themes +docs/_build diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 401f191b..00000000 --- a/.travis.yml +++ /dev/null @@ -1,39 +0,0 @@ -# Wire up travis -language: python -sudo: false - -matrix: - include: - - python: 2.7 - env: TOXENV=py27 - - python: 3.3 - env: TOXENV=py33 - - python: 3.4 - env: TOXENV=py34 - - python: 3.5 - env: TOXENV=py35 - - python: 3.6 - env: TOXENV=py36 - - python: pypy - env: TOXENV=pypy - - python: pypy3 - env: TOXENV=pypy3 - - python: 3.5 - env: TOXENV=py2-cover,py3-cover,coverage - - python: 3.5 - env: TOXENV=docs - allow_failures: - - env: TOXENV=pypy3 - -install: - - travis_retry pip install tox - -script: - - travis_retry tox - -notifications: - email: - - pyramid-checkins@lists.repoze.org - irc: - channels: - - "chat.freenode.net#pyramid" diff --git a/CHANGES.txt b/CHANGES.txt index e1726c6e..17ca87e7 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,94 +1,102 @@ -1.1.0 (2017-10-10) ------------------- +2.1.2 +----- -Features -~~~~~~~~ +Bugfix +~~~~~~ -- Waitress now has a __main__ and thus may be called with ``python -mwaitress`` +- When 
expose_tracebacks is enabled waitress would fail to properly encode + unicode thereby causing another error during error handling. See + https://github.com/Pylons/waitress/pull/378 -Bugfixes -~~~~~~~~ +- Header length checking had a calculation that was done incorrectly when the + data was received across multiple socket reads. This calculation has been + corrected, and no longer will Waitress send back a 413 Request Entity Too + Large. See https://github.com/Pylons/waitress/pull/376 -- Waitress no longer allows lowercase HTTP verbs. This change was made to fall - in line with most HTTP servers. See https://github.com/Pylons/waitress/pull/170 +Security Bugfix +~~~~~~~~~~~~~~~ -- When receiving non-ascii bytes in the request URL, waitress will no longer - abruptly close the connection, instead returning a 400 Bad Request. See - https://github.com/Pylons/waitress/pull/162 and - https://github.com/Pylons/waitress/issues/64 +- in 2.1.0 a new feature was introduced that allowed the WSGI thread to start + sending data to the socket. However this introduced a race condition whereby + a socket may be closed in the sending thread while the main thread is about + to call select() thereby causing the entire application to be taken down. + Waitress will no longer close the socket in the WSGI thread, instead waking + up the main thread to cleanup. See https://github.com/Pylons/waitress/pull/377 -1.0.2 (2017-02-04) ------------------- +2.1.1 +----- -Features -~~~~~~~~ +Security Bugfix +~~~~~~~~~~~~~~~ -- Python 3.6 is now officially supported in Waitress +- Waitress now validates that chunked encoding extensions are valid, and don't + contain invalid characters that are not allowed. They are still skipped/not + processed, but if they contain invalid data we no longer continue in and + return a 400 Bad Request. This stops potential HTTP desync/HTTP request + smuggling. Thanks to Zhang Zeyu for reporting this issue.
See + https://github.com/Pylons/waitress/security/advisories/GHSA-4f7p-27jc-3c36 -Bugfixes -~~~~~~~~ +- Waitress now validates that the chunk length is only valid hex digits when + parsing chunked encoding, and values such as ``0x01`` and ``+01`` are no + longer supported. This stops potential HTTP desync/HTTP request smuggling. + Thanks to Zhang Zeyu for reporting this issue. See + https://github.com/Pylons/waitress/security/advisories/GHSA-4f7p-27jc-3c36 -- Add a work-around for libc issue on Linux not following the documented - standards. If getnameinfo() fails because of DNS not being available it - should return the IP address instead of the reverse DNS entry, however - instead getnameinfo() raises. We catch this, and ask getnameinfo() - for the same information again, explicitly asking for IP address instead of - reverse DNS hostname. See https://github.com/Pylons/waitress/issues/149 and - https://github.com/Pylons/waitress/pull/153 +- Waitress now validates that the Content-Length sent by a remote contains only + digits in accordance with RFC7230 and will return a 400 Bad Request when the + Content-Length header contains invalid data, such as ``+10`` which would + previously get parsed as ``10`` and accepted. This stops potential HTTP + desync/HTTP request smuggling Thanks to Zhang Zeyu for reporting this issue. See + https://github.com/Pylons/waitress/security/advisories/GHSA-4f7p-27jc-3c36 -1.0.1 (2016-10-22) ------------------- +2.1.0 +----- -Bugfixes -~~~~~~~~ +Python Version Support +~~~~~~~~~~~~~~~~~~~~~~ -- IPv6 support on Windows was broken due to missing constants in the socket - module. This has been resolved by setting the constants on Windows if they - are missing. See https://github.com/Pylons/waitress/issues/138 +- Python 3.6 is no longer supported by Waitress -- A ValueError was raised on Windows when passing a string for the port, on - Windows in Python 2 using service names instead of port numbers doesn't work - with `getaddrinfo`. 
This has been resolved by attempting to convert the port - number to an integer, if that fails a ValueError will be raised. See - https://github.com/Pylons/waitress/issues/139 +- Python 3.10 is fully supported by Waitress +Bugfix +~~~~~~ -1.0.0 (2016-08-31) ------------------- +- ``wsgi.file_wrapper`` now sets the ``seekable``, ``seek``, and ``tell`` + attributes from the underlying file if the underlying file is seekable. This + allows WSGI middleware to implement things like range requests for example -Bugfixes -~~~~~~~~ + See https://github.com/Pylons/waitress/issues/359 and + https://github.com/Pylons/waitress/pull/363 -- Removed `AI_ADDRCONFIG` from the call to `getaddrinfo`, this resolves an - issue whereby `getaddrinfo` wouldn't return any addresses to `bind` to on - hosts where there is no internet connection but localhost is requested to be - bound to. See https://github.com/Pylons/waitress/issues/131 for more - information. +- In Python 3 ``OSError`` is no longer subscriptable, this caused failures on + Windows attempting to loop to find an socket that would work for use in the + trigger. -Deprecations -~~~~~~~~~~~~ + See https://github.com/Pylons/waitress/pull/361 -- Python 2.6 is no longer supported. +- Fixed an issue whereby ``BytesIO`` objects were not properly closed, and + thereby would not get cleaned up until garbage collection would get around to + it. -Features -~~~~~~~~ + This led to potential for random memory spikes/memory issues, see + https://github.com/Pylons/waitress/pull/358 and + https://github.com/Pylons/waitress/issues/357 . -- IPv6 support + With thanks to Florian Schulze for testing/vaidating this fix! -- Waitress is now able to listen on multiple sockets, including IPv4 and IPv6. - Instead of passing in a host/port combination you now provide waitress with a - space delineated list, and it will create as many sockets as required. +Features +~~~~~~~~ - .. 
code-block:: python +- When the WSGI app starts sending data to the output buffer, we now attempt to + send data directly to the socket. This avoids needing to wake up the main + thread to start sending data. Allowing faster transmission of the first byte. + See https://github.com/Pylons/waitress/pull/364 - from waitress import serve - serve(wsgiapp, listen='0.0.0.0:8080 [::]:9090 *:6543') + With thanks to Michael Merickel for being a great rubber ducky! -Security -~~~~~~~~ +- Add REQUEST_URI to the WSGI environment. -- Waitress will now drop HTTP headers that contain an underscore in the key - when received from a client. This is to stop any possible underscore/dash - conflation that may lead to security issues. See - https://github.com/Pylons/waitress/pull/80 and - https://www.djangoproject.com/weblog/2015/jan/13/security/ + REQUEST_URI is similar to ``request_uri`` in nginx. It is a string that + contains the request path before separating the query string and + decoding ``%``-escaped characters. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 3f309fc6..6d6df5d1 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -88,7 +88,7 @@ Committed Code. Licensing Exceptions ==================== -Code committed within the ``docs/`` subdirectory of the Pyramid source +Code committed within the ``docs/`` subdirectory of the Waitress source control repository and "docstrings" which appear in the documentation generated by running "make" within this directory is licensed under the Creative Commons Attribution-Noncommercial-Share Alike 3.0 United States @@ -98,7 +98,7 @@ List of Contributors ==================== The below-signed are contributors to a code repository that is part of the -project named "Pyramid". Each below-signed contributor has read, understand +project named "Waitress". 
Each below-signed contributor has read, understand and agrees to the terms above in the section within this document entitled "Pylons Project Contributor Agreement" as of the date beside his or her name. @@ -107,6 +107,8 @@ Contributors - Chris McDonough, 2011/12/17 +- Michael Merickel, 2012/01/16 + - Damien Baty, 2012/10/25 - Georges Dubus, 2012/11/24 @@ -140,3 +142,7 @@ Contributors - Atsushi Odagiri, 2017-02-12 - David D Lowe, 2017-06-02 + +- Jack Wearden, 2018-05-18 + +- Frank Krick, 2018-10-29 diff --git a/HISTORY.txt b/HISTORY.txt index 99325824..2eb829d2 100644 --- a/HISTORY.txt +++ b/HISTORY.txt @@ -1,3 +1,521 @@ +2.0.0 (2021-03-07) +------------------ + +Friendly Reminder +~~~~~~~~~~~~~~~~~ + +This release still contains a variety of deprecation notices about defaults +that can be set for a variety of options. + +Please note that this is your last warning, and you should update your +configuration if you do NOT want to use the new defaults. + +See the arguments documentation page for all supported options, and pay +attention to the warnings: + +https://docs.pylonsproject.org/projects/waitress/en/stable/arguments.html + +Without further ado, here's a short list of great changes thanks to our +contributors! + +Bugfixes/Features +~~~~~~~~~~~~~~~~~ + +- Fix a crash on startup when listening to multiple interfaces. + See https://github.com/Pylons/waitress/pull/332 + +- Waitress no longer attempts to guess at what the ``server_name`` should be for + a listen socket, instead it always use a new adjustment/argument named + ``server_name``. + + Please see the documentation for ``server_name`` in + https://docs.pylonsproject.org/projects/waitress/en/latest/arguments.html and + see https://github.com/Pylons/waitress/pull/329 + +- Allow tasks to notice if the client disconnected. 
+ + This inserts a callable ``waitress.client_disconnected`` into the environment + that allows the task to check if the client disconnected while waiting for + the response at strategic points in the execution and to cancel the + operation. + + It requires setting the new adjustment ``channel_request_lookahead`` to a value + larger than 0, which continues to read requests from a channel even if a + request is already being processed on that channel, up to the given count, + since a client disconnect is detected by reading from a readable socket and + receiving an empty result. + + See https://github.com/Pylons/waitress/pull/310 + +- Drop Python 2.7 and 3.5 support + +- The server now issues warning output when there are enough open + connections (controlled by "connection_limit"), that it is no longer + accepting new connections. This situation was previously difficult to + diagnose. + See https://github.com/Pylons/waitress/pull/322 + +1.4.4 (2020-06-01) +------------------ + +- Fix an issue with keep-alive connections in which memory usage was higher + than expected because output buffers were being reused across requests on + a long-lived connection and each buffer would not be freed until it was full + or the connection was closed. Buffers are now rotated per-request to + stabilize their behavior. + + See https://github.com/Pylons/waitress/pull/300 + +- Waitress threads have been updated to contain their thread number. This will + allow loggers that use that information to print the thread that the log is + coming from. + + See https://github.com/Pylons/waitress/pull/302 + +1.4.3 (2020-02-02) +------------------ + +Security Fixes +~~~~~~~~~~~~~~ + +- In Waitress version 1.4.2 a new regular expression was added to validate the + headers that Waitress receives to make sure that it matches RFC7230.
+ Unfortunately the regular expression was written in a way that with invalid + input it leads to catastrophic backtracking which allows for a Denial of + Service and CPU usage going to a 100%. + + This was reported by Fil Zembowicz to the Pylons Project. Please see + https://github.com/Pylons/waitress/security/advisories/GHSA-73m2-3pwg-5fgc + for more information. + +1.4.2 (2020-01-02) +------------------ + +Security Fixes +~~~~~~~~~~~~~~ + +- This is a follow-up to the fix introduced in 1.4.1 to tighten up the way + Waitress strips whitespace from header values. This makes sure Waitress won't + accidentally treat non-printable characters as whitespace and lead to a + potential HTTP request smuggling/splitting security issue. + + Thanks to ZeddYu Lu for the extra test cases. + + Please see the security advisory for more information: + https://github.com/Pylons/waitress/security/advisories/GHSA-m5ff-3wj3-8ph4 + + CVE-ID: CVE-2019-16789 + +Bugfixes +~~~~~~~~ + +- Updated the regex used to validate header-field content to match the errata + that was published for RFC7230. + + See: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 + + +1.4.1 (2019-12-24) +------------------ + +Security Fixes +~~~~~~~~~~~~~~ + +- Waitress did not properly validate that the HTTP headers it received were + properly formed, thereby potentially allowing a front-end server to treat a + request differently from Waitress. This could lead to HTTP request + smuggling/splitting. + + Please see the security advisory for more information: + https://github.com/Pylons/waitress/security/advisories/GHSA-m5ff-3wj3-8ph4 + + CVE-ID: CVE-2019-16789 + +1.4.0 (2019-12-20) +------------------ + +Bugfixes +~~~~~~~~ + +- Waitress used to slam the door shut on HTTP pipelined requests without + setting the ``Connection: close`` header as appropriate in the response. This + is of course not very friendly.
Waitress now explicitly sets the header when + responding with an internally generated error such as 400 Bad Request or 500 + Internal Server Error to notify the remote client that it will be closing the + connection after the response is sent. + +- Waitress no longer allows any spaces to exist between the header field-name + and the colon. While waitress did not strip the space and thereby was not + vulnerable to any potential header field-name confusion, it should have sent + back a 400 Bad Request. See https://github.com/Pylons/waitress/issues/273 + +Security Fixes +~~~~~~~~~~~~~~ + +- Waitress implemented a "MAY" part of the RFC7230 + (https://tools.ietf.org/html/rfc7230#section-3.5) which states: + + Although the line terminator for the start-line and header fields is + the sequence CRLF, a recipient MAY recognize a single LF as a line + terminator and ignore any preceding CR. + + Unfortunately if a front-end server does not parse header fields with an LF + the same way as it does those with a CRLF it can lead to the front-end and + the back-end server parsing the same HTTP message in two different ways. This + can lead to a potential for HTTP request smuggling/splitting whereby Waitress + may see two requests while the front-end server only sees a single HTTP + message. + + For more information I can highly recommend the blog post by ZeddYu Lu + https://blog.zeddyu.info/2019/12/08/HTTP-Smuggling-en/ + + Please see the security advisory for more information: + https://github.com/Pylons/waitress/security/advisories/GHSA-pg36-wpm5-g57p + + CVE-ID: CVE-2019-16785 + +- Waitress used to treat LF the same as CRLF in ``Transfer-Encoding: chunked`` + requests, while the maintainer doesn't believe this could lead to a security + issue, this is no longer supported and all chunks are now validated to be + properly framed with CRLF as required by RFC7230. 
+ +- Waitress now validates that the ``Transfer-Encoding`` header contains only + transfer codes that it is able to decode. At the moment that includes the + only valid header value being ``chunked``. + + That means that if the following header is sent: + + ``Transfer-Encoding: gzip, chunked`` + + Waitress will send back a 501 Not Implemented with an error message stating + as such, as while Waitress supports ``chunked`` encoding it does not support + ``gzip`` and it is unable to pass that to the underlying WSGI environment + correctly. + + Waitress DOES NOT implement support for ``Transfer-Encoding: identity`` + even though ``identity`` was valid in RFC2616, it was removed in RFC7230. + Please update your clients to remove the ``Transfer-Encoding`` header if the + only transfer coding is ``identity`` or update your client to use + ``Transfer-Encoding: chunked`` instead of ``Transfer-Encoding: identity, + chunked``. + + Please see the security advisory for more information: + https://github.com/Pylons/waitress/security/advisories/GHSA-g2xc-35jw-c63p + + CVE-ID: CVE-2019-16786 + +- While validating the ``Transfer-Encoding`` header, Waitress now properly + handles line-folded ``Transfer-Encoding`` headers or those that contain + multiple comma separated values. This closes a potential issue where a + front-end server may treat the request as being a chunked request (and thus + ignoring the Content-Length) and Waitress using the Content-Length as it was + looking for the single value ``chunked`` and did not support comma separated + values. + +- Waitress used to explicitly set the Content-Length header to 0 if it was + unable to parse it as an integer (for example if the Content-Length header + was sent twice (and thus folded together), or was invalid) thereby allowing + for a potential request to be split and treated as two requests by HTTP + pipelining support in Waitress.
If Waitress is now unable to parse the + Content-Length header, a 400 Bad Request is sent back to the client. + + Please see the security advisory for more information: + https://github.com/Pylons/waitress/security/advisories/GHSA-4ppp-gpcr-7qf6 + +1.3.1 (2019-08-27) +------------------ + +Bugfixes +~~~~~~~~ + +- Waitress won't accidentally throw away part of the path if it starts with a + double slash (``GET //testing/whatever HTTP/1.0``). WSGI applications will + now receive a ``PATH_INFO`` in the environment that contains + ``//testing/whatever`` as required. See + https://github.com/Pylons/waitress/issues/260 and + https://github.com/Pylons/waitress/pull/261 + + +1.3.0 (2019-04-22) +------------------ + +Deprecations +~~~~~~~~~~~~ + +- The ``send_bytes`` adjustment now defaults to ``1`` and is deprecated + pending removal in a future release. + and https://github.com/Pylons/waitress/pull/246 + +Features +~~~~~~~~ + +- Add a new ``outbuf_high_watermark`` adjustment which is used to apply + backpressure on the ``app_iter`` to avoid letting it spin faster than data + can be written to the socket. This stabilizes responses that iterate quickly + with a lot of data. + See https://github.com/Pylons/waitress/pull/242 + +- Stop early and close the ``app_iter`` when attempting to write to a closed + socket due to a client disconnect. This should notify a long-lived streaming + response when a client hangs up. + See https://github.com/Pylons/waitress/pull/238 + and https://github.com/Pylons/waitress/pull/240 + and https://github.com/Pylons/waitress/pull/241 + +- Adjust the flush to output ``SO_SNDBUF`` bytes instead of whatever was + set in the ``send_bytes`` adjustment. ``send_bytes`` now only controls how + much waitress will buffer internally before flushing to the kernel, whereas + previously it used to also throttle how much data was sent to the kernel. + This change enables a streaming ``app_iter`` containing small chunks to + still be flushed efficiently. 
+ See https://github.com/Pylons/waitress/pull/246 + +Bugfixes +~~~~~~~~ + +- Upon receiving a request that does not include HTTP/1.0 or HTTP/1.1 we will + no longer set the version to the string value "None". See + https://github.com/Pylons/waitress/pull/252 and + https://github.com/Pylons/waitress/issues/110 + +- When a client closes a socket unexpectedly there was potential for memory + leaks in which data was written to the buffers after they were closed, + causing them to reopen. + See https://github.com/Pylons/waitress/pull/239 + +- Fix the queue depth warnings to only show when all threads are busy. + See https://github.com/Pylons/waitress/pull/243 + and https://github.com/Pylons/waitress/pull/247 + +- Trigger the ``app_iter`` to close as part of shutdown. This will only be + noticeable for users of the internal server api. In more typical operations + the server will die before benefiting from these changes. + See https://github.com/Pylons/waitress/pull/245 + +- Fix a bug in which a streaming ``app_iter`` may never cleanup data that has + already been sent. This would cause buffers in waitress to grow without + bounds. These buffers now properly rotate and release their data. + See https://github.com/Pylons/waitress/pull/242 + +- Fix a bug in which non-seekable subclasses of ``io.IOBase`` would trigger + an exception when passed to the ``wsgi.file_wrapper`` callback. + See https://github.com/Pylons/waitress/pull/249 + +1.2.1 (2019-01-25) +------------------ + +Bugfixes +~~~~~~~~ + +- When given an IPv6 address in ``X-Forwarded-For`` or ``Forwarded for=`` + waitress was placing the IP address in ``REMOTE_ADDR`` with brackets: + ``[2001:db8::0]``, this does not match the requirements in the CGI spec which + ``REMOTE_ADDR`` was lifted from. Waitress will now place the bare IPv6 + address in ``REMOTE_ADDR``: ``2001:db8::0``. 
See + https://github.com/Pylons/waitress/pull/232 and + https://github.com/Pylons/waitress/issues/230 + +1.2.0 (2019-01-15) +------------------ + +No changes since the last beta release. Enjoy Waitress! + +1.2.0b3 (2019-01-07) +-------------------- + +Bugfixes +~~~~~~~~ + +- Modified ``clear_untrusted_proxy_headers`` to be usable without a + ``trusted_proxy``. + https://github.com/Pylons/waitress/pull/228 + +- Modified ``trusted_proxy_count`` to error when used without a + ``trusted_proxy``. + https://github.com/Pylons/waitress/pull/228 + +1.2.0b2 (2019-02-02) +-------------------- + +Bugfixes +~~~~~~~~ + +- Fixed logic to no longer warn on writes where the output is required to have + a body but there may not be any data to be written. Solves issue posted on + the Pylons Project mailing list with 1.2.0b1. + +1.2.0b1 (2018-12-31) +-------------------- + +Happy New Year! + +Features +~~~~~~~~ + +- Setting the ``trusted_proxy`` setting to ``'*'`` (wildcard) will allow all + upstreams to be considered trusted proxies, thereby allowing services behind + Cloudflare/ELBs to function correctly whereby there may not be a singular IP + address that requests are received from. + + Using this setting is potentially dangerous if your server is also available + from anywhere on the internet, and further protections should be used to lock + down access to Waitress. See https://github.com/Pylons/waitress/pull/224 + +- Waitress has increased its support of the X-Forwarded-* headers and includes + Forwarded (RFC7239) support. This may be used to allow proxy servers to + influence the WSGI environment. See + https://github.com/Pylons/waitress/pull/209 + + This also provides a new security feature when using Waitress behind a proxy + in that it is possible to remove untrusted proxy headers thereby making sure + that downstream WSGI applications don't accidentally use those proxy headers + to make security decisions. 
+ + The documentation has more information, see the following new arguments: + + - trusted_proxy_count + - trusted_proxy_headers + - clear_untrusted_proxy_headers + - log_untrusted_proxy_headers (useful for debugging) + + Be aware that the defaults for these are currently backwards compatible with + older versions of Waitress, this will change in a future release of waitress. + If you expect to need this behaviour please explicitly set these variables in + your configuration, or pin this version of waitress. + + Documentation: + https://docs.pylonsproject.org/projects/waitress/en/latest/reverse-proxy.html + +- Waitress can now accept a list of sockets that are already pre-bound rather + than creating its own to allow for socket activation. Support for init + systems/other systems that create said activated sockets is not included. See + https://github.com/Pylons/waitress/pull/215 + +- Server header can be omitted by specifying ``ident=None`` or ``ident=''``. + See https://github.com/Pylons/waitress/pull/187 + +Bugfixes +~~~~~~~~ + +- Waitress will no longer send Transfer-Encoding or Content-Length for 1xx, + 204, or 304 responses, and will completely ignore any message body sent by + the WSGI application, making sure to follow the HTTP standard. See + https://github.com/Pylons/waitress/pull/166, + https://github.com/Pylons/waitress/issues/165, + https://github.com/Pylons/waitress/issues/152, and + https://github.com/Pylons/waitress/pull/202 + +Compatibility +~~~~~~~~~~~~~ + +- Waitress has now "vendored" asyncore into itself as ``waitress.wasyncore``. + This is to cope with the eventuality that asyncore will be removed from + the Python standard library in 3.8 or so. + +Documentation +~~~~~~~~~~~~~ + +- Bring in documentation of paste.translogger from Pyramid. Reorganize and + clean up documentation. 
See + https://github.com/Pylons/waitress/pull/205 + https://github.com/Pylons/waitress/pull/70 + https://github.com/Pylons/waitress/pull/206 + +1.1.0 (2017-10-10) +------------------ + +Features +~~~~~~~~ + +- Waitress now has a __main__ and thus may be called with ``python -mwaitress`` + +Bugfixes +~~~~~~~~ + +- Waitress no longer allows lowercase HTTP verbs. This change was made to fall + in line with most HTTP servers. See https://github.com/Pylons/waitress/pull/170 + +- When receiving non-ascii bytes in the request URL, waitress will no longer + abruptly close the connection, instead returning a 400 Bad Request. See + https://github.com/Pylons/waitress/pull/162 and + https://github.com/Pylons/waitress/issues/64 + +1.0.2 (2017-02-04) +------------------ + +Features +~~~~~~~~ + +- Python 3.6 is now officially supported in Waitress + +Bugfixes +~~~~~~~~ + +- Add a work-around for libc issue on Linux not following the documented + standards. If getnameinfo() fails because of DNS not being available it + should return the IP address instead of the reverse DNS entry, however + instead getnameinfo() raises. We catch this, and ask getnameinfo() + for the same information again, explicitly asking for IP address instead of + reverse DNS hostname. See https://github.com/Pylons/waitress/issues/149 and + https://github.com/Pylons/waitress/pull/153 + +1.0.1 (2016-10-22) +------------------ + +Bugfixes +~~~~~~~~ + +- IPv6 support on Windows was broken due to missing constants in the socket + module. This has been resolved by setting the constants on Windows if they + are missing. See https://github.com/Pylons/waitress/issues/138 + +- A ValueError was raised on Windows when passing a string for the port, on + Windows in Python 2 using service names instead of port numbers doesn't work + with `getaddrinfo`. This has been resolved by attempting to convert the port + number to an integer, if that fails a ValueError will be raised. 
See + https://github.com/Pylons/waitress/issues/139 + + +1.0.0 (2016-08-31) +------------------ + +Bugfixes +~~~~~~~~ + +- Removed `AI_ADDRCONFIG` from the call to `getaddrinfo`, this resolves an + issue whereby `getaddrinfo` wouldn't return any addresses to `bind` to on + hosts where there is no internet connection but localhost is requested to be + bound to. See https://github.com/Pylons/waitress/issues/131 for more + information. + +Deprecations +~~~~~~~~~~~~ + +- Python 2.6 is no longer supported. + +Features +~~~~~~~~ + +- IPv6 support + +- Waitress is now able to listen on multiple sockets, including IPv4 and IPv6. + Instead of passing in a host/port combination you now provide waitress with a + space delineated list, and it will create as many sockets as required. + + .. code-block:: python + + from waitress import serve + serve(wsgiapp, listen='0.0.0.0:8080 [::]:9090 *:6543') + +Security +~~~~~~~~ + +- Waitress will now drop HTTP headers that contain an underscore in the key + when received from a client. This is to stop any possible underscore/dash + conflation that may lead to security issues. See + https://github.com/Pylons/waitress/pull/80 and + https://www.djangoproject.com/weblog/2015/jan/13/security/ + 0.9.0 (2016-04-15) ------------------ @@ -223,7 +741,7 @@ Bug Fixes Bug Fixes ~~~~~~~~~ -- http://corte.si/posts/code/pathod/pythonservers/index.html pointed out that +- https://corte.si/posts/code/pathod/pythonservers/index.html pointed out that sending a bad header resulted in an exception leading to a 500 response instead of the more proper 400 response without an exception. @@ -254,7 +772,7 @@ Features ~~~~~~~~ - Support the WSGI ``wsgi.file_wrapper`` protocol as per - http://www.python.org/dev/peps/pep-0333/#optional-platform-specific-file-handling. + https://www.python.org/dev/peps/pep-0333/#optional-platform-specific-file-handling. 
Here's a usage example:: import os diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..b41b4db3 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,22 @@ +graft src/waitress +graft tests +graft docs +graft .github + +include README.rst +include CHANGES.txt +include HISTORY.txt +include RELEASING.txt +include LICENSE.txt +include contributing.md +include CONTRIBUTORS.txt +include COPYRIGHT.txt + +include pyproject.toml setup.cfg +include .coveragerc .flake8 +include tox.ini rtd.txt + +exclude TODO.txt +prune docs/_build + +recursive-exclude * __pycache__ *.py[cod] diff --git a/README.rst b/README.rst index ce9ea017..0d9c5e94 100644 --- a/README.rst +++ b/README.rst @@ -1,8 +1,26 @@ -Waitress is meant to be a production-quality pure-Python WSGI server with very -acceptable performance. It has no dependencies except ones which live in the -Python standard library. It runs on CPython on Unix and Windows under Python -2.7+ and Python 3.3+. It is also known to run on PyPy 1.6.0+ on UNIX. It +Waitress +======== + +.. image:: https://img.shields.io/pypi/v/waitress.svg + :target: https://pypi.org/project/waitress/ + :alt: latest version of waitress on PyPI + +.. image:: https://github.com/Pylons/waitress/workflows/Build%20and%20test/badge.svg + :target: https://github.com/Pylons/waitress/actions?query=workflow%3A%22Build+and+test%22 + +.. image:: https://readthedocs.org/projects/waitress/badge/?version=master + :target: https://docs.pylonsproject.org/projects/waitress/en/master + :alt: master Documentation Status + +.. image:: https://img.shields.io/badge/irc-freenode-blue.svg + :target: https://webchat.freenode.net/?channels=pyramid + :alt: IRC Freenode + +Waitress is a production-quality pure-Python WSGI server with very acceptable +performance. It has no dependencies except ones which live in the Python +standard library. It runs on CPython on Unix and Windows under Python 3.7+. It +is also known to run on PyPy 3 (version 3.7 compatible python) on UNIX. 
It supports HTTP/1.0 and HTTP/1.1. -For more information, see the "docs" directory of the Waitress package or -http://docs.pylonsproject.org/projects/waitress/en/latest/ . +For more information, see the "docs" directory of the Waitress package or visit +https://docs.pylonsproject.org/projects/waitress/en/latest/ diff --git a/RELEASING.txt b/RELEASING.txt new file mode 100644 index 00000000..13060986 --- /dev/null +++ b/RELEASING.txt @@ -0,0 +1,115 @@ +Releasing +========= + +- For clarity, we define releases as follows. + + - Alpha, beta, dev and similar statuses do not qualify whether a release is + major or minor. The term "pre-release" means alpha, beta, or dev. + + - A release is final when it is no longer pre-release. + + - A *major* release is where the first number either before or after the + first dot increases. Examples: 1.0 to 1.1a1, or 0.9 to 1.0. + + - A *minor* or *bug fix* release is where the number after the second dot + increases. Example: 1.0 to 1.0.1. + +Prepare new release +------------------- + +- Do platform test via tox: + + $ tox -r + + Make sure statement coverage is at 100% (the test run will fail if not). + +- Run tests on Windows if feasible. + +- Ensure all features of the release are documented (audit CHANGES.txt or + communicate with contributors). + +- Change CHANGES.txt heading to reflect the new version number. + +- Minor releases should include a link under "Bug Fix Releases" to the minor + feature changes in CHANGES.txt. + +- Change setup.py version to the release version number. + +- Make sure PyPI long description renders (requires ``readme_renderer`` + installed into your Python):: + + $ python setup.py check -r -s -m + +- Create a release tag. 
+ +- Make sure your Python has ``setuptools-git``, ``twine``, and ``wheel`` + installed and release to PyPI:: + + $ python setup.py sdist bdist_wheel + $ twine upload dist/waitress-X.X-* + + +Prepare master for further development (major releases only) +------------------------------------------------------------ + +- In CHANGES.txt, preserve headings but clear out content. Add heading + "unreleased" for the version number. + +- Forward port the changes in CHANGES.txt to HISTORY.txt. + +- Change setup.py version to the next version number. + + +Marketing and communications +---------------------------- + +- Check `https://wiki.python.org/moin/WebServers + `_. + +- Announce to Twitter. + +``` +waitress 1.x released. + +PyPI +https://pypi.org/project/waitress/1.x/ + +=== One time only for new version, first pre-release === +What's New +https://docs.pylonsproject.org/projects/waitress/en/latest/#id2 +=== For all subsequent pre-releases === +Changes +https://docs.pylonsproject.org/projects/waitress/en/latest/#change-history + +Documentation: +https://docs.pylonsproject.org/projects/waitress/en/latest/ + +Issues +https://github.com/Pylons/waitress/issues +``` + +- Announce to maillist. + +``` +waitress 1.X.X has been released. + +The full changelog is here: +https://docs.pylonsproject.org/projects/waitress/en/latest/#change-history + +What's New In waitress 1.X: +https://docs.pylonsproject.org/projects/waitress/en/latest/#id2 + +Documentation: +https://docs.pylonsproject.org/projects/waitress/en/latest/ + +You can install it via PyPI: + + pip install waitress==1.X + +Enjoy, and please report any issues you find to the issue tracker at +https://github.com/Pylons/waitress/issues + +Thanks! 
+ +- waitress core developers +``` diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index bc7aa9bf..00000000 --- a/appveyor.yml +++ /dev/null @@ -1,23 +0,0 @@ -environment: - matrix: - - PYTHON: "C:\\Python35" - TOXENV: "py35" - - PYTHON: "C:\\Python27" - TOXENV: "py27" - - PYTHON: "C:\\Python27-x64" - TOXENV: "py27" - - PYTHON: "C:\\Python35-x64" - TOXENV: "py35" - -cache: - - '%LOCALAPPDATA%\pip\Cache' - -version: '{branch}.{build}' - -install: - - "%PYTHON%\\python.exe -m pip install tox" - -build: off - -test_script: - - "%PYTHON%\\Scripts\\tox.exe" diff --git a/contributing.md b/contributing.md new file mode 100644 index 00000000..6bdfb523 --- /dev/null +++ b/contributing.md @@ -0,0 +1,95 @@ +Contributing +============ + +All projects under the Pylons Projects, including this one, follow the guidelines established at [How to Contribute](https://pylonsproject.org/community-how-to-contribute.html) and [Coding Style and Standards](https://pylonsproject.org/community-coding-style-standards.html). + + +Get support +----------- + +See [Get Support](https://pylonsproject.org/community-support.html). You are reading this document most likely because you want to *contribute* to the project and not *get support*. + + +Working on issues +----------------- + +To respect both your time and ours, we emphasize the following points. + +* We use the [Issue Tracker on GitHub](https://github.com/Pylons/waitress/issues) to discuss bugs, improvements, and feature requests. Search through existing issues before reporting a new one. Issues may be complex or wide-ranging. A discussion up front sets us all on the best path forward. +* Minor issues—such as spelling, grammar, and syntax—don't require discussion and a pull request is sufficient. +* After discussing the issue with maintainers and agreeing on a resolution, submit a pull request of your work. 
[GitHub Flow](https://guides.github.com/introduction/flow/index.html) describes the workflow process and why it's a good practice. + + +Git branches +------------ + +There is a single branch [master](https://github.com/Pylons/waitress/) on which development takes place and from which releases to PyPI are tagged. This is the default branch on GitHub. + + +Running tests and building documentation +---------------------------------------- + +We use [tox](https://tox.readthedocs.io/en/latest/) to automate test running, coverage, and building documentation across all supported Python versions. + +To run everything configured in the `tox.ini` file: + + $ tox + +To run tests on Python 2 and 3, and ensure full coverage, but exclude building of docs: + + $ tox -e py2-cover,py3-cover,coverage + +To build the docs only: + + $ tox -e docs + +See the `tox.ini` file for details. + + +Contributing documentation +-------------------------- + +*Note:* These instructions might not work for Windows users. Suggestions to improve the process for Windows users are welcome by submitting an issue or a pull request. + +1. Fork the repo on GitHub by clicking the [Fork] button. +2. Clone your fork into a workspace on your local machine. + + cd ~/projects + git clone git@github.com:/waitress.git + +3. Add a git remote "upstream" for the cloned fork. + + git remote add upstream git@github.com:Pylons/waitress.git + +4. Set an environment variable to your virtual environment. + + # Mac and Linux + $ export VENV=~/projects/waitress/env + + # Windows + set VENV=c:\projects\waitress\env + +5. Try to build the docs in your workspace. + + # Mac and Linux + $ make clean html SPHINXBUILD=$VENV/bin/sphinx-build + + # Windows + c:\> make clean html SPHINXBUILD=%VENV%\bin\sphinx-build + + If successful, then you can make changes to the documentation. You can load the built documentation in the `/_build/html/` directory in a web browser. + +6. 
From this point forward, follow the typical [git workflow](https://help.github.com/articles/what-is-a-good-git-workflow/). Start by pulling from the upstream to get the most current changes. + + git pull upstream master + +7. Make a branch, make changes to the docs, and rebuild them as indicated in step 5. To speed up the build process, you can omit `clean` from the above command to rebuild only those pages that depend on the files you have changed. + +8. Once you are satisfied with your changes and the documentation builds successfully without errors or warnings, then git commit and push them to your "origin" repository on GitHub. + + git commit -m "commit message" + git push -u origin --all # first time only, subsequent can be just 'git push'. + +9. Create a [pull request](https://help.github.com/articles/using-pull-requests/). + +10. Repeat the process starting from Step 6. diff --git a/docs/.gitignore b/docs/.gitignore deleted file mode 100644 index da7abd0c..00000000 --- a/docs/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -_themes -_build - - diff --git a/docs/api.rst b/docs/api.rst index 5e0a5231..318ff2c2 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -5,6 +5,6 @@ .. module:: waitress -.. function:: serve(app, listen='0.0.0.0:8080', unix_socket=None, unix_socket_perms='600', threads=4, url_scheme='http', url_prefix='', ident='waitress', backlog=1204, recv_bytes=8192, send_bytes=18000, outbuf_overflow=104856, inbuf_overflow=52488, connection_limit=1000, cleanup_interval=30, channel_timeout=120, log_socket_errors=True, max_request_header_size=262144, max_request_body_size=1073741824, expose_tracebacks=False) +.. 
function:: serve(app, listen='0.0.0.0:8080', unix_socket=None, unix_socket_perms='600', threads=4, url_scheme='http', url_prefix='', ident='waitress', backlog=1024, recv_bytes=8192, send_bytes=1, outbuf_overflow=104856, outbuf_high_watermark=16777216, inbuf_overflow=52488, connection_limit=1000, cleanup_interval=30, channel_timeout=120, log_socket_errors=True, max_request_header_size=262144, max_request_body_size=1073741824, expose_tracebacks=False) See :ref:`arguments` for more information. diff --git a/docs/arguments.rst b/docs/arguments.rst index b827b0ee..f9b9310f 100644 --- a/docs/arguments.rst +++ b/docs/arguments.rst @@ -3,11 +3,11 @@ Arguments to ``waitress.serve`` ------------------------------- -Here are the arguments you can pass to the `waitress.serve`` function or use +Here are the arguments you can pass to the ``waitress.serve`` function or use in :term:`PasteDeploy` configuration (interchangeably): host - hostname or IP address (string) on which to listen, default ``0.0.0.0``, + Hostname or IP address (string) on which to listen, default ``0.0.0.0``, which means "all IP addresses on this host". .. warning:: @@ -20,21 +20,38 @@ port May not be used with ``listen`` listen - Tell waitress to listen on an host/port combination. It is to be provided - as a space delineated list of host/port: + Tell waitress to listen on combinations of ``host:port`` arguments. + Combinations should be a quoted, space-delimited list, as in the following examples. - Examples: + .. code-block:: python - - ``listen="127.0.0.1:8080 [::1]:8080"`` - - ``listen="*:8080 *:6543"`` + listen="127.0.0.1:8080 [::1]:8080" + listen="*:8080 *:6543" - A wildcard for the hostname is also supported and will bind to both - IPv4/IPv6 depending on whether they are enabled or disabled. + A wildcard for the hostname is also supported and will bind to both + IPv4/IPv6 depending on whether they are enabled or disabled. - IPv6 IP addresses are supported by surrounding the IP address with brackets. 
+ IPv6 IP addresses are supported by surrounding the IP address with brackets. .. versionadded:: 1.0 +server_name + This is the value that will be placed in the WSGI environment as + ``SERVER_NAME``, the only time that this value is used in the WSGI + environment for a request is if the client sent a HTTP/1.0 request without + a ``Host`` header set, and no other proxy headers. + + The default value is ``waitress.invalid``, if your WSGI application is + creating URLs that include this as the hostname and you are using a + reverse proxy setup, you may want to validate that your reverse proxy is + sending the appropriate headers. + + In most situations you will not need to set this value. + + Default: ``waitress.invalid`` + + .. versionadded:: 2.0 + ipv4 Enable or disable IPv4 (boolean) @@ -42,66 +59,180 @@ ipv6 Enable or disable IPv6 (boolean) unix_socket - Path of Unix socket (string), default is ``None``. If a socket path is - specified, a Unix domain socket is made instead of the usual inet domain - socket. + Path of Unix socket (string). If a socket path is specified, a Unix domain + socket is made instead of the usual inet domain socket. Not available on Windows. + Default: ``None`` + unix_socket_perms - Octal permissions to use for the Unix domain socket (string), default is - ``600``. Only used if ``unix_socket`` is not ``None``. + Octal permissions to use for the Unix domain socket (string). + Only used if ``unix_socket`` is not ``None``. + + Default: ``'600'`` + +sockets + A list of sockets. The sockets can be either Internet or UNIX sockets and have + to be bound. Internet and UNIX sockets cannot be mixed. + If the socket list is not empty, waitress creates one server for each socket. + + Default: ``[]`` + + .. versionadded:: 1.1.1 + + ..
warning:: + May not be used with ``listen``, ``host``, ``port`` or ``unix_socket`` threads - number of threads used to process application logic (integer), default - ``4`` + The number of threads used to process application logic (integer). + + Default: ``4`` trusted_proxy - IP address of a client allowed to override ``url_scheme`` via the - ``X_FORWARDED_PROTO`` header. + IP address of a remote peer allowed to override various WSGI environment + variables using proxy headers. + + For unix sockets, set this value to ``localhost`` instead of an IP address. + + Default: ``None`` + +trusted_proxy_count + How many proxies we trust when chained. For example, + + ``X-Forwarded-For: 192.0.2.1, "[2001:db8::1]"`` + + or + + ``Forwarded: for=192.0.2.1, For="[2001:db8::1]"`` + + means there were (potentially), two proxies involved. If we know there is + only 1 valid proxy, then that initial IP address "192.0.2.1" is not trusted + and we completely ignore it. + + If there are two trusted proxies in the path, this value should be set to + 2. If there are more proxies, this value should be set higher. + + Default: ``1`` + + .. versionadded:: 1.2.0 + +trusted_proxy_headers + Which of the proxy headers should we trust, this is a set where you + either specify "forwarded" or one or more of "x-forwarded-host", "x-forwarded-for", + "x-forwarded-proto", "x-forwarded-port", "x-forwarded-by". + + This list of trusted headers is used when ``trusted_proxy`` is set and will + allow waitress to modify the WSGI environment using the values provided by + the proxy. + + .. versionadded:: 1.2.0 + + .. warning:: + If ``trusted_proxy`` is set, the default is ``x-forwarded-proto`` to + match older versions of Waitress. Users should explicitly opt-in by + selecting the headers to be trusted as future versions of waitress will + use an empty default. + + .. warning:: + It is an error to set this value without setting ``trusted_proxy``. 
+ +log_untrusted_proxy_headers + Should waitress log warning messages about proxy headers that are being + sent from upstream that are not trusted by ``trusted_proxy_headers`` but + are being cleared due to ``clear_untrusted_proxy_headers``? + + This may be useful for debugging if you expect your upstream proxy server + to only send specific headers. + + Default: ``False`` + + .. versionadded:: 1.2.0 + + .. warning:: + It is a no-op to set this value without also setting + ``clear_untrusted_proxy_headers`` and ``trusted_proxy`` + +clear_untrusted_proxy_headers + This tells Waitress to remove any untrusted proxy headers ("Forwarded", + "X-Forwarded-For", "X-Forwarded-By", "X-Forwarded-Host", "X-Forwarded-Port", + "X-Forwarded-Proto") not explicitly allowed by ``trusted_proxy_headers``. + + Default: ``False`` + + .. versionadded:: 1.2.0 + + .. warning:: + The default value is set to ``False`` for backwards compatibility. In + future versions of Waitress this default will be changed to ``True``. + Warnings will be raised unless the user explicitly provides a value for + this option, allowing the user to opt-in to the new safety features + automatically. + + .. warning:: + It is an error to set this value without setting ``trusted_proxy``. url_scheme - default ``wsgi.url_scheme`` value (string), default ``http``; can be + The value of ``wsgi.url_scheme`` in the environ. This can be overridden per-request by the value of the ``X_FORWARDED_PROTO`` header, but only if the client address matches ``trusted_proxy``. + Default: ``http`` + ident - server identity (string) used in "Server:" header in responses, default - ``waitress`` + Server identity (string) used in "Server:" header in responses. + + Default: ``waitress`` backlog - backlog is the value waitress passes to pass to socket.listen() - (integer), default ``1024``. This is the maximum number of incoming TCP + The value waitress passes to ``socket.listen()`` (integer).
+ This is the maximum number of incoming TCP connections that will wait in an OS queue for an available channel. From listen(1): "If a connection request arrives when the queue is full, the client may receive an error with an indication of ECONNREFUSED or, if the underlying protocol supports retransmission, the request may be ignored so that a later reattempt at connection succeeds." + Default: ``1024`` + recv_bytes - recv_bytes is the argument waitress passes to socket.recv() (integer), - default ``8192`` + The argument waitress passes to ``socket.recv()`` (integer). + + Default: ``8192`` send_bytes - send_bytes is the number of bytes to send to socket.send() (integer), - default ``18000``. Multiples of 9000 should avoid partly-filled TCP + The number of bytes to send to ``socket.send()`` (integer). + Multiples of 9000 should avoid partly-filled TCP packets, but don't set this larger than the TCP write buffer size. In - Linux, /proc/sys/net/ipv4/tcp_wmem controls the minimum, default, and + Linux, ``/proc/sys/net/ipv4/tcp_wmem`` controls the minimum, default, and maximum sizes of TCP write buffers. + Default: ``1`` + + .. deprecated:: 1.3 + outbuf_overflow A tempfile should be created if the pending output is larger than - outbuf_overflow, which is measured in bytes. The default is 1MB - (``1048576``). This is conservative. + outbuf_overflow, which is measured in bytes. The default is conservative. + + Default: ``1048576`` (1MB) + +outbuf_high_watermark + The app_iter will pause when pending output is larger than this value + and will resume once enough data is written to the socket to fall below + this threshold. + + Default: ``16777216`` (16MB) inbuf_overflow A tempfile should be created if the pending input is larger than - inbuf_overflow, which is measured in bytes. The default is 512K - (``524288``). This is conservative. + inbuf_overflow, which is measured in bytes. The default is conservative. 
+ + Default: ``524288`` (512K) connection_limit Stop creating new channels if too many are already active (integer). - Default is ``100``. Each channel consumes at least one file descriptor, + Each channel consumes at least one file descriptor, and, depending on the input and output body sizes, potentially up to three, plus whatever file descriptors your application logic happens to open. The default is conservative, but you may need to increase the @@ -111,45 +242,62 @@ connection_limit connections that can be waiting for processing; the ``backlog`` argument controls that. + Default: ``100`` + cleanup_interval - Minimum seconds between cleaning up inactive channels (integer), default - ``30``. See "channel_timeout". + Minimum seconds between cleaning up inactive channels (integer). + See also ``channel_timeout``. + + Default: ``30`` channel_timeout - Maximum seconds to leave an inactive connection open (integer), default - ``120``. "Inactive" is defined as "has received no data from a client + Maximum seconds to leave an inactive connection open (integer). + "Inactive" is defined as "has received no data from a client and has sent no data to a client". + Default: ``120`` + log_socket_errors - Boolean: turn off to not log premature client disconnect tracebacks. - Default: ``True``. + Set to ``False`` to not log premature client disconnect tracebacks. + + Default: ``True`` max_request_header_size - maximum number of bytes of all request headers combined (integer), 256K - (``262144``) default) + Maximum number of bytes of all request headers combined (integer). + + Default: ``262144`` (256K) max_request_body_size - maximum number of bytes in request body (integer), 1GB (``1073741824``) - default. + Maximum number of bytes in request body (integer). + + Default: ``1073741824`` (1GB) expose_tracebacks - Boolean: expose tracebacks of unhandled exceptions to client. Default: - ``False``. + Set to ``True`` to expose tracebacks of unhandled exceptions to client. 
+ + Default: ``False`` asyncore_loop_timeout - The ``timeout`` value (seconds) passed to ``asyncore.loop`` to run the - mainloop. Default: 1. (New in 0.8.3.) + The ``timeout`` value (seconds) passed to ``asyncore.loop`` to run the mainloop. + + Default: ``1`` + + .. versionadded:: 0.8.3 asyncore_use_poll - Boolean: switch from using select() to poll() in ``asyncore.loop``. - By default asyncore.loop() uses select() which has a limit of 1024 - file descriptors. Select() and poll() provide basically the same - functionality, but poll() doesn't have the file descriptors limit. - Default: False (New in 0.8.6) + Set to ``True`` to switch from using ``select()`` to ``poll()`` in ``asyncore.loop``. + By default ``asyncore.loop()`` uses ``select()`` which has a limit of 1024 file descriptors. + ``select()`` and ``poll()`` provide basically the same functionality, but ``poll()`` doesn't have the file descriptors limit. + + Default: ``False`` + + .. versionadded:: 0.8.6 url_prefix String: the value used as the WSGI ``SCRIPT_NAME`` value. Setting this to anything except the empty string will cause the WSGI ``SCRIPT_NAME`` value to be the value passed minus any trailing slashes you add, and it will cause the ``PATH_INFO`` of any request which is prefixed with this value to - be stripped of the prefix. Default: the empty string. + be stripped of the prefix. + + Default: ``''`` diff --git a/docs/conf.py b/docs/conf.py index aa9b3fbd..cf0ff9b5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -15,9 +15,9 @@ # If your extensions are in another directory, add it here. If the # directory is relative to the documentation root, use os.path.abspath to # make it absolute, like shown here. -#sys.path.append(os.path.abspath('some/directory')) +# sys.path.append(os.path.abspath('some/directory')) -import sys, os +import datetime import pkg_resources import pylons_sphinx_themes @@ -27,71 +27,82 @@ # Add any Sphinx extension module names here, as strings. 
They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = [ - 'sphinx.ext.autodoc', - ] + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", +] + +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), +} # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General substitutions. -project = 'waitress' -copyright = '2012, Agendaless Consulting ' +project = "waitress" +thisyear = datetime.datetime.now().year +copyright = "2012-%s, Agendaless Consulting " % thisyear # The default replacements for |version| and |release|, also used in various # other places throughout the built documents. # # The short X.Y version. -version = pkg_resources.get_distribution('waitress').version +version = pkg_resources.get_distribution("waitress").version # The full version, including alpha/beta/rc tags. release = version # There are two options for replacing |today|: either, you set today to # some non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -today_fmt = '%B %d, %Y' +today_fmt = "%B %d, %Y" # List of documents that shouldn't be included in the build. -#unused_docs = [] +# unused_docs = [] # List of directories, relative to source directories, that shouldn't be # searched for source files. -#exclude_dirs = [] -exclude_patterns = ['_themes/README.rst',] +# exclude_dirs = [] +exclude_patterns = [ + "_themes/README.rst", +] # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. 
-#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True add_module_names = False # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" + +# Do not use smart quotes. +smartquotes = False # Options for HTML output # ----------------------- # Add and use Pylons theme -html_theme = 'pylons' +html_theme = "pylons" html_theme_path = pylons_sphinx_themes.get_html_themes_path() -html_theme_options = dict(github_url='http://github.com/Pylons/waitress') +html_theme_options = dict(github_url="https://github.com/Pylons/waitress") # The style sheet to use for HTML and HTML Help pages. A file of that name # must exist either in Sphinx' static/ path, or in one of the custom paths @@ -100,11 +111,11 @@ # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as # html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (within the static path) to place at the top of # the sidebar. @@ -113,84 +124,99 @@ # The name of an image file (within the static path) to use as favicon of # the docs. This file should be a Windows icon file (.ico) being 16x16 or # 32x32 pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) # here, relative to this directory. They are copied after the builtin # static files, so a file named "default.css" will overwrite the builtin # "default.css". 
-#html_static_path = ['.static'] +# html_static_path = ['.static'] # If not '', a 'Last updated on:' timestamp is inserted at every page # bottom, using the given strftime format. -html_last_updated_fmt = '%b %d, %Y' +html_last_updated_fmt = "%b %d, %Y" # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_use_modindex = True +# html_use_modindex = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, the reST sources are included in the HTML build as # _sources/. -#html_copy_source = True +# html_copy_source = True # If true, an OpenSearch description file will be output, and all pages # will contain a tag referring to it. The value of this option must # be the base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # If nonempty, this is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = '' +# html_file_suffix = '' # Output file base name for HTML help builder. -htmlhelp_basename = 'atemplatedoc' - +htmlhelp_basename = "waitress" + +# Control display of sidebars +html_sidebars = { + "**": [ + "localtoc.html", + "ethicalads.html", + "relations.html", + "sourcelink.html", + "searchbox.html", + ] +} # Options for LaTeX output # ------------------------ # The paper size ('letter' or 'a4'). -#latex_paper_size = 'letter' +# latex_paper_size = 'letter' # The font size ('10pt', '11pt' or '12pt'). 
-#latex_font_size = '10pt' +# latex_font_size = '10pt' # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, document class [howto/manual]). latex_documents = [ - ('index', 'waitress.tex', 'waitress Documentation', - 'Pylons Developers', 'manual'), + ( + "index", + "waitress.tex", + "waitress Documentation", + "Pylons Project Developers", + "manual", + ), ] # The name of an image file (relative to this directory) to place at the # top of the title page. -#latex_logo = '.static/logo_hi.gif' +# latex_logo = '.static/logo_hi.gif' # For "manual" documents, if this is true, then toplevel headings are # parts, not chapters. -#latex_use_parts = False +# latex_use_parts = False # Additional stuff for the LaTeX preamble. -#latex_preamble = '' +# latex_preamble = '' # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_use_modindex = True +# latex_use_modindex = True diff --git a/docs/design.rst b/docs/design.rst index 591a4236..c0d13eff 100644 --- a/docs/design.rst +++ b/docs/design.rst @@ -1,11 +1,19 @@ Design ------ -Waitress uses a combination of asynchronous and synchronous code to do its -job. It handles I/O to and from clients using the :term:`asyncore` library. +Waitress uses a combination of asynchronous and synchronous code to do its job. +It handles I/O to and from clients using the :term:`wasyncore`, which is :term:`asyncore` vendored into Waitress. It services requests via threads. -The :term:`asyncore` module in the Python standard library: +.. note:: + :term:`asyncore` has been deprecated since Python 3.6. + Work continues on its inevitable removal from the Python standard library. + Its recommended replacement is :mod:`asyncio`. + + Although :term:`asyncore` has been vendored into Waitress as :term:`wasyncore`, you may see references to "asyncore" in this documentation's code examples and API. 
+ The terms are effectively the same and may be used interchangeably. + +The :term:`wasyncore` module: - Uses the ``select.select`` function to wait for connections from clients and determine if a connected client is ready to receive output. @@ -37,10 +45,12 @@ channel, and can write back to the channel's output buffer. When all worker threads are in use, scheduled tasks will wait in a queue for a worker thread to become available. -I/O is always done asynchronously (by asyncore) in the main thread. Worker -threads never do any I/O. This means that 1) a large number of clients can -be connected to the server at once and 2) worker threads will never be hung -up trying to send data to a slow client. +I/O is always done asynchronously (by :term:`wasyncore`) in the main thread. +Worker threads never do any I/O. +This means that + +#. a large number of clients can be connected to the server at once, and +#. worker threads will never be hung up trying to send data to a slow client. No attempt is made to kill a "hung thread". It's assumed that when a task (application logic) starts that it will eventually complete. If for some diff --git a/docs/filewrapper.rst b/docs/filewrapper.rst index a1195944..e682046b 100644 --- a/docs/filewrapper.rst +++ b/docs/filewrapper.rst @@ -1,9 +1,7 @@ Support for ``wsgi.file_wrapper`` --------------------------------- -Waitress supports the `WSGI file_wrapper protocol -`_ -. Here's a usage example: +Waitress supports the Python Web Server Gateway Interface v1.0 as specified in :pep:`3333`. Here's a usage example: .. code-block:: python diff --git a/docs/glossary.rst b/docs/glossary.rst index 92aaf31d..53098450 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -4,12 +4,31 @@ Glossary ======== .. glossary:: - :sorted: + :sorted: - PasteDeploy - A system for configuration of WSGI web components in declarative - ``.ini`` format. See http://pythonpaste.org/deploy/. 
+ PasteDeploy + A system for configuration of WSGI web components in declarative ``.ini`` format. + See https://docs.pylonsproject.org/projects/pastedeploy/en/latest/. - asyncore - A standard library module for asynchronous communications. See - http://docs.python.org/library/asyncore.html . + asyncore + A Python standard library module for asynchronous communications. See :mod:`asyncore`. + + .. versionchanged:: 1.2.0 + Waitress has now "vendored" ``asyncore`` into itself as ``waitress.wasyncore``. + This is to cope with the eventuality that ``asyncore`` will be removed from the Python standard library in Python 3.8 or so. + + middleware + *Middleware* is a :term:`WSGI` concept. + It is a WSGI component that acts both as a server and an application. + Interesting uses for middleware exist, such as caching, content-transport encoding, and other functions. + See `WSGI.org `_ or `PyPI `_ to find middleware for your application. + + WSGI + `Web Server Gateway Interface `_. + This is a Python standard for connecting web applications to web servers, similar to the concept of Java Servlets. + Waitress requires that your application be served as a WSGI application. + + wasyncore + .. versionchanged:: 1.2.0 + Waitress has now "vendored" :term:`asyncore` into itself as ``waitress.wasyncore``. + This is to cope with the eventuality that ``asyncore`` will be removed from the Python standard library in Python 3.8 or so. diff --git a/docs/index.rst b/docs/index.rst index 9e893950..ba34b5e2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,298 +1,15 @@ -Waitress --------- - -Waitress is meant to be a production-quality pure-Python WSGI server with -very acceptable performance. It has no dependencies except ones which live -in the Python standard library. It runs on CPython on Unix and Windows under -Python 2.7+ and Python 3.3+. It is also known to run on PyPy 1.6.0 on UNIX. -It supports HTTP/1.0 and HTTP/1.1. - -Usage ------ - -Here's normal usage of the server: - -.. 
code-block:: python - - from waitress import serve - serve(wsgiapp, listen='*:8080') - -This will run waitress on port 8080 on all available IP addresses, both IPv4 -and IPv6. - - -.. code-block:: python - - from waitress import serve - serve(wsgiapp, host='0.0.0.0', port=8080) - -This will run waitress on port 8080 on all available IPv4 addresses. - -If you want to serve your application on all IP addresses, on port 8080, you -can omit the ``host`` and ``port`` arguments and just call ``serve`` with the -WSGI app as a single argument: - -.. code-block:: python - - from waitress import serve - serve(wsgiapp) - -Press Ctrl-C (or Ctrl-Break on Windows) to exit the server. - -The default is to bind to any IPv4 address on port 8080: - -.. code-block:: python - - from waitress import serve - serve(wsgiapp) - -If you want to serve your application through a UNIX domain socket (to serve -a downstream HTTP server/proxy, e.g. nginx, lighttpd, etc.), call ``serve`` -with the ``unix_socket`` argument: - -.. code-block:: python - - from waitress import serve - serve(wsgiapp, unix_socket='/path/to/unix.sock') - -Needless to say, this configuration won't work on Windows. - -Exceptions generated by your application will be shown on the console by -default. See :ref:`logging` to change this. - -There's an entry point for :term:`PasteDeploy` (``egg:waitress#main``) that -lets you use Waitress's WSGI gateway from a configuration file, e.g.: - -.. code-block:: ini - - [server:main] - use = egg:waitress#main - listen = 127.0.0.1:8080 - -Using ``host`` and ``port`` is also supported: - -.. code-block:: ini - - [server:main] - host = 127.0.0.1 - port = 8080 - -The :term:`PasteDeploy` syntax for UNIX domain sockets is analagous: - -.. code-block:: ini - - [server:main] - use = egg:waitress#main - unix_socket = /path/to/unix.sock - -You can find more settings to tweak (arguments to ``waitress.serve`` or -equivalent settings in PasteDeploy) in :ref:`arguments`. 
- -Additionally, there is a command line runner called ``waitress-serve``, which -can be used in development and in situations where the likes of -:term:`PasteDeploy` is not necessary: - -.. code-block:: bash - - # Listen on both IPv4 and IPv6 on port 8041 - waitress-serve --listen=*:8041 myapp:wsgifunc - - # Listen on only IPv4 on port 8041 - waitress-serve --port=8041 myapp:wsgifunc - -For more information on this, see :ref:`runner`. - -.. _logging: - -Logging -------- - -``waitress.serve`` calls ``logging.basicConfig()`` to set up logging to the -console when the server starts up. Assuming no other logging configuration -has already been done, this sets the logging default level to -``logging.WARNING``. The Waitress logger will inherit the root logger's -level information (it logs at level ``WARNING`` or above). - -Waitress sends its logging output (including application exception -renderings) to the Python logger object named ``waitress``. You can -influence the logger level and output stream using the normal Python -``logging`` module API. For example: - -.. code-block:: python - - import logging - logger = logging.getLogger('waitress') - logger.setLevel(logging.INFO) - -Within a PasteDeploy configuration file, you can use the normal Python -``logging`` module ``.ini`` file format to change similar Waitress logging -options. For example: +.. _index: -.. code-block:: ini - - [logger_waitress] - level = INFO - -Using Behind a Reverse Proxy ----------------------------- - -Often people will set up "pure Python" web servers behind reverse proxies, -especially if they need SSL support (Waitress does not natively support SSL). -Even if you don't need SSL support, it's not uncommon to see Waitress and -other pure-Python web servers set up to "live" behind a reverse proxy; these -proxies often have lots of useful deployment knobs. 
- -If you're using Waitress behind a reverse proxy, you'll almost always want -your reverse proxy to pass along the ``Host`` header sent by the client to -Waitress, in either case, as it will be used by most applications to generate -correct URLs. - -For example, when using Nginx as a reverse proxy, you might add the following -lines in a ``location`` section:: - - proxy_set_header Host $host; - -The Apache directive named ``ProxyPreserveHost`` does something similar when -used as a reverse proxy. - -Unfortunately, even if you pass the ``Host`` header, the Host header does not -contain enough information to regenerate the original URL sent by the client. -For example, if your reverse proxy accepts HTTPS requests (and therefore URLs -which start with ``https://``), the URLs generated by your application when -used behind a reverse proxy served by Waitress might inappropriately be -``http://foo`` rather than ``https://foo``. To fix this, you'll want to -change the ``wsgi.url_scheme`` in the WSGI environment before it reaches your -application. You can do this in one of three ways: - -1. You can pass a ``url_scheme`` configuration variable to the - ``waitress.serve`` function. - -2. You can configure the proxy reverse server to pass a header, - ``X_FORWARDED_PROTO``, whose value will be set for that request as - the ``wsgi.url_scheme`` environment value. Note that you must also - conigure ``waitress.serve`` by passing the IP address of that proxy - as its ``trusted_proxy``. - -3. You can use Paste's ``PrefixMiddleware`` in conjunction with - configuration settings on the reverse proxy server. - -Using ``url_scheme`` to set ``wsgi.url_scheme`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -You can have the Waitress server use the ``https`` url scheme by default.: - -.. 
code-block:: python - - from waitress import serve - serve(wsgiapp, listen='0.0.0.0:8080', url_scheme='https') - -This works if all URLs generated by your application should use the ``https`` -scheme. - -Passing the ``X_FORWARDED_PROTO`` header to set ``wsgi.url_scheme`` -------------------------------------------------------------------- - -If your proxy accepts both HTTP and HTTPS URLs, and you want your application -to generate the appropriate url based on the incoming scheme, also set up -your proxy to send a ``X-Forwarded-Proto`` with the original URL scheme along -with each proxied request. For example, when using Nginx:: - - proxy_set_header X-Forwarded-Proto $scheme; - -or via Apache:: - - RequestHeader set X-Forwarded-Proto https - -.. note:: - - You must also configure the Waitress server's ``trusted_proxy`` to - contain the IP address of the proxy in order for this header to override - the default URL scheme. - -Using ``url_prefix`` to influence ``SCRIPT_NAME`` and ``PATH_INFO`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -You can have the Waitress server use a particular url prefix by default for all -URLs generated by downstream applications that take ``SCRIPT_NAME`` into -account.: - -.. code-block:: python - - from waitress import serve - serve(wsgiapp, listen='0.0.0.0:8080', url_prefix='/foo') - -Setting this to any value except the empty string will cause the WSGI -``SCRIPT_NAME`` value to be that value, minus any trailing slashes you add, and -it will cause the ``PATH_INFO`` of any request which is prefixed with this -value to be stripped of the prefix. This is useful in proxying scenarios where -you wish to forward all traffic to a Waitress server but need URLs generated by -downstream applications to be prefixed with a particular path segment. 
- -Using Paste's ``PrefixMiddleware`` to set ``wsgi.url_scheme`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -If only some of the URLs generated by your application should use the -``https`` scheme (and some should use ``http``), you'll need to use Paste's -``PrefixMiddleware`` as well as change some configuration settings on your -proxy. To use ``PrefixMiddleware``, wrap your application before serving it -using Waitress: - -.. code-block:: python - - from waitress import serve - from paste.deploy.config import PrefixMiddleware - app = PrefixMiddleware(app) - serve(app) - -Once you wrap your application in the the ``PrefixMiddleware``, the -middleware will notice certain headers sent from your proxy and will change -the ``wsgi.url_scheme`` and possibly other WSGI environment variables -appropriately. - -Once your application is wrapped by the prefix middleware, you should -instruct your proxy server to send along the original ``Host`` header from -the client to your Waitress server, as well as sending along a -``X-Forwarded-Proto`` header with the appropriate value for -``wsgi.url_scheme``. - -If your proxy accepts both HTTP and HTTPS URLs, and you want your application -to generate the appropriate url based on the incoming scheme, also set up -your proxy to send a ``X-Forwarded-Proto`` with the original URL scheme along -with each proxied request. For example, when using Nginx:: - - proxy_set_header X-Forwarded-Proto $scheme; - -It's permitted to set an ``X-Forwarded-For`` header too; the -``PrefixMiddleware`` uses this to adjust other environment variables (you'll -have to read its docs to find out which ones, I don't know what they are). For -the ``X-Forwarded-For`` header:: - - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - -Note that you can wrap your application in the PrefixMiddleware declaratively -in a :term:`PasteDeploy` configuration file too, if your web framework uses -PasteDeploy-style configuration: - -.. 
code-block:: ini - - [app:myapp] - use = egg:mypackage#myapp - - [filter:paste_prefix] - use = egg:PasteDeploy#prefix - - [pipeline:main] - pipeline = - paste_prefix - myapp +======== +Waitress +======== - [server:main] - use = egg:waitress#main - listen = 127.0.0.1:8080 +Waitress is meant to be a production-quality pure-Python WSGI server with very +acceptable performance. It has no dependencies except ones which live in the +Python standard library. It runs on CPython on Unix and Windows under Python +3.7+. It is also known to run on PyPy 3 (python version 3.7+) on UNIX. It +supports HTTP/1.0 and HTTP/1.1. -Note that you can also set ``PATH_INFO`` and ``SCRIPT_NAME`` using -PrefixMiddleware too (its original purpose, really) instead of using Waitress' -``url_prefix`` adjustment. See the PasteDeploy docs for more information. Extended Documentation ---------------------- @@ -300,13 +17,17 @@ Extended Documentation .. toctree:: :maxdepth: 1 - design.rst - differences.rst - api.rst - arguments.rst - filewrapper.rst - runner.rst - glossary.rst + usage + logging + reverse-proxy + design + differences + api + arguments + filewrapper + runner + socket-activation + glossary Change History -------------- @@ -317,33 +38,31 @@ Change History Known Issues ------------ -- Does not support SSL natively. +- Does not support TLS natively. See :ref:`using-behind-a-reverse-proxy` for more information. Support and Development ----------------------- -The `Pylons Project web site `_ is the main online +The `Pylons Project web site `_ is the main online source of Waitress support and development information. To report bugs, use the `issue tracker -`_. +`_. If you've got questions that aren't answered by this documentation, -contact the `Pylons-devel maillist -`_ or join the `#pyramid -IRC channel `_. +contact the `Pylons-discuss maillist +`_ or join the `#pyramid +IRC channel `_. Browse and check out tagged and trunk versions of Waitress via -the `Waitress GitHub repository `_. 
+the `Waitress GitHub repository `_. To check out the trunk via ``git``, use this command: .. code-block:: text git clone git@github.com:Pylons/waitress.git -To find out how to become a contributor to Waitress, please see the -`contributor's section of the documentation -`_. +To find out how to become a contributor to Waitress, please see the guidelines in `contributing.md `_ and `How to Contribute Source Code and Documentation `_. Why? ---- @@ -373,7 +92,7 @@ framework distribution simply for its server component is awkward. The test suite of the CherryPy server also depends on the CherryPy web framework, so even if we forked its server component into a separate distribution, we would have still needed to backfill for all of its tests. The CherryPy team has -started work on `Cheroot `_, which +started work on `Cheroot `_, which should solve this problem, however. Waitress is a fork of the WSGI-related components which existed in diff --git a/docs/logging.rst b/docs/logging.rst new file mode 100644 index 00000000..799b75d6 --- /dev/null +++ b/docs/logging.rst @@ -0,0 +1,190 @@ +.. _access-logging: + +============== +Access Logging +============== + +The WSGI design is modular. Waitress logs error conditions, debugging +output, etc., but not web traffic. For web traffic logging, Paste +provides `TransLogger +`_ +:term:`middleware`. TransLogger produces logs in the `Apache Combined +Log Format `_. + + +.. _logging-to-the-console-using-python: + +Logging to the Console Using Python +----------------------------------- + +``waitress.serve`` calls ``logging.basicConfig()`` to set up logging to the +console when the server starts up. Assuming no other logging configuration +has already been done, this sets the logging default level to +``logging.WARNING``. The Waitress logger will inherit the root logger's +level information (it logs at level ``WARNING`` or above). 
+ +Waitress sends its logging output (including application exception +renderings) to the Python logger object named ``waitress``. You can +influence the logger level and output stream using the normal Python +``logging`` module API. For example: + +.. code-block:: python + + import logging + logger = logging.getLogger('waitress') + logger.setLevel(logging.INFO) + +Within a PasteDeploy configuration file, you can use the normal Python +``logging`` module ``.ini`` file format to change similar Waitress logging +options. For example: + +.. code-block:: ini + + [logger_waitress] + level = INFO + + +.. _logging-to-the-console-using-pastedeploy: + +Logging to the Console Using PasteDeploy +---------------------------------------- + +TransLogger will automatically setup a logging handler to the console when called with no arguments. +It "just works" in environments that don't configure logging. +This is by virtue of its default configuration setting of ``setup_console_handler = True``. + + +.. TODO: +.. .. _logging-to-a-file-using-python: + +.. Logging to a File Using Python +.. ------------------------------ + +.. Show how to configure the WSGI logger via python. + + +.. _logging-to-a-file-using-pastedeploy: + +Logging to a File Using PasteDeploy +------------------------------------ + +TransLogger does not write to files, and the Python logging system +must be configured to do this. The Python class :class:`FileHandler` +logging handler can be used alongside TransLogger to create an +``access.log`` file similar to Apache's. + +Like any standard :term:`middleware` with a Paste entry point, +TransLogger can be configured to wrap your application using ``.ini`` +file syntax. First add a +``[filter:translogger]`` section, then use a ``[pipeline:main]`` +section file to form a WSGI pipeline with both the translogger and +your application in it. For instance, if you have this: + +.. 
code-block:: ini + + [app:wsgiapp] + use = egg:mypackage#wsgiapp + + [server:main] + use = egg:waitress#main + host = 127.0.0.1 + port = 8080 + +Add this: + +.. code-block:: ini + + [filter:translogger] + use = egg:Paste#translogger + setup_console_handler = False + + [pipeline:main] + pipeline = translogger + wsgiapp + +Using PasteDeploy this way to form and serve a pipeline is equivalent to +wrapping your app in a TransLogger instance via the bottom of the ``main`` +function of your project's ``__init__`` file: + +.. code-block:: python + + from mypackage import wsgiapp + from waitress import serve + from paste.translogger import TransLogger + serve(TransLogger(wsgiapp, setup_console_handler=False)) + +.. note:: + TransLogger will automatically set up a logging handler to the console when + called with no arguments, so it "just works" in environments that don't + configure logging. Since our logging handlers are configured, we disable + the automation via ``setup_console_handler = False``. + +With the filter in place, TransLogger's logger (named the ``wsgi`` logger) will +propagate its log messages to the parent logger (the root logger), sending +its output to the console when we request a page: + +.. code-block:: text + + 00:50:53,694 INFO [wsgiapp] Returning: Hello World! + (content-type: text/plain) + 00:50:53,695 INFO [wsgi] 192.168.1.111 - - [11/Aug/2011:20:09:33 -0700] "GET /hello + HTTP/1.1" 404 - "-" + "Mozilla/5.0 (Macintosh; U; Intel Mac OS X; en-US; rv:1.8.1.6) Gecko/20070725 + Firefox/2.0.0.6" + +To direct TransLogger to an ``access.log`` FileHandler, we need the +following to add a FileHandler (named ``accesslog``) to the list of +handlers, and ensure that the ``wsgi`` logger is configured and uses +this handler accordingly: + +.. 
code-block:: ini + + # Begin logging configuration + + [loggers] + keys = root, wsgiapp, wsgi + + [handlers] + keys = console, accesslog + + [logger_wsgi] + level = INFO + handlers = accesslog + qualname = wsgi + propagate = 0 + + [handler_accesslog] + class = FileHandler + args = ('%(here)s/access.log','a') + level = INFO + formatter = generic + +As mentioned above, non-root loggers by default propagate their log records +to the root logger's handlers (currently the console handler). Setting +``propagate`` to ``0`` (``False``) here disables this; so the ``wsgi`` logger +directs its records only to the ``accesslog`` handler. + +Finally, there's no need to use the ``generic`` formatter with +TransLogger, as TransLogger itself provides all the information we +need. We'll use a formatter that passes-through the log messages as +is. Add a new formatter called ``accesslog`` by including the +following in your configuration file: + +.. code-block:: ini + + [formatters] + keys = generic, accesslog + + [formatter_accesslog] + format = %(message)s + +Finally alter the existing configuration to wire this new +``accesslog`` formatter into the FileHandler: + +.. code-block:: ini + + [handler_accesslog] + class = FileHandler + args = ('%(here)s/access.log','a') + level = INFO + formatter = accesslog diff --git a/docs/reverse-proxy.rst b/docs/reverse-proxy.rst new file mode 100644 index 00000000..6490e3d7 --- /dev/null +++ b/docs/reverse-proxy.rst @@ -0,0 +1,132 @@ +.. index:: reverse, proxy, TLS, SSL, https + +.. _using-behind-a-reverse-proxy: + +============================ +Using Behind a Reverse Proxy +============================ + +Often people will set up "pure Python" web servers behind reverse proxies, +especially if they need TLS support (Waitress does not natively support TLS). 
+Even if you don't need TLS support, it's not uncommon to see Waitress and +other pure-Python web servers set up to only handle requests behind a reverse proxy; +these proxies often have lots of useful deployment knobs. + +If you're using Waitress behind a reverse proxy, you'll almost always want +your reverse proxy to pass along the ``Host`` header sent by the client to +Waitress, in either case, as it will be used by most applications to generate +correct URLs. You may also use the proxy headers if passing ``Host`` directly +is not possible, or there are multiple proxies involved. + +For example, when using nginx as a reverse proxy, you might add the following +lines in a ``location`` section. + +.. code-block:: nginx + + proxy_set_header Host $host; + +The Apache directive named ``ProxyPreserveHost`` does something similar when +used as a reverse proxy. + +Unfortunately, even if you pass the ``Host`` header, the Host header does not +contain enough information to regenerate the original URL sent by the client. +For example, if your reverse proxy accepts HTTPS requests (and therefore URLs +which start with ``https://``), the URLs generated by your application when +used behind a reverse proxy served by Waitress might inappropriately be +``http://foo`` rather than ``https://foo``. To fix this, you'll want to +change the ``wsgi.url_scheme`` in the WSGI environment before it reaches your +application. You can do this in one of three ways: + +1. You can pass a ``url_scheme`` configuration variable to the + ``waitress.serve`` function. + +2. You can pass certain well known proxy headers from your proxy server and + use waitress's ``trusted_proxy`` support to automatically configure the + WSGI environment. + +Using ``url_scheme`` to set ``wsgi.url_scheme`` +----------------------------------------------- + +You can have the Waitress server use the ``https`` url scheme by default.: + +.. 
code-block:: python + + from waitress import serve + serve(wsgiapp, listen='0.0.0.0:8080', url_scheme='https') + +This works if all URLs generated by your application should use the ``https`` +scheme. + +Passing the proxy headers to setup the WSGI environment +------------------------------------------------------- + +If your proxy accepts both HTTP and HTTPS URLs, and you want your application +to generate the appropriate url based on the incoming scheme, you'll want to +pass waitress ``X-Forwarded-Proto``, however Waitress is also able to update +the environment using ``X-Forwarded-Proto``, ``X-Forwarded-For``, +``X-Forwarded-Host``, and ``X-Forwarded-Port``:: + + proxy_set_header X-Forwarded-Proto $scheme; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Host $host:$server_port; + proxy_set_header X-Forwarded-Port $server_port; + +when using Apache, ``mod_proxy`` automatically forwards the following headers:: + + X-Forwarded-For + X-Forwarded-Host + X-Forwarded-Server + +You will also want to add to Apache:: + + RequestHeader set X-Forwarded-Proto https + +Configure waitress's ``trusted_proxy_headers`` as appropriate:: + + trusted_proxy_headers = "x-forwarded-for x-forwarded-host x-forwarded-proto x-forwarded-port" + +At this point waitress will set up the WSGI environment using the information +specified in the trusted proxy headers. This will setup the following +variables:: + + HTTP_HOST + SERVER_NAME + SERVER_PORT + REMOTE_ADDR + REMOTE_PORT (if available) + wsgi.url_scheme + +Waitress also has support for the `Forwarded (RFC7239) HTTP header +`_ which is better defined than the ad-hoc +``X-Forwarded-*``, however support is not nearly as widespread yet. +``Forwarded`` supports similar functionality as the different individual +headers, and is mutually exclusive to using the ``X-Forwarded-*`` headers. + +To configure waitress to use the ``Forwarded`` header, set:: + + trusted_proxy_headers = "forwarded" + +.. 
note:: + + You must also configure the Waitress server's ``trusted_proxy`` to + contain the IP address of the proxy. + + +Using ``url_prefix`` to influence ``SCRIPT_NAME`` and ``PATH_INFO`` +------------------------------------------------------------------- + +You can have the Waitress server use a particular url prefix by default for all +URLs generated by downstream applications that take ``SCRIPT_NAME`` into +account.: + +.. code-block:: python + + from waitress import serve + serve(wsgiapp, listen='0.0.0.0:8080', url_prefix='/foo') + +Setting this to any value except the empty string will cause the WSGI +``SCRIPT_NAME`` value to be that value, minus any trailing slashes you add, and +it will cause the ``PATH_INFO`` of any request which is prefixed with this +value to be stripped of the prefix. This is useful in proxying scenarios where +you wish to forward all traffic to a Waitress server but need URLs generated by +downstream applications to be prefixed with a particular path segment. diff --git a/docs/runner.rst b/docs/runner.rst index 88a7d63a..2776e444 100644 --- a/docs/runner.rst +++ b/docs/runner.rst @@ -3,14 +3,10 @@ waitress-serve -------------- -Waitress comes bundled with a thin command-line wrapper around the -``waitress.serve`` function called ``waitress-serve``. This is useful for -development, and in production situations where serving of static assets is -delegated to a reverse proxy, such as Nginx or Apache. +.. versionadded:: 0.8.4 -.. note:: - - This feature is new as of Waitress 0.8.4. + Waitress comes bundled with a thin command-line wrapper around the ``waitress.serve`` function called ``waitress-serve``. + This is useful for development, and in production situations where serving of static assets is delegated to a reverse proxy, such as nginx or Apache. ``waitress-serve`` takes the very same :ref:`arguments ` as the ``waitress.serve`` function, but where the function's arguments have @@ -147,19 +143,26 @@ Tuning options: 8192. 
``--send-bytes=INT`` - Number of bytes to send to socket.send(). Default is 18000. + Number of bytes to send to socket.send(). Default is 1. Multiples of 9000 should avoid partly-filled TCP packets. + .. deprecated:: 1.3 + ``--outbuf-overflow=INT`` A temporary file should be created if the pending output is larger than this. Default is 1048576 (1MB). +``--outbuf-high-watermark=INT`` + The app_iter will pause when pending output is larger than this value + and will resume once enough data is written to the socket to fall below + this threshold. Default is 16777216 (16MB). + ``--inbuf-overflow=INT`` A temporary file should be created if the pending input is larger than this. Default is 524288 (512KB). ``--connection-limit=INT`` - Stop creating new channelse if too many are already active. Default is + Stop creating new channels if too many are already active. Default is 100. ``--cleanup-interval=INT`` @@ -168,11 +171,11 @@ Tuning options: ``--channel-timeout=INT`` Maximum number of seconds to leave inactive connections open. Default is - 120. 'Inactive' is defined as 'has recieved no data from the client and has + 120. 'Inactive' is defined as 'has received no data from the client and has sent no data to the client'. ``--[no-]log-socket-errors`` - Toggle whether premature client disconnect tracepacks ought to be logged. + Toggle whether premature client disconnect tracebacks ought to be logged. On by default. ``--max-request-header-size=INT`` diff --git a/docs/socket-activation.rst b/docs/socket-activation.rst new file mode 100644 index 00000000..63483a31 --- /dev/null +++ b/docs/socket-activation.rst @@ -0,0 +1,45 @@ +Socket Activation +----------------- + +While waitress does not support the various implementations of socket activation, +for example using systemd or launchd, it is prepared to receive pre-bound sockets +from init systems, process and socket managers, or other launchers that can provide +pre-bound sockets. 
+ +The following shows a code example starting waitress with two pre-bound Internet sockets. + +.. code-block:: python + + import socket + import waitress + + + def app(environ, start_response): + content_length = environ.get('CONTENT_LENGTH', None) + if content_length is not None: + content_length = int(content_length) + body = environ['wsgi.input'].read(content_length) + content_length = str(len(body)) + start_response( + '200 OK', + [('Content-Length', content_length), ('Content-Type', 'text/plain')] + ) + return [body] + + + if __name__ == '__main__': + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + sockets[0].bind(('127.0.0.1', 8080)) + sockets[1].bind(('127.0.0.1', 9090)) + waitress.serve(app, sockets=sockets) + for socket in sockets: + socket.close() + +Generally, to implement socket activation for a given init system, a wrapper +script uses the init system specific libraries to retrieve the sockets from +the init system. Afterwards it starts waitress, passing the sockets with the parameter +``sockets``. Note that the sockets have to be bound, which all init systems +supporting socket activation do. + diff --git a/docs/usage.rst b/docs/usage.rst new file mode 100644 index 00000000..dfcd6dcb --- /dev/null +++ b/docs/usage.rst @@ -0,0 +1,104 @@ +.. _usage: + +===== +Usage +===== + +The following code will run waitress on port 8080 on all available IP addresses, both IPv4 and IPv6. + +.. code-block:: python + + from waitress import serve + serve(wsgiapp, listen='*:8080') + +Press :kbd:`Ctrl-C` (or :kbd:`Ctrl-Break` on Windows) to exit the server. + +The following will run waitress on port 8080 on all available IPv4 addresses, but not IPv6. + +.. code-block:: python + + from waitress import serve + serve(wsgiapp, host='0.0.0.0', port=8080) + +By default Waitress binds to any IPv4 address on port 8080. 
+You can omit the ``host`` and ``port`` arguments and just call ``serve`` with the WSGI app as a single argument: + +.. code-block:: python + + from waitress import serve + serve(wsgiapp) + +If you want to serve your application through a UNIX domain socket (to serve a downstream HTTP server/proxy such as nginx, lighttpd, and so on), call ``serve`` with the ``unix_socket`` argument: + +.. code-block:: python + + from waitress import serve + serve(wsgiapp, unix_socket='/path/to/unix.sock') + +Needless to say, this configuration won't work on Windows. + +Exceptions generated by your application will be shown on the console by +default. See :ref:`access-logging` to change this. + +There's an entry point for :term:`PasteDeploy` (``egg:waitress#main``) that +lets you use Waitress's WSGI gateway from a configuration file, e.g.: + +.. code-block:: ini + + [server:main] + use = egg:waitress#main + listen = 127.0.0.1:8080 + +Using ``host`` and ``port`` is also supported: + +.. code-block:: ini + + [server:main] + host = 127.0.0.1 + port = 8080 + +The :term:`PasteDeploy` syntax for UNIX domain sockets is analagous: + +.. code-block:: ini + + [server:main] + use = egg:waitress#main + unix_socket = /path/to/unix.sock + +You can find more settings to tweak (arguments to ``waitress.serve`` or +equivalent settings in PasteDeploy) in :ref:`arguments`. + +Additionally, there is a command line runner called ``waitress-serve``, which +can be used in development and in situations where the likes of +:term:`PasteDeploy` is not necessary: + +.. code-block:: bash + + # Listen on both IPv4 and IPv6 on port 8041 + waitress-serve --listen=*:8041 myapp:wsgifunc + + # Listen on only IPv4 on port 8041 + waitress-serve --port=8041 myapp:wsgifunc + +Heroku +------ + +Waitress can be used to serve WSGI apps on Heroku, include waitress in your requirements.txt file a update the Procfile as following: + +.. 
code-block:: bash + + web: waitress-serve \ + --listen "*:$PORT" \ + --trusted-proxy '*' \ + --trusted-proxy-headers 'x-forwarded-for x-forwarded-proto x-forwarded-port' \ + --log-untrusted-proxy-headers \ + --clear-untrusted-proxy-headers \ + --threads ${WEB_CONCURRENCY:-4} \ + myapp:wsgifunc + +The proxy config informs Waitress to trust the `forwarding headers `_ set by the Heroku load balancer. +It also allows for setting the standard ``WEB_CONCURRENCY`` environment variable to tweak the number of requests handled by Waitress at a time. + +Note that Waitress uses a thread-based model and careful effort should be taken to ensure that requests do not take longer than 30 seconds or Heroku will inform the client that the request failed even though the request is still being processed by Waitress and occupying a thread until it completes. + +For more information on this, see :ref:`runner`. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..b68b9058 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,27 @@ +[build-system] +requires = ["setuptools >= 41"] +build-backend = "setuptools.build_meta" + +[tool.black] +target-version = ['py35', 'py36', 'py37', 'py38'] +exclude = ''' +/( + \.git + | .tox +)/ +''' + + # This next section only exists for people that have their editors +# automatically call isort, black already sorts entries on its own when run. 
+[tool.isort] +profile = "black" +multi_line_output = 3 +src_paths = ["src", "tests"] +skip_glob = ["docs/*"] +include_trailing_comma = true +force_grid_wrap = false +combine_as_imports = true +line_length = 88 +force_sort_within_sections = true +default_section = "THIRDPARTY" +known_first_party = "waitress" diff --git a/setup.cfg b/setup.cfg index 4be5f9cf..333766a9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,16 +1,68 @@ -[easy_install] -zip_ok = false +[metadata] +name = waitress +version = 2.1.2 +description = Waitress WSGI server +long_description = file: README.rst, CHANGES.txt +long_description_content_type = text/x-rst +keywords = waitress wsgi server http +license = ZPL 2.1 +classifiers = + Development Status :: 6 - Mature + Environment :: Web Environment + Intended Audience :: Developers + License :: OSI Approved :: Zope Public License + Programming Language :: Python + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: Implementation :: CPython + Programming Language :: Python :: Implementation :: PyPy + Operating System :: OS Independent + Topic :: Internet :: WWW/HTTP + Topic :: Internet :: WWW/HTTP :: WSGI +url = https://github.com/Pylons/waitress +project_urls = + Documentation = https://docs.pylonsproject.org/projects/waitress/en/latest/index.html + Changelog = https://docs.pylonsproject.org/projects/waitress/en/latest/index.html#change-history + Issue Tracker = https://github.com/Pylons/waitress/issues -[nosetests] -match=^test -where=waitress -nocapture=1 -cover-package=waitress -cover-erase=1 +author = Zope Foundation and Contributors +author_email = zope-dev@zope.org +maintainer = Pylons Project +maintainer_email = pylons-discuss@googlegroups.com -[bdist_wheel] -universal = 1 +[options] +package_dir= + =src +packages=find: +python_requires = >=3.7.0 -[aliases] -dev = 
develop easy_install waitress[testing] -docs = develop easy_install waitress[docs] +[options.entry_points] +paste.server_runner = + main = waitress:serve_paste +console_scripts = + waitress-serve = waitress.runner:run + +[options.packages.find] +where=src + +[options.extras_require] +testing = + pytest + pytest-cover + coverage>=5.0 + +docs = + Sphinx>=1.8.1 + docutils + pylons-sphinx-themes>=1.0.9 + +[tool:pytest] +python_files = test_*.py +# For the benefit of test_wasyncore.py +python_classes = Test* +testpaths = + tests +addopts = --cov -W always diff --git a/setup.py b/setup.py index 3bda37d8..60684932 100644 --- a/setup.py +++ b/setup.py @@ -1,80 +1,3 @@ -############################################################################## -# -# Copyright (c) 2006 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. 
-# -############################################################################## -import os -from setuptools import setup, find_packages +from setuptools import setup -here = os.path.abspath(os.path.dirname(__file__)) -try: - README = open(os.path.join(here, 'README.rst')).read() - CHANGES = open(os.path.join(here, 'CHANGES.txt')).read() -except IOError: - README = CHANGES = '' - -docs_extras = [ - 'Sphinx', - 'docutils', - 'pylons-sphinx-themes >= 0.3', -] - -testing_extras = [ - 'nose', - 'coverage', -] - -setup( - name='waitress', - version='1.1.0', - author='Zope Foundation and Contributors', - author_email='zope-dev@zope.org', - maintainer="Pylons Project", - maintainer_email="pylons-discuss@googlegroups.com", - description='Waitress WSGI server', - long_description=README + '\n\n' + CHANGES, - license='ZPL 2.1', - keywords='waitress wsgi server http', - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Environment :: Web Environment', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: Zope Public License', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - 'Natural Language :: English', - 'Operating System :: OS Independent', - 'Topic :: Internet :: WWW/HTTP', - ], - url='https://github.com/Pylons/waitress', - packages=find_packages(), - extras_require={ - 'testing': testing_extras, - 'docs': docs_extras, - }, - include_package_data=True, - test_suite='waitress', - zip_safe=False, - entry_points=""" - [paste.server_runner] - main = waitress:serve_paste - [console_scripts] - waitress-serve = waitress.runner:run - """ -) 
+setup() diff --git a/waitress/__init__.py b/src/waitress/__init__.py similarity index 54% rename from waitress/__init__.py rename to src/waitress/__init__.py index 775fe3a5..bbb99da0 100644 --- a/waitress/__init__.py +++ b/src/waitress/__init__.py @@ -1,41 +1,46 @@ -from waitress.server import create_server import logging +from waitress.server import create_server + + def serve(app, **kw): - _server = kw.pop('_server', create_server) # test shim - _quiet = kw.pop('_quiet', False) # test shim - _profile = kw.pop('_profile', False) # test shim - if not _quiet: # pragma: no cover + _server = kw.pop("_server", create_server) # test shim + _quiet = kw.pop("_quiet", False) # test shim + _profile = kw.pop("_profile", False) # test shim + if not _quiet: # pragma: no cover # idempotent if logging has already been set up logging.basicConfig() server = _server(app, **kw) - if not _quiet: # pragma: no cover - server.print_listen('Serving on http://{}:{}') - if _profile: # pragma: no cover - profile('server.run()', globals(), locals(), (), False) + if not _quiet: # pragma: no cover + server.print_listen("Serving on http://{}:{}") + if _profile: # pragma: no cover + profile("server.run()", globals(), locals(), (), False) else: server.run() + def serve_paste(app, global_conf, **kw): serve(app, **kw) return 0 -def profile(cmd, globals, locals, sort_order, callers): # pragma: no cover + +def profile(cmd, globals, locals, sort_order, callers): # pragma: no cover # runs a command under the profiler and print profiling output at shutdown import os import profile import pstats import tempfile + fd, fn = tempfile.mkstemp() try: profile.runctx(cmd, globals, locals, fn) stats = pstats.Stats(fn) stats.strip_dirs() # calls,time,cumulative and cumulative,calls,time are useful - stats.sort_stats(*(sort_order or ('cumulative', 'calls', 'time'))) + stats.sort_stats(*(sort_order or ("cumulative", "calls", "time"))) if callers: - stats.print_callers(.3) + stats.print_callers(0.3) else: - 
stats.print_stats(.3) + stats.print_stats(0.3) finally: os.remove(fn) diff --git a/waitress/__main__.py b/src/waitress/__main__.py similarity index 98% rename from waitress/__main__.py rename to src/waitress/__main__.py index e484f40a..9bcd07e5 100644 --- a/waitress/__main__.py +++ b/src/waitress/__main__.py @@ -1,2 +1,3 @@ from waitress.runner import run # pragma nocover + run() # pragma nocover diff --git a/src/waitress/adjustments.py b/src/waitress/adjustments.py new file mode 100644 index 00000000..f2a852c7 --- /dev/null +++ b/src/waitress/adjustments.py @@ -0,0 +1,523 @@ +############################################################################## +# +# Copyright (c) 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""Adjustments are tunable parameters. +""" +import getopt +import socket +import warnings + +from .compat import HAS_IPV6, WIN +from .proxy_headers import PROXY_HEADERS + +truthy = frozenset(("t", "true", "y", "yes", "on", "1")) + +KNOWN_PROXY_HEADERS = frozenset( + header.lower().replace("_", "-") for header in PROXY_HEADERS +) + + +def asbool(s): + """Return the boolean value ``True`` if the case-lowered value of string + input ``s`` is any of ``t``, ``true``, ``y``, ``on``, or ``1``, otherwise + return the boolean value ``False``. If ``s`` is the value ``None``, + return ``False``. 
If ``s`` is already one of the boolean values ``True`` + or ``False``, return it.""" + if s is None: + return False + if isinstance(s, bool): + return s + s = str(s).strip() + return s.lower() in truthy + + +def asoctal(s): + """Convert the given octal string to an actual number.""" + return int(s, 8) + + +def aslist_cronly(value): + if isinstance(value, str): + value = filter(None, [x.strip() for x in value.splitlines()]) + return list(value) + + +def aslist(value): + """Return a list of strings, separating the input based on newlines + and, if flatten=True (the default), also split on spaces within + each line.""" + values = aslist_cronly(value) + result = [] + for value in values: + subvalues = value.split() + result.extend(subvalues) + return result + + +def asset(value): + return set(aslist(value)) + + +def slash_fixed_str(s): + s = s.strip() + if s: + # always have a leading slash, replace any number of leading slashes + # with a single slash, and strip any trailing slashes + s = "/" + s.lstrip("/").rstrip("/") + return s + + +def str_iftruthy(s): + return str(s) if s else None + + +def as_socket_list(sockets): + """Checks if the elements in the list are of type socket and + removes them if not.""" + return [sock for sock in sockets if isinstance(sock, socket.socket)] + + +class _str_marker(str): + pass + + +class _int_marker(int): + pass + + +class _bool_marker: + pass + + +class Adjustments: + """This class contains tunable parameters.""" + + _params = ( + ("host", str), + ("port", int), + ("ipv4", asbool), + ("ipv6", asbool), + ("listen", aslist), + ("threads", int), + ("trusted_proxy", str_iftruthy), + ("trusted_proxy_count", int), + ("trusted_proxy_headers", asset), + ("log_untrusted_proxy_headers", asbool), + ("clear_untrusted_proxy_headers", asbool), + ("url_scheme", str), + ("url_prefix", slash_fixed_str), + ("backlog", int), + ("recv_bytes", int), + ("send_bytes", int), + ("outbuf_overflow", int), + ("outbuf_high_watermark", int), + 
("inbuf_overflow", int), + ("connection_limit", int), + ("cleanup_interval", int), + ("channel_timeout", int), + ("log_socket_errors", asbool), + ("max_request_header_size", int), + ("max_request_body_size", int), + ("expose_tracebacks", asbool), + ("ident", str_iftruthy), + ("asyncore_loop_timeout", int), + ("asyncore_use_poll", asbool), + ("unix_socket", str), + ("unix_socket_perms", asoctal), + ("sockets", as_socket_list), + ("channel_request_lookahead", int), + ("server_name", str), + ) + + _param_map = dict(_params) + + # hostname or IP address to listen on + host = _str_marker("0.0.0.0") + + # TCP port to listen on + port = _int_marker(8080) + + listen = [f"{host}:{port}"] + + # number of threads available for tasks + threads = 4 + + # Host allowed to overrid ``wsgi.url_scheme`` via header + trusted_proxy = None + + # How many proxies we trust when chained + # + # X-Forwarded-For: 192.0.2.1, "[2001:db8::1]" + # + # or + # + # Forwarded: for=192.0.2.1, For="[2001:db8::1]" + # + # means there were (potentially), two proxies involved. If we know there is + # only 1 valid proxy, then that initial IP address "192.0.2.1" is not + # trusted and we completely ignore it. If there are two trusted proxies in + # the path, this value should be set to a higher number. + trusted_proxy_count = None + + # Which of the proxy headers should we trust, this is a set where you + # either specify forwarded or one or more of forwarded-host, forwarded-for, + # forwarded-proto, forwarded-port. + trusted_proxy_headers = set() + + # Would you like waitress to log warnings about untrusted proxy headers + # that were encountered while processing the proxy headers? This only makes + # sense to set when you have a trusted_proxy, and you expect the upstream + # proxy server to filter invalid headers + log_untrusted_proxy_headers = False + + # Should waitress clear any proxy headers that are not deemed trusted from + # the environ? 
Change to True by default in 2.x + clear_untrusted_proxy_headers = _bool_marker + + # default ``wsgi.url_scheme`` value + url_scheme = "http" + + # default ``SCRIPT_NAME`` value, also helps reset ``PATH_INFO`` + # when nonempty + url_prefix = "" + + # server identity (sent in Server: header) + ident = "waitress" + + # backlog is the value waitress passes to pass to socket.listen() This is + # the maximum number of incoming TCP connections that will wait in an OS + # queue for an available channel. From listen(1): "If a connection + # request arrives when the queue is full, the client may receive an error + # with an indication of ECONNREFUSED or, if the underlying protocol + # supports retransmission, the request may be ignored so that a later + # reattempt at connection succeeds." + backlog = 1024 + + # recv_bytes is the argument to pass to socket.recv(). + recv_bytes = 8192 + + # deprecated setting controls how many bytes will be buffered before + # being flushed to the socket + send_bytes = 1 + + # A tempfile should be created if the pending output is larger than + # outbuf_overflow, which is measured in bytes. The default is 1MB. This + # is conservative. + outbuf_overflow = 1048576 + + # The app_iter will pause when pending output is larger than this value + # in bytes. + outbuf_high_watermark = 16777216 + + # A tempfile should be created if the pending input is larger than + # inbuf_overflow, which is measured in bytes. The default is 512K. This + # is conservative. + inbuf_overflow = 524288 + + # Stop creating new channels if too many are already active (integer). + # Each channel consumes at least one file descriptor, and, depending on + # the input and output body sizes, potentially up to three. The default + # is conservative, but you may need to increase the number of file + # descriptors available to the Waitress process on most platforms in + # order to safely change it (see ``ulimit -a`` "open files" setting). 
+ # Note that this doesn't control the maximum number of TCP connections + # that can be waiting for processing; the ``backlog`` argument controls + # that. + connection_limit = 100 + + # Minimum seconds between cleaning up inactive channels. + cleanup_interval = 30 + + # Maximum seconds to leave an inactive connection open. + channel_timeout = 120 + + # Boolean: turn off to not log premature client disconnects. + log_socket_errors = True + + # maximum number of bytes of all request headers combined (256K default) + max_request_header_size = 262144 + + # maximum number of bytes in request body (1GB default) + max_request_body_size = 1073741824 + + # expose tracebacks of uncaught exceptions + expose_tracebacks = False + + # Path to a Unix domain socket to use. + unix_socket = None + + # Path to a Unix domain socket to use. + unix_socket_perms = 0o600 + + # The socket options to set on receiving a connection. It is a list of + # (level, optname, value) tuples. TCP_NODELAY disables the Nagle + # algorithm for writes (Waitress already buffers its writes). + socket_options = [ + (socket.SOL_TCP, socket.TCP_NODELAY, 1), + ] + + # The asyncore.loop timeout value + asyncore_loop_timeout = 1 + + # The asyncore.loop flag to use poll() instead of the default select(). + asyncore_use_poll = False + + # Enable IPv4 by default + ipv4 = True + + # Enable IPv6 by default + ipv6 = True + + # A list of sockets that waitress will use to accept connections. They can + # be used for e.g. socket activation + sockets = [] + + # By setting this to a value larger than zero, each channel stays readable + # and continues to read requests from the client even if a request is still + # running, until the number of buffered requests exceeds this value. + # This allows detecting if a client closed the connection while its request + # is being processed. 
+ channel_request_lookahead = 0 + + # This setting controls the SERVER_NAME of the WSGI environment, this is + # only ever used if the remote client sent a request without a Host header + # (or when using the Proxy settings, without forwarding a Host header) + server_name = "waitress.invalid" + + def __init__(self, **kw): + + if "listen" in kw and ("host" in kw or "port" in kw): + raise ValueError("host or port may not be set if listen is set.") + + if "listen" in kw and "sockets" in kw: + raise ValueError("socket may not be set if listen is set.") + + if "sockets" in kw and ("host" in kw or "port" in kw): + raise ValueError("host or port may not be set if sockets is set.") + + if "sockets" in kw and "unix_socket" in kw: + raise ValueError("unix_socket may not be set if sockets is set") + + if "unix_socket" in kw and ("host" in kw or "port" in kw): + raise ValueError("unix_socket may not be set if host or port is set") + + if "unix_socket" in kw and "listen" in kw: + raise ValueError("unix_socket may not be set if listen is set") + + if "send_bytes" in kw: + warnings.warn( + "send_bytes will be removed in a future release", DeprecationWarning + ) + + for k, v in kw.items(): + if k not in self._param_map: + raise ValueError("Unknown adjustment %r" % k) + setattr(self, k, self._param_map[k](v)) + + if not isinstance(self.host, _str_marker) or not isinstance( + self.port, _int_marker + ): + self.listen = [f"{self.host}:{self.port}"] + + enabled_families = socket.AF_UNSPEC + + if not self.ipv4 and not HAS_IPV6: # pragma: no cover + raise ValueError( + "IPv4 is disabled but IPv6 is not available. Cowardly refusing to start." 
+ ) + + if self.ipv4 and not self.ipv6: + enabled_families = socket.AF_INET + + if not self.ipv4 and self.ipv6 and HAS_IPV6: + enabled_families = socket.AF_INET6 + + wanted_sockets = [] + hp_pairs = [] + for i in self.listen: + if ":" in i: + (host, port) = i.rsplit(":", 1) + + # IPv6 we need to make sure that we didn't split on the address + if "]" in port: # pragma: nocover + (host, port) = (i, str(self.port)) + else: + (host, port) = (i, str(self.port)) + + if WIN: # pragma: no cover + try: + # Try turning the port into an integer + port = int(port) + + except Exception: + raise ValueError( + "Windows does not support service names instead of port numbers" + ) + + try: + if "[" in host and "]" in host: # pragma: nocover + host = host.strip("[").rstrip("]") + + if host == "*": + host = None + + for s in socket.getaddrinfo( + host, + port, + enabled_families, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + socket.AI_PASSIVE, + ): + (family, socktype, proto, _, sockaddr) = s + + # It seems that getaddrinfo() may sometimes happily return + # the same result multiple times, this of course makes + # bind() very unhappy... + # + # Split on %, and drop the zone-index from the host in the + # sockaddr. Works around a bug in OS X whereby + # getaddrinfo() returns the same link-local interface with + # two different zone-indices (which makes no sense what so + # ever...) yet treats them equally when we attempt to bind(). 
+ if ( + sockaddr[1] == 0 + or (sockaddr[0].split("%", 1)[0], sockaddr[1]) not in hp_pairs + ): + wanted_sockets.append((family, socktype, proto, sockaddr)) + hp_pairs.append((sockaddr[0].split("%", 1)[0], sockaddr[1])) + + except Exception: + raise ValueError("Invalid host/port specified.") + + if self.trusted_proxy_count is not None and self.trusted_proxy is None: + raise ValueError( + "trusted_proxy_count has no meaning without setting " "trusted_proxy" + ) + + elif self.trusted_proxy_count is None: + self.trusted_proxy_count = 1 + + if self.trusted_proxy_headers and self.trusted_proxy is None: + raise ValueError( + "trusted_proxy_headers has no meaning without setting " "trusted_proxy" + ) + + if self.trusted_proxy_headers: + self.trusted_proxy_headers = { + header.lower() for header in self.trusted_proxy_headers + } + + unknown_values = self.trusted_proxy_headers - KNOWN_PROXY_HEADERS + if unknown_values: + raise ValueError( + "Received unknown trusted_proxy_headers value (%s) expected one " + "of %s" + % (", ".join(unknown_values), ", ".join(KNOWN_PROXY_HEADERS)) + ) + + if ( + "forwarded" in self.trusted_proxy_headers + and self.trusted_proxy_headers - {"forwarded"} + ): + raise ValueError( + "The Forwarded proxy header and the " + "X-Forwarded-{By,Host,Proto,Port,For} headers are mutually " + "exclusive. Can't trust both!" + ) + + elif self.trusted_proxy is not None: + warnings.warn( + "No proxy headers were marked as trusted, but trusted_proxy was set. " + "Implicitly trusting X-Forwarded-Proto for backwards compatibility. " + "This will be removed in future versions of waitress.", + DeprecationWarning, + ) + self.trusted_proxy_headers = {"x-forwarded-proto"} + + if self.clear_untrusted_proxy_headers is _bool_marker: + warnings.warn( + "In future versions of Waitress clear_untrusted_proxy_headers will be " + "set to True by default. 
You may opt-out by setting this value to " + "False, or opt-in explicitly by setting this to True.", + DeprecationWarning, + ) + self.clear_untrusted_proxy_headers = False + + self.listen = wanted_sockets + + self.check_sockets(self.sockets) + + @classmethod + def parse_args(cls, argv): + """Pre-parse command line arguments for input into __init__. Note that + this does not cast values into adjustment types, it just creates a + dictionary suitable for passing into __init__, where __init__ does the + casting. + """ + long_opts = ["help", "call"] + for opt, cast in cls._params: + opt = opt.replace("_", "-") + if cast is asbool: + long_opts.append(opt) + long_opts.append("no-" + opt) + else: + long_opts.append(opt + "=") + + kw = { + "help": False, + "call": False, + } + + opts, args = getopt.getopt(argv, "", long_opts) + for opt, value in opts: + param = opt.lstrip("-").replace("-", "_") + + if param == "listen": + kw["listen"] = "{} {}".format(kw.get("listen", ""), value) + continue + + if param.startswith("no_"): + param = param[3:] + kw[param] = "false" + elif param in ("help", "call"): + kw[param] = True + elif cls._param_map[param] is asbool: + kw[param] = "true" + else: + kw[param] = value + + return kw, args + + @classmethod + def check_sockets(cls, sockets): + has_unix_socket = False + has_inet_socket = False + has_unsupported_socket = False + for sock in sockets: + if ( + sock.family == socket.AF_INET or sock.family == socket.AF_INET6 + ) and sock.type == socket.SOCK_STREAM: + has_inet_socket = True + elif ( + hasattr(socket, "AF_UNIX") + and sock.family == socket.AF_UNIX + and sock.type == socket.SOCK_STREAM + ): + has_unix_socket = True + else: + has_unsupported_socket = True + if has_unix_socket and has_inet_socket: + raise ValueError("Internet and UNIX sockets may not be mixed.") + if has_unsupported_socket: + raise ValueError("Only Internet or UNIX stream sockets may be used.") diff --git a/waitress/buffers.py b/src/waitress/buffers.py similarity index 
83% rename from waitress/buffers.py rename to src/waitress/buffers.py index cacc0947..8091ff0e 100644 --- a/waitress/buffers.py +++ b/src/waitress/buffers.py @@ -16,12 +16,13 @@ from io import BytesIO # copy_bytes controls the size of temp. strings for shuffling data around. -COPY_BYTES = 1 << 18 # 256K +COPY_BYTES = 1 << 18 # 256K # The maximum number of bytes to buffer in a simple string. STRBUF_LIMIT = 8192 -class FileBasedBuffer(object): + +class FileBasedBuffer: remain = 0 @@ -46,7 +47,7 @@ def __len__(self): def __nonzero__(self): return True - __bool__ = __nonzero__ # py3 + __bool__ = __nonzero__ # py3 def append(self, s): file = self.file @@ -73,8 +74,8 @@ def get(self, numbytes=-1, skip=False): def skip(self, numbytes, allow_prune=0): if self.remain < numbytes: - raise ValueError("Can't skip %d bytes in buffer of %d bytes" % ( - numbytes, self.remain) + raise ValueError( + "Can't skip %d bytes in buffer of %d bytes" % (numbytes, self.remain) ) self.file.seek(numbytes, 1) self.remain = self.remain - numbytes @@ -104,21 +105,22 @@ def getfile(self): return self.file def close(self): - if hasattr(self.file, 'close'): + if hasattr(self.file, "close"): self.file.close() self.remain = 0 -class TempfileBasedBuffer(FileBasedBuffer): +class TempfileBasedBuffer(FileBasedBuffer): def __init__(self, from_buffer=None): FileBasedBuffer.__init__(self, self.newfile(), from_buffer) def newfile(self): from tempfile import TemporaryFile - return TemporaryFile('w+b') -class BytesIOBasedBuffer(FileBasedBuffer): + return TemporaryFile("w+b") + +class BytesIOBasedBuffer(FileBasedBuffer): def __init__(self, from_buffer=None): if from_buffer is not None: FileBasedBuffer.__init__(self, BytesIO(), from_buffer) @@ -129,15 +131,32 @@ def __init__(self, from_buffer=None): def newfile(self): return BytesIO() + +def _is_seekable(fp): + if hasattr(fp, "seekable"): + return fp.seekable() + return hasattr(fp, "seek") and hasattr(fp, "tell") + + class 
ReadOnlyFileBasedBuffer(FileBasedBuffer): # used as wsgi.file_wrapper def __init__(self, file, block_size=32768): self.file = file - self.block_size = block_size # for __iter__ + self.block_size = block_size # for __iter__ + + # This is for the benefit of anyone that is attempting to wrap this + # wsgi.file_wrapper in a WSGI middleware and wants to seek, this is + # useful for instance for support Range requests + if _is_seekable(self.file): + if hasattr(self.file, "seekable"): + self.seekable = self.file.seekable + + self.seek = self.file.seek + self.tell = self.file.tell def prepare(self, size=None): - if hasattr(self.file, 'seek') and hasattr(self.file, 'tell'): + if _is_seekable(self.file): start_pos = self.file.tell() self.file.seek(0, 2) end_pos = self.file.tell() @@ -163,7 +182,7 @@ def get(self, numbytes=-1, skip=False): file.seek(read_pos) return res - def __iter__(self): # called by task if self.filelike has no seek/tell + def __iter__(self): # called by task if self.filelike has no seek/tell return self def next(self): @@ -172,12 +191,13 @@ def next(self): raise StopIteration return val - __next__ = next # py3 + __next__ = next # py3 def append(self, s): raise NotImplementedError -class OverflowableBuffer(object): + +class OverflowableBuffer: """ This buffer implementation has four stages: - No data @@ -189,7 +209,7 @@ class OverflowableBuffer(object): overflowed = False buf = None - strbuf = b'' # Bytes-based buffer. + strbuf = b"" # Bytes-based buffer. def __init__(self, overflow): # overflow is the maximum to be stored in a StringIO buffer. 
@@ -209,7 +229,7 @@ def __nonzero__(self): # OverflowError on Python 2 return self.__len__() > 0 - __bool__ = __nonzero__ # py3 + __bool__ = __nonzero__ # py3 def _create_buffer(self): strbuf = self.strbuf @@ -220,15 +240,27 @@ def _create_buffer(self): buf = self.buf if strbuf: buf.append(self.strbuf) - self.strbuf = b'' + self.strbuf = b"" return buf def _set_small_buffer(self): - self.buf = BytesIOBasedBuffer(self.buf) + oldbuf = self.buf + self.buf = BytesIOBasedBuffer(oldbuf) + + # Attempt to close the old buffer + if hasattr(oldbuf, "close"): + oldbuf.close() + self.overflowed = False def _set_large_buffer(self): - self.buf = TempfileBasedBuffer(self.buf) + oldbuf = self.buf + self.buf = TempfileBasedBuffer(oldbuf) + + # Attempt to close the old buffer + if hasattr(oldbuf, "close"): + oldbuf.close() + self.overflowed = True def append(self, s): @@ -263,7 +295,7 @@ def skip(self, numbytes, allow_prune=False): # We could slice instead of converting to # a buffer, but that would eat up memory in # large transfers. - self.strbuf = b'' + self.strbuf = b"" return buf = self._create_buffer() buf.skip(numbytes, allow_prune) @@ -275,7 +307,7 @@ def prune(self): """ buf = self.buf if buf is None: - self.strbuf = b'' + self.strbuf = b"" return buf.prune() if self.overflowed: diff --git a/src/waitress/channel.py b/src/waitress/channel.py new file mode 100644 index 00000000..eb59dd3f --- /dev/null +++ b/src/waitress/channel.py @@ -0,0 +1,519 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. 
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +import socket +import threading +import time +import traceback + +from waitress.buffers import OverflowableBuffer, ReadOnlyFileBasedBuffer +from waitress.parser import HTTPRequestParser +from waitress.task import ErrorTask, WSGITask +from waitress.utilities import InternalServerError + +from . import wasyncore + + +class ClientDisconnected(Exception): + """Raised when attempting to write to a closed socket.""" + + +class HTTPChannel(wasyncore.dispatcher): + """ + Setting self.requests = [somerequest] prevents more requests from being + received until the out buffers have been flushed. + + Setting self.requests = [] allows more requests to be received. + """ + + task_class = WSGITask + error_task_class = ErrorTask + parser_class = HTTPRequestParser + + # A request that has not been received yet completely is stored here + request = None + last_activity = 0 # Time of last activity + will_close = False # set to True to close the socket. 
+ close_when_flushed = False # set to True to close the socket when flushed + sent_continue = False # used as a latch after sending 100 continue + total_outbufs_len = 0 # total bytes ready to send + current_outbuf_count = 0 # total bytes written to current outbuf + + # + # ASYNCHRONOUS METHODS (including __init__) + # + + def __init__(self, server, sock, addr, adj, map=None): + self.server = server + self.adj = adj + self.outbufs = [OverflowableBuffer(adj.outbuf_overflow)] + self.creation_time = self.last_activity = time.time() + self.sendbuf_len = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF) + + # requests_lock used to push/pop requests and modify the request that is + # currently being created + self.requests_lock = threading.Lock() + # outbuf_lock used to access any outbuf (expected to use an RLock) + self.outbuf_lock = threading.Condition() + + wasyncore.dispatcher.__init__(self, sock, map=map) + + # Don't let wasyncore.dispatcher throttle self.addr on us. + self.addr = addr + self.requests = [] + + def check_client_disconnected(self): + """ + This method is inserted into the environment of any created task so it + may occasionally check if the client has disconnected and interrupt + execution. + """ + + return not self.connected + + def writable(self): + # if there's data in the out buffer or we've been instructed to close + # the channel (possibly by our server maintenance logic), run + # handle_write + + return self.total_outbufs_len or self.will_close or self.close_when_flushed + + def handle_write(self): + # Precondition: there's data in the out buffer to be sent, or + # there's a pending will_close request + + if not self.connected: + # we dont want to close the channel twice + + return + + # try to flush any pending output + + if not self.requests: + # 1. There are no running tasks, so we don't need to try to lock + # the outbuf before sending + # 2. 
The data in the out buffer should be sent as soon as possible + # because it's either data left over from task output + # or a 100 Continue line sent within "received". + flush = self._flush_some + elif self.total_outbufs_len >= self.adj.send_bytes: + # 1. There's a running task, so we need to try to lock + # the outbuf before sending + # 2. Only try to send if the data in the out buffer is larger + # than self.adj_bytes to avoid TCP fragmentation + flush = self._flush_some_if_lockable + else: + # 1. There's not enough data in the out buffer to bother to send + # right now. + flush = None + + self._flush_exception(flush) + + if self.close_when_flushed and not self.total_outbufs_len: + self.close_when_flushed = False + self.will_close = True + + if self.will_close: + self.handle_close() + + def _flush_exception(self, flush, do_close=True): + if flush: + try: + return (flush(do_close=do_close), False) + except OSError: + if self.adj.log_socket_errors: + self.logger.exception("Socket error") + self.will_close = True + + return (False, True) + except Exception: # pragma: nocover + self.logger.exception("Unexpected exception when flushing") + self.will_close = True + + return (False, True) + + def readable(self): + # We might want to read more requests. We can only do this if: + # 1. We're not already about to close the connection. + # 2. We're not waiting to flush remaining data before closing the + # connection + # 3. There are not too many tasks already queued + # 4. There's no data in the output buffer that needs to be sent + # before we potentially create a new task. 
+ + return not ( + self.will_close + or self.close_when_flushed + or len(self.requests) > self.adj.channel_request_lookahead + or self.total_outbufs_len + ) + + def handle_read(self): + try: + data = self.recv(self.adj.recv_bytes) + except OSError: + if self.adj.log_socket_errors: + self.logger.exception("Socket error") + self.handle_close() + + return + + if data: + self.last_activity = time.time() + self.received(data) + else: + # Client disconnected. + self.connected = False + + def send_continue(self): + """ + Send a 100-Continue header to the client. This is either called from + receive (if no requests are running and the client expects it) or at + the end of service (if no more requests are queued and a request has + been read partially that expects it). + """ + self.request.expect_continue = False + outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n" + num_bytes = len(outbuf_payload) + with self.outbuf_lock: + self.outbufs[-1].append(outbuf_payload) + self.current_outbuf_count += num_bytes + self.total_outbufs_len += num_bytes + self.sent_continue = True + self._flush_some() + self.request.completed = False + + def received(self, data): + """ + Receives input asynchronously and assigns one or more requests to the + channel. + """ + + if not data: + return False + + with self.requests_lock: + while data: + if self.request is None: + self.request = self.parser_class(self.adj) + n = self.request.received(data) + + # if there are requests queued, we can not send the continue + # header yet since the responses need to be kept in order + + if ( + self.request.expect_continue + and self.request.headers_finished + and not self.requests + and not self.sent_continue + ): + self.send_continue() + + if self.request.completed: + # The request (with the body) is ready to use. 
+ self.sent_continue = False + + if not self.request.empty: + self.requests.append(self.request) + + if len(self.requests) == 1: + # self.requests was empty before so the main thread + # is in charge of starting the task. Otherwise, + # service() will add a new task after each request + # has been processed + self.server.add_task(self) + self.request = None + + if n >= len(data): + break + data = data[n:] + + return True + + def _flush_some_if_lockable(self, do_close=True): + # Since our task may be appending to the outbuf, we try to acquire + # the lock, but we don't block if we can't. + + if self.outbuf_lock.acquire(False): + try: + self._flush_some(do_close=do_close) + + if self.total_outbufs_len < self.adj.outbuf_high_watermark: + self.outbuf_lock.notify() + finally: + self.outbuf_lock.release() + + def _flush_some(self, do_close=True): + # Send as much data as possible to our client + + sent = 0 + dobreak = False + + while True: + outbuf = self.outbufs[0] + # use outbuf.__len__ rather than len(outbuf) FBO of not getting + # OverflowError on 32-bit Python + outbuflen = outbuf.__len__() + + while outbuflen > 0: + chunk = outbuf.get(self.sendbuf_len) + num_sent = self.send(chunk, do_close=do_close) + + if num_sent: + outbuf.skip(num_sent, True) + outbuflen -= num_sent + sent += num_sent + self.total_outbufs_len -= num_sent + else: + # failed to write anything, break out entirely + dobreak = True + + break + else: + # self.outbufs[-1] must always be a writable outbuf + + if len(self.outbufs) > 1: + toclose = self.outbufs.pop(0) + try: + toclose.close() + except Exception: + self.logger.exception("Unexpected error when closing an outbuf") + else: + # caught up, done flushing for now + dobreak = True + + if dobreak: + break + + if sent: + self.last_activity = time.time() + + return True + + return False + + def handle_close(self): + with self.outbuf_lock: + for outbuf in self.outbufs: + try: + outbuf.close() + except Exception: + self.logger.exception( + "Unknown 
exception while trying to close outbuf" + ) + self.total_outbufs_len = 0 + self.connected = False + self.outbuf_lock.notify() + wasyncore.dispatcher.close(self) + + def add_channel(self, map=None): + """See wasyncore.dispatcher + + This hook keeps track of opened channels. + """ + wasyncore.dispatcher.add_channel(self, map) + self.server.active_channels[self._fileno] = self + + def del_channel(self, map=None): + """See wasyncore.dispatcher + + This hook keeps track of closed channels. + """ + fd = self._fileno # next line sets this to None + wasyncore.dispatcher.del_channel(self, map) + ac = self.server.active_channels + + if fd in ac: + del ac[fd] + + # + # SYNCHRONOUS METHODS + # + + def write_soon(self, data): + if not self.connected: + # if the socket is closed then interrupt the task so that it + # can cleanup possibly before the app_iter is exhausted + raise ClientDisconnected + + if data: + # the async mainloop might be popping data off outbuf; we can + # block here waiting for it because we're in a task thread + with self.outbuf_lock: + self._flush_outbufs_below_high_watermark() + + if not self.connected: + raise ClientDisconnected + num_bytes = len(data) + + if data.__class__ is ReadOnlyFileBasedBuffer: + # they used wsgi.file_wrapper + self.outbufs.append(data) + nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) + self.outbufs.append(nextbuf) + self.current_outbuf_count = 0 + else: + if self.current_outbuf_count >= self.adj.outbuf_high_watermark: + # rotate to a new buffer if the current buffer has hit + # the watermark to avoid it growing unbounded + nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) + self.outbufs.append(nextbuf) + self.current_outbuf_count = 0 + self.outbufs[-1].append(data) + self.current_outbuf_count += num_bytes + self.total_outbufs_len += num_bytes + + if self.total_outbufs_len >= self.adj.send_bytes: + (flushed, exception) = self._flush_exception( + self._flush_some, do_close=False + ) + + if ( + exception + or not 
flushed + or self.total_outbufs_len >= self.adj.send_bytes + ): + self.server.pull_trigger() + + return num_bytes + + return 0 + + def _flush_outbufs_below_high_watermark(self): + # check first to avoid locking if possible + + if self.total_outbufs_len > self.adj.outbuf_high_watermark: + with self.outbuf_lock: + (_, exception) = self._flush_exception(self._flush_some, do_close=False) + + if exception: + # An exception happened while flushing, wake up the main + # thread, then wait for it to decide what to do next + # (probably close the socket, and then just return) + self.server.pull_trigger() + self.outbuf_lock.wait() + + return + + while ( + self.connected + and self.total_outbufs_len > self.adj.outbuf_high_watermark + ): + self.server.pull_trigger() + self.outbuf_lock.wait() + + def service(self): + """Execute one request. If there are more, we add another task to the + server at the end.""" + + request = self.requests[0] + + if request.error: + task = self.error_task_class(self, request) + else: + task = self.task_class(self, request) + + try: + if self.connected: + task.service() + else: + task.close_on_finish = True + except ClientDisconnected: + self.logger.info("Client disconnected while serving %s" % task.request.path) + task.close_on_finish = True + except Exception: + self.logger.exception("Exception while serving %s" % task.request.path) + + if not task.wrote_header: + if self.adj.expose_tracebacks: + body = traceback.format_exc() + else: + body = "The server encountered an unexpected internal server error" + req_version = request.version + req_headers = request.headers + err_request = self.parser_class(self.adj) + err_request.error = InternalServerError(body) + # copy some original request attributes to fulfill + # HTTP 1.1 requirements + err_request.version = req_version + try: + err_request.headers["CONNECTION"] = req_headers["CONNECTION"] + except KeyError: + pass + task = self.error_task_class(self, err_request) + try: + task.service() # must not 
fail + except ClientDisconnected: + task.close_on_finish = True + else: + task.close_on_finish = True + + if task.close_on_finish: + with self.requests_lock: + self.close_when_flushed = True + + for request in self.requests: + request.close() + self.requests = [] + else: + # before processing a new request, ensure there is not too + # much data in the outbufs waiting to be flushed + # NB: currently readable() returns False while we are + # flushing data so we know no new requests will come in + # that we need to account for, otherwise it'd be better + # to do this check at the start of the request instead of + # at the end to account for consecutive service() calls + + if len(self.requests) > 1: + self._flush_outbufs_below_high_watermark() + + # this is a little hacky but basically it's forcing the + # next request to create a new outbuf to avoid sharing + # outbufs across requests which can cause outbufs to + # not be deallocated regularly when a connection is open + # for a long time + + if self.current_outbuf_count > 0: + self.current_outbuf_count = self.adj.outbuf_high_watermark + + request.close() + + # Add new task to process the next request + with self.requests_lock: + self.requests.pop(0) + + if self.connected and self.requests: + self.server.add_task(self) + elif ( + self.connected + and self.request is not None + and self.request.expect_continue + and self.request.headers_finished + and not self.sent_continue + ): + # A request waits for a signal to continue, but we could + # not send it until now because requests were being + # processed and the output needs to be kept in order + self.send_continue() + + if self.connected: + self.server.pull_trigger() + + self.last_activity = time.time() + + def cancel(self): + """Cancels all pending / active requests""" + self.will_close = True + self.connected = False + self.last_activity = time.time() + self.requests = [] diff --git a/src/waitress/compat.py b/src/waitress/compat.py new file mode 100644 index 
00000000..67543b9c --- /dev/null +++ b/src/waitress/compat.py @@ -0,0 +1,29 @@ +import platform + +# Fix for issue reported in https://github.com/Pylons/waitress/issues/138, +# Python on Windows may not define IPPROTO_IPV6 in socket. +import socket +import sys +import warnings + +# True if we are running on Windows +WIN = platform.system() == "Windows" + +MAXINT = sys.maxsize +HAS_IPV6 = socket.has_ipv6 + +if hasattr(socket, "IPPROTO_IPV6") and hasattr(socket, "IPV6_V6ONLY"): + IPPROTO_IPV6 = socket.IPPROTO_IPV6 + IPV6_V6ONLY = socket.IPV6_V6ONLY +else: # pragma: no cover + if WIN: + IPPROTO_IPV6 = 41 + IPV6_V6ONLY = 27 + else: + warnings.warn( + "OS does not support required IPv6 socket flags. This is requirement " + "for Waitress. Please open an issue at https://github.com/Pylons/waitress. " + "IPv6 support has been disabled.", + RuntimeWarning, + ) + HAS_IPV6 = False diff --git a/src/waitress/parser.py b/src/waitress/parser.py new file mode 100644 index 00000000..b31b5ccb --- /dev/null +++ b/src/waitress/parser.py @@ -0,0 +1,442 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +"""HTTP Request Parser + +This server uses asyncore to accept connections and do initial +processing but threads to do work. 
+""" +from io import BytesIO +import re +from urllib import parse +from urllib.parse import unquote_to_bytes + +from waitress.buffers import OverflowableBuffer +from waitress.receiver import ChunkedReceiver, FixedStreamReceiver +from waitress.rfc7230 import HEADER_FIELD_RE, ONLY_DIGIT_RE +from waitress.utilities import ( + BadRequest, + RequestEntityTooLarge, + RequestHeaderFieldsTooLarge, + ServerNotImplemented, + find_double_newline, +) + + +def unquote_bytes_to_wsgi(bytestring): + return unquote_to_bytes(bytestring).decode("latin-1") + + +class ParsingError(Exception): + pass + + +class TransferEncodingNotImplemented(Exception): + pass + + +class HTTPRequestParser: + """A structure that collects the HTTP request. + + Once the stream is completed, the instance is passed to + a server task constructor. + """ + + completed = False # Set once request is completed. + empty = False # Set if no request was made. + expect_continue = False # client sent "Expect: 100-continue" header + headers_finished = False # True when headers have been read + header_plus = b"" + chunked = False + content_length = 0 + header_bytes_received = 0 + body_bytes_received = 0 + body_rcv = None + version = "1.0" + error = None + connection_close = False + + # Other attributes: first_line, header, headers, command, uri, version, + # path, query, fragment + + def __init__(self, adj): + """ + adj is an Adjustments object. + """ + # headers is a mapping containing keys translated to uppercase + # with dashes turned into underscores. + self.headers = {} + self.adj = adj + + def received(self, data): + """ + Receives the HTTP stream for one request. Returns the number of + bytes consumed. Sets the completed flag once both the header and the + body have been received. + """ + + if self.completed: + return 0 # Can't consume any more. + + datalen = len(data) + br = self.body_rcv + + if br is None: + # In header. 
+ max_header = self.adj.max_request_header_size + + s = self.header_plus + data + index = find_double_newline(s) + consumed = 0 + + if index >= 0: + # If the headers have ended, and we also have part of the body + # message in data we still want to validate we aren't going + # over our limit for received headers. + self.header_bytes_received = index + consumed = datalen - (len(s) - index) + else: + self.header_bytes_received += datalen + consumed = datalen + + # If the first line + headers is over the max length, we return a + # RequestHeaderFieldsTooLarge error rather than continuing to + # attempt to parse the headers. + + if self.header_bytes_received >= max_header: + self.parse_header(b"GET / HTTP/1.0\r\n") + self.error = RequestHeaderFieldsTooLarge( + "exceeds max_header of %s" % max_header + ) + self.completed = True + + return consumed + + if index >= 0: + # Header finished. + header_plus = s[:index] + + # Remove preceeding blank lines. This is suggested by + # https://tools.ietf.org/html/rfc7230#section-3.5 to support + # clients sending an extra CR LF after another request when + # using HTTP pipelining + header_plus = header_plus.lstrip() + + if not header_plus: + self.empty = True + self.completed = True + else: + try: + self.parse_header(header_plus) + except ParsingError as e: + self.error = BadRequest(e.args[0]) + self.completed = True + except TransferEncodingNotImplemented as e: + self.error = ServerNotImplemented(e.args[0]) + self.completed = True + else: + if self.body_rcv is None: + # no content-length header and not a t-e: chunked + # request + self.completed = True + + if self.content_length > 0: + max_body = self.adj.max_request_body_size + # we won't accept this request if the content-length + # is too large + + if self.content_length >= max_body: + self.error = RequestEntityTooLarge( + "exceeds max_body of %s" % max_body + ) + self.completed = True + self.headers_finished = True + + return consumed + + # Header not finished yet. 
+ self.header_plus = s + + return datalen + else: + # In body. + consumed = br.received(data) + self.body_bytes_received += consumed + max_body = self.adj.max_request_body_size + + if self.body_bytes_received >= max_body: + # this will only be raised during t-e: chunked requests + self.error = RequestEntityTooLarge("exceeds max_body of %s" % max_body) + self.completed = True + elif br.error: + # garbage in chunked encoding input probably + self.error = br.error + self.completed = True + elif br.completed: + # The request (with the body) is ready to use. + self.completed = True + + if self.chunked: + # We've converted the chunked transfer encoding request + # body into a normal request body, so we know its content + # length; set the header here. We already popped the + # TRANSFER_ENCODING header in parse_header, so this will + # appear to the client to be an entirely non-chunked HTTP + # request with a valid content-length. + self.headers["CONTENT_LENGTH"] = str(br.__len__()) + + return consumed + + def parse_header(self, header_plus): + """ + Parses the header_plus block of text (the headers plus the + first line of the request). + """ + index = header_plus.find(b"\r\n") + + if index >= 0: + first_line = header_plus[:index].rstrip() + header = header_plus[index + 2 :] + else: + raise ParsingError("HTTP message header invalid") + + if b"\r" in first_line or b"\n" in first_line: + raise ParsingError("Bare CR or LF found in HTTP message") + + self.first_line = first_line # for testing + + lines = get_header_lines(header) + + headers = self.headers + + for line in lines: + header = HEADER_FIELD_RE.match(line) + + if not header: + raise ParsingError("Invalid header") + + key, value = header.group("name", "value") + + if b"_" in key: + # TODO(xistence): Should we drop this request instead? 
+ + continue + + # Only strip off whitespace that is considered valid whitespace by + # RFC7230, don't strip the rest + value = value.strip(b" \t") + key1 = key.upper().replace(b"-", b"_").decode("latin-1") + # If a header already exists, we append subsequent values + # separated by a comma. Applications already need to handle + # the comma separated values, as HTTP front ends might do + # the concatenation for you (behavior specified in RFC2616). + try: + headers[key1] += (b", " + value).decode("latin-1") + except KeyError: + headers[key1] = value.decode("latin-1") + + # command, uri, version will be bytes + command, uri, version = crack_first_line(first_line) + # self.request_uri is like nginx's request_uri: + # "full original request URI (with arguments)" + self.request_uri = uri.decode("latin-1") + version = version.decode("latin-1") + command = command.decode("latin-1") + self.command = command + self.version = version + ( + self.proxy_scheme, + self.proxy_netloc, + self.path, + self.query, + self.fragment, + ) = split_uri(uri) + self.url_scheme = self.adj.url_scheme + connection = headers.get("CONNECTION", "") + + if version == "1.0": + if connection.lower() != "keep-alive": + self.connection_close = True + + if version == "1.1": + # since the server buffers data from chunked transfers and clients + # never need to deal with chunked requests, downstream clients + # should not see the HTTP_TRANSFER_ENCODING header; we pop it + # here + te = headers.pop("TRANSFER_ENCODING", "") + + # NB: We can not just call bare strip() here because it will also + # remove other non-printable characters that we explicitly do not + # want removed so that if someone attempts to smuggle a request + # with these characters we don't fall prey to it. + # + # For example \x85 is stripped by default, but it is not considered + # valid whitespace to be stripped by RFC7230. 
+ encodings = [ + encoding.strip(" \t").lower() for encoding in te.split(",") if encoding + ] + + for encoding in encodings: + # Out of the transfer-codings listed in + # https://tools.ietf.org/html/rfc7230#section-4 we only support + # chunked at this time. + + # Note: the identity transfer-coding was removed in RFC7230: + # https://tools.ietf.org/html/rfc7230#appendix-A.2 and is thus + # not supported + + if encoding not in {"chunked"}: + raise TransferEncodingNotImplemented( + "Transfer-Encoding requested is not supported." + ) + + if encodings and encodings[-1] == "chunked": + self.chunked = True + buf = OverflowableBuffer(self.adj.inbuf_overflow) + self.body_rcv = ChunkedReceiver(buf) + elif encodings: # pragma: nocover + raise TransferEncodingNotImplemented( + "Transfer-Encoding requested is not supported." + ) + + expect = headers.get("EXPECT", "").lower() + self.expect_continue = expect == "100-continue" + + if connection.lower() == "close": + self.connection_close = True + + if not self.chunked: + cl = headers.get("CONTENT_LENGTH", "0") + + if not ONLY_DIGIT_RE.match(cl.encode("latin-1")): + raise ParsingError("Content-Length is invalid") + + cl = int(cl) + self.content_length = cl + + if cl > 0: + buf = OverflowableBuffer(self.adj.inbuf_overflow) + self.body_rcv = FixedStreamReceiver(cl, buf) + + def get_body_stream(self): + body_rcv = self.body_rcv + + if body_rcv is not None: + return body_rcv.getfile() + else: + return BytesIO() + + def close(self): + body_rcv = self.body_rcv + + if body_rcv is not None: + body_rcv.getbuf().close() + + +def split_uri(uri): + # urlsplit handles byte input by returning bytes on py3, so + # scheme, netloc, path, query, and fragment are bytes + + scheme = netloc = path = query = fragment = b"" + + # urlsplit below will treat this as a scheme-less netloc, thereby losing + # the original intent of the request. 
Here we shamelessly stole 4 lines of + # code from the CPython stdlib to parse out the fragment and query but + # leave the path alone. See + # https://github.com/python/cpython/blob/8c9e9b0cd5b24dfbf1424d1f253d02de80e8f5ef/Lib/urllib/parse.py#L465-L468 + # and https://github.com/Pylons/waitress/issues/260 + + if uri[:2] == b"//": + path = uri + + if b"#" in path: + path, fragment = path.split(b"#", 1) + + if b"?" in path: + path, query = path.split(b"?", 1) + else: + try: + scheme, netloc, path, query, fragment = parse.urlsplit(uri) + except UnicodeError: + raise ParsingError("Bad URI") + + return ( + scheme.decode("latin-1"), + netloc.decode("latin-1"), + unquote_bytes_to_wsgi(path), + query.decode("latin-1"), + fragment.decode("latin-1"), + ) + + +def get_header_lines(header): + """ + Splits the header into lines, putting multi-line headers together. + """ + r = [] + lines = header.split(b"\r\n") + + for line in lines: + if not line: + continue + + if b"\r" in line or b"\n" in line: + raise ParsingError( + 'Bare CR or LF found in header line "%s"' % str(line, "latin-1") + ) + + if line.startswith((b" ", b"\t")): + if not r: + # https://corte.si/posts/code/pathod/pythonservers/index.html + raise ParsingError('Malformed header line "%s"' % str(line, "latin-1")) + r[-1] += line + else: + r.append(line) + + return r + + +first_line_re = re.compile( + b"([^ ]+) " + b"((?:[^ :?#]+://[^ ?#/]*(?:[0-9]{1,5})?)?[^ ]+)" + b"(( HTTP/([0-9.]+))$|$)" +) + + +def crack_first_line(line): + m = first_line_re.match(line) + + if m is not None and m.end() == len(line): + if m.group(3): + version = m.group(5) + else: + version = b"" + method = m.group(1) + + # the request methods that are currently defined are all uppercase: + # https://www.iana.org/assignments/http-methods/http-methods.xhtml and + # the request method is case sensitive according to + # https://tools.ietf.org/html/rfc7231#section-4.1 + + # By disallowing anything but uppercase methods we save poor + # unsuspecting 
souls from sending lowercase HTTP methods to waitress + # and having the request complete, while servers like nginx drop the + # request onto the floor. + + if method != method.upper(): + raise ParsingError('Malformed HTTP method "%s"' % str(method, "latin-1")) + uri = m.group(2) + + return method, uri, version + else: + return b"", b"", b"" diff --git a/src/waitress/proxy_headers.py b/src/waitress/proxy_headers.py new file mode 100644 index 00000000..652ca0bc --- /dev/null +++ b/src/waitress/proxy_headers.py @@ -0,0 +1,330 @@ +from collections import namedtuple + +from .utilities import BadRequest, logger, undquote + +PROXY_HEADERS = frozenset( + { + "X_FORWARDED_FOR", + "X_FORWARDED_HOST", + "X_FORWARDED_PROTO", + "X_FORWARDED_PORT", + "X_FORWARDED_BY", + "FORWARDED", + } +) + +Forwarded = namedtuple("Forwarded", ["by", "for_", "host", "proto"]) + + +class MalformedProxyHeader(Exception): + def __init__(self, header, reason, value): + self.header = header + self.reason = reason + self.value = value + super().__init__(header, reason, value) + + +def proxy_headers_middleware( + app, + trusted_proxy=None, + trusted_proxy_count=1, + trusted_proxy_headers=None, + clear_untrusted=True, + log_untrusted=False, + logger=logger, +): + def translate_proxy_headers(environ, start_response): + untrusted_headers = PROXY_HEADERS + remote_peer = environ["REMOTE_ADDR"] + if trusted_proxy == "*" or remote_peer == trusted_proxy: + try: + untrusted_headers = parse_proxy_headers( + environ, + trusted_proxy_count=trusted_proxy_count, + trusted_proxy_headers=trusted_proxy_headers, + logger=logger, + ) + except MalformedProxyHeader as ex: + logger.warning( + 'Malformed proxy header "%s" from "%s": %s value: %s', + ex.header, + remote_peer, + ex.reason, + ex.value, + ) + error = BadRequest(f'Header "{ex.header}" malformed.') + return error.wsgi_response(environ, start_response) + + # Clear out the untrusted proxy headers + if clear_untrusted: + clear_untrusted_headers( + environ, 
untrusted_headers, log_warning=log_untrusted, logger=logger + ) + + return app(environ, start_response) + + return translate_proxy_headers + + +def parse_proxy_headers( + environ, trusted_proxy_count, trusted_proxy_headers, logger=logger +): + if trusted_proxy_headers is None: + trusted_proxy_headers = set() + + forwarded_for = [] + forwarded_host = forwarded_proto = forwarded_port = forwarded = "" + client_addr = None + untrusted_headers = set(PROXY_HEADERS) + + def raise_for_multiple_values(): + raise ValueError("Unspecified behavior for multiple values found in header") + + if "x-forwarded-for" in trusted_proxy_headers and "HTTP_X_FORWARDED_FOR" in environ: + try: + forwarded_for = [] + + for forward_hop in environ["HTTP_X_FORWARDED_FOR"].split(","): + forward_hop = forward_hop.strip() + forward_hop = undquote(forward_hop) + + # Make sure that all IPv6 addresses are surrounded by brackets, + # this is assuming that the IPv6 representation here does not + # include a port number. + + if "." 
not in forward_hop and ( + ":" in forward_hop and forward_hop[-1] != "]" + ): + forwarded_for.append(f"[{forward_hop}]") + else: + forwarded_for.append(forward_hop) + + forwarded_for = forwarded_for[-trusted_proxy_count:] + client_addr = forwarded_for[0] + + untrusted_headers.remove("X_FORWARDED_FOR") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-For", str(ex), environ["HTTP_X_FORWARDED_FOR"] + ) + + if ( + "x-forwarded-host" in trusted_proxy_headers + and "HTTP_X_FORWARDED_HOST" in environ + ): + try: + forwarded_host_multiple = [] + + for forward_host in environ["HTTP_X_FORWARDED_HOST"].split(","): + forward_host = forward_host.strip() + forward_host = undquote(forward_host) + forwarded_host_multiple.append(forward_host) + + forwarded_host_multiple = forwarded_host_multiple[-trusted_proxy_count:] + forwarded_host = forwarded_host_multiple[0] + + untrusted_headers.remove("X_FORWARDED_HOST") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-Host", str(ex), environ["HTTP_X_FORWARDED_HOST"] + ) + + if "x-forwarded-proto" in trusted_proxy_headers: + try: + forwarded_proto = undquote(environ.get("HTTP_X_FORWARDED_PROTO", "")) + if "," in forwarded_proto: + raise_for_multiple_values() + untrusted_headers.remove("X_FORWARDED_PROTO") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-Proto", str(ex), environ["HTTP_X_FORWARDED_PROTO"] + ) + + if "x-forwarded-port" in trusted_proxy_headers: + try: + forwarded_port = undquote(environ.get("HTTP_X_FORWARDED_PORT", "")) + if "," in forwarded_port: + raise_for_multiple_values() + untrusted_headers.remove("X_FORWARDED_PORT") + except Exception as ex: + raise MalformedProxyHeader( + "X-Forwarded-Port", str(ex), environ["HTTP_X_FORWARDED_PORT"] + ) + + if "x-forwarded-by" in trusted_proxy_headers: + # Waitress itself does not use X-Forwarded-By, but we can not + # remove it so it can get set in the environ + untrusted_headers.remove("X_FORWARDED_BY") + + if 
"forwarded" in trusted_proxy_headers: + forwarded = environ.get("HTTP_FORWARDED", None) + untrusted_headers = PROXY_HEADERS - {"FORWARDED"} + + # If the Forwarded header exists, it gets priority + if forwarded: + proxies = [] + try: + for forwarded_element in forwarded.split(","): + # Remove whitespace that may have been introduced when + # appending a new entry + forwarded_element = forwarded_element.strip() + + forwarded_for = forwarded_host = forwarded_proto = "" + forwarded_port = forwarded_by = "" + + for pair in forwarded_element.split(";"): + pair = pair.lower() + + if not pair: + continue + + token, equals, value = pair.partition("=") + + if equals != "=": + raise ValueError('Invalid forwarded-pair missing "="') + + if token.strip() != token: + raise ValueError("Token may not be surrounded by whitespace") + + if value.strip() != value: + raise ValueError("Value may not be surrounded by whitespace") + + if token == "by": + forwarded_by = undquote(value) + + elif token == "for": + forwarded_for = undquote(value) + + elif token == "host": + forwarded_host = undquote(value) + + elif token == "proto": + forwarded_proto = undquote(value) + + else: + logger.warning("Unknown Forwarded token: %s" % token) + + proxies.append( + Forwarded( + forwarded_by, forwarded_for, forwarded_host, forwarded_proto + ) + ) + except Exception as ex: + raise MalformedProxyHeader("Forwarded", str(ex), environ["HTTP_FORWARDED"]) + + proxies = proxies[-trusted_proxy_count:] + + # Iterate backwards and fill in some values, the oldest entry that + # contains the information we expect is the one we use. 
We expect + # that intermediate proxies may re-write the host header or proto, + # but the oldest entry is the one that contains the information the + # client expects when generating URL's + # + # Forwarded: for="[2001:db8::1]";host="example.com:8443";proto="https" + # Forwarded: for=192.0.2.1;host="example.internal:8080" + # + # (After HTTPS header folding) should mean that we use as values: + # + # Host: example.com + # Protocol: https + # Port: 8443 + + for proxy in proxies[::-1]: + client_addr = proxy.for_ or client_addr + forwarded_host = proxy.host or forwarded_host + forwarded_proto = proxy.proto or forwarded_proto + + if forwarded_proto: + forwarded_proto = forwarded_proto.lower() + + if forwarded_proto not in {"http", "https"}: + raise MalformedProxyHeader( + "Forwarded Proto=" if forwarded else "X-Forwarded-Proto", + "unsupported proto value", + forwarded_proto, + ) + + # Set the URL scheme to the proxy provided proto + environ["wsgi.url_scheme"] = forwarded_proto + + if not forwarded_port: + if forwarded_proto == "http": + forwarded_port = "80" + + if forwarded_proto == "https": + forwarded_port = "443" + + if forwarded_host: + if ":" in forwarded_host and forwarded_host[-1] != "]": + host, port = forwarded_host.rsplit(":", 1) + host, port = host.strip(), str(port) + + # We trust the port in the Forwarded Host/X-Forwarded-Host over + # X-Forwarded-Port, or whatever we got from Forwarded + # Proto/X-Forwarded-Proto. 
+ + if forwarded_port != port: + forwarded_port = port + + # We trust the proxy server's forwarded Host + environ["SERVER_NAME"] = host + environ["HTTP_HOST"] = forwarded_host + else: + # We trust the proxy server's forwarded Host + environ["SERVER_NAME"] = forwarded_host + environ["HTTP_HOST"] = forwarded_host + + if forwarded_port: + if forwarded_port not in {"443", "80"}: + environ["HTTP_HOST"] = "{}:{}".format( + forwarded_host, forwarded_port + ) + elif forwarded_port == "80" and environ["wsgi.url_scheme"] != "http": + environ["HTTP_HOST"] = "{}:{}".format( + forwarded_host, forwarded_port + ) + elif forwarded_port == "443" and environ["wsgi.url_scheme"] != "https": + environ["HTTP_HOST"] = "{}:{}".format( + forwarded_host, forwarded_port + ) + + if forwarded_port: + environ["SERVER_PORT"] = str(forwarded_port) + + if client_addr: + if ":" in client_addr and client_addr[-1] != "]": + addr, port = client_addr.rsplit(":", 1) + environ["REMOTE_ADDR"] = strip_brackets(addr.strip()) + environ["REMOTE_PORT"] = port.strip() + else: + environ["REMOTE_ADDR"] = strip_brackets(client_addr.strip()) + environ["REMOTE_HOST"] = environ["REMOTE_ADDR"] + + return untrusted_headers + + +def strip_brackets(addr): + if addr[0] == "[" and addr[-1] == "]": + return addr[1:-1] + return addr + + +def clear_untrusted_headers( + environ, untrusted_headers, log_warning=False, logger=logger +): + untrusted_headers_removed = [ + header + for header in untrusted_headers + if environ.pop("HTTP_" + header, False) is not False + ] + + if log_warning and untrusted_headers_removed: + untrusted_headers_removed = [ + "-".join(x.capitalize() for x in header.split("_")) + for header in untrusted_headers_removed + ] + logger.warning( + "Removed untrusted headers (%s). 
Waitress recommends these be " + "removed upstream.", + ", ".join(untrusted_headers_removed), + ) diff --git a/waitress/receiver.py b/src/waitress/receiver.py similarity index 62% rename from waitress/receiver.py rename to src/waitress/receiver.py index 594ae971..76633552 100644 --- a/waitress/receiver.py +++ b/src/waitress/receiver.py @@ -14,11 +14,11 @@ """Data Chunk Receiver """ -from waitress.utilities import find_double_newline +from waitress.rfc7230 import CHUNK_EXT_RE, ONLY_HEXDIG_RE +from waitress.utilities import BadRequest, find_double_newline -from waitress.utilities import BadRequest -class FixedStreamReceiver(object): +class FixedStreamReceiver: # See IStreamConsumer completed = False @@ -30,22 +30,27 @@ def __init__(self, cl, buf): def __len__(self): return self.buf.__len__() - + def received(self, data): - 'See IStreamConsumer' + "See IStreamConsumer" rm = self.remain + if rm < 1: - self.completed = True # Avoid any chance of spinning + self.completed = True # Avoid any chance of spinning + return 0 datalen = len(data) + if rm <= datalen: self.buf.append(data[:rm]) self.remain = 0 self.completed = True + return rm else: self.buf.append(data) self.remain -= datalen + return datalen def getfile(self): @@ -54,12 +59,15 @@ def getfile(self): def getbuf(self): return self.buf -class ChunkedReceiver(object): + +class ChunkedReceiver: chunk_remainder = 0 - control_line = b'' + validate_chunk_end = False + control_line = b"" + chunk_end = b"" all_chunks_received = False - trailer = b'' + trailer = b"" completed = False error = None @@ -74,44 +82,86 @@ def __len__(self): def received(self, s): # Returns the number of bytes consumed. + if self.completed: return 0 orig_size = len(s) + while s: rm = self.chunk_remainder + if rm > 0: # Receive the remainder of a chunk. 
to_write = s[:rm] self.buf.append(to_write) written = len(to_write) s = s[written:] + self.chunk_remainder -= written + + if self.chunk_remainder == 0: + self.validate_chunk_end = True + elif self.validate_chunk_end: + s = self.chunk_end + s + + pos = s.find(b"\r\n") + + if pos < 0 and len(s) < 2: + self.chunk_end = s + s = b"" + else: + self.chunk_end = b"" + + if pos == 0: + # Chop off the terminating CR LF from the chunk + s = s[2:] + else: + self.error = BadRequest("Chunk not properly terminated") + self.all_chunks_received = True + + # Always exit this loop + self.validate_chunk_end = False elif not self.all_chunks_received: # Receive a control line. s = self.control_line + s - pos = s.find(b'\n') + pos = s.find(b"\r\n") + if pos < 0: # Control line not finished. self.control_line = s - s = '' + s = b"" else: # Control line finished. line = s[:pos] - s = s[pos + 1:] - self.control_line = b'' - line = line.strip() + s = s[pos + 2 :] + self.control_line = b"" + if line: # Begin a new chunk. - semi = line.find(b';') + semi = line.find(b";") + if semi >= 0: - # discard extension info. + extinfo = line[semi:] + valid_ext_info = CHUNK_EXT_RE.match(extinfo) + + if not valid_ext_info: + self.error = BadRequest("Invalid chunk extension") + self.all_chunks_received = True + + break + line = line[:semi] - try: - sz = int(line.strip(), 16) # hexadecimal - except ValueError: # garbage in input - self.error = BadRequest( - 'garbage in chunked encoding input') - sz = 0 + + if not ONLY_HEXDIG_RE.match(line): + self.error = BadRequest("Invalid chunk size") + self.all_chunks_received = True + + break + + # Can not fail due to matching against the regular + # expression above + sz = int(line, 16) # hexadecimal + if sz > 0: # Start a new chunk. self.chunk_remainder = sz @@ -122,24 +172,25 @@ def received(self, s): else: # Receive the trailer. trailer = self.trailer + s - if trailer.startswith(b'\r\n'): + + if trailer.startswith(b"\r\n"): # No trailer. 
self.completed = True + return orig_size - (len(trailer) - 2) - elif trailer.startswith(b'\n'): - # No trailer. - self.completed = True - return orig_size - (len(trailer) - 1) pos = find_double_newline(trailer) + if pos < 0: # Trailer not finished. self.trailer = trailer - s = b'' + s = b"" else: # Finished the trailer. self.completed = True self.trailer = trailer[:pos] + return orig_size - (len(trailer) - pos) + return orig_size def getfile(self): diff --git a/src/waitress/rfc7230.py b/src/waitress/rfc7230.py new file mode 100644 index 00000000..26e64260 --- /dev/null +++ b/src/waitress/rfc7230.py @@ -0,0 +1,75 @@ +""" +This contains a bunch of RFC7230 definitions and regular expressions that are +needed to properly parse HTTP messages. +""" + +import re + +HEXDIG = "[0-9a-fA-F]" +DIGIT = "[0-9]" + +WS = "[ \t]" +OWS = WS + "{0,}?" +RWS = WS + "{1,}?" +BWS = OWS + +# RFC 7230 Section 3.2.6 "Field Value Components": +# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" +# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" +# / DIGIT / ALPHA +# obs-text = %x80-FF +TCHAR = r"[!#$%&'*+\-.^_`|~0-9A-Za-z]" +OBS_TEXT = r"\x80-\xff" + +TOKEN = TCHAR + "{1,}" + +# RFC 5234 Appendix B.1 "Core Rules": +# VCHAR = %x21-7E +# ; visible (printing) characters +VCHAR = r"\x21-\x7e" + +# The '\\' between \x5b and \x5d is needed to escape \x5d (']') +QDTEXT = "[\t \x21\x23-\x5b\\\x5d-\x7e" + OBS_TEXT + "]" + +QUOTED_PAIR = r"\\" + "([\t " + VCHAR + OBS_TEXT + "])" +QUOTED_STRING = '"(?:(?:' + QDTEXT + ")|(?:" + QUOTED_PAIR + '))*"' + +# header-field = field-name ":" OWS field-value OWS +# field-name = token +# field-value = *( field-content / obs-fold ) +# field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +# field-vchar = VCHAR / obs-text + +# Errata from: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 +# changes field-content to: +# +# field-content = field-vchar [ 1*( SP / HTAB / field-vchar ) +# field-vchar ] + +FIELD_VCHAR = "[" + VCHAR + OBS_TEXT + "]" +# 
Field content is more greedy than the ABNF, in that it will match the whole value +FIELD_CONTENT = FIELD_VCHAR + "+(?:[ \t]+" + FIELD_VCHAR + "+)*" +# Which allows the field value here to just see if there is even a value in the first place +FIELD_VALUE = "(?:" + FIELD_CONTENT + ")?" + +# chunk-ext = *( ";" chunk-ext-name [ "=" chunk-ext-val ] ) +# chunk-ext-name = token +# chunk-ext-val = token / quoted-string + +CHUNK_EXT_NAME = TOKEN +CHUNK_EXT_VAL = "(?:" + TOKEN + ")|(?:" + QUOTED_STRING + ")" +CHUNK_EXT = ( + "(?:;(?P" + CHUNK_EXT_NAME + ")(?:=(?P" + CHUNK_EXT_VAL + "))?)*" +) + +# Pre-compiled regular expressions for use elsewhere +ONLY_HEXDIG_RE = re.compile(("^" + HEXDIG + "+$").encode("latin-1")) +ONLY_DIGIT_RE = re.compile(("^" + DIGIT + "+$").encode("latin-1")) +HEADER_FIELD_RE = re.compile( + ( + "^(?P" + TOKEN + "):" + OWS + "(?P" + FIELD_VALUE + ")" + OWS + "$" + ).encode("latin-1") +) +QUOTED_PAIR_RE = re.compile(QUOTED_PAIR) +QUOTED_STRING_RE = re.compile(QUOTED_STRING) +CHUNK_EXT_RE = re.compile(("^" + CHUNK_EXT + "$").encode("latin-1")) diff --git a/waitress/runner.py b/src/waitress/runner.py similarity index 77% rename from waitress/runner.py rename to src/waitress/runner.py index abdb38e8..22f70f38 100644 --- a/waitress/runner.py +++ b/src/waitress/runner.py @@ -14,9 +14,9 @@ """Command line runner. """ -from __future__ import print_function, unicode_literals import getopt +import logging import os import os.path import re @@ -24,6 +24,7 @@ from waitress import serve from waitress.adjustments import Adjustments +from waitress.utilities import logger HELP = """\ Usage: @@ -58,7 +59,7 @@ --listen=[::1]:8080 --listen=*:8080 - This option may be used multiple times to listen on multipe sockets. + This option may be used multiple times to listen on multiple sockets. A wildcard for the hostname is also supported and will bind to both IPv4/IPv6 depending on whether they are enabled or disabled. 
@@ -95,7 +96,7 @@ --url-scheme=STR Default wsgi.url_scheme value, default is 'http'. - --url-prefix=STR + --url-prefix=STR The ``SCRIPT_NAME`` WSGI environment value. Setting this to anything except the empty string will cause the WSGI ``SCRIPT_NAME`` value to be the value passed minus any trailing slashes you add, and it will cause @@ -126,12 +127,17 @@ A temporary file should be created if the pending output is larger than this. Default is 1048576 (1MB). + --outbuf-high-watermark=INT + The app_iter will pause when pending output is larger than this value + and will resume once enough data is written to the socket to fall below + this threshold. Default is 16777216 (16MB). + --inbuf-overflow=INT A temporary file should be created if the pending input is larger than this. Default is 524288 (512KB). --connection-limit=INT - Stop creating new channelse if too many are already active. + Stop creating new channels if too many are already active. Default is 100. --cleanup-interval=INT @@ -140,11 +146,11 @@ --channel-timeout=INT Maximum number of seconds to leave inactive connections open. - Default is 120. 'Inactive' is defined as 'has recieved no data + Default is 120. 'Inactive' is defined as 'has received no data from the client and has sent no data to the client'. --[no-]log-socket-errors - Toggle whether premature client disconnect tracepacks ought to be + Toggle whether premature client disconnect tracebacks ought to be logged. On by default. --max-request-header-size=INT @@ -165,9 +171,16 @@ The use_poll argument passed to ``asyncore.loop()``. Helps overcome open file descriptors limit. Default is False. + --channel-request-lookahead=INT + Allows channels to stay readable and buffer more requests up to the + given maximum even if a request is already being processed. This allows + detecting if a client closed the connection while its request is being + processed. Default is 0. 
+ """ -RUNNER_PATTERN = re.compile(r""" +RUNNER_PATTERN = re.compile( + r""" ^ (?P [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)* @@ -177,13 +190,17 @@ [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)* ) $ - """, re.I | re.X) + """, + re.I | re.X, +) + def match(obj_name): matches = RUNNER_PATTERN.match(obj_name) if not matches: - raise ValueError("Malformed application '{0}'".format(obj_name)) - return matches.group('module'), matches.group('object') + raise ValueError(f"Malformed application '{obj_name}'") + return matches.group("module"), matches.group("object") + def resolve(module_name, object_name): """Resolve a named object in a module.""" @@ -197,34 +214,35 @@ def resolve(module_name, object_name): # but I've yet to go over the commits. I know, however, that the NEWS # file makes no mention of such a change to the behaviour of # ``__import__``. - segments = [str(segment) for segment in object_name.split('.')] + segments = [str(segment) for segment in object_name.split(".")] obj = __import__(module_name, fromlist=segments[:1]) for segment in segments: obj = getattr(obj, segment) return obj -def show_help(stream, name, error=None): # pragma: no cover + +def show_help(stream, name, error=None): # pragma: no cover if error is not None: - print('Error: {0}\n'.format(error), file=stream) + print(f"Error: {error}\n", file=stream) print(HELP.format(name), file=stream) + def show_exception(stream): exc_type, exc_value = sys.exc_info()[:2] - args = getattr(exc_value, 'args', None) + args = getattr(exc_value, "args", None) print( - ( - 'There was an exception ({0}) importing your module.\n' - ).format( + ("There was an exception ({}) importing your module.\n").format( exc_type.__name__, ), - file=stream + file=stream, ) if args: - print('It had these arguments: ', file=stream) + print("It had these arguments: ", file=stream) for idx, arg in enumerate(args, start=1): - print('{0}. {1}\n'.format(idx, arg), file=stream) + print(f"{idx}. 
{arg}\n", file=stream) else: - print('It had no arguments.', file=stream) + print("It had no arguments.", file=stream) + def run(argv=sys.argv, _serve=serve): """Command line runner.""" @@ -236,14 +254,20 @@ def run(argv=sys.argv, _serve=serve): show_help(sys.stderr, name, str(exc)) return 1 - if kw['help']: + if kw["help"]: show_help(sys.stdout, name) return 0 if len(args) != 1: - show_help(sys.stderr, name, 'Specify one application only') + show_help(sys.stderr, name, "Specify one application only") return 1 + # set a default level for the logger only if it hasn't been set explicitly + # note that this level does not override any parent logger levels, + # handlers, etc but without it no log messages are emitted by default + if logger.level == logging.NOTSET: + logger.setLevel(logging.INFO) + try: module, obj_name = match(args[0]) except ValueError as exc: @@ -258,18 +282,18 @@ def run(argv=sys.argv, _serve=serve): try: app = resolve(module, obj_name) except ImportError: - show_help(sys.stderr, name, "Bad module '{0}'".format(module)) + show_help(sys.stderr, name, f"Bad module '{module}'") show_exception(sys.stderr) return 1 except AttributeError: - show_help(sys.stderr, name, "Bad object name '{0}'".format(obj_name)) + show_help(sys.stderr, name, f"Bad object name '{obj_name}'") show_exception(sys.stderr) return 1 - if kw['call']: + if kw["call"]: app = app() # These arguments are specific to the runner, not waitress itself. 
- del kw['call'], kw['help'] + del kw["call"], kw["help"] _serve(app, **kw) return 0 diff --git a/waitress/server.py b/src/waitress/server.py similarity index 52% rename from waitress/server.py rename to src/waitress/server.py index 79aa9b75..0a0f876b 100644 --- a/waitress/server.py +++ b/src/waitress/server.py @@ -12,7 +12,6 @@ # ############################################################################## -import asyncore import os import os.path import socket @@ -21,23 +20,22 @@ from waitress import trigger from waitress.adjustments import Adjustments from waitress.channel import HTTPChannel +from waitress.compat import IPPROTO_IPV6, IPV6_V6ONLY from waitress.task import ThreadedTaskDispatcher -from waitress.utilities import ( - cleanup_unix_socket, - logging_dispatcher, - ) -from waitress.compat import ( - IPPROTO_IPV6, - IPV6_V6ONLY, - ) - -def create_server(application, - map=None, - _start=True, # test shim - _sock=None, # test shim - _dispatcher=None, # test shim - **kw # adjustments - ): +from waitress.utilities import cleanup_unix_socket + +from . import wasyncore +from .proxy_headers import proxy_headers_middleware + + +def create_server( + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + _dispatcher=None, # test shim + **kw # adjustments +): """ if __name__ == '__main__': server = create_server(app) @@ -46,11 +44,11 @@ def create_server(application, if application is None: raise ValueError( 'The "app" passed to ``create_server`` was ``None``. You forgot ' - 'to return a WSGI app within your application.' - ) + "to return a WSGI app within your application." 
+ ) adj = Adjustments(**kw) - if map is None: # pragma: nocover + if map is None: # pragma: nocover map = {} dispatcher = _dispatcher @@ -58,7 +56,7 @@ def create_server(application, dispatcher = ThreadedTaskDispatcher() dispatcher.set_thread_count(adj.threads) - if adj.unix_socket and hasattr(socket, 'AF_UNIX'): + if adj.unix_socket and hasattr(socket, "AF_UNIX"): sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) return UnixWSGIServer( application, @@ -67,61 +65,101 @@ def create_server(application, _sock, dispatcher=dispatcher, adj=adj, - sockinfo=sockinfo) + sockinfo=sockinfo, + ) effective_listen = [] last_serv = None - for sockinfo in adj.listen: - # When TcpWSGIServer is called, it registers itself in the map. This - # side-effect is all we need it for, so we don't store a reference to - # or return it to the user. - last_serv = TcpWSGIServer( - application, - map, - _start, - _sock, - dispatcher=dispatcher, - adj=adj, - sockinfo=sockinfo) - effective_listen.append((last_serv.effective_host, last_serv.effective_port)) + if not adj.sockets: + for sockinfo in adj.listen: + # When TcpWSGIServer is called, it registers itself in the map. This + # side-effect is all we need it for, so we don't store a reference to + # or return it to the user. 
+ last_serv = TcpWSGIServer( + application, + map, + _start, + _sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + + for sock in adj.sockets: + sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname()) + if sock.family == socket.AF_INET or sock.family == socket.AF_INET6: + last_serv = TcpWSGIServer( + application, + map, + _start, + sock, + dispatcher=dispatcher, + adj=adj, + bind_socket=False, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) + elif hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX: + last_serv = UnixWSGIServer( + application, + map, + _start, + sock, + dispatcher=dispatcher, + adj=adj, + bind_socket=False, + sockinfo=sockinfo, + ) + effective_listen.append( + (last_serv.effective_host, last_serv.effective_port) + ) # We are running a single server, so we can just return the last server, # saves us from having to create one more object - if len(adj.listen) == 1: + if len(effective_listen) == 1: # In this case we have no need to use a MultiSocketServer return last_serv + log_info = last_serv.log_info # Return a class that has a utility function to print out the sockets it's # listening on, and has a .run() function. All of the TcpWSGIServers # registered themselves in the map above. - return MultiSocketServer(map, adj, effective_listen, dispatcher) + return MultiSocketServer(map, adj, effective_listen, dispatcher, log_info) # This class is only ever used if we have multiple listen sockets. It allows -# the serve() API to call .run() which starts the asyncore loop, and catches +# the serve() API to call .run() which starts the wasyncore loop, and catches # SystemExit/KeyboardInterrupt so that it can atempt to cleanly shut down. 
-class MultiSocketServer(object): - asyncore = asyncore # test shim - - def __init__(self, - map=None, - adj=None, - effective_listen=None, - dispatcher=None, - ): +class MultiSocketServer: + asyncore = wasyncore # test shim + + def __init__( + self, + map=None, + adj=None, + effective_listen=None, + dispatcher=None, + log_info=None, + ): self.adj = adj self.map = map self.effective_listen = effective_listen self.task_dispatcher = dispatcher + self.log_info = log_info - def print_listen(self, format_str): # pragma: nocover + def print_listen(self, format_str): # pragma: nocover for l in self.effective_listen: l = list(l) - if ':' in l[0]: - l[0] = '[{}]'.format(l[0]) + if ":" in l[0]: + l[0] = f"[{l[0]}]" - print(format_str.format(*l)) + self.log_info(format_str.format(*l)) def run(self): try: @@ -131,31 +169,53 @@ def run(self): use_poll=self.adj.asyncore_use_poll, ) except (SystemExit, KeyboardInterrupt): - self.task_dispatcher.shutdown() + self.close() + + def close(self): + self.task_dispatcher.shutdown() + wasyncore.close_all(self.map) -class BaseWSGIServer(logging_dispatcher, object): +class BaseWSGIServer(wasyncore.dispatcher): channel_class = HTTPChannel next_channel_cleanup = 0 - socketmod = socket # test shim - asyncore = asyncore # test shim - - def __init__(self, - application, - map=None, - _start=True, # test shim - _sock=None, # test shim - dispatcher=None, # dispatcher - adj=None, # adjustments - sockinfo=None, # opaque object - **kw - ): + socketmod = socket # test shim + asyncore = wasyncore # test shim + in_connection_overflow = False + + def __init__( + self, + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + dispatcher=None, # dispatcher + adj=None, # adjustments + sockinfo=None, # opaque object + bind_socket=True, + **kw + ): if adj is None: adj = Adjustments(**kw) + + if adj.trusted_proxy or adj.clear_untrusted_proxy_headers: + # wrap the application to deal with proxy headers + # we wrap it here because webtest 
subclasses the TcpWSGIServer + # directly and thus doesn't run any code that's in create_server + application = proxy_headers_middleware( + application, + trusted_proxy=adj.trusted_proxy, + trusted_proxy_count=adj.trusted_proxy_count, + trusted_proxy_headers=adj.trusted_proxy_headers, + clear_untrusted=adj.clear_untrusted_proxy_headers, + log_untrusted=adj.log_untrusted_proxy_headers, + logger=self.logger, + ) + if map is None: # use a nonglobal socket map by default to hopefully prevent - # conflicts with apps and libs that use the asyncore global socket + # conflicts with apps and libs that use the wasyncore global socket # map ala https://github.com/Pylons/waitress/issues/63 map = {} if sockinfo is None: @@ -175,45 +235,29 @@ def __init__(self, self.asyncore.dispatcher.__init__(self, _sock, map=map) if _sock is None: self.create_socket(self.family, self.socktype) - if self.family == socket.AF_INET6: # pragma: nocover + if self.family == socket.AF_INET6: # pragma: nocover self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1) self.set_reuse_addr() - self.bind_server_socket() + + if bind_socket: + self.bind_server_socket() + self.effective_host, self.effective_port = self.getsockname() - self.server_name = self.get_server_name(self.effective_host) + self.server_name = adj.server_name self.active_channels = {} if _start: self.accept_connections() def bind_server_socket(self): - raise NotImplementedError # pragma: no cover - - def get_server_name(self, ip): - """Given an IP or hostname, try to determine the server name.""" - if ip: - server_name = str(ip) - else: - server_name = str(self.socketmod.gethostname()) - - # Convert to a host name if necessary. - for c in server_name: - if c != '.' 
and not c.isdigit(): - return server_name - try: - if server_name == '0.0.0.0' or server_name == '::': - return 'localhost' - server_name = self.socketmod.gethostbyaddr(server_name)[0] - except socket.error: # pragma: no cover - pass - return server_name + raise NotImplementedError # pragma: no cover def getsockname(self): - raise NotImplementedError # pragma: no cover + raise NotImplementedError # pragma: no cover def accept_connections(self): self.accepting = True - self.socket.listen(self.adj.backlog) # Get around asyncore NT limit + self.socket.listen(self.adj.backlog) # Get around asyncore NT limit def add_task(self, task): self.task_dispatcher.add_task(task) @@ -223,7 +267,28 @@ def readable(self): if now >= self.next_channel_cleanup: self.next_channel_cleanup = now + self.adj.cleanup_interval self.maintenance(now) - return (self.accepting and len(self._map) < self.adj.connection_limit) + + if self.accepting: + if ( + not self.in_connection_overflow + and len(self._map) >= self.adj.connection_limit + ): + self.in_connection_overflow = True + self.logger.warning( + "total open connections reached the connection limit, " + "no longer accepting new connections" + ) + elif ( + self.in_connection_overflow + and len(self._map) < self.adj.connection_limit + ): + self.in_connection_overflow = False + self.logger.info( + "total open connections dropped below the connection limit, " + "listening again" + ) + return not self.in_connection_overflow + return False def writable(self): return False @@ -240,14 +305,13 @@ def handle_accept(self): if v is None: return conn, addr = v - except socket.error: + except OSError: # Linux: On rare occasions we get a bogus socket back from # accept. socketmodule.c:makesockaddr complains that the # address family is unknown. We don't want the whole server # to shut down because of this. 
if self.adj.log_socket_errors: - self.logger.warning('server accept() threw an exception', - exc_info=True) + self.logger.warning("server accept() threw an exception", exc_info=True) return self.set_socket_options(conn) addr = self.fix_addr(addr) @@ -283,55 +347,49 @@ def maintenance(self, now): if (not channel.requests) and channel.last_activity < cutoff: channel.will_close = True - def print_listen(self, format_str): # pragma: nocover - print(format_str.format(self.effective_host, self.effective_port)) + def print_listen(self, format_str): # pragma: no cover + self.log_info(format_str.format(self.effective_host, self.effective_port)) + def close(self): + self.trigger.close() + return wasyncore.dispatcher.close(self) -class TcpWSGIServer(BaseWSGIServer): +class TcpWSGIServer(BaseWSGIServer): def bind_server_socket(self): (_, _, _, sockaddr) = self.sockinfo self.bind(sockaddr) def getsockname(self): - try: - return self.socketmod.getnameinfo( - self.socket.getsockname(), - self.socketmod.NI_NUMERICSERV - ) - except: # pragma: no cover - # This only happens on Linux because a DNS issue is considered a - # temporary failure that will raise (even when NI_NAMEREQD is not - # set). Instead we try again, but this time we just ask for the - # numerichost and the numericserv (port) and return those. It is - # better than nothing. 
- return self.socketmod.getnameinfo( - self.socket.getsockname(), - self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV - ) + # Return the IP address, port as numeric + return self.socketmod.getnameinfo( + self.socket.getsockname(), + self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV, + ) def set_socket_options(self, conn): for (level, optname, value) in self.adj.socket_options: conn.setsockopt(level, optname, value) -if hasattr(socket, 'AF_UNIX'): +if hasattr(socket, "AF_UNIX"): class UnixWSGIServer(BaseWSGIServer): - - def __init__(self, - application, - map=None, - _start=True, # test shim - _sock=None, # test shim - dispatcher=None, # dispatcher - adj=None, # adjustments - sockinfo=None, # opaque object - **kw): + def __init__( + self, + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + dispatcher=None, # dispatcher + adj=None, # adjustments + sockinfo=None, # opaque object + **kw + ): if sockinfo is None: sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) - super(UnixWSGIServer, self).__init__( + super().__init__( application, map=map, _start=_start, @@ -339,7 +397,8 @@ def __init__(self, dispatcher=dispatcher, adj=adj, sockinfo=sockinfo, - **kw) + **kw, + ) def bind_server_socket(self): cleanup_unix_socket(self.adj.unix_socket) @@ -348,10 +407,11 @@ def bind_server_socket(self): os.chmod(self.adj.unix_socket, self.adj.unix_socket_perms) def getsockname(self): - return ('unix', self.socket.getsockname()) + return ("unix", self.socket.getsockname()) def fix_addr(self, addr): - return ('localhost', None) + return ("localhost", None) + # Compatibility alias. WSGIServer = TcpWSGIServer diff --git a/src/waitress/task.py b/src/waitress/task.py new file mode 100644 index 00000000..574532fa --- /dev/null +++ b/src/waitress/task.py @@ -0,0 +1,571 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. 
+# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## + +from collections import deque +import socket +import sys +import threading +import time + +from .buffers import ReadOnlyFileBasedBuffer +from .utilities import build_http_date, logger, queue_logger + +rename_headers = { # or keep them without the HTTP_ prefix added + "CONTENT_LENGTH": "CONTENT_LENGTH", + "CONTENT_TYPE": "CONTENT_TYPE", +} + +hop_by_hop = frozenset( + ( + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + ) +) + + +class ThreadedTaskDispatcher: + """A Task Dispatcher that creates a thread for each task.""" + + stop_count = 0 # Number of threads that will stop soon. 
+ active_count = 0 # Number of currently active threads + logger = logger + queue_logger = queue_logger + + def __init__(self): + self.threads = set() + self.queue = deque() + self.lock = threading.Lock() + self.queue_cv = threading.Condition(self.lock) + self.thread_exit_cv = threading.Condition(self.lock) + + def start_new_thread(self, target, thread_no): + t = threading.Thread( + target=target, name=f"waitress-{thread_no}", args=(thread_no,) + ) + t.daemon = True + t.start() + + def handler_thread(self, thread_no): + while True: + with self.lock: + while not self.queue and self.stop_count == 0: + # Mark ourselves as idle before waiting to be + # woken up, then we will once again be active + self.active_count -= 1 + self.queue_cv.wait() + self.active_count += 1 + + if self.stop_count > 0: + self.active_count -= 1 + self.stop_count -= 1 + self.threads.discard(thread_no) + self.thread_exit_cv.notify() + break + + task = self.queue.popleft() + try: + task.service() + except BaseException: + self.logger.exception("Exception when servicing %r", task) + + def set_thread_count(self, count): + with self.lock: + threads = self.threads + thread_no = 0 + running = len(threads) - self.stop_count + while running < count: + # Start threads. + while thread_no in threads: + thread_no = thread_no + 1 + threads.add(thread_no) + running += 1 + self.start_new_thread(self.handler_thread, thread_no) + self.active_count += 1 + thread_no = thread_no + 1 + if running > count: + # Stop threads. 
+ self.stop_count += running - count + self.queue_cv.notify_all() + + def add_task(self, task): + with self.lock: + self.queue.append(task) + self.queue_cv.notify() + queue_size = len(self.queue) + idle_threads = len(self.threads) - self.stop_count - self.active_count + if queue_size > idle_threads: + self.queue_logger.warning( + "Task queue depth is %d", queue_size - idle_threads + ) + + def shutdown(self, cancel_pending=True, timeout=5): + self.set_thread_count(0) + # Ensure the threads shut down. + threads = self.threads + expiration = time.time() + timeout + with self.lock: + while threads: + if time.time() >= expiration: + self.logger.warning("%d thread(s) still running", len(threads)) + break + self.thread_exit_cv.wait(0.1) + if cancel_pending: + # Cancel remaining tasks. + queue = self.queue + if len(queue) > 0: + self.logger.warning("Canceling %d pending task(s)", len(queue)) + while queue: + task = queue.popleft() + task.cancel() + self.queue_cv.notify_all() + return True + return False + + +class Task: + close_on_finish = False + status = "200 OK" + wrote_header = False + start_time = 0 + content_length = None + content_bytes_written = 0 + logged_write_excess = False + logged_write_no_body = False + complete = False + chunked_response = False + logger = logger + + def __init__(self, channel, request): + self.channel = channel + self.request = request + self.response_headers = [] + version = request.version + if version not in ("1.0", "1.1"): + # fall back to a version we support. 
+ version = "1.0" + self.version = version + + def service(self): + try: + self.start() + self.execute() + self.finish() + except OSError: + self.close_on_finish = True + if self.channel.adj.log_socket_errors: + raise + + @property + def has_body(self): + return not ( + self.status.startswith("1") + or self.status.startswith("204") + or self.status.startswith("304") + ) + + def build_response_header(self): + version = self.version + # Figure out whether the connection should be closed. + connection = self.request.headers.get("CONNECTION", "").lower() + response_headers = [] + content_length_header = None + date_header = None + server_header = None + connection_close_header = None + + for (headername, headerval) in self.response_headers: + headername = "-".join([x.capitalize() for x in headername.split("-")]) + + if headername == "Content-Length": + if self.has_body: + content_length_header = headerval + else: + continue # pragma: no cover + + if headername == "Date": + date_header = headerval + + if headername == "Server": + server_header = headerval + + if headername == "Connection": + connection_close_header = headerval.lower() + # replace with properly capitalized version + response_headers.append((headername, headerval)) + + if ( + content_length_header is None + and self.content_length is not None + and self.has_body + ): + content_length_header = str(self.content_length) + response_headers.append(("Content-Length", content_length_header)) + + def close_on_finish(): + if connection_close_header is None: + response_headers.append(("Connection", "close")) + self.close_on_finish = True + + if version == "1.0": + if connection == "keep-alive": + if not content_length_header: + close_on_finish() + else: + response_headers.append(("Connection", "Keep-Alive")) + else: + close_on_finish() + + elif version == "1.1": + if connection == "close": + close_on_finish() + + if not content_length_header: + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for 
any response with a status code of 1xx, 204 or 304. + + if self.has_body: + response_headers.append(("Transfer-Encoding", "chunked")) + self.chunked_response = True + + if not self.close_on_finish: + close_on_finish() + + # under HTTP 1.1 keep-alive is default, no need to set the header + else: + raise AssertionError("neither HTTP/1.0 or HTTP/1.1") + + # Set the Server and Date field, if not yet specified. This is needed + # if the server is used as a proxy. + ident = self.channel.server.adj.ident + + if not server_header: + if ident: + response_headers.append(("Server", ident)) + else: + response_headers.append(("Via", ident or "waitress")) + + if not date_header: + response_headers.append(("Date", build_http_date(self.start_time))) + + self.response_headers = response_headers + + first_line = f"HTTP/{self.version} {self.status}" + # NB: sorting headers needs to preserve same-named-header order + # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here; + # rely on stable sort to keep relative position of same-named headers + next_lines = [ + "%s: %s" % hv for hv in sorted(self.response_headers, key=lambda x: x[0]) + ] + lines = [first_line] + next_lines + res = "%s\r\n\r\n" % "\r\n".join(lines) + + return res.encode("latin-1") + + def remove_content_length_header(self): + response_headers = [] + + for header_name, header_value in self.response_headers: + if header_name.lower() == "content-length": + continue # pragma: nocover + response_headers.append((header_name, header_value)) + + self.response_headers = response_headers + + def start(self): + self.start_time = time.time() + + def finish(self): + if not self.wrote_header: + self.write(b"") + if self.chunked_response: + # not self.write, it will chunk it! 
+ self.channel.write_soon(b"0\r\n\r\n") + + def write(self, data): + if not self.complete: + raise RuntimeError("start_response was not called before body written") + channel = self.channel + if not self.wrote_header: + rh = self.build_response_header() + channel.write_soon(rh) + self.wrote_header = True + + if data and self.has_body: + towrite = data + cl = self.content_length + if self.chunked_response: + # use chunked encoding response + towrite = hex(len(data))[2:].upper().encode("latin-1") + b"\r\n" + towrite += data + b"\r\n" + elif cl is not None: + towrite = data[: cl - self.content_bytes_written] + self.content_bytes_written += len(towrite) + if towrite != data and not self.logged_write_excess: + self.logger.warning( + "application-written content exceeded the number of " + "bytes specified by Content-Length header (%s)" % cl + ) + self.logged_write_excess = True + if towrite: + channel.write_soon(towrite) + elif data: + # Cheat, and tell the application we have written all of the bytes, + # even though the response shouldn't have a body and we are + # ignoring it entirely. + self.content_bytes_written += len(data) + + if not self.logged_write_no_body: + self.logger.warning( + "application-written content was ignored due to HTTP " + "response that may not contain a message-body: (%s)" % self.status + ) + self.logged_write_no_body = True + + +class ErrorTask(Task): + """An error task produces an error response""" + + complete = True + + def execute(self): + e = self.request.error + status, headers, body = e.to_response() + self.status = status + self.response_headers.extend(headers) + # We need to explicitly tell the remote client we are closing the + # connection, because self.close_on_finish is set, and we are going to + # slam the door in the clients face. 
+ self.response_headers.append(("Connection", "close")) + self.close_on_finish = True + self.content_length = len(body) + self.write(body) + + +class WSGITask(Task): + """A WSGI task produces a response from a WSGI application.""" + + environ = None + + def execute(self): + environ = self.get_environment() + + def start_response(status, headers, exc_info=None): + if self.complete and not exc_info: + raise AssertionError( + "start_response called a second time without providing exc_info." + ) + if exc_info: + try: + if self.wrote_header: + # higher levels will catch and handle raised exception: + # 1. "service" method in task.py + # 2. "service" method in channel.py + # 3. "handler_thread" method in task.py + raise exc_info[1] + else: + # As per WSGI spec existing headers must be cleared + self.response_headers = [] + finally: + exc_info = None + + self.complete = True + + if not status.__class__ is str: + raise AssertionError("status %s is not a string" % status) + if "\n" in status or "\r" in status: + raise ValueError( + "carriage return/line feed character present in status" + ) + + self.status = status + + # Prepare the headers for output + for k, v in headers: + if not k.__class__ is str: + raise AssertionError( + f"Header name {k!r} is not a string in {(k, v)!r}" + ) + if not v.__class__ is str: + raise AssertionError( + f"Header value {v!r} is not a string in {(k, v)!r}" + ) + + if "\n" in v or "\r" in v: + raise ValueError( + "carriage return/line feed character present in header value" + ) + if "\n" in k or "\r" in k: + raise ValueError( + "carriage return/line feed character present in header name" + ) + + kl = k.lower() + if kl == "content-length": + self.content_length = int(v) + elif kl in hop_by_hop: + raise AssertionError( + '%s is a "hop-by-hop" header; it cannot be used by ' + "a WSGI application (see PEP 3333)" % k + ) + + self.response_headers.extend(headers) + + # Return a method used to write the response data. 
+ return self.write + + # Call the application to handle the request and write a response + app_iter = self.channel.server.application(environ, start_response) + + can_close_app_iter = True + try: + if app_iter.__class__ is ReadOnlyFileBasedBuffer: + cl = self.content_length + size = app_iter.prepare(cl) + if size: + if cl != size: + if cl is not None: + self.remove_content_length_header() + self.content_length = size + self.write(b"") # generate headers + # if the write_soon below succeeds then the channel will + # take over closing the underlying file via the channel's + # _flush_some or handle_close so we intentionally avoid + # calling close in the finally block + self.channel.write_soon(app_iter) + can_close_app_iter = False + return + + first_chunk_len = None + for chunk in app_iter: + if first_chunk_len is None: + first_chunk_len = len(chunk) + # Set a Content-Length header if one is not supplied. + # start_response may not have been called until first + # iteration as per PEP, so we must reinterrogate + # self.content_length here + if self.content_length is None: + app_iter_len = None + if hasattr(app_iter, "__len__"): + app_iter_len = len(app_iter) + if app_iter_len == 1: + self.content_length = first_chunk_len + # transmit headers only after first iteration of the iterable + # that returns a non-empty bytestring (PEP 3333) + if chunk: + self.write(chunk) + + cl = self.content_length + if cl is not None: + if self.content_bytes_written != cl: + # close the connection so the client isn't sitting around + # waiting for more data when there are too few bytes + # to service content-length + self.close_on_finish = True + if self.request.command != "HEAD": + self.logger.warning( + "application returned too few bytes (%s) " + "for specified Content-Length (%s) via app_iter" + % (self.content_bytes_written, cl), + ) + finally: + if can_close_app_iter and hasattr(app_iter, "close"): + app_iter.close() + + def get_environment(self): + """Returns a WSGI 
environment.""" + environ = self.environ + if environ is not None: + # Return the cached copy. + return environ + + request = self.request + path = request.path + channel = self.channel + server = channel.server + url_prefix = server.adj.url_prefix + + if path.startswith("/"): + # strip extra slashes at the beginning of a path that starts + # with any number of slashes + path = "/" + path.lstrip("/") + + if url_prefix: + # NB: url_prefix is guaranteed by the configuration machinery to + # be either the empty string or a string that starts with a single + # slash and ends without any slashes + if path == url_prefix: + # if the path is the same as the url prefix, the SCRIPT_NAME + # should be the url_prefix and PATH_INFO should be empty + path = "" + else: + # if the path starts with the url prefix plus a slash, + # the SCRIPT_NAME should be the url_prefix and PATH_INFO should + # the value of path from the slash until its end + url_prefix_with_trailing_slash = url_prefix + "/" + if path.startswith(url_prefix_with_trailing_slash): + path = path[len(url_prefix) :] + + environ = { + "REMOTE_ADDR": channel.addr[0], + # Nah, we aren't actually going to look up the reverse DNS for + # REMOTE_ADDR, but we will happily set this environment variable + # for the WSGI application. Spec says we can just set this to + # REMOTE_ADDR, so we do. 
+ "REMOTE_HOST": channel.addr[0], + # try and set the REMOTE_PORT to something useful, but maybe None + "REMOTE_PORT": str(channel.addr[1]), + "REQUEST_METHOD": request.command.upper(), + "SERVER_PORT": str(server.effective_port), + "SERVER_NAME": server.server_name, + "SERVER_SOFTWARE": server.adj.ident, + "SERVER_PROTOCOL": "HTTP/%s" % self.version, + "SCRIPT_NAME": url_prefix, + "PATH_INFO": path, + "REQUEST_URI": request.request_uri, + "QUERY_STRING": request.query, + "wsgi.url_scheme": request.url_scheme, + # the following environment variables are required by the WSGI spec + "wsgi.version": (1, 0), + # apps should use the logging module + "wsgi.errors": sys.stderr, + "wsgi.multithread": True, + "wsgi.multiprocess": False, + "wsgi.run_once": False, + "wsgi.input": request.get_body_stream(), + "wsgi.file_wrapper": ReadOnlyFileBasedBuffer, + "wsgi.input_terminated": True, # wsgi.input is EOF terminated + } + + for key, value in dict(request.headers).items(): + value = value.strip() + mykey = rename_headers.get(key, None) + if mykey is None: + mykey = "HTTP_" + key + if mykey not in environ: + environ[mykey] = value + + # Insert a callable into the environment that allows the application to + # check if the client disconnected. Only works with + # channel_request_lookahead larger than 0. + environ["waitress.client_disconnected"] = self.channel.check_client_disconnected + + # cache the environ for this request + self.environ = environ + return environ diff --git a/waitress/trigger.py b/src/waitress/trigger.py similarity index 81% rename from waitress/trigger.py rename to src/waitress/trigger.py index cac8e264..73ac31c3 100644 --- a/waitress/trigger.py +++ b/src/waitress/trigger.py @@ -12,12 +12,13 @@ # ############################################################################## -import asyncore +import errno import os import socket -import errno import threading +from . import wasyncore + # Wake up a call to select() running in the main thread. 
# # This is useful in a context where you are using Medusa's I/O @@ -48,10 +49,11 @@ # new data onto a channel's outgoing data queue at the same time that # the main thread is trying to remove some] -class _triggerbase(object): + +class _triggerbase: """OS-independent base class for OS-dependent trigger class.""" - kind = None # subclass must set to "pipe" or "loopback"; used by repr + kind = None # subclass must set to "pipe" or "loopback"; used by repr def __init__(self): self._closed = False @@ -61,7 +63,7 @@ def __init__(self): self.lock = threading.Lock() # List of no-argument callbacks to invoke when the trigger is - # pulled. These run in the thread running the asyncore mainloop, + # pulled. These run in the thread running the wasyncore mainloop, # regardless of which thread pulls the trigger. self.thunks = [] @@ -77,7 +79,7 @@ def handle_connect(self): def handle_close(self): self.close() - # Override the asyncore close() method, because it doesn't know about + # Override the wasyncore close() method, because it doesn't know about # (so can't close) all the gimmicks we have open. Subclass must # supply a _close() method to do platform-specific closing work. _close() # will be called iff we're not already closed. 
@@ -85,7 +87,7 @@ def close(self): if not self._closed: self._closed = True self.del_channel() - self._close() # subclass does OS-specific stuff + self._close() # subclass does OS-specific stuff def pull_trigger(self, thunk=None): if thunk: @@ -96,42 +98,42 @@ def pull_trigger(self, thunk=None): def handle_read(self): try: self.recv(8192) - except (OSError, socket.error): + except OSError: return with self.lock: for thunk in self.thunks: try: thunk() except: - nil, t, v, tbinfo = asyncore.compact_traceback() - self.log_info( - 'exception in trigger thunk: (%s:%s %s)' % - (t, v, tbinfo)) + nil, t, v, tbinfo = wasyncore.compact_traceback() + self.log_info(f"exception in trigger thunk: ({t}:{v} {tbinfo})") self.thunks = [] -if os.name == 'posix': - class trigger(_triggerbase, asyncore.file_dispatcher): +if os.name == "posix": + + class trigger(_triggerbase, wasyncore.file_dispatcher): kind = "pipe" def __init__(self, map): _triggerbase.__init__(self) r, self.trigger = self._fds = os.pipe() - asyncore.file_dispatcher.__init__(self, r, map=map) + wasyncore.file_dispatcher.__init__(self, r, map=map) def _close(self): for fd in self._fds: os.close(fd) self._fds = [] + wasyncore.file_dispatcher.close(self) def _physical_pull(self): - os.write(self.trigger, b'x') + os.write(self.trigger, b"x") -else: # pragma: no cover +else: # pragma: no cover # Windows version; uses just sockets, because a pipe isn't select'able # on Windows. - class trigger(_triggerbase, asyncore.dispatcher): + class trigger(_triggerbase, wasyncore.dispatcher): kind = "loopback" def __init__(self, map): @@ -139,12 +141,12 @@ def __init__(self, map): # Get a pair of connected sockets. The trigger is the 'w' # end of the pair, which is connected to 'r'. 'r' is put - # in the asyncore socket map. "pulling the trigger" then + # in the wasyncore socket map. "pulling the trigger" then # means writing something on w, which will wake up r. 
w = socket.socket() # Disable buffering -- pulling the trigger sends 1 byte, - # and we want that sent immediately, to wake up asyncore's + # and we want that sent immediately, to wake up wasyncore's # select() ASAP. w.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -163,20 +165,20 @@ def __init__(self, map): # for hideous details. a = socket.socket() a.bind(("127.0.0.1", 0)) - connect_address = a.getsockname() # assigned (host, port) pair + connect_address = a.getsockname() # assigned (host, port) pair a.listen(1) try: w.connect(connect_address) - break # success - except socket.error as detail: - if detail[0] != errno.WSAEADDRINUSE: + break # success + except OSError as detail: + if getattr(detail, "winerror", None) != errno.WSAEADDRINUSE: # "Address already in use" is the only error # I've seen on two WinXP Pro SP2 boxes, under # Pythons 2.3.5 and 2.4.1. raise # (10048, 'Address already in use') # assert count <= 2 # never triggered in Tim's tests - if count >= 10: # I've never seen it go above 2 + if count >= 10: # I've never seen it go above 2 a.close() w.close() raise RuntimeError("Cannot bind trigger!") @@ -184,10 +186,10 @@ def __init__(self, map): # sleep() here, but it didn't appear to help or hurt. 
a.close() - r, addr = a.accept() # r becomes asyncore's (self.)socket + r, addr = a.accept() # r becomes wasyncore's (self.)socket a.close() self.trigger = w - asyncore.dispatcher.__init__(self, r, map=map) + wasyncore.dispatcher.__init__(self, r, map=map) def _close(self): # self.socket is r, and self.trigger is w, from __init__ @@ -195,4 +197,4 @@ def _close(self): self.trigger.close() def _physical_pull(self): - self.trigger.send(b'x') + self.trigger.send(b"x") diff --git a/src/waitress/utilities.py b/src/waitress/utilities.py new file mode 100644 index 00000000..164752f9 --- /dev/null +++ b/src/waitress/utilities.py @@ -0,0 +1,298 @@ +############################################################################## +# +# Copyright (c) 2004 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. 
+# +############################################################################## +"""Utility functions +""" + +import calendar +import errno +import logging +import os +import re +import stat +import time + +from .rfc7230 import QUOTED_PAIR_RE, QUOTED_STRING_RE + +logger = logging.getLogger("waitress") +queue_logger = logging.getLogger("waitress.queue") + + +def find_double_newline(s): + """Returns the position just after a double newline in the given string.""" + pos = s.find(b"\r\n\r\n") + + if pos >= 0: + pos += 4 + + return pos + + +def concat(*args): + return "".join(args) + + +def join(seq, field=" "): + return field.join(seq) + + +def group(s): + return "(" + s + ")" + + +short_days = ["sun", "mon", "tue", "wed", "thu", "fri", "sat"] +long_days = [ + "sunday", + "monday", + "tuesday", + "wednesday", + "thursday", + "friday", + "saturday", +] + +short_day_reg = group(join(short_days, "|")) +long_day_reg = group(join(long_days, "|")) + +daymap = {} + +for i in range(7): + daymap[short_days[i]] = i + daymap[long_days[i]] = i + +hms_reg = join(3 * [group("[0-9][0-9]")], ":") + +months = [ + "jan", + "feb", + "mar", + "apr", + "may", + "jun", + "jul", + "aug", + "sep", + "oct", + "nov", + "dec", +] + +monmap = {} + +for i in range(12): + monmap[months[i]] = i + 1 + +months_reg = group(join(months, "|")) + +# From draft-ietf-http-v11-spec-07.txt/3.3.1 +# Sun, 06 Nov 1994 08:49:37 GMT ; RFC 822, updated by RFC 1123 +# Sunday, 06-Nov-94 08:49:37 GMT ; RFC 850, obsoleted by RFC 1036 +# Sun Nov 6 08:49:37 1994 ; ANSI C's asctime() format + +# rfc822 format +rfc822_date = join( + [ + concat(short_day_reg, ","), # day + group("[0-9][0-9]?"), # date + months_reg, # month + group("[0-9]+"), # year + hms_reg, # hour minute second + "gmt", + ], + " ", +) + +rfc822_reg = re.compile(rfc822_date) + + +def unpack_rfc822(m): + g = m.group + + return ( + int(g(4)), # year + monmap[g(3)], # month + int(g(2)), # day + int(g(5)), # hour + int(g(6)), # minute + int(g(7)), # second 
+ 0, + 0, + 0, + ) + + +# rfc850 format +rfc850_date = join( + [ + concat(long_day_reg, ","), + join([group("[0-9][0-9]?"), months_reg, group("[0-9]+")], "-"), + hms_reg, + "gmt", + ], + " ", +) + +rfc850_reg = re.compile(rfc850_date) +# they actually unpack the same way +def unpack_rfc850(m): + g = m.group + yr = g(4) + + if len(yr) == 2: + yr = "19" + yr + + return ( + int(yr), # year + monmap[g(3)], # month + int(g(2)), # day + int(g(5)), # hour + int(g(6)), # minute + int(g(7)), # second + 0, + 0, + 0, + ) + + +# parsdate.parsedate - ~700/sec. +# parse_http_date - ~1333/sec. + +weekdayname = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] +monthname = [ + None, + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec", +] + + +def build_http_date(when): + year, month, day, hh, mm, ss, wd, y, z = time.gmtime(when) + + return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( + weekdayname[wd], + day, + monthname[month], + year, + hh, + mm, + ss, + ) + + +def parse_http_date(d): + d = d.lower() + m = rfc850_reg.match(d) + + if m and m.end() == len(d): + retval = int(calendar.timegm(unpack_rfc850(m))) + else: + m = rfc822_reg.match(d) + + if m and m.end() == len(d): + retval = int(calendar.timegm(unpack_rfc822(m))) + else: + return 0 + + return retval + + +def undquote(value): + if value.startswith('"') and value.endswith('"'): + # So it claims to be DQUOTE'ed, let's validate that + matches = QUOTED_STRING_RE.match(value) + + if matches and matches.end() == len(value): + # Remove the DQUOTE's from the value + value = value[1:-1] + + # Remove all backslashes that are followed by a valid vchar or + # obs-text + value = QUOTED_PAIR_RE.sub(r"\1", value) + + return value + elif not value.startswith('"') and not value.endswith('"'): + return value + + raise ValueError("Invalid quoting in value") + + +def cleanup_unix_socket(path): + try: + st = os.stat(path) + except OSError as exc: + if exc.errno != errno.ENOENT: + raise # 
pragma: no cover + else: + if stat.S_ISSOCK(st.st_mode): + try: + os.remove(path) + except OSError: # pragma: no cover + # avoid race condition error during tests + pass + + +class Error: + code = 500 + reason = "Internal Server Error" + + def __init__(self, body): + self.body = body + + def to_response(self): + status = f"{self.code} {self.reason}" + body = f"{self.reason}\r\n\r\n{self.body}" + tag = "\r\n\r\n(generated by waitress)" + body = (body + tag).encode("utf-8") + headers = [("Content-Type", "text/plain; charset=utf-8")] + + return status, headers, body + + def wsgi_response(self, environ, start_response): + status, headers, body = self.to_response() + start_response(status, headers) + yield body + + +class BadRequest(Error): + code = 400 + reason = "Bad Request" + + +class RequestHeaderFieldsTooLarge(BadRequest): + code = 431 + reason = "Request Header Fields Too Large" + + +class RequestEntityTooLarge(BadRequest): + code = 413 + reason = "Request Entity Too Large" + + +class InternalServerError(Error): + code = 500 + reason = "Internal Server Error" + + +class ServerNotImplemented(Error): + code = 501 + reason = "Not Implemented" diff --git a/src/waitress/wasyncore.py b/src/waitress/wasyncore.py new file mode 100644 index 00000000..b3459e01 --- /dev/null +++ b/src/waitress/wasyncore.py @@ -0,0 +1,692 @@ +# -*- Mode: Python -*- +# Id: asyncore.py,v 2.51 2000/09/07 22:29:26 rushing Exp +# Author: Sam Rushing + +# ====================================================================== +# Copyright 1996 by Sam Rushing +# +# All Rights Reserved +# +# Permission to use, copy, modify, and distribute this software and +# its documentation for any purpose and without fee is hereby +# granted, provided that the above copyright notice appear in all +# copies and that both that copyright notice and this permission +# notice appear in supporting documentation, and that the name of Sam +# Rushing not be used in advertising or publicity pertaining to +# distribution of 
the software without specific, written prior +# permission. +# +# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, +# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN +# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR +# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS +# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, +# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +# ====================================================================== + +"""Basic infrastructure for asynchronous socket service clients and servers. + +There are only two ways to have a program on a single processor do "more +than one thing at a time". Multi-threaded programming is the simplest and +most popular way to do it, but there is another very different technique, +that lets you have nearly all the advantages of multi-threading, without +actually using multiple threads. it's really only practical if your program +is largely I/O bound. If your program is CPU bound, then pre-emptive +scheduled threads are probably what you really need. Network servers are +rarely CPU-bound, however. + +If your operating system supports the select() system call in its I/O +library (and nearly all do), then you can use it to juggle multiple +communication channels at once; doing other work while your I/O is taking +place in the "background." Although this strategy can seem strange and +complex, especially at first, it is in many ways easier to understand and +control than multi-threaded programming. The module documented here solves +many of the difficult problems for you, making the task of building +sophisticated high-performance network servers and clients a snap. + +NB: this is a fork of asyncore from the stdlib that we've (the waitress +developers) named 'wasyncore' to ensure forward compatibility, as asyncore +in the stdlib will be dropped soon. 
# fd -> channel map shared by all dispatchers that don't pass their own map.
# The try/except keeps a pre-existing map alive if this module is reloaded.
_DISCONNECTED = frozenset(
    {ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, EBADF}
)

try:
    socket_map
except NameError:
    socket_map = {}


def _strerror(err):
    """Return a human-readable message for errno *err*, tolerating junk input."""
    try:
        return os.strerror(err)
    except (TypeError, ValueError, OverflowError, NameError):
        return "Unknown error %s" % err


class ExitNow(Exception):
    """Raised by a handler to abort the event loop immediately."""


# Exceptions that must never be swallowed by the generic handlers below.
_reraised_exceptions = (ExitNow, KeyboardInterrupt, SystemExit)


def _safely(obj, handler):
    # Run one event handler, funnelling unexpected errors to handle_error()
    # while letting loop-control exceptions propagate.
    try:
        handler()
    except _reraised_exceptions:
        raise
    except:
        obj.handle_error()


def read(obj):
    """Deliver a read-readiness event to *obj*."""
    _safely(obj, obj.handle_read_event)


def write(obj):
    """Deliver a write-readiness event to *obj*."""
    _safely(obj, obj.handle_write_event)


def _exception(obj):
    """Deliver an exceptional-condition event to *obj*."""
    _safely(obj, obj.handle_expt_event)


def readwrite(obj, flags):
    """Dispatch a poll() *flags* bitmask to the matching handlers on *obj*."""
    try:
        if flags & select.POLLIN:
            obj.handle_read_event()
        if flags & select.POLLOUT:
            obj.handle_write_event()
        if flags & select.POLLPRI:
            obj.handle_expt_event()
        if flags & (select.POLLHUP | select.POLLERR | select.POLLNVAL):
            obj.handle_close()
    except OSError as exc:
        if exc.args[0] in _DISCONNECTED:
            obj.handle_close()
        else:
            obj.handle_error()
    except _reraised_exceptions:
        raise
    except:
        obj.handle_error()


def poll(timeout=0.0, map=None):
    """Run a single select()-based iteration over the channels in *map*."""
    if map is None:  # pragma: no cover
        map = socket_map
    if not map:
        return
    readers = []
    writers = []
    exceptional = []
    for fd, channel in list(map.items()):  # list() call FBO py3
        wants_read = channel.readable()
        wants_write = channel.writable()
        if wants_read:
            readers.append(fd)
        # accepting sockets should not be writable
        if wants_write and not channel.accepting:
            writers.append(fd)
        if wants_read or wants_write:
            exceptional.append(fd)
    if not (readers or writers or exceptional):
        # Nothing to watch: emulate the select() timeout and bail out.
        time.sleep(timeout)
        return

    try:
        readers, writers, exceptional = select.select(
            readers, writers, exceptional, timeout
        )
    except OSError as err:
        if err.args[0] != EINTR:
            raise
        return

    for ready, dispatch in (
        (readers, read),
        (writers, write),
        (exceptional, _exception),
    ):
        for fd in ready:
            channel = map.get(fd)
            if channel is None:  # pragma: no cover
                continue
            dispatch(channel)


def poll2(timeout=0.0, map=None):
    """Run a single iteration using the select.poll() interface."""
    if map is None:  # pragma: no cover
        map = socket_map
    if timeout is not None:
        # poll() wants its timeout in milliseconds
        timeout = int(timeout * 1000)
    pollster = select.poll()
    if map:
        for fd, channel in list(map.items()):
            mask = 0
            if channel.readable():
                mask |= select.POLLIN | select.POLLPRI
            # accepting sockets should not be writable
            if channel.writable() and not channel.accepting:
                mask |= select.POLLOUT
            if mask:
                pollster.register(fd, mask)

        try:
            ready = pollster.poll(timeout)
        except OSError as err:
            if err.args[0] != EINTR:
                raise
            ready = []

        for fd, mask in ready:
            channel = map.get(fd)
            if channel is None:  # pragma: no cover
                continue
            readwrite(channel, mask)


poll3 = poll2  # Alias for backward compatibility


def loop(timeout=30.0, use_poll=False, map=None, count=None):
    """Drive the event loop until *map* empties (or *count* iterations run)."""
    if map is None:  # pragma: no cover
        map = socket_map

    poll_fun = poll2 if use_poll and hasattr(select, "poll") else poll

    if count is None:  # pragma: no cover
        while map:
            poll_fun(timeout, map)
    else:
        while map and count > 0:
            poll_fun(timeout, map)
            count = count - 1
def compact_traceback():
    """Summarize the exception currently being handled.

    Returns ``((file, function, line), type, value, info)`` where *info* is
    a one-line ``[file|function|line]`` chain for the whole traceback and the
    first element describes the innermost (raising) frame.
    """
    exc_type, exc_value, tb = sys.exc_info()
    if not tb:  # pragma: no cover
        raise AssertionError("traceback does not exist")
    entries = []
    while tb:
        code = tb.tb_frame.f_code
        entries.append((code.co_filename, code.co_name, str(tb.tb_lineno)))
        tb = tb.tb_next

    # just to be safe: drop our reference into the frame chain
    del tb

    file, function, line = entries[-1]
    info = " ".join("[%s|%s|%s]" % entry for entry in entries)
    return (file, function, line), exc_type, exc_value, info


class dispatcher:
    """One channel: wraps a non-blocking socket and routes readiness events
    (read/write/accept/connect/close) to overridable ``handle_*`` methods.

    Registration in the fd -> channel map happens automatically; the base
    ``handle_*`` implementations only log.
    """

    debug = False
    connected = False
    accepting = False
    connecting = False
    closing = False
    addr = None
    ignore_log_types = frozenset({"warning"})
    logger = utilities.logger
    compact_traceback = staticmethod(compact_traceback)  # for testing

    def __init__(self, sock=None, map=None):
        if map is None:  # pragma: no cover
            self._map = socket_map
        else:
            self._map = map

        self._fileno = None

        if sock:
            # Always force non-blocking: the caller may hand us a socket
            # obtained from a blocking source.
            sock.setblocking(0)
            self.set_socket(sock, map)
            self.connected = True
            # The socket need not be connected yet; probe with getpeername.
            try:
                self.addr = sock.getpeername()
            except OSError as exc:
                if exc.args[0] in (ENOTCONN, EINVAL):
                    # Merely unconnected: valid, just record the fact.
                    self.connected = False
                else:
                    # Broken in some unknown way: remove it from the map so
                    # it is never polled, and alert the caller.
                    self.del_channel(map)
                    raise
        else:
            self.socket = None

    def __repr__(self):
        parts = [self.__class__.__module__ + "." + self.__class__.__qualname__]
        if self.accepting and self.addr:
            parts.append("listening")
        elif self.connected:
            parts.append("connected")
        if self.addr is not None:
            try:
                parts.append("%s:%d" % self.addr)
            except TypeError:  # pragma: no cover
                parts.append(repr(self.addr))
        return "<{} at {:#x}>".format(" ".join(parts), id(self))

    __str__ = __repr__

    def add_channel(self, map=None):
        """Register this channel under its fd (default: our own map)."""
        if map is None:
            map = self._map
        map[self._fileno] = self

    def del_channel(self, map=None):
        """Unregister this channel and forget its file descriptor."""
        fd = self._fileno
        if map is None:
            map = self._map
        if fd in map:
            del map[fd]
        self._fileno = None

    def create_socket(self, family=socket.AF_INET, type=socket.SOCK_STREAM):
        """Create, remember and register a fresh non-blocking socket."""
        self.family_and_type = family, type
        sock = socket.socket(family, type)
        sock.setblocking(0)
        self.set_socket(sock)

    def set_socket(self, sock, map=None):
        self.socket = sock
        self._fileno = sock.fileno()
        self.add_channel(map)

    def set_reuse_addr(self):
        # Best effort: allow fast re-use of a recently bound server port.
        try:
            self.socket.setsockopt(
                socket.SOL_SOCKET,
                socket.SO_REUSEADDR,
                self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1,
            )
        except OSError:
            pass

    # ==================================================
    # predicates for select(): should this channel's fd
    # be watched for readability / writability?
    # ==================================================

    def readable(self):
        return True

    def writable(self):
        return True

    # ==================================================
    # socket object methods.
    # ==================================================

    def listen(self, num):
        self.accepting = True
        if os.name == "nt" and num > 5:  # pragma: no cover
            num = 5  # Windows caps the listen backlog
        return self.socket.listen(num)

    def bind(self, addr):
        self.addr = addr
        return self.socket.bind(addr)

    def connect(self, address):
        """Start a non-blocking connect; completion arrives as an event."""
        self.connected = False
        self.connecting = True
        err = self.socket.connect_ex(address)
        if err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or (
            err == EINVAL and os.name == "nt"
        ):  # pragma: no cover
            self.addr = address
            return
        if err in (0, EISCONN):
            self.addr = address
            self.handle_connect_event()
        else:
            raise OSError(err, errorcode[err])

    def accept(self):
        # XXX can return either an address pair or None
        try:
            pair = self.socket.accept()
        except TypeError:
            return None
        except OSError as exc:
            if exc.args[0] in (EWOULDBLOCK, ECONNABORTED, EAGAIN):
                return None
            raise
        else:
            return pair

    def send(self, data, do_close=True):
        """Send *data*; return the byte count (0 when the kernel buffer is
        full, or when the peer vanished and *do_close* shut the channel)."""
        try:
            return self.socket.send(data)
        except OSError as exc:
            if exc.args[0] == EWOULDBLOCK:
                return 0
            if exc.args[0] in _DISCONNECTED:
                if do_close:
                    self.handle_close()
                return 0
            raise

    def recv(self, buffer_size):
        try:
            data = self.socket.recv(buffer_size)
            if not data:
                # A closed connection is indicated by signaling a read
                # condition, and having recv() return 0 bytes.
                self.handle_close()
                return b""
            return data
        except OSError as exc:
            # winsock sometimes raises ENOTCONN
            if exc.args[0] in _DISCONNECTED:
                self.handle_close()
                return b""
            raise

    def close(self):
        self.connected = False
        self.accepting = False
        self.connecting = False
        self.del_channel()
        if self.socket is not None:
            try:
                self.socket.close()
            except OSError as exc:
                if exc.args[0] not in (ENOTCONN, EBADF):
                    raise

    # log and log_info may be overridden to provide more sophisticated
    # logging and warning methods.  In general, log is for 'hit' logging
    # and 'log_info' is for informational, warning and error logging.

    def log(self, message):
        self.logger.log(logging.DEBUG, message)

    def log_info(self, message, type="info"):
        severity = {
            "info": logging.INFO,
            "warning": logging.WARNING,  # same value as the legacy WARN alias
            "error": logging.ERROR,
        }
        self.logger.log(severity.get(type, logging.INFO), message)

    def handle_read_event(self):
        if self.accepting:
            # Accepting sockets are never "connected"; they spawn new
            # sockets that are.
            self.handle_accept()
            return
        if not self.connected and self.connecting:
            self.handle_connect_event()
        self.handle_read()

    def handle_connect_event(self):
        err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
        if err != 0:
            raise OSError(err, _strerror(err))
        self.handle_connect()
        self.connected = True
        self.connecting = False

    def handle_write_event(self):
        if self.accepting:
            # Accepting sockets shouldn't get a write event; pretend it
            # didn't happen.
            return
        if not self.connected and self.connecting:
            self.handle_connect_event()
        self.handle_write()

    def handle_expt_event(self):
        # Called when there might be an error on the socket or OOB data;
        # check for the error condition first.
        err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
        if err != 0:
            # select() reported an exceptional condition: treat it like a
            # subclassed handle_read() that received no data and close.
            self.handle_close()
        else:
            self.handle_expt()

    def handle_error(self):
        nil, exc_type, exc_value, info = self.compact_traceback()

        # sometimes a user repr method will crash.
        try:
            self_repr = repr(self)
        except:  # pragma: no cover
            self_repr = "<__repr__(self) failed for object at %0x>" % id(self)

        self.log_info(
            "uncaptured python exception, closing channel %s (%s:%s %s)"
            % (self_repr, exc_type, exc_value, info),
            "error",
        )
        self.handle_close()

    def handle_expt(self):
        self.log_info("unhandled incoming priority event", "warning")

    def handle_read(self):
        self.log_info("unhandled read event", "warning")

    def handle_write(self):
        self.log_info("unhandled write event", "warning")

    def handle_connect(self):
        self.log_info("unhandled connect event", "warning")

    def handle_accept(self):
        pair = self.accept()
        if pair is not None:
            self.handle_accepted(*pair)

    def handle_accepted(self, sock, addr):
        sock.close()
        self.log_info("unhandled accepted event", "warning")

    def handle_close(self):
        self.log_info("unhandled close event", "warning")
        self.close()
# ---------------------------------------------------------------------------
# adds simple buffered output capability, useful for simple clients.
# [for more sophisticated usage use asynchat.async_chat]
# ---------------------------------------------------------------------------


class dispatcher_with_send(dispatcher):
    """dispatcher plus an outgoing byte buffer drained as the socket allows."""

    def __init__(self, sock=None, map=None):
        dispatcher.__init__(self, sock, map)
        self.out_buffer = b""

    def initiate_send(self):
        # Push at most 64KiB per event; whatever the kernel refuses stays
        # buffered for the next write event.
        sent = dispatcher.send(self, self.out_buffer[:65536])
        self.out_buffer = self.out_buffer[sent:]

    handle_write = initiate_send

    def writable(self):
        # Stay writable while unconnected so connect completion is noticed.
        return (not self.connected) or len(self.out_buffer)

    def send(self, data):
        if self.debug:  # pragma: no cover
            self.log_info("sending %s" % repr(data))
        self.out_buffer = self.out_buffer + data
        self.initiate_send()


def close_all(map=None, ignore_all=False):
    """Close every channel in *map* (default: the global map), then clear it.

    EBADF is always tolerated; other errors propagate unless *ignore_all*.
    """
    if map is None:  # pragma: no cover
        map = socket_map
    for channel in list(map.values()):  # list() FBO py3
        try:
            channel.close()
        except OSError as exc:
            if exc.args[0] == EBADF:
                pass  # descriptor already gone; nothing left to close
            elif not ignore_all:
                raise
        except _reraised_exceptions:
            raise
        except:
            if not ignore_all:
                raise
    map.clear()


# Asynchronous File I/O:
#
# select() is not meant for regular-file i/o, but most unixen perform
# read-ahead so file data is usually already in memory when read.  This
# machinery is therefore mainly useful for pipes and stdin/stdout.

if os.name == "posix":

    class file_wrapper:
        # Duck-types just enough of the socket API over a file descriptor
        # so a pipe or stdio fd can live in the dispatcher map.
        # The passed fd is automatically os.dup()'d.

        def __init__(self, fd):
            # dup() so closing us never closes the caller's descriptor.
            self.fd = os.dup(fd)

        def __del__(self):
            if self.fd >= 0:
                warnings.warn("unclosed file %r" % self, ResourceWarning)
            self.close()

        def recv(self, *args):
            return os.read(self.fd, *args)

        def send(self, *args):
            return os.write(self.fd, *args)

        def getsockopt(self, level, optname, buflen=None):  # pragma: no cover
            if (
                level == socket.SOL_SOCKET
                and optname == socket.SO_ERROR
                and not buflen
            ):
                return 0
            raise NotImplementedError(
                "Only asyncore specific behaviour implemented."
            )

        read = recv
        write = send

        def close(self):
            if self.fd < 0:
                return
            fd, self.fd = self.fd, -1  # mark closed before the syscall
            os.close(fd)

        def fileno(self):
            return self.fd

    class file_dispatcher(dispatcher):
        """dispatcher whose channel is a plain file descriptor, not a socket."""

        def __init__(self, fd, map=None):
            dispatcher.__init__(self, None, map)
            self.connected = True
            try:
                fd = fd.fileno()
            except AttributeError:
                pass
            self.set_file(fd)
            # set it to non-blocking mode
            os.set_blocking(fd, False)

        def set_file(self, fd):
            self.socket = file_wrapper(fd)
            self._fileno = self.socket.fileno()
            self.add_channel()
[body] diff --git a/tests/fixtureapps/echo.py b/tests/fixtureapps/echo.py new file mode 100644 index 00000000..84975621 --- /dev/null +++ b/tests/fixtureapps/echo.py @@ -0,0 +1,65 @@ +from collections import namedtuple +import json + + +def app_body_only(environ, start_response): # pragma: no cover + cl = environ.get("CONTENT_LENGTH", None) + if cl is not None: + cl = int(cl) + body = environ["wsgi.input"].read(cl) + cl = str(len(body)) + start_response( + "200 OK", + [ + ("Content-Length", cl), + ("Content-Type", "text/plain"), + ], + ) + return [body] + + +def app(environ, start_response): # pragma: no cover + cl = environ.get("CONTENT_LENGTH", None) + if cl is not None: + cl = int(cl) + request_body = environ["wsgi.input"].read(cl) + cl = str(len(request_body)) + meta = { + "method": environ["REQUEST_METHOD"], + "path_info": environ["PATH_INFO"], + "script_name": environ["SCRIPT_NAME"], + "query_string": environ["QUERY_STRING"], + "content_length": cl, + "scheme": environ["wsgi.url_scheme"], + "remote_addr": environ["REMOTE_ADDR"], + "remote_host": environ["REMOTE_HOST"], + "server_port": environ["SERVER_PORT"], + "server_name": environ["SERVER_NAME"], + "headers": { + k[len("HTTP_") :]: v for k, v in environ.items() if k.startswith("HTTP_") + }, + } + response = json.dumps(meta).encode("utf8") + b"\r\n\r\n" + request_body + start_response( + "200 OK", + [ + ("Content-Length", str(len(response))), + ("Content-Type", "text/plain"), + ], + ) + return [response] + + +Echo = namedtuple( + "Echo", + ( + "method path_info script_name query_string content_length scheme " + "remote_addr remote_host server_port server_name headers body" + ), +) + + +def parse_response(response): + meta, body = response.split(b"\r\n\r\n", 1) + meta = json.loads(meta.decode("utf8")) + return Echo(body=body, **meta) diff --git a/tests/fixtureapps/error.py b/tests/fixtureapps/error.py new file mode 100644 index 00000000..5afb1c54 --- /dev/null +++ b/tests/fixtureapps/error.py @@ -0,0 +1,21 
import io
import os

# Directory holding this fixture and the image that app() serves.
here = os.path.dirname(os.path.abspath(__file__))
fn = os.path.join(here, "groundhog1.jpg")


class KindaFilelike:  # pragma: no cover
    """Minimal read(n)-only object: file-ish, but not a real file."""

    def __init__(self, bytes):
        self.bytes = bytes

    def read(self, n):
        chunk, self.bytes = self.bytes[:n], self.bytes[n:]
        return chunk


class UnseekableIOBase(io.RawIOBase):  # pragma: no cover
    """io.RawIOBase wrapper that is readable but refuses to seek."""

    def __init__(self, bytes):
        self.buf = io.BytesIO(bytes)

    def writable(self):
        return False

    def readable(self):
        return True

    def seekable(self):
        return False

    def read(self, n):
        return self.buf.read(n)


def app(environ, start_response):  # pragma: no cover
    """WSGI fixture: serve groundhog1.jpg through wsgi.file_wrapper with
    correct, missing, too-short or too-long Content-Length headers, using
    either a real file object or one of the fake file-likes above."""
    path_info = environ["PATH_INFO"]
    if path_info.startswith("/filelike"):
        f = open(fn, "rb")
        f.seek(0, 2)
        cl = f.tell()
        f.seek(0)
        if path_info == "/filelike":
            headers = [
                ("Content-Length", str(cl)),
                ("Content-Type", "image/jpeg"),
            ]
        elif path_info == "/filelike_nocl":
            headers = [("Content-Type", "image/jpeg")]
        elif path_info == "/filelike_shortcl":
            # deliberately understated content length
            headers = [
                ("Content-Length", "1"),
                ("Content-Type", "image/jpeg"),
            ]
        else:
            # /filelike_longcl: deliberately overstated content length
            headers = [
                ("Content-Length", str(cl + 10)),
                ("Content-Type", "image/jpeg"),
            ]
    else:
        with open(fn, "rb") as fp:
            data = fp.read()
        cl = len(data)
        f = KindaFilelike(data)
        if path_info == "/notfilelike":
            headers = [
                ("Content-Length", str(len(data))),
                ("Content-Type", "image/jpeg"),
            ]
        elif path_info == "/notfilelike_iobase":
            headers = [
                ("Content-Length", str(len(data))),
                ("Content-Type", "image/jpeg"),
            ]
            f = UnseekableIOBase(data)
        elif path_info == "/notfilelike_nocl":
            headers = [("Content-Type", "image/jpeg")]
        elif path_info == "/notfilelike_shortcl":
            # deliberately understated content length
            headers = [
                ("Content-Length", "1"),
                ("Content-Type", "image/jpeg"),
            ]
        else:
            # /notfilelike_longcl: deliberately overstated content length
            headers = [
                ("Content-Length", str(cl + 10)),
                ("Content-Type", "image/jpeg"),
            ]

    start_response("200 OK", headers)
    return environ["wsgi.file_wrapper"](f, 8192)
tests/fixtureapps/groundhog1.jpg diff --git a/tests/fixtureapps/nocl.py b/tests/fixtureapps/nocl.py new file mode 100644 index 00000000..8948422c --- /dev/null +++ b/tests/fixtureapps/nocl.py @@ -0,0 +1,21 @@ +def chunks(l, n): # pragma: no cover + """Yield successive n-sized chunks from l.""" + for i in range(0, len(l), n): + yield l[i : i + n] + + +def gen(body): # pragma: no cover + yield from chunks(body, 10) + + +def app(environ, start_response): # pragma: no cover + cl = environ.get("CONTENT_LENGTH", None) + if cl is not None: + cl = int(cl) + body = environ["wsgi.input"].read(cl) + start_response("200 OK", [("Content-Type", "text/plain")]) + if environ["PATH_INFO"] == "/list": + return [body] + if environ["PATH_INFO"] == "/list_lentwo": + return [body[0:1], body[1:]] + return gen(body) diff --git a/tests/fixtureapps/runner.py b/tests/fixtureapps/runner.py new file mode 100644 index 00000000..1d66ad1c --- /dev/null +++ b/tests/fixtureapps/runner.py @@ -0,0 +1,6 @@ +def app(): # pragma: no cover + return None + + +def returns_app(): # pragma: no cover + return app diff --git a/tests/fixtureapps/sleepy.py b/tests/fixtureapps/sleepy.py new file mode 100644 index 00000000..2d171d8b --- /dev/null +++ b/tests/fixtureapps/sleepy.py @@ -0,0 +1,12 @@ +import time + + +def app(environ, start_response): # pragma: no cover + if environ["PATH_INFO"] == "/sleepy": + time.sleep(2) + body = b"sleepy returned" + else: + body = b"notsleepy returned" + cl = str(len(body)) + start_response("200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")]) + return [body] diff --git a/tests/fixtureapps/toolarge.py b/tests/fixtureapps/toolarge.py new file mode 100644 index 00000000..a0f36d2c --- /dev/null +++ b/tests/fixtureapps/toolarge.py @@ -0,0 +1,7 @@ +def app(environ, start_response): # pragma: no cover + body = b"abcdef" + cl = len(body) + start_response( + "200 OK", [("Content-Length", str(cl)), ("Content-Type", "text/plain")] + ) + return [body] diff --git 
a/tests/fixtureapps/writecb.py b/tests/fixtureapps/writecb.py new file mode 100644 index 00000000..e1d2792e --- /dev/null +++ b/tests/fixtureapps/writecb.py @@ -0,0 +1,14 @@ +def app(environ, start_response): # pragma: no cover + path_info = environ["PATH_INFO"] + if path_info == "/no_content_length": + headers = [] + else: + headers = [("Content-Length", "9")] + write = start_response("200 OK", headers) + if path_info == "/long_body": + write(b"abcdefghij") + elif path_info == "/short_body": + write(b"abcdefgh") + else: + write(b"abcdefghi") + return [] diff --git a/tests/test_adjustments.py b/tests/test_adjustments.py new file mode 100644 index 00000000..69cdf513 --- /dev/null +++ b/tests/test_adjustments.py @@ -0,0 +1,498 @@ +import socket +import sys +import unittest +import warnings + +from waitress.compat import WIN + + +class Test_asbool(unittest.TestCase): + def _callFUT(self, s): + from waitress.adjustments import asbool + + return asbool(s) + + def test_s_is_None(self): + result = self._callFUT(None) + self.assertEqual(result, False) + + def test_s_is_True(self): + result = self._callFUT(True) + self.assertEqual(result, True) + + def test_s_is_False(self): + result = self._callFUT(False) + self.assertEqual(result, False) + + def test_s_is_true(self): + result = self._callFUT("True") + self.assertEqual(result, True) + + def test_s_is_false(self): + result = self._callFUT("False") + self.assertEqual(result, False) + + def test_s_is_yes(self): + result = self._callFUT("yes") + self.assertEqual(result, True) + + def test_s_is_on(self): + result = self._callFUT("on") + self.assertEqual(result, True) + + def test_s_is_1(self): + result = self._callFUT(1) + self.assertEqual(result, True) + + +class Test_as_socket_list(unittest.TestCase): + def test_only_sockets_in_list(self): + from waitress.adjustments import as_socket_list + + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET6, socket.SOCK_STREAM), + ] + + if 
hasattr(socket, "AF_UNIX"): + sockets.append(socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)) + new_sockets = as_socket_list(sockets) + self.assertEqual(sockets, new_sockets) + + for sock in sockets: + sock.close() + + def test_not_only_sockets_in_list(self): + from waitress.adjustments import as_socket_list + + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET6, socket.SOCK_STREAM), + {"something": "else"}, + ] + new_sockets = as_socket_list(sockets) + self.assertEqual(new_sockets, [sockets[0], sockets[1]]) + + for sock in [sock for sock in sockets if isinstance(sock, socket.socket)]: + sock.close() + + +class TestAdjustments(unittest.TestCase): + def _hasIPv6(self): # pragma: nocover + if not socket.has_ipv6: + return False + + try: + socket.getaddrinfo( + "::1", + 0, + socket.AF_UNSPEC, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + socket.AI_PASSIVE | socket.AI_ADDRCONFIG, + ) + + return True + except socket.gaierror as e: + # Check to see what the error is + + if e.errno == socket.EAI_ADDRFAMILY: + return False + else: + raise e + + def _makeOne(self, **kw): + from waitress.adjustments import Adjustments + + return Adjustments(**kw) + + def test_goodvars(self): + inst = self._makeOne( + host="localhost", + port="8080", + threads="5", + trusted_proxy="192.168.1.1", + trusted_proxy_headers={"forwarded"}, + trusted_proxy_count=2, + log_untrusted_proxy_headers=True, + url_scheme="https", + backlog="20", + recv_bytes="200", + send_bytes="300", + outbuf_overflow="400", + inbuf_overflow="500", + connection_limit="1000", + cleanup_interval="1100", + channel_timeout="1200", + log_socket_errors="true", + max_request_header_size="1300", + max_request_body_size="1400", + expose_tracebacks="true", + ident="abc", + asyncore_loop_timeout="5", + asyncore_use_poll=True, + unix_socket_perms="777", + url_prefix="///foo/", + ipv4=True, + ipv6=False, + ) + + self.assertEqual(inst.host, "localhost") + self.assertEqual(inst.port, 8080) 
+ self.assertEqual(inst.threads, 5) + self.assertEqual(inst.trusted_proxy, "192.168.1.1") + self.assertEqual(inst.trusted_proxy_headers, {"forwarded"}) + self.assertEqual(inst.trusted_proxy_count, 2) + self.assertEqual(inst.log_untrusted_proxy_headers, True) + self.assertEqual(inst.url_scheme, "https") + self.assertEqual(inst.backlog, 20) + self.assertEqual(inst.recv_bytes, 200) + self.assertEqual(inst.send_bytes, 300) + self.assertEqual(inst.outbuf_overflow, 400) + self.assertEqual(inst.inbuf_overflow, 500) + self.assertEqual(inst.connection_limit, 1000) + self.assertEqual(inst.cleanup_interval, 1100) + self.assertEqual(inst.channel_timeout, 1200) + self.assertEqual(inst.log_socket_errors, True) + self.assertEqual(inst.max_request_header_size, 1300) + self.assertEqual(inst.max_request_body_size, 1400) + self.assertEqual(inst.expose_tracebacks, True) + self.assertEqual(inst.asyncore_loop_timeout, 5) + self.assertEqual(inst.asyncore_use_poll, True) + self.assertEqual(inst.ident, "abc") + self.assertEqual(inst.unix_socket_perms, 0o777) + self.assertEqual(inst.url_prefix, "/foo") + self.assertEqual(inst.ipv4, True) + self.assertEqual(inst.ipv6, False) + + bind_pairs = [ + sockaddr[:2] + for (family, _, _, sockaddr) in inst.listen + if family == socket.AF_INET + ] + + # On Travis, somehow we start listening to two sockets when resolving + # localhost... 
+ self.assertEqual(("127.0.0.1", 8080), bind_pairs[0]) + + def test_goodvar_listen(self): + inst = self._makeOne(listen="127.0.0.1") + + bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen] + + self.assertEqual(bind_pairs, [("127.0.0.1", 8080)]) + + def test_default_listen(self): + inst = self._makeOne() + + bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen] + + self.assertEqual(bind_pairs, [("0.0.0.0", 8080)]) + + def test_multiple_listen(self): + inst = self._makeOne(listen="127.0.0.1:9090 127.0.0.1:8080") + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertEqual(bind_pairs, [("127.0.0.1", 9090), ("127.0.0.1", 8080)]) + + def test_wildcard_listen(self): + inst = self._makeOne(listen="*:8080") + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertTrue(len(bind_pairs) >= 1) + + def test_ipv6_no_port(self): # pragma: nocover + if not self._hasIPv6(): + return + + inst = self._makeOne(listen="[::1]") + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertEqual(bind_pairs, [("::1", 8080)]) + + def test_bad_port(self): + self.assertRaises(ValueError, self._makeOne, listen="127.0.0.1:test") + + def test_service_port(self): + if WIN: # pragma: no cover + # On Windows this is broken, so we raise a ValueError + self.assertRaises( + ValueError, + self._makeOne, + listen="127.0.0.1:http", + ) + + return + + inst = self._makeOne(listen="127.0.0.1:http 0.0.0.0:https") + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertEqual(bind_pairs, [("127.0.0.1", 80), ("0.0.0.0", 443)]) + + def test_dont_mix_host_port_listen(self): + self.assertRaises( + ValueError, + self._makeOne, + host="localhost", + port="8080", + listen="127.0.0.1:8080", + ) + + def test_good_sockets(self): + sockets = [ + socket.socket(socket.AF_INET6, socket.SOCK_STREAM), + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + ] + inst = 
self._makeOne(sockets=sockets) + self.assertEqual(inst.sockets, sockets) + sockets[0].close() + sockets[1].close() + + def test_dont_mix_sockets_and_listen(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + self.assertRaises( + ValueError, self._makeOne, listen="127.0.0.1:8080", sockets=sockets + ) + sockets[0].close() + + def test_dont_mix_sockets_and_host_port(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + self.assertRaises( + ValueError, self._makeOne, host="localhost", port="8080", sockets=sockets + ) + sockets[0].close() + + def test_dont_mix_sockets_and_unix_socket(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + self.assertRaises( + ValueError, self._makeOne, unix_socket="./tmp/test", sockets=sockets + ) + sockets[0].close() + + def test_dont_mix_unix_socket_and_host_port(self): + self.assertRaises( + ValueError, + self._makeOne, + unix_socket="./tmp/test", + host="localhost", + port="8080", + ) + + def test_dont_mix_unix_socket_and_listen(self): + self.assertRaises( + ValueError, self._makeOne, unix_socket="./tmp/test", listen="127.0.0.1:8080" + ) + + def test_dont_use_unsupported_socket_types(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_DGRAM)] + self.assertRaises(ValueError, self._makeOne, sockets=sockets) + sockets[0].close() + + def test_dont_mix_forwarded_with_x_forwarded(self): + with self.assertRaises(ValueError) as cm: + self._makeOne( + trusted_proxy="localhost", + trusted_proxy_headers={"forwarded", "x-forwarded-for"}, + ) + + self.assertIn("The Forwarded proxy header", str(cm.exception)) + + def test_unknown_trusted_proxy_header(self): + with self.assertRaises(ValueError) as cm: + self._makeOne( + trusted_proxy="localhost", + trusted_proxy_headers={"forwarded", "x-forwarded-unknown"}, + ) + + self.assertIn( + "unknown trusted_proxy_headers value (x-forwarded-unknown)", + str(cm.exception), + ) + + def test_trusted_proxy_count_no_trusted_proxy(self): + with 
self.assertRaises(ValueError) as cm: + self._makeOne(trusted_proxy_count=1) + + self.assertIn("trusted_proxy_count has no meaning", str(cm.exception)) + + def test_trusted_proxy_headers_no_trusted_proxy(self): + with self.assertRaises(ValueError) as cm: + self._makeOne(trusted_proxy_headers={"forwarded"}) + + self.assertIn("trusted_proxy_headers has no meaning", str(cm.exception)) + + def test_trusted_proxy_headers_string_list(self): + inst = self._makeOne( + trusted_proxy="localhost", + trusted_proxy_headers="x-forwarded-for x-forwarded-by", + ) + self.assertEqual( + inst.trusted_proxy_headers, {"x-forwarded-for", "x-forwarded-by"} + ) + + def test_trusted_proxy_headers_string_list_newlines(self): + inst = self._makeOne( + trusted_proxy="localhost", + trusted_proxy_headers="x-forwarded-for\nx-forwarded-by\nx-forwarded-host", + ) + self.assertEqual( + inst.trusted_proxy_headers, + {"x-forwarded-for", "x-forwarded-by", "x-forwarded-host"}, + ) + + def test_no_trusted_proxy_headers_trusted_proxy(self): + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + warnings.simplefilter("always") + self._makeOne(trusted_proxy="localhost") + + self.assertGreaterEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("Implicitly trusting X-Forwarded-Proto", str(w[0])) + + def test_clear_untrusted_proxy_headers(self): + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + warnings.simplefilter("always") + self._makeOne( + trusted_proxy="localhost", trusted_proxy_headers={"x-forwarded-for"} + ) + + self.assertGreaterEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn( + "clear_untrusted_proxy_headers will be set to True", str(w[0]) + ) + + def test_deprecated_send_bytes(self): + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + warnings.simplefilter("always") + self._makeOne(send_bytes=1) + + 
self.assertGreaterEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn("send_bytes", str(w[0])) + + def test_badvar(self): + self.assertRaises(ValueError, self._makeOne, nope=True) + + def test_ipv4_disabled(self): + self.assertRaises( + ValueError, self._makeOne, ipv4=False, listen="127.0.0.1:8080" + ) + + def test_ipv6_disabled(self): + self.assertRaises(ValueError, self._makeOne, ipv6=False, listen="[::]:8080") + + def test_server_header_removable(self): + inst = self._makeOne(ident=None) + self.assertEqual(inst.ident, None) + + inst = self._makeOne(ident="") + self.assertEqual(inst.ident, None) + + inst = self._makeOne(ident="specific_header") + self.assertEqual(inst.ident, "specific_header") + + +class TestCLI(unittest.TestCase): + def parse(self, argv): + from waitress.adjustments import Adjustments + + return Adjustments.parse_args(argv) + + def assertDictContainsSubset(self, subset, dictionary): + self.assertTrue(set(subset.items()) <= set(dictionary.items())) + + def test_noargs(self): + opts, args = self.parse([]) + self.assertDictEqual(opts, {"call": False, "help": False}) + self.assertSequenceEqual(args, []) + + def test_help(self): + opts, args = self.parse(["--help"]) + self.assertDictEqual(opts, {"call": False, "help": True}) + self.assertSequenceEqual(args, []) + + def test_call(self): + opts, args = self.parse(["--call"]) + self.assertDictEqual(opts, {"call": True, "help": False}) + self.assertSequenceEqual(args, []) + + def test_both(self): + opts, args = self.parse(["--call", "--help"]) + self.assertDictEqual(opts, {"call": True, "help": True}) + self.assertSequenceEqual(args, []) + + def test_positive_boolean(self): + opts, args = self.parse(["--expose-tracebacks"]) + self.assertDictContainsSubset({"expose_tracebacks": "true"}, opts) + self.assertSequenceEqual(args, []) + + def test_negative_boolean(self): + opts, args = self.parse(["--no-expose-tracebacks"]) + 
self.assertDictContainsSubset({"expose_tracebacks": "false"}, opts) + self.assertSequenceEqual(args, []) + + def test_cast_params(self): + opts, args = self.parse( + ["--host=localhost", "--port=80", "--unix-socket-perms=777"] + ) + self.assertDictContainsSubset( + { + "host": "localhost", + "port": "80", + "unix_socket_perms": "777", + }, + opts, + ) + self.assertSequenceEqual(args, []) + + def test_listen_params(self): + opts, args = self.parse( + [ + "--listen=test:80", + ] + ) + + self.assertDictContainsSubset({"listen": " test:80"}, opts) + self.assertSequenceEqual(args, []) + + def test_multiple_listen_params(self): + opts, args = self.parse( + [ + "--listen=test:80", + "--listen=test:8080", + ] + ) + + self.assertDictContainsSubset({"listen": " test:80 test:8080"}, opts) + self.assertSequenceEqual(args, []) + + def test_bad_param(self): + import getopt + + self.assertRaises(getopt.GetoptError, self.parse, ["--no-host"]) + + +if hasattr(socket, "AF_UNIX"): + + class TestUnixSocket(unittest.TestCase): + def _makeOne(self, **kw): + from waitress.adjustments import Adjustments + + return Adjustments(**kw) + + def test_dont_mix_internet_and_unix_sockets(self): + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), + ] + self.assertRaises(ValueError, self._makeOne, sockets=sockets) + sockets[0].close() + sockets[1].close() diff --git a/waitress/tests/test_buffers.py b/tests/test_buffers.py similarity index 68% rename from waitress/tests/test_buffers.py rename to tests/test_buffers.py index 46a215eb..b37949b8 100644 --- a/waitress/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -1,18 +1,28 @@ -import unittest import io +import unittest -class TestFileBasedBuffer(unittest.TestCase): +class TestFileBasedBuffer(unittest.TestCase): def _makeOne(self, file=None, from_buffer=None): from waitress.buffers import FileBasedBuffer - return FileBasedBuffer(file, from_buffer=from_buffer) + + buf = 
FileBasedBuffer(file, from_buffer=from_buffer) + self.buffers_to_close.append(buf) + return buf + + def setUp(self): + self.buffers_to_close = [] + + def tearDown(self): + for buf in self.buffers_to_close: + buf.close() def test_ctor_from_buffer_None(self): - inst = self._makeOne('file') - self.assertEqual(inst.file, 'file') + inst = self._makeOne("file") + self.assertEqual(inst.file, "file") def test_ctor_from_buffer(self): - from_buffer = io.BytesIO(b'data') + from_buffer = io.BytesIO(b"data") from_buffer.getfile = lambda *x: from_buffer f = io.BytesIO() inst = self._makeOne(f, from_buffer) @@ -34,42 +44,42 @@ def test___nonzero__(self): self.assertEqual(bool(inst), True) def test_append(self): - f = io.BytesIO(b'data') + f = io.BytesIO(b"data") inst = self._makeOne(f) - inst.append(b'data2') - self.assertEqual(f.getvalue(), b'datadata2') + inst.append(b"data2") + self.assertEqual(f.getvalue(), b"datadata2") self.assertEqual(inst.remain, 5) def test_get_skip_true(self): - f = io.BytesIO(b'data') + f = io.BytesIO(b"data") inst = self._makeOne(f) result = inst.get(100, skip=True) - self.assertEqual(result, b'data') + self.assertEqual(result, b"data") self.assertEqual(inst.remain, -4) def test_get_skip_false(self): - f = io.BytesIO(b'data') + f = io.BytesIO(b"data") inst = self._makeOne(f) result = inst.get(100, skip=False) - self.assertEqual(result, b'data') + self.assertEqual(result, b"data") self.assertEqual(inst.remain, 0) def test_get_skip_bytes_less_than_zero(self): - f = io.BytesIO(b'data') + f = io.BytesIO(b"data") inst = self._makeOne(f) result = inst.get(-1, skip=False) - self.assertEqual(result, b'data') + self.assertEqual(result, b"data") self.assertEqual(inst.remain, 0) def test_skip_remain_gt_bytes(self): - f = io.BytesIO(b'd') + f = io.BytesIO(b"d") inst = self._makeOne(f) inst.remain = 1 inst.skip(1) self.assertEqual(inst.remain, 0) def test_skip_remain_lt_bytes(self): - f = io.BytesIO(b'd') + f = io.BytesIO(b"d") inst = self._makeOne(f) inst.remain 
= 1 self.assertRaises(ValueError, inst.skip, 2) @@ -79,24 +89,24 @@ def test_newfile(self): self.assertRaises(NotImplementedError, inst.newfile) def test_prune_remain_notzero(self): - f = io.BytesIO(b'd') + f = io.BytesIO(b"d") inst = self._makeOne(f) inst.remain = 1 nf = io.BytesIO() inst.newfile = lambda *x: nf inst.prune() self.assertTrue(inst.file is not f) - self.assertEqual(nf.getvalue(), b'd') + self.assertEqual(nf.getvalue(), b"d") def test_prune_remain_zero_tell_notzero(self): - f = io.BytesIO(b'd') + f = io.BytesIO(b"d") inst = self._makeOne(f) - nf = io.BytesIO(b'd') + nf = io.BytesIO(b"d") inst.newfile = lambda *x: nf inst.remain = 0 inst.prune() self.assertTrue(inst.file is not f) - self.assertEqual(nf.getvalue(), b'd') + self.assertEqual(nf.getvalue(), b"d") def test_prune_remain_zero_tell_zero(self): f = io.BytesIO() @@ -110,132 +120,169 @@ def test_close(self): inst = self._makeOne(f) inst.close() self.assertTrue(f.closed) + self.buffers_to_close.remove(inst) -class TestTempfileBasedBuffer(unittest.TestCase): +class TestTempfileBasedBuffer(unittest.TestCase): def _makeOne(self, from_buffer=None): from waitress.buffers import TempfileBasedBuffer - return TempfileBasedBuffer(from_buffer=from_buffer) + + buf = TempfileBasedBuffer(from_buffer=from_buffer) + self.buffers_to_close.append(buf) + return buf + + def setUp(self): + self.buffers_to_close = [] + + def tearDown(self): + for buf in self.buffers_to_close: + buf.close() def test_newfile(self): inst = self._makeOne() r = inst.newfile() - self.assertTrue(hasattr(r, 'fileno')) # file + self.assertTrue(hasattr(r, "fileno")) # file + r.close() -class TestBytesIOBasedBuffer(unittest.TestCase): +class TestBytesIOBasedBuffer(unittest.TestCase): def _makeOne(self, from_buffer=None): from waitress.buffers import BytesIOBasedBuffer + return BytesIOBasedBuffer(from_buffer=from_buffer) def test_ctor_from_buffer_not_None(self): f = io.BytesIO() f.getfile = lambda *x: f inst = self._makeOne(f) - 
self.assertTrue(hasattr(inst.file, 'read')) + self.assertTrue(hasattr(inst.file, "read")) def test_ctor_from_buffer_None(self): inst = self._makeOne() - self.assertTrue(hasattr(inst.file, 'read')) + self.assertTrue(hasattr(inst.file, "read")) def test_newfile(self): inst = self._makeOne() r = inst.newfile() - self.assertTrue(hasattr(r, 'read')) + self.assertTrue(hasattr(r, "read")) -class TestReadOnlyFileBasedBuffer(unittest.TestCase): +class TestReadOnlyFileBasedBuffer(unittest.TestCase): def _makeOne(self, file, block_size=8192): from waitress.buffers import ReadOnlyFileBasedBuffer - return ReadOnlyFileBasedBuffer(file, block_size) + + buf = ReadOnlyFileBasedBuffer(file, block_size) + self.buffers_to_close.append(buf) + return buf + + def setUp(self): + self.buffers_to_close = [] + + def tearDown(self): + for buf in self.buffers_to_close: + buf.close() def test_prepare_not_seekable(self): - f = KindaFilelike(b'abc') + f = KindaFilelike(b"abc") inst = self._makeOne(f) + self.assertFalse(hasattr(inst, "seek")) + self.assertFalse(hasattr(inst, "tell")) result = inst.prepare() self.assertEqual(result, False) self.assertEqual(inst.remain, 0) def test_prepare_not_seekable_closeable(self): - f = KindaFilelike(b'abc', close=1) + f = KindaFilelike(b"abc", close=1) inst = self._makeOne(f) result = inst.prepare() self.assertEqual(result, False) self.assertEqual(inst.remain, 0) - self.assertTrue(hasattr(inst, 'close')) + self.assertTrue(hasattr(inst, "close")) def test_prepare_seekable_closeable(self): - f = Filelike(b'abc', close=1, tellresults=[0, 10]) + f = Filelike(b"abc", close=1, tellresults=[0, 10]) inst = self._makeOne(f) + self.assertEqual(inst.seek, f.seek) + self.assertEqual(inst.tell, f.tell) result = inst.prepare() self.assertEqual(result, 10) self.assertEqual(inst.remain, 10) self.assertEqual(inst.file.seeked, 0) - self.assertTrue(hasattr(inst, 'close')) + self.assertTrue(hasattr(inst, "close")) def test_get_numbytes_neg_one(self): - f = io.BytesIO(b'abcdef') + 
f = io.BytesIO(b"abcdef") inst = self._makeOne(f) inst.remain = 2 result = inst.get(-1) - self.assertEqual(result, b'ab') + self.assertEqual(result, b"ab") self.assertEqual(inst.remain, 2) self.assertEqual(f.tell(), 0) def test_get_numbytes_gt_remain(self): - f = io.BytesIO(b'abcdef') + f = io.BytesIO(b"abcdef") inst = self._makeOne(f) inst.remain = 2 result = inst.get(3) - self.assertEqual(result, b'ab') + self.assertEqual(result, b"ab") self.assertEqual(inst.remain, 2) self.assertEqual(f.tell(), 0) def test_get_numbytes_lt_remain(self): - f = io.BytesIO(b'abcdef') + f = io.BytesIO(b"abcdef") inst = self._makeOne(f) inst.remain = 2 result = inst.get(1) - self.assertEqual(result, b'a') + self.assertEqual(result, b"a") self.assertEqual(inst.remain, 2) self.assertEqual(f.tell(), 0) def test_get_numbytes_gt_remain_withskip(self): - f = io.BytesIO(b'abcdef') + f = io.BytesIO(b"abcdef") inst = self._makeOne(f) inst.remain = 2 result = inst.get(3, skip=True) - self.assertEqual(result, b'ab') + self.assertEqual(result, b"ab") self.assertEqual(inst.remain, 0) self.assertEqual(f.tell(), 2) def test_get_numbytes_lt_remain_withskip(self): - f = io.BytesIO(b'abcdef') + f = io.BytesIO(b"abcdef") inst = self._makeOne(f) inst.remain = 2 result = inst.get(1, skip=True) - self.assertEqual(result, b'a') + self.assertEqual(result, b"a") self.assertEqual(inst.remain, 1) self.assertEqual(f.tell(), 1) def test___iter__(self): - data = b'a' * 10000 + data = b"a" * 10000 f = io.BytesIO(data) inst = self._makeOne(f) - r = b'' + r = b"" for val in inst: r += val self.assertEqual(r, data) def test_append(self): inst = self._makeOne(None) - self.assertRaises(NotImplementedError, inst.append, 'a') + self.assertRaises(NotImplementedError, inst.append, "a") -class TestOverflowableBuffer(unittest.TestCase): +class TestOverflowableBuffer(unittest.TestCase): def _makeOne(self, overflow=10): from waitress.buffers import OverflowableBuffer - return OverflowableBuffer(overflow) + + buf = 
OverflowableBuffer(overflow) + self.buffers_to_close.append(buf) + return buf + + def setUp(self): + self.buffers_to_close = [] + + def tearDown(self): + for buf in self.buffers_to_close: + buf.close() def test___len__buf_is_None(self): inst = self._makeOne() @@ -243,15 +290,17 @@ def test___len__buf_is_None(self): def test___len__buf_is_not_None(self): inst = self._makeOne() - inst.buf = b'abc' + inst.buf = b"abc" self.assertEqual(len(inst), 3) + self.buffers_to_close.remove(inst) def test___nonzero__(self): inst = self._makeOne() - inst.buf = b'abc' + inst.buf = b"abc" self.assertEqual(bool(inst), True) - inst.buf = b'' + inst.buf = b"" self.assertEqual(bool(inst), False) + self.buffers_to_close.remove(inst) def test___nonzero___on_int_overflow_buffer(self): inst = self._makeOne() @@ -259,124 +308,141 @@ def test___nonzero___on_int_overflow_buffer(self): class int_overflow_buf(bytes): def __len__(self): # maxint + 1 - return 0x7fffffffffffffff + 1 + return 0x7FFFFFFFFFFFFFFF + 1 + inst.buf = int_overflow_buf() self.assertEqual(bool(inst), True) - inst.buf = b'' + inst.buf = b"" self.assertEqual(bool(inst), False) + self.buffers_to_close.remove(inst) def test__create_buffer_large(self): from waitress.buffers import TempfileBasedBuffer + inst = self._makeOne() - inst.strbuf = b'x' * 11 + inst.strbuf = b"x" * 11 inst._create_buffer() self.assertEqual(inst.buf.__class__, TempfileBasedBuffer) - self.assertEqual(inst.buf.get(100), b'x' * 11) - self.assertEqual(inst.strbuf, b'') + self.assertEqual(inst.buf.get(100), b"x" * 11) + self.assertEqual(inst.strbuf, b"") def test__create_buffer_small(self): from waitress.buffers import BytesIOBasedBuffer + inst = self._makeOne() - inst.strbuf = b'x' * 5 + inst.strbuf = b"x" * 5 inst._create_buffer() self.assertEqual(inst.buf.__class__, BytesIOBasedBuffer) - self.assertEqual(inst.buf.get(100), b'x' * 5) - self.assertEqual(inst.strbuf, b'') + self.assertEqual(inst.buf.get(100), b"x" * 5) + self.assertEqual(inst.strbuf, b"") def 
test_append_with_len_more_than_max_int(self): from waitress.compat import MAXINT + inst = self._makeOne() inst.overflowed = True buf = DummyBuffer(length=MAXINT) inst.buf = buf - result = inst.append(b'x') + result = inst.append(b"x") # we don't want this to throw an OverflowError on Python 2 (see # https://github.com/Pylons/waitress/issues/47) self.assertEqual(result, None) - + self.buffers_to_close.remove(inst) + def test_append_buf_None_not_longer_than_srtbuf_limit(self): inst = self._makeOne() - inst.strbuf = b'x' * 5 - inst.append(b'hello') - self.assertEqual(inst.strbuf, b'xxxxxhello') + inst.strbuf = b"x" * 5 + inst.append(b"hello") + self.assertEqual(inst.strbuf, b"xxxxxhello") def test_append_buf_None_longer_than_strbuf_limit(self): inst = self._makeOne(10000) - inst.strbuf = b'x' * 8192 - inst.append(b'hello') - self.assertEqual(inst.strbuf, b'') + inst.strbuf = b"x" * 8192 + inst.append(b"hello") + self.assertEqual(inst.strbuf, b"") self.assertEqual(len(inst.buf), 8197) def test_append_overflow(self): inst = self._makeOne(10) - inst.strbuf = b'x' * 8192 - inst.append(b'hello') - self.assertEqual(inst.strbuf, b'') + inst.strbuf = b"x" * 8192 + inst.append(b"hello") + self.assertEqual(inst.strbuf, b"") self.assertEqual(len(inst.buf), 8197) def test_append_sz_gt_overflow(self): from waitress.buffers import BytesIOBasedBuffer - f = io.BytesIO(b'data') + + f = io.BytesIO(b"data") inst = self._makeOne(f) buf = BytesIOBasedBuffer() inst.buf = buf inst.overflow = 2 - inst.append(b'data2') - self.assertEqual(f.getvalue(), b'data') + inst.append(b"data2") + self.assertEqual(f.getvalue(), b"data") self.assertTrue(inst.overflowed) self.assertNotEqual(inst.buf, buf) def test_get_buf_None_skip_False(self): inst = self._makeOne() - inst.strbuf = b'x' * 5 + inst.strbuf = b"x" * 5 r = inst.get(5) - self.assertEqual(r, b'xxxxx') + self.assertEqual(r, b"xxxxx") def test_get_buf_None_skip_True(self): inst = self._makeOne() - inst.strbuf = b'x' * 5 + inst.strbuf = b"x" * 5 r 
= inst.get(5, skip=True) self.assertFalse(inst.buf is None) - self.assertEqual(r, b'xxxxx') + self.assertEqual(r, b"xxxxx") def test_skip_buf_None(self): inst = self._makeOne() - inst.strbuf = b'data' + inst.strbuf = b"data" inst.skip(4) - self.assertEqual(inst.strbuf, b'') + self.assertEqual(inst.strbuf, b"") self.assertNotEqual(inst.buf, None) def test_skip_buf_None_allow_prune_True(self): inst = self._makeOne() - inst.strbuf = b'data' + inst.strbuf = b"data" inst.skip(4, True) - self.assertEqual(inst.strbuf, b'') + self.assertEqual(inst.strbuf, b"") self.assertEqual(inst.buf, None) def test_prune_buf_None(self): inst = self._makeOne() inst.prune() - self.assertEqual(inst.strbuf, b'') + self.assertEqual(inst.strbuf, b"") def test_prune_with_buf(self): inst = self._makeOne() - class Buf(object): + + class Buf: def prune(self): self.pruned = True + inst.buf = Buf() inst.prune() self.assertEqual(inst.buf.pruned, True) + self.buffers_to_close.remove(inst) def test_prune_with_buf_overflow(self): inst = self._makeOne() + class DummyBuffer(io.BytesIO): def getfile(self): return self + def prune(self): return True + def __len__(self): return 5 - buf = DummyBuffer(b'data') + + def close(self): + pass + + buf = DummyBuffer(b"data") inst.buf = buf inst.overflowed = True inst.overflow = 10 @@ -385,19 +451,20 @@ def __len__(self): def test_prune_with_buflen_more_than_max_int(self): from waitress.compat import MAXINT + inst = self._makeOne() inst.overflowed = True - buf = DummyBuffer(length=MAXINT+1) + buf = DummyBuffer(length=MAXINT + 1) inst.buf = buf result = inst.prune() # we don't want this to throw an OverflowError on Python 2 (see # https://github.com/Pylons/waitress/issues/47) self.assertEqual(result, None) - + def test_getfile_buf_None(self): inst = self._makeOne() f = inst.getfile() - self.assertTrue(hasattr(f, 'read')) + self.assertTrue(hasattr(f, "read")) def test_getfile_buf_not_None(self): inst = self._makeOne() @@ -410,28 +477,31 @@ def 
test_getfile_buf_not_None(self): def test_close_nobuf(self): inst = self._makeOne() inst.buf = None - self.assertEqual(inst.close(), None) # doesnt raise + self.assertEqual(inst.close(), None) # doesnt raise + self.buffers_to_close.remove(inst) def test_close_withbuf(self): - class Buffer(object): + class Buffer: def close(self): self.closed = True + buf = Buffer() inst = self._makeOne() inst.buf = buf inst.close() self.assertTrue(buf.closed) + self.buffers_to_close.remove(inst) -class KindaFilelike(object): +class KindaFilelike: def __init__(self, bytes, close=None, tellresults=None): self.bytes = bytes self.tellresults = tellresults if close is not None: - self.close = close + self.close = lambda: close -class Filelike(KindaFilelike): +class Filelike(KindaFilelike): def seek(self, v, whence=0): self.seeked = v @@ -439,7 +509,8 @@ def tell(self): v = self.tellresults.pop(0) return v -class DummyBuffer(object): + +class DummyBuffer: def __init__(self, length=0): self.length = length @@ -451,3 +522,6 @@ def append(self, s): def prune(self): pass + + def close(self): + pass diff --git a/waitress/tests/test_channel.py b/tests/test_channel.py similarity index 58% rename from waitress/tests/test_channel.py rename to tests/test_channel.py index afe6e510..8467ae7a 100644 --- a/waitress/tests/test_channel.py +++ b/tests/test_channel.py @@ -1,10 +1,13 @@ -import unittest import io +import unittest -class TestHTTPChannel(unittest.TestCase): +import pytest + +class TestHTTPChannel(unittest.TestCase): def _makeOne(self, sock, addr, adj, map=None): from waitress.channel import HTTPChannel + server = DummyServer() return HTTPChannel(server, sock, addr, adj=adj, map=map) @@ -13,30 +16,41 @@ def _makeOneWithMap(self, adj=None): adj = DummyAdjustments() sock = DummySock() map = {} - inst = self._makeOne(sock, '127.0.0.1', adj, map=map) + inst = self._makeOne(sock, "127.0.0.1", adj, map=map) inst.outbuf_lock = DummyLock() return inst, sock, map def test_ctor(self): inst, _, map = 
self._makeOneWithMap() - self.assertEqual(inst.addr, '127.0.0.1') + self.assertEqual(inst.addr, "127.0.0.1") + self.assertEqual(inst.sendbuf_len, 2048) self.assertEqual(map[100], inst) def test_total_outbufs_len_an_outbuf_size_gt_sys_maxint(self): from waitress.compat import MAXINT + inst, _, map = self._makeOneWithMap() - class DummyHugeBuffer(object): + + class DummyBuffer: + chunks = [] + + def append(self, data): + self.chunks.append(data) + + class DummyData: def __len__(self): - return MAXINT + 1 - inst.outbufs = [DummyHugeBuffer()] - result = inst.total_outbufs_len() + return MAXINT + + inst.total_outbufs_len = 1 + inst.outbufs = [DummyBuffer()] + inst.write_soon(DummyData()) # we are testing that this method does not raise an OverflowError # (see https://github.com/Pylons/waitress/issues/47) - self.assertEqual(result, MAXINT+1) + self.assertEqual(inst.total_outbufs_len, MAXINT + 1) def test_writable_something_in_outbuf(self): inst, sock, map = self._makeOneWithMap() - inst.outbufs[0].append(b'abc') + inst.total_outbufs_len = 3 self.assertTrue(inst.writable()) def test_writable_nothing_in_outbuf(self): @@ -64,46 +78,50 @@ def test_handle_write_with_requests(self): def test_handle_write_no_request_with_outbuf(self): inst, sock, map = self._makeOneWithMap() inst.requests = [] - inst.outbufs = [DummyBuffer(b'abc')] + inst.outbufs = [DummyBuffer(b"abc")] + inst.total_outbufs_len = len(inst.outbufs[0]) inst.last_activity = 0 result = inst.handle_write() self.assertEqual(result, None) self.assertNotEqual(inst.last_activity, 0) - self.assertEqual(sock.sent, b'abc') + self.assertEqual(sock.sent, b"abc") def test_handle_write_outbuf_raises_socketerror(self): import socket + inst, sock, map = self._makeOneWithMap() inst.requests = [] - outbuf = DummyBuffer(b'abc', socket.error) + outbuf = DummyBuffer(b"abc", socket.error) inst.outbufs = [outbuf] + inst.total_outbufs_len = len(outbuf) inst.last_activity = 0 inst.logger = DummyLogger() result = inst.handle_write() 
self.assertEqual(result, None) self.assertEqual(inst.last_activity, 0) - self.assertEqual(sock.sent, b'') + self.assertEqual(sock.sent, b"") self.assertEqual(len(inst.logger.exceptions), 1) self.assertTrue(outbuf.closed) def test_handle_write_outbuf_raises_othererror(self): inst, sock, map = self._makeOneWithMap() inst.requests = [] - outbuf = DummyBuffer(b'abc', IOError) + outbuf = DummyBuffer(b"abc", IOError) inst.outbufs = [outbuf] + inst.total_outbufs_len = len(outbuf) inst.last_activity = 0 inst.logger = DummyLogger() result = inst.handle_write() self.assertEqual(result, None) self.assertEqual(inst.last_activity, 0) - self.assertEqual(sock.sent, b'') + self.assertEqual(sock.sent, b"") self.assertEqual(len(inst.logger.exceptions), 1) self.assertTrue(outbuf.closed) def test_handle_write_no_requests_no_outbuf_will_close(self): inst, sock, map = self._makeOneWithMap() inst.requests = [] - outbuf = DummyBuffer(b'') + outbuf = DummyBuffer(b"") inst.outbufs = [outbuf] inst.will_close = True inst.last_activity = 0 @@ -114,24 +132,11 @@ def test_handle_write_no_requests_no_outbuf_will_close(self): self.assertEqual(inst.last_activity, 0) self.assertTrue(outbuf.closed) - def test_handle_write_no_requests_force_flush(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [True] - inst.outbufs = [DummyBuffer(b'abc')] - inst.will_close = False - inst.force_flush = True - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.will_close, False) - self.assertTrue(inst.outbuf_lock.acquired) - self.assertEqual(inst.force_flush, False) - self.assertEqual(sock.sent, b'abc') - def test_handle_write_no_requests_outbuf_gt_send_bytes(self): inst, sock, map = self._makeOneWithMap() inst.requests = [True] - inst.outbufs = [DummyBuffer(b'abc')] + inst.outbufs = [DummyBuffer(b"abc")] + inst.total_outbufs_len = len(inst.outbufs[0]) inst.adj.send_bytes = 2 inst.will_close = False inst.last_activity = 0 @@ -139,12 +144,13 
@@ def test_handle_write_no_requests_outbuf_gt_send_bytes(self): self.assertEqual(result, None) self.assertEqual(inst.will_close, False) self.assertTrue(inst.outbuf_lock.acquired) - self.assertEqual(sock.sent, b'abc') + self.assertEqual(sock.sent, b"abc") def test_handle_write_close_when_flushed(self): inst, sock, map = self._makeOneWithMap() - outbuf = DummyBuffer(b'abc') + outbuf = DummyBuffer(b"abc") inst.outbufs = [outbuf] + inst.total_outbufs_len = len(outbuf) inst.will_close = False inst.close_when_flushed = True inst.last_activity = 0 @@ -152,7 +158,7 @@ def test_handle_write_close_when_flushed(self): self.assertEqual(result, None) self.assertEqual(inst.will_close, True) self.assertEqual(inst.close_when_flushed, False) - self.assertEqual(sock.sent, b'abc') + self.assertEqual(sock.sent, b"abc") self.assertTrue(outbuf.closed) def test_readable_no_requests_not_will_close(self): @@ -169,27 +175,28 @@ def test_readable_no_requests_will_close(self): def test_readable_with_requests(self): inst, sock, map = self._makeOneWithMap() - inst.requests = True + inst.requests = [True] self.assertEqual(inst.readable(), False) def test_handle_read_no_error(self): inst, sock, map = self._makeOneWithMap() inst.will_close = False - inst.recv = lambda *arg: b'abc' + inst.recv = lambda *arg: b"abc" inst.last_activity = 0 L = [] inst.received = lambda x: L.append(x) result = inst.handle_read() self.assertEqual(result, None) self.assertNotEqual(inst.last_activity, 0) - self.assertEqual(L, [b'abc']) + self.assertEqual(L, [b"abc"]) def test_handle_read_error(self): - import socket inst, sock, map = self._makeOneWithMap() inst.will_close = False + def recv(b): - raise socket.error + raise OSError + inst.recv = recv inst.last_activity = 0 inst.logger = DummyLogger() @@ -200,30 +207,184 @@ def recv(b): def test_write_soon_empty_byte(self): inst, sock, map = self._makeOneWithMap() - wrote = inst.write_soon(b'') + wrote = inst.write_soon(b"") self.assertEqual(wrote, 0) 
self.assertEqual(len(inst.outbufs[0]), 0) def test_write_soon_nonempty_byte(self): inst, sock, map = self._makeOneWithMap() - wrote = inst.write_soon(b'a') + + # _flush_some will no longer flush + def send(_): + return 0 + + sock.send = send + + wrote = inst.write_soon(b"a") self.assertEqual(wrote, 1) self.assertEqual(len(inst.outbufs[0]), 1) def test_write_soon_filewrapper(self): from waitress.buffers import ReadOnlyFileBasedBuffer - f = io.BytesIO(b'abc') + + f = io.BytesIO(b"abc") wrapper = ReadOnlyFileBasedBuffer(f, 8192) wrapper.prepare() inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush + def send(_): + return 0 + + sock.send = send + outbufs = inst.outbufs - orig_outbuf = outbufs[0] wrote = inst.write_soon(wrapper) self.assertEqual(wrote, 3) - self.assertEqual(len(outbufs), 3) - self.assertEqual(outbufs[0], orig_outbuf) - self.assertEqual(outbufs[1], wrapper) - self.assertEqual(outbufs[2].__class__.__name__, 'OverflowableBuffer') + self.assertEqual(len(outbufs), 2) + self.assertEqual(outbufs[0], wrapper) + self.assertEqual(outbufs[1].__class__.__name__, "OverflowableBuffer") + + def test_write_soon_disconnected(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + inst.connected = False + self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"stuff")) + + def test_write_soon_disconnected_while_over_watermark(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + + def dummy_flush(): + inst.connected = False + + inst._flush_outbufs_below_high_watermark = dummy_flush + self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"stuff")) + + def test_write_soon_rotates_outbuf_on_overflow(self): + inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush + def send(_): + return 0 + + sock.send = send + + inst.adj.outbuf_high_watermark = 3 + inst.current_outbuf_count = 4 + wrote = 
inst.write_soon(b"xyz") + self.assertEqual(wrote, 3) + self.assertEqual(len(inst.outbufs), 1) + self.assertEqual(inst.outbufs[0].get(), b"xyz") + + def test_write_soon_waits_on_backpressure(self): + inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush + def send(_): + return 0 + + sock.send = send + + inst.adj.outbuf_high_watermark = 3 + inst.total_outbufs_len = 4 + inst.current_outbuf_count = 4 + + class Lock(DummyLock): + def wait(self): + inst.total_outbufs_len = 0 + super().wait() + + inst.outbuf_lock = Lock() + wrote = inst.write_soon(b"xyz") + self.assertEqual(wrote, 3) + self.assertEqual(len(inst.outbufs), 1) + self.assertEqual(inst.outbufs[0].get(), b"xyz") + self.assertTrue(inst.outbuf_lock.waited) + + def test_write_soon_attempts_flush_high_water_and_exception(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush, it will raise Exception, which + # disconnects the remote end + def send(_): + inst.connected = False + raise Exception() + + sock.send = send + + inst.adj.outbuf_high_watermark = 3 + inst.total_outbufs_len = 4 + inst.current_outbuf_count = 4 + + inst.outbufs[0].append(b"test") + + class Lock(DummyLock): + def wait(self): + inst.total_outbufs_len = 0 + super().wait() + + inst.outbuf_lock = Lock() + self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"xyz")) + + # Validate we woke up the main thread to deal with the exception of + # trying to send + self.assertTrue(inst.outbuf_lock.waited) + self.assertTrue(inst.server.trigger_pulled) + + def test_write_soon_flush_and_exception(self): + inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush, it will raise Exception, which + # disconnects the remote end + def send(_): + inst.connected = False + raise Exception() + + sock.send = send + + wrote = inst.write_soon(b"xyz") + self.assertEqual(wrote, 3) + # Validate we woke up the main thread to deal with 
the exception of + # trying to send + self.assertTrue(inst.server.trigger_pulled) + + def test_handle_write_notify_after_flush(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [True] + inst.outbufs = [DummyBuffer(b"abc")] + inst.total_outbufs_len = len(inst.outbufs[0]) + inst.adj.send_bytes = 1 + inst.adj.outbuf_high_watermark = 5 + inst.will_close = False + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.will_close, False) + self.assertTrue(inst.outbuf_lock.acquired) + self.assertTrue(inst.outbuf_lock.notified) + self.assertEqual(sock.sent, b"abc") + + def test_handle_write_no_notify_after_flush(self): + inst, sock, map = self._makeOneWithMap() + inst.requests = [True] + inst.outbufs = [DummyBuffer(b"abc")] + inst.total_outbufs_len = len(inst.outbufs[0]) + inst.adj.send_bytes = 1 + inst.adj.outbuf_high_watermark = 2 + sock.send = lambda x, do_close=True: False + inst.will_close = False + inst.last_activity = 0 + result = inst.handle_write() + self.assertEqual(result, None) + self.assertEqual(inst.will_close, False) + self.assertTrue(inst.outbuf_lock.acquired) + self.assertFalse(inst.outbuf_lock.notified) + self.assertEqual(sock.sent, b"") def test__flush_some_empty_outbuf(self): inst, sock, map = self._makeOneWithMap() @@ -232,22 +393,25 @@ def test__flush_some_empty_outbuf(self): def test__flush_some_full_outbuf_socket_returns_nonzero(self): inst, sock, map = self._makeOneWithMap() - inst.outbufs[0].append(b'abc') + inst.outbufs[0].append(b"abc") + inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) result = inst._flush_some() self.assertEqual(result, True) def test__flush_some_full_outbuf_socket_returns_zero(self): inst, sock, map = self._makeOneWithMap() sock.send = lambda x: False - inst.outbufs[0].append(b'abc') + inst.outbufs[0].append(b"abc") + inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) result = inst._flush_some() self.assertEqual(result, False) def 
test_flush_some_multiple_buffers_first_empty(self): inst, sock, map = self._makeOneWithMap() sock.send = lambda x: len(x) - buffer = DummyBuffer(b'abc') + buffer = DummyBuffer(b"abc") inst.outbufs.append(buffer) + inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) result = inst._flush_some() self.assertEqual(result, True) self.assertEqual(buffer.skipped, 3) @@ -256,11 +420,14 @@ def test_flush_some_multiple_buffers_first_empty(self): def test_flush_some_multiple_buffers_close_raises(self): inst, sock, map = self._makeOneWithMap() sock.send = lambda x: len(x) - buffer = DummyBuffer(b'abc') + buffer = DummyBuffer(b"abc") inst.outbufs.append(buffer) + inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) inst.logger = DummyLogger() + def doraise(): raise NotImplementedError + inst.outbufs[0].close = doraise result = inst._flush_some() self.assertEqual(result, True) @@ -270,24 +437,28 @@ def doraise(): def test__flush_some_outbuf_len_gt_sys_maxint(self): from waitress.compat import MAXINT + inst, sock, map = self._makeOneWithMap() - class DummyHugeOutbuffer(object): + + class DummyHugeOutbuffer: def __init__(self): self.length = MAXINT + 1 + def __len__(self): return self.length + def get(self, numbytes): self.length = 0 - return b'123' - def skip(self, *args): pass + return b"123" + buf = DummyHugeOutbuffer() inst.outbufs = [buf] - inst.send = lambda *arg: 0 + inst.send = lambda *arg, do_close: 0 result = inst._flush_some() # we are testing that _flush_some doesn't raise an OverflowError # when one of its outbufs has a __len__ that returns gt sys.maxint self.assertEqual(result, False) - + def test_handle_close(self): inst, sock, map = self._makeOneWithMap() inst.handle_close() @@ -296,8 +467,10 @@ def test_handle_close(self): def test_handle_close_outbuf_raises_on_close(self): inst, sock, map = self._makeOneWithMap() + def doraise(): raise NotImplementedError + inst.outbufs[0].close = doraise inst.logger = DummyLogger() inst.handle_close() @@ -323,13 
+496,13 @@ def test_del_channel(self): def test_received(self): inst, sock, map = self._makeOneWithMap() inst.server = DummyServer() - inst.received(b'GET / HTTP/1.1\n\n') + inst.received(b"GET / HTTP/1.1\r\n\r\n") self.assertEqual(inst.server.tasks, [inst]) self.assertTrue(inst.requests) def test_received_no_chunk(self): inst, sock, map = self._makeOneWithMap() - self.assertEqual(inst.received(b''), False) + self.assertEqual(inst.received(b""), False) def test_received_preq_not_completed(self): inst, sock, map = self._makeOneWithMap() @@ -338,8 +511,8 @@ def test_received_preq_not_completed(self): inst.request = preq preq.completed = False preq.empty = True - inst.received(b'GET / HTTP/1.1\n\n') - self.assertEqual(inst.requests, ()) + inst.received(b"GET / HTTP/1.1\r\n\r\n") + self.assertEqual(inst.requests, []) self.assertEqual(inst.server.tasks, []) def test_received_preq_completed_empty(self): @@ -349,7 +522,7 @@ def test_received_preq_completed_empty(self): inst.request = preq preq.completed = True preq.empty = True - inst.received(b'GET / HTTP/1.1\n\n') + inst.received(b"GET / HTTP/1.1\r\n\r\n") self.assertEqual(inst.request, None) self.assertEqual(inst.server.tasks, []) @@ -360,7 +533,7 @@ def test_received_preq_error(self): inst.request = preq preq.completed = True preq.error = True - inst.received(b'GET / HTTP/1.1\n\n') + inst.received(b"GET / HTTP/1.1\r\n\r\n") self.assertEqual(inst.request, None) self.assertEqual(len(inst.server.tasks), 1) self.assertTrue(inst.requests) @@ -373,24 +546,10 @@ def test_received_preq_completed_connection_close(self): preq.completed = True preq.empty = True preq.connection_close = True - inst.received(b'GET / HTTP/1.1\n\n' + b'a' * 50000) + inst.received(b"GET / HTTP/1.1\r\n\r\n" + b"a" * 50000) self.assertEqual(inst.request, None) self.assertEqual(inst.server.tasks, []) - def test_received_preq_completed_n_lt_data(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - 
inst.request = preq - preq.completed = True - preq.empty = False - line = b'GET / HTTP/1.1\n\n' - preq.retval = len(line) - inst.received(line + line) - self.assertEqual(inst.request, None) - self.assertEqual(len(inst.requests), 2) - self.assertEqual(len(inst.server.tasks), 1) - def test_received_headers_finished_expect_continue_false(self): inst, sock, map = self._makeOneWithMap() inst.server = DummyServer() @@ -401,10 +560,10 @@ def test_received_headers_finished_expect_continue_false(self): preq.completed = False preq.empty = False preq.retval = 1 - inst.received(b'GET / HTTP/1.1\n\n') + inst.received(b"GET / HTTP/1.1\r\n\r\n") self.assertEqual(inst.request, preq) self.assertEqual(inst.server.tasks, []) - self.assertEqual(inst.outbufs[0].get(100), b'') + self.assertEqual(inst.outbufs[0].get(100), b"") def test_received_headers_finished_expect_continue_true(self): inst, sock, map = self._makeOneWithMap() @@ -415,10 +574,10 @@ def test_received_headers_finished_expect_continue_true(self): preq.headers_finished = True preq.completed = False preq.empty = False - inst.received(b'GET / HTTP/1.1\n\n') + inst.received(b"GET / HTTP/1.1\r\n\r\n") self.assertEqual(inst.request, preq) self.assertEqual(inst.server.tasks, []) - self.assertEqual(sock.sent, b'HTTP/1.1 100 Continue\r\n\r\n') + self.assertEqual(sock.sent, b"HTTP/1.1 100 Continue\r\n\r\n") self.assertEqual(inst.sent_continue, True) self.assertEqual(preq.completed, False) @@ -432,21 +591,13 @@ def test_received_headers_finished_expect_continue_true_sent_true(self): preq.completed = False preq.empty = False inst.sent_continue = True - inst.received(b'GET / HTTP/1.1\n\n') + inst.received(b"GET / HTTP/1.1\r\n\r\n") self.assertEqual(inst.request, preq) self.assertEqual(inst.server.tasks, []) - self.assertEqual(sock.sent, b'') + self.assertEqual(sock.sent, b"") self.assertEqual(inst.sent_continue, True) self.assertEqual(preq.completed, False) - def test_service_no_requests(self): - inst, sock, map = 
self._makeOneWithMap() - inst.requests = [] - inst.service() - self.assertEqual(inst.requests, []) - self.assertTrue(inst.force_flush) - self.assertTrue(inst.last_activity) - def test_service_with_one_request(self): inst, sock, map = self._makeOneWithMap() request = DummyRequest() @@ -475,6 +626,7 @@ def test_service_with_multiple_requests(self): inst.task_class = DummyTaskClass() inst.requests = [request1, request2] inst.service() + inst.service() self.assertEqual(inst.requests, []) self.assertTrue(request1.serviced) self.assertTrue(request2.serviced) @@ -495,7 +647,7 @@ def test_service_with_request_raises(self): self.assertTrue(request.serviced) self.assertEqual(inst.requests, []) self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(inst.force_flush) + self.assertTrue(inst.server.trigger_pulled) self.assertTrue(inst.last_activity) self.assertFalse(inst.will_close) self.assertEqual(inst.error_task_class.serviced, True) @@ -514,7 +666,7 @@ def test_service_with_requests_raises_already_wrote_header(self): self.assertTrue(request.serviced) self.assertEqual(inst.requests, []) self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(inst.force_flush) + self.assertTrue(inst.server.trigger_pulled) self.assertTrue(inst.last_activity) self.assertTrue(inst.close_when_flushed) self.assertEqual(inst.error_task_class.serviced, False) @@ -535,7 +687,7 @@ def test_service_with_requests_raises_didnt_write_header_expose_tbs(self): self.assertFalse(inst.will_close) self.assertEqual(inst.requests, []) self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(inst.force_flush) + self.assertTrue(inst.server.trigger_pulled) self.assertTrue(inst.last_activity) self.assertEqual(inst.error_task_class.serviced, True) self.assertTrue(request.closed) @@ -553,11 +705,59 @@ def test_service_with_requests_raises_didnt_write_header(self): self.assertTrue(request.serviced) self.assertEqual(inst.requests, []) self.assertEqual(len(inst.logger.exceptions), 1) - 
self.assertTrue(inst.force_flush) + self.assertTrue(inst.server.trigger_pulled) self.assertTrue(inst.last_activity) self.assertTrue(inst.close_when_flushed) self.assertTrue(request.closed) + def test_service_with_request_raises_disconnect(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + inst.requests = [request] + inst.task_class = DummyTaskClass(ClientDisconnected) + inst.error_task_class = DummyTaskClass() + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.infos), 1) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertFalse(inst.will_close) + self.assertEqual(inst.error_task_class.serviced, False) + self.assertTrue(request.closed) + + def test_service_with_request_error_raises_disconnect(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + inst.adj.expose_tracebacks = False + inst.server = DummyServer() + request = DummyRequest() + err_request = DummyRequest() + inst.requests = [request] + inst.parser_class = lambda x: err_request + inst.task_class = DummyTaskClass(RuntimeError) + inst.task_class.wrote_header = False + inst.error_task_class = DummyTaskClass(ClientDisconnected) + inst.logger = DummyLogger() + inst.service() + self.assertTrue(request.serviced) + self.assertTrue(err_request.serviced) + self.assertEqual(inst.requests, []) + self.assertEqual(len(inst.logger.exceptions), 1) + self.assertEqual(len(inst.logger.infos), 0) + self.assertTrue(inst.server.trigger_pulled) + self.assertTrue(inst.last_activity) + self.assertFalse(inst.will_close) + self.assertEqual(inst.task_class.serviced, True) + self.assertEqual(inst.error_task_class.serviced, True) + self.assertTrue(request.closed) + def 
test_cancel_no_requests(self): inst, sock, map = self._makeOneWithMap() inst.requests = () @@ -570,16 +770,144 @@ def test_cancel_with_requests(self): inst.cancel() self.assertEqual(inst.requests, []) - def test_defer(self): - inst, sock, map = self._makeOneWithMap() - self.assertEqual(inst.defer(), None) -class DummySock(object): +class TestHTTPChannelLookahead(TestHTTPChannel): + def app_check_disconnect(self, environ, start_response): + """ + Application that checks for client disconnection every + second for up to two seconds. + """ + import time + + if hasattr(self, "app_started"): + self.app_started.set() + + try: + request_body_size = int(environ.get("CONTENT_LENGTH", 0)) + except ValueError: + request_body_size = 0 + self.request_body = environ["wsgi.input"].read(request_body_size) + + self.disconnect_detected = False + check = environ["waitress.client_disconnected"] + if environ["PATH_INFO"] == "/sleep": + for i in range(3): + if i != 0: + time.sleep(1) + if check(): + self.disconnect_detected = True + break + + body = b"finished" + cl = str(len(body)) + start_response( + "200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")] + ) + return [body] + + def _make_app_with_lookahead(self): + """ + Setup a channel with lookahead and store it and the socket in self + """ + adj = DummyAdjustments() + adj.channel_request_lookahead = 5 + channel, sock, map = self._makeOneWithMap(adj=adj) + channel.server.application = self.app_check_disconnect + + self.channel = channel + self.sock = sock + + def _send(self, *lines): + """ + Send lines through the socket with correct line endings + """ + self.sock.send("".join(line + "\r\n" for line in lines).encode("ascii")) + + def test_client_disconnect(self, close_before_start=False): + """Disconnect the socket after starting the task.""" + import threading + + self._make_app_with_lookahead() + self._send( + "GET /sleep HTTP/1.1", + "Host: localhost:8080", + "", + ) + self.assertTrue(self.channel.readable()) + 
self.channel.handle_read() + self.assertEqual(len(self.channel.server.tasks), 1) + self.app_started = threading.Event() + self.disconnect_detected = False + thread = threading.Thread(target=self.channel.server.tasks[0].service) + + if not close_before_start: + thread.start() + self.assertTrue(self.app_started.wait(timeout=5)) + + # Close the socket, check that the channel is still readable due to the + # lookahead and read it, which marks the channel as closed. + self.sock.close() + self.assertTrue(self.channel.readable()) + self.channel.handle_read() + + if close_before_start: + thread.start() + + thread.join() + + if close_before_start: + self.assertFalse(self.app_started.is_set()) + else: + self.assertTrue(self.disconnect_detected) + + def test_client_disconnect_immediate(self): + """ + The same test, but this time we close the socket even before processing + started. The app should not be executed. + """ + self.test_client_disconnect(close_before_start=True) + + def test_lookahead_continue(self): + """ + Send two requests to a channel with lookahead and use an + expect-continue on the second one, making sure the responses still come + in the right order. 
+ """ + self._make_app_with_lookahead() + self._send( + "POST / HTTP/1.1", + "Host: localhost:8080", + "Content-Length: 1", + "", + "x", + "POST / HTTP/1.1", + "Host: localhost:8080", + "Content-Length: 1", + "Expect: 100-continue", + "", + ) + self.channel.handle_read() + self.assertEqual(len(self.channel.requests), 1) + self.channel.server.tasks[0].service() + data = self.sock.recv(256).decode("ascii") + self.assertTrue(data.endswith("HTTP/1.1 100 Continue\r\n\r\n")) + + self.sock.send(b"x") + self.channel.handle_read() + self.assertEqual(len(self.channel.requests), 1) + self.channel.server.tasks[0].service() + self.channel._flush_some() + data = self.sock.recv(256).decode("ascii") + self.assertEqual(data.split("\r\n")[-1], "finished") + self.assertEqual(self.request_body, b"x") + + +class DummySock: blocking = False closed = False def __init__(self): - self.sent = b'' + self.sent = b"" def setblocking(self, *arg): self.blocking = True @@ -588,7 +916,10 @@ def fileno(self): return 100 def getpeername(self): - return '127.0.0.1' + return "127.0.0.1" + + def getsockopt(self, level, option): + return 2048 def close(self): self.closed = True @@ -597,7 +928,14 @@ def send(self, data): self.sent += data return len(data) -class DummyLock(object): + def recv(self, buffer_size): + result = self.sent[:buffer_size] + self.sent = self.sent[buffer_size:] + return result + + +class DummyLock: + notified = False def __init__(self, acquirable=True): self.acquirable = acquirable @@ -610,13 +948,20 @@ def acquire(self, val): def release(self): self.released = True + def notify(self): + self.notified = True + + def wait(self): + self.waited = True + def __exit__(self, type, val, traceback): self.acquire(True) def __enter__(self): pass -class DummyBuffer(object): + +class DummyBuffer: closed = False def __init__(self, data, toraise=None): @@ -627,7 +972,7 @@ def get(self, *arg): if self.toraise: raise self.toraise data = self.data - self.data = b'' + self.data = b"" return data def 
skip(self, num, x): @@ -639,22 +984,30 @@ def __len__(self): def close(self): self.closed = True -class DummyAdjustments(object): + +class DummyAdjustments: outbuf_overflow = 1048576 + outbuf_high_watermark = 1048576 inbuf_overflow = 512000 cleanup_interval = 900 - send_bytes = 9000 - url_scheme = 'http' + url_scheme = "http" channel_timeout = 300 log_socket_errors = True recv_bytes = 8192 + send_bytes = 1 expose_tracebacks = True - ident = 'waitress' + ident = "waitress" max_request_header_size = 10000 + url_prefix = "" + channel_request_lookahead = 0 + max_request_body_size = 1048576 + -class DummyServer(object): +class DummyServer: trigger_pulled = False adj = DummyAdjustments() + effective_port = 8080 + server_name = "" def __init__(self): self.tasks = [] @@ -666,7 +1019,8 @@ def add_task(self, task): def pull_trigger(self): self.trigger_pulled = True -class DummyParser(object): + +class DummyParser: version = 1 data = None completed = True @@ -683,10 +1037,11 @@ def received(self, data): return self.retval return len(data) -class DummyRequest(object): + +class DummyRequest: error = None - path = '/' - version = '1.0' + path = "/" + version = "1.0" closed = False def __init__(self): @@ -695,20 +1050,27 @@ def __init__(self): def close(self): self.closed = True -class DummyLogger(object): +class DummyLogger: def __init__(self): self.exceptions = [] + self.infos = [] + self.warnings = [] + + def info(self, msg): + self.infos.append(msg) def exception(self, msg): self.exceptions.append(msg) -class DummyError(object): - code = '431' - reason = 'Bleh' - body = 'My body' -class DummyTaskClass(object): +class DummyError: + code = "431" + reason = "Bleh" + body = "My body" + + +class DummyTaskClass: wrote_header = True close_on_finish = False serviced = False diff --git a/tests/test_functional.py b/tests/test_functional.py new file mode 100644 index 00000000..1dfd8891 --- /dev/null +++ b/tests/test_functional.py @@ -0,0 +1,1782 @@ +import errno +from http import client 
as httplib +import logging +import multiprocessing +import os +import signal +import socket +import string +import subprocess +import sys +import time +import unittest + +from waitress import server +from waitress.compat import WIN +from waitress.utilities import cleanup_unix_socket + +dn = os.path.dirname +here = dn(__file__) + + +class NullHandler(logging.Handler): # pragma: no cover + """A logging handler that swallows all emitted messages.""" + + def emit(self, record): + pass + + +def start_server(app, svr, queue, **kwargs): # pragma: no cover + """Run a fixture application.""" + logging.getLogger("waitress").addHandler(NullHandler()) + try_register_coverage() + svr(app, queue, **kwargs).run() + + +def try_register_coverage(): # pragma: no cover + # Hack around multiprocessing exiting early and not triggering coverage's + # atexit handler by always registering a signal handler + + if "COVERAGE_PROCESS_START" in os.environ: + + def sigterm(*args): + sys.exit(0) + + signal.signal(signal.SIGTERM, sigterm) + + +class FixtureTcpWSGIServer(server.TcpWSGIServer): + """A version of TcpWSGIServer that relays back what it's bound to.""" + + family = socket.AF_INET # Testing + + def __init__(self, application, queue, **kw): # pragma: no cover + # Coverage doesn't see this as it's ran in a separate process. + kw["host"] = "127.0.0.1" + kw["port"] = 0 # Bind to any available port. + super().__init__(application, **kw) + host, port = self.socket.getsockname() + + if os.name == "nt": + host = "127.0.0.1" + queue.put((host, port)) + + +class SubprocessTests: + + exe = sys.executable + + server = None + + def start_subprocess(self, target, **kw): + # Spawn a server process. 
+ self.queue = multiprocessing.Queue() + + if "COVERAGE_RCFILE" in os.environ: + os.environ["COVERAGE_PROCESS_START"] = os.environ["COVERAGE_RCFILE"] + + if not WIN: + ctx = multiprocessing.get_context("fork") + else: + ctx = multiprocessing.get_context("spawn") + + self.proc = ctx.Process( + target=start_server, + args=(target, self.server, self.queue), + kwargs=kw, + ) + self.proc.start() + + if self.proc.exitcode is not None: # pragma: no cover + raise RuntimeError("%s didn't start" % str(target)) + # Get the socket the server is listening on. + self.bound_to = self.queue.get(timeout=5) + self.sock = self.create_socket() + + def stop_subprocess(self): + if self.proc.exitcode is None: + self.proc.terminate() + self.sock.close() + # This give us one FD back ... + self.proc.join() + self.proc.close() + self.queue.close() + self.queue.join_thread() + + # The following is for the benefit of PyPy 3, for some reason it is + # holding on to some resources way longer than necessary causing tests + # to fail with file desctriptor exceeded errors on macOS which defaults + # to 256 file desctriptors per process. While we could use ulimit to + # increase the limits before running tests, this works as well and + # means we don't need to remember to do that. 
+ import gc + + gc.collect() + + def assertline(self, line, status, reason, version): + v, s, r = (x.strip() for x in line.split(None, 2)) + self.assertEqual(s, status.encode("latin-1")) + self.assertEqual(r, reason.encode("latin-1")) + self.assertEqual(v, version.encode("latin-1")) + + def create_socket(self): + return socket.socket(self.server.family, socket.SOCK_STREAM) + + def connect(self): + self.sock.connect(self.bound_to) + + def make_http_connection(self): + raise NotImplementedError # pragma: no cover + + def send_check_error(self, to_send): + self.sock.send(to_send) + + +class TcpTests(SubprocessTests): + + server = FixtureTcpWSGIServer + + def make_http_connection(self): + return httplib.HTTPConnection(*self.bound_to) + + +class SleepyThreadTests(TcpTests, unittest.TestCase): + # test that sleepy thread doesnt block other requests + + def setUp(self): + from tests.fixtureapps import sleepy + + self.start_subprocess(sleepy.app) + + def tearDown(self): + self.stop_subprocess() + + def test_it(self): + getline = os.path.join(here, "fixtureapps", "getline.py") + cmds = ( + [self.exe, getline, "http://%s:%d/sleepy" % self.bound_to], + [self.exe, getline, "http://%s:%d/" % self.bound_to], + ) + r, w = os.pipe() + procs = [] + + for cmd in cmds: + procs.append(subprocess.Popen(cmd, stdout=w)) + time.sleep(3) + + for proc in procs: + if proc.returncode is not None: # pragma: no cover + proc.terminate() + proc.wait() + # the notsleepy response should always be first returned (it sleeps + # for 2 seconds, then returns; the notsleepy response should be + # processed in the meantime) + result = os.read(r, 10000) + os.close(r) + os.close(w) + self.assertEqual(result, b"notsleepy returnedsleepy returned") + + +class EchoTests: + def setUp(self): + from tests.fixtureapps import echo + + self.start_subprocess( + echo.app, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-for", "x-forwarded-proto"}, + clear_untrusted_proxy_headers=True, 
+ ) + + def tearDown(self): + self.stop_subprocess() + + def _read_echo(self, fp): + from tests.fixtureapps import echo + + line, headers, body = read_http(fp) + + return line, headers, echo.parse_response(body) + + def test_date_and_server(self): + to_send = b"GET / HTTP/1.0\r\nContent-Length: 0\r\n\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers.get("server"), "waitress") + self.assertTrue(headers.get("date")) + + def test_bad_host_header(self): + # https://corte.si/posts/code/pathod/pythonservers/index.html + to_send = b"GET / HTTP/1.0\r\n Host: 0\r\n\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "400", "Bad Request", "HTTP/1.0") + self.assertEqual(headers.get("server"), "waitress") + self.assertTrue(headers.get("date")) + + def test_send_with_body(self): + to_send = b"GET / HTTP/1.0\r\nContent-Length: 5\r\n\r\n" + to_send += b"hello" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(echo.content_length, "5") + self.assertEqual(echo.body, b"hello") + + def test_send_empty_body(self): + to_send = b"GET / HTTP/1.0\r\nContent-Length: 0\r\n\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(echo.content_length, "0") + self.assertEqual(echo.body, b"") + + def test_multiple_requests_with_body(self): + orig_sock = self.sock + + for x in range(3): + self.sock = self.create_socket() + self.test_send_with_body() + self.sock.close() + self.sock = orig_sock + + def test_multiple_requests_without_body(self): + 
orig_sock = self.sock + + for x in range(3): + self.sock = self.create_socket() + self.test_send_empty_body() + self.sock.close() + self.sock = orig_sock + + def test_without_crlf(self): + data = b"Echo\r\nthis\r\nplease" + s = ( + b"GET / HTTP/1.0\r\n" + b"Connection: close\r\n" + b"Content-Length: %d\r\n" + b"\r\n" + b"%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + with self.sock.makefile("rb", 0) as fp: + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(int(echo.content_length), len(data)) + self.assertEqual(len(echo.body), len(data)) + self.assertEqual(echo.body, (data)) + + def test_large_body(self): + # 1024 characters. + body = b"This string has 32 characters.\r\n" * 32 + s = b"GET / HTTP/1.0\r\nContent-Length: %d\r\n\r\n%s" % (len(body), body) + self.connect() + self.sock.send(s) + with self.sock.makefile("rb", 0) as fp: + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(echo.content_length, "1024") + self.assertEqual(echo.body, body) + + def test_many_clients(self): + conns = [] + + for n in range(50): + h = self.make_http_connection() + h.request("GET", "/", headers={"Accept": "text/plain"}) + conns.append(h) + responses = [] + + for h in conns: + response = h.getresponse() + self.assertEqual(response.status, 200) + responses.append(response) + + for response in responses: + response.read() + + for h in conns: + h.close() + + def test_chunking_request_without_content(self): + header = b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + self.connect() + self.sock.send(header) + self.sock.send(b"0\r\n\r\n") + with self.sock.makefile("rb", 0) as fp: + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(echo.body, b"") + self.assertEqual(echo.content_length, "0") + self.assertFalse("transfer-encoding" in headers) + + def 
test_chunking_request_with_content(self): + control_line = b"20\r\n" # 20 hex = 32 dec + s = b"This string has 32 characters.\r\n" + expected = s * 12 + header = b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + self.connect() + self.sock.send(header) + with self.sock.makefile("rb", 0) as fp: + for n in range(12): + self.sock.send(control_line) + self.sock.send(s) + self.sock.send(b"\r\n") # End the chunk + self.sock.send(b"0\r\n\r\n") + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(echo.body, expected) + self.assertEqual(echo.content_length, str(len(expected))) + self.assertFalse("transfer-encoding" in headers) + + def test_broken_chunked_encoding(self): + control_line = b"20\r\n" # 20 hex = 32 dec + s = b"This string has 32 characters.\r\n" + to_send = b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + to_send += control_line + s + b"\r\n" + # garbage in input + to_send += b"garbage\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + # receiver caught garbage and turned it into a 400 + self.assertline(line, "400", "Bad Request", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["content-type"], "text/plain; charset=utf-8") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_broken_chunked_encoding_invalid_hex(self): + control_line = b"0x20\r\n" # 20 hex = 32 dec + s = b"This string has 32 characters.\r\n" + to_send = b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + to_send += control_line + s + b"\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = 
read_http(fp) + self.assertline(line, "400", "Bad Request", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertIn(b"Invalid chunk size", response_body) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["content-type"], "text/plain; charset=utf-8") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_broken_chunked_encoding_invalid_extension(self): + control_line = b"20;invalid=\r\n" # 20 hex = 32 dec + s = b"This string has 32 characters.\r\n" + to_send = b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + to_send += control_line + s + b"\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "400", "Bad Request", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertIn(b"Invalid chunk extension", response_body) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["content-type"], "text/plain; charset=utf-8") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_broken_chunked_encoding_missing_chunk_end(self): + control_line = b"20\r\n" # 20 hex = 32 dec + s = b"This string has 32 characters.\r\n" + to_send = b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + to_send += control_line + s + # garbage in input + to_send += b"garbage" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + # receiver caught garbage and turned it into a 400 + self.assertline(line, "400", "Bad Request", "HTTP/1.1") + cl = int(headers["content-length"]) 
+ self.assertEqual(cl, len(response_body)) + self.assertTrue(b"Chunk not properly terminated" in response_body) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["content-type"], "text/plain; charset=utf-8") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_keepalive_http_10(self): + # Handling of Keep-Alive within HTTP 1.0 + data = b"Default: Don't keep me alive" + s = b"GET / HTTP/1.0\r\nContent-Length: %d\r\n\r\n%s" % (len(data), data) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + connection = response.getheader("Connection", "") + # We sent no Connection: Keep-Alive header + # Connection: close (or no header) is default. + self.assertTrue(connection != "Keep-Alive") + + def test_keepalive_http10_explicit(self): + # If header Connection: Keep-Alive is explicitly sent, + # we want to keept the connection open, we also need to return + # the corresponding header + data = b"Keep me alive" + s = ( + b"GET / HTTP/1.0\r\n" + b"Connection: Keep-Alive\r\n" + b"Content-Length: %d\r\n" + b"\r\n" + b"%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + connection = response.getheader("Connection", "") + self.assertEqual(connection, "Keep-Alive") + + def test_keepalive_http_11(self): + # Handling of Keep-Alive within HTTP 1.1 + + # All connections are kept alive, unless stated otherwise + data = b"Default: Keep me alive" + s = b"GET / HTTP/1.1\r\nContent-Length: %d\r\n\r\n%s" % (len(data), data) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + 
self.assertTrue(response.getheader("connection") != "close") + + def test_keepalive_http11_explicit(self): + # Explicitly set keep-alive + data = b"Default: Keep me alive" + s = ( + b"GET / HTTP/1.1\r\n" + b"Connection: keep-alive\r\n" + b"Content-Length: %d\r\n" + b"\r\n" + b"%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + self.assertTrue(response.getheader("connection") != "close") + + def test_keepalive_http11_connclose(self): + # specifying Connection: close explicitly + data = b"Don't keep me alive" + s = ( + b"GET / HTTP/1.1\r\n" + b"Connection: close\r\n" + b"Content-Length: %d\r\n" + b"\r\n" + b"%s" % (len(data), data) + ) + self.connect() + self.sock.send(s) + response = httplib.HTTPResponse(self.sock) + response.begin() + self.assertEqual(int(response.status), 200) + self.assertEqual(response.getheader("connection"), "close") + + def test_proxy_headers(self): + to_send = ( + b"GET / HTTP/1.0\r\n" + b"Content-Length: 0\r\n" + b"Host: www.google.com:8080\r\n" + b"X-Forwarded-For: 192.168.1.1\r\n" + b"X-Forwarded-Proto: https\r\n" + b"X-Forwarded-Port: 5000\r\n\r\n" + ) + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, echo = self._read_echo(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers.get("server"), "waitress") + self.assertTrue(headers.get("date")) + self.assertIsNone(echo.headers.get("X_FORWARDED_PORT")) + self.assertEqual(echo.headers["HOST"], "www.google.com:8080") + self.assertEqual(echo.scheme, "https") + self.assertEqual(echo.remote_addr, "192.168.1.1") + self.assertEqual(echo.remote_host, "192.168.1.1") + + +class PipeliningTests: + def setUp(self): + from tests.fixtureapps import echo + + self.start_subprocess(echo.app_body_only) + + def tearDown(self): + self.stop_subprocess() + + def test_pipelining(self): + s = ( + b"GET / 
HTTP/1.0\r\n" + b"Connection: %s\r\n" + b"Content-Length: %d\r\n" + b"\r\n" + b"%s" + ) + to_send = b"" + count = 25 + + for n in range(count): + body = b"Response #%d\r\n" % (n + 1) + + if n + 1 < count: + conn = b"keep-alive" + else: + conn = b"close" + to_send += s % (conn, len(body), body) + + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + for n in range(count): + expect_body = b"Response #%d\r\n" % (n + 1) + line = fp.readline() # status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + length = int(headers.get("content-length")) or None + response_body = fp.read(length) + self.assertEqual(int(status), 200) + self.assertEqual(length, len(response_body)) + self.assertEqual(response_body, expect_body) + + +class ExpectContinueTests: + def setUp(self): + from tests.fixtureapps import echo + + self.start_subprocess(echo.app_body_only) + + def tearDown(self): + self.stop_subprocess() + + def test_expect_continue(self): + # specifying Connection: close explicitly + data = b"I have expectations" + to_send = ( + b"GET / HTTP/1.1\r\n" + b"Connection: close\r\n" + b"Content-Length: %d\r\n" + b"Expect: 100-continue\r\n" + b"\r\n" + b"%s" % (len(data), data) + ) + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line = fp.readline() # continue status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + self.assertEqual(int(status), 100) + self.assertEqual(reason, b"Continue") + self.assertEqual(version, b"HTTP/1.1") + fp.readline() # blank line + line = fp.readline() # next status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + length = int(headers.get("content-length")) or None + response_body = fp.read(length) + self.assertEqual(int(status), 200) + self.assertEqual(length, len(response_body)) + self.assertEqual(response_body, data) + + +class 
BadContentLengthTests: + def setUp(self): + from tests.fixtureapps import badcl + + self.start_subprocess(badcl.app) + + def tearDown(self): + self.stop_subprocess() + + def test_short_body(self): + # check to see if server closes connection when body is too short + # for cl header + to_send = ( + b"GET /short_body HTTP/1.0\r\n" + b"Connection: Keep-Alive\r\n" + b"Content-Length: 0\r\n" + b"\r\n" + ) + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line = fp.readline() # status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + content_length = int(headers.get("content-length")) + response_body = fp.read(content_length) + self.assertEqual(int(status), 200) + self.assertNotEqual(content_length, len(response_body)) + self.assertEqual(len(response_body), content_length - 1) + self.assertEqual(response_body, b"abcdefghi") + # remote closed connection (despite keepalive header); not sure why + # first send succeeds + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_long_body(self): + # check server doesnt close connection when body is too short + # for cl header + to_send = ( + b"GET /long_body HTTP/1.0\r\n" + b"Connection: Keep-Alive\r\n" + b"Content-Length: 0\r\n" + b"\r\n" + ) + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line = fp.readline() # status line + version, status, reason = (x.strip() for x in line.split(None, 2)) + headers = parse_headers(fp) + content_length = int(headers.get("content-length")) or None + response_body = fp.read(content_length) + self.assertEqual(int(status), 200) + self.assertEqual(content_length, len(response_body)) + self.assertEqual(response_body, b"abcdefgh") + # remote does not close connection (keepalive header) + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line = fp.readline() # status line + version, status, reason = (x.strip() 
for x in line.split(None, 2)) + headers = parse_headers(fp) + content_length = int(headers.get("content-length")) or None + response_body = fp.read(content_length) + self.assertEqual(int(status), 200) + + +class NoContentLengthTests: + def setUp(self): + from tests.fixtureapps import nocl + + self.start_subprocess(nocl.app) + + def tearDown(self): + self.stop_subprocess() + + def test_http10_generator(self): + body = string.ascii_letters.encode("latin-1") + to_send = ( + b"GET / HTTP/1.0\r\n" + b"Connection: Keep-Alive\r\n" + b"Content-Length: %d\r\n\r\n" % len(body) + ) + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers.get("content-length"), None) + self.assertEqual(headers.get("connection"), "close") + self.assertEqual(response_body, body) + # remote closed connection (despite keepalive header), because + # generators cannot have a content-length divined + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_http10_list(self): + body = string.ascii_letters.encode("latin-1") + to_send = ( + b"GET /list HTTP/1.0\r\n" + b"Connection: Keep-Alive\r\n" + b"Content-Length: %d\r\n\r\n" % len(body) + ) + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers["content-length"], str(len(body))) + self.assertEqual(headers.get("connection"), "Keep-Alive") + self.assertEqual(response_body, body) + # remote keeps connection open because it divined the content length + # from a length-1 list + self.sock.send(to_send) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + + def test_http10_listlentwo(self): + body = string.ascii_letters.encode("latin-1") + 
to_send = ( + b"GET /list_lentwo HTTP/1.0\r\n" + b"Connection: Keep-Alive\r\n" + b"Content-Length: %d\r\n\r\n" % len(body) + ) + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(headers.get("content-length"), None) + self.assertEqual(headers.get("connection"), "close") + self.assertEqual(response_body, body) + # remote closed connection (despite keepalive header), because + # lists of length > 1 cannot have their content length divined + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_http11_generator(self): + body = string.ascii_letters + body = body.encode("latin-1") + to_send = b"GET / HTTP/1.1\r\nContent-Length: %d\r\n\r\n" % len(body) + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb") as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + expected = b"" + + for chunk in chunks(body, 10): + expected += b"%s\r\n%s\r\n" % ( + hex(len(chunk))[2:].upper().encode("latin-1"), + chunk, + ) + expected += b"0\r\n\r\n" + self.assertEqual(response_body, expected) + # connection is always closed at the end of a chunked response + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_http11_list(self): + body = string.ascii_letters.encode("latin-1") + to_send = b"GET /list HTTP/1.1\r\nContent-Length: %d\r\n\r\n" % len(body) + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(headers["content-length"], str(len(body))) + self.assertEqual(response_body, body) + # remote keeps connection open because it divined the content length + # from a length-1 list + 
self.sock.send(to_send) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + + def test_http11_listlentwo(self): + body = string.ascii_letters.encode("latin-1") + to_send = b"GET /list_lentwo HTTP/1.1\r\nContent-Length: %d\r\n\r\n" % len(body) + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb") as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + expected = b"" + + for chunk in (body[:1], body[1:]): + expected += b"%s\r\n%s\r\n" % ( + (hex(len(chunk))[2:].upper().encode("latin-1")), + chunk, + ) + expected += b"0\r\n\r\n" + self.assertEqual(response_body, expected) + # connection is always closed at the end of a chunked response + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class WriteCallbackTests: + def setUp(self): + from tests.fixtureapps import writecb + + self.start_subprocess(writecb.app) + + def tearDown(self): + self.stop_subprocess() + + def test_short_body(self): + # check to see if server closes connection when body is too short + # for cl header + to_send = ( + b"GET /short_body HTTP/1.0\r\n" + b"Connection: Keep-Alive\r\n" + b"Content-Length: 0\r\n" + b"\r\n" + ) + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + # server trusts the content-length header (5) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, 9) + self.assertNotEqual(cl, len(response_body)) + self.assertEqual(len(response_body), cl - 1) + self.assertEqual(response_body, b"abcdefgh") + # remote closed connection (despite keepalive header) + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_long_body(self): + # check server doesnt close connection when body is too long + # for cl header + to_send = ( + b"GET /long_body 
HTTP/1.0\r\n" + b"Connection: Keep-Alive\r\n" + b"Content-Length: 0\r\n" + b"\r\n" + ) + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + content_length = int(headers.get("content-length")) or None + self.assertEqual(content_length, 9) + self.assertEqual(content_length, len(response_body)) + self.assertEqual(response_body, b"abcdefghi") + # remote does not close connection (keepalive header) + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + + def test_equal_body(self): + # check server doesnt close connection when body is equal to + # cl header + to_send = ( + b"GET /equal_body HTTP/1.0\r\n" + b"Connection: Keep-Alive\r\n" + b"Content-Length: 0\r\n" + b"\r\n" + ) + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + content_length = int(headers.get("content-length")) or None + self.assertEqual(content_length, 9) + self.assertline(line, "200", "OK", "HTTP/1.0") + self.assertEqual(content_length, len(response_body)) + self.assertEqual(response_body, b"abcdefghi") + # remote does not close connection (keepalive header) + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + + def test_no_content_length(self): + # wtf happens when there's no content-length + to_send = ( + b"GET /no_content_length HTTP/1.0\r\n" + b"Connection: Keep-Alive\r\n" + b"Content-Length: 0\r\n" + b"\r\n" + ) + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line = fp.readline() # status line + line, headers, response_body = read_http(fp) + content_length = headers.get("content-length") + self.assertEqual(content_length, None) + self.assertEqual(response_body, b"abcdefghi") + # 
remote closed connection (despite keepalive header) + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class TooLargeTests: + + toobig = 1050 + + def setUp(self): + from tests.fixtureapps import toolarge + + self.start_subprocess( + toolarge.app, max_request_header_size=1000, max_request_body_size=1000 + ) + + def tearDown(self): + self.stop_subprocess() + + def test_request_headers_too_large_http11(self): + body = b"" + bad_headers = b"X-Random-Header: 100\r\n" * int(self.toobig / 20) + to_send = b"GET / HTTP/1.1\r\nContent-Length: 0\r\n" + to_send += bad_headers + to_send += b"\r\n\r\n" + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb") as fp: + response_line, headers, response_body = read_http(fp) + self.assertline( + response_line, "431", "Request Header Fields Too Large", "HTTP/1.0" + ) + self.assertEqual(headers["connection"], "close") + + def test_request_body_too_large_with_wrong_cl_http10(self): + body = b"a" * self.toobig + to_send = b"GET / HTTP/1.0\r\nContent-Length: 5\r\n\r\n" + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb") as fp: + # first request succeeds (content-length 5) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # server trusts the content-length header; no pipelining, + # so request fulfilled, extra bytes are thrown away + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_wrong_cl_http10_keepalive(self): + body = b"a" * self.toobig + to_send = ( + b"GET / HTTP/1.0\r\nContent-Length: 5\r\nConnection: Keep-Alive\r\n\r\n" + ) + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb") as fp: + # first request succeeds (content-length 5) + line, headers, 
response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + line, headers, response_body = read_http(fp) + self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_no_cl_http10(self): + body = b"a" * self.toobig + to_send = b"GET / HTTP/1.0\r\n\r\n" + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # extra bytes are thrown away (no pipelining), connection closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_no_cl_http10_keepalive(self): + body = b"a" * self.toobig + to_send = b"GET / HTTP/1.0\r\nConnection: Keep-Alive\r\n\r\n" + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + # server trusts the content-length header (assumed zero) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + line, headers, response_body = read_http(fp) + # next response overruns because the extra data appears to be + # header data + self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def 
test_request_body_too_large_with_wrong_cl_http11(self): + body = b"a" * self.toobig + to_send = b"GET / HTTP/1.1\r\nContent-Length: 5\r\n\r\n" + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb") as fp: + # first request succeeds (content-length 5) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # second response is an error response + line, headers, response_body = read_http(fp) + self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_wrong_cl_http11_connclose(self): + body = b"a" * self.toobig + to_send = b"GET / HTTP/1.1\r\nContent-Length: 5\r\nConnection: close\r\n\r\n" + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + # server trusts the content-length header (5) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_no_cl_http11(self): + body = b"a" * self.toobig + to_send = b"GET / HTTP/1.1\r\n\r\n" + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb") as fp: + # server trusts the content-length header (assumed 0) + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # server assumes pipelined requests due to http/1.1, and the first + # request 
was assumed c-l 0 because it had no content-length header, + # so entire body looks like the header of the subsequent request + # second response is an error response + line, headers, response_body = read_http(fp) + self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_with_no_cl_http11_connclose(self): + body = b"a" * self.toobig + to_send = b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n" + to_send += body + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + # server trusts the content-length header (assumed 0) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_request_body_too_large_chunked_encoding(self): + control_line = b"20;\r\n" # 20 hex = 32 dec + s = b"This string has 32 characters.\r\n" + to_send = b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + repeat = control_line + s + to_send += repeat * ((self.toobig // len(repeat)) + 1) + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + # body bytes counter caught a max_request_body_size overrun + self.assertline(line, "413", "Request Entity Too Large", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertEqual(headers["content-type"], "text/plain; charset=utf-8") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class InternalServerErrorTests: + def setUp(self): + 
from tests.fixtureapps import error + + self.start_subprocess(error.app, expose_tracebacks=True) + + def tearDown(self): + self.stop_subprocess() + + def test_before_start_response_http_10(self): + to_send = b"GET /before_start_response HTTP/1.0\r\n\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual(headers["connection"], "close") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_before_start_response_http_11(self): + to_send = b"GET /before_start_response HTTP/1.1\r\n\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_before_start_response_http_11_close(self): + to_send = b"GET /before_start_response HTTP/1.1\r\nConnection: close\r\n\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + 
sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["connection"], "close") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_after_start_response_http10(self): + to_send = b"GET /after_start_response HTTP/1.0\r\n\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["connection"], "close") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_after_start_response_http11(self): + to_send = b"GET /after_start_response HTTP/1.1\r\n\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_after_start_response_http11_close(self): + to_send = b"GET /after_start_response HTTP/1.1\r\nConnection: close\r\n\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, 
"500", "Internal Server Error", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + self.assertEqual(headers["connection"], "close") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_after_write_cb(self): + to_send = b"GET /after_write_cb HTTP/1.1\r\n\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(response_body, b"") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_in_generator(self): + to_send = b"GET /in_generator HTTP/1.1\r\n\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + self.assertEqual(response_body, b"") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class InternalServerErrorTestsWithTraceback: + def setUp(self): + from tests.fixtureapps import error_traceback + + self.start_subprocess(error_traceback.app, expose_tracebacks=True) + + def tearDown(self): + self.stop_subprocess() + + def test_expose_tracebacks_http_10(self): + to_send = b"GET / HTTP/1.0\r\n\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) 
+ self.assertEqual(headers["connection"], "close") + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_expose_tracebacks_http_11(self): + to_send = b"GET / HTTP/1.1\r\n\r\n" + self.connect() + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "500", "Internal Server Error", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + self.assertTrue(response_body.startswith(b"Internal Server Error")) + self.assertEqual( + sorted(headers.keys()), + ["connection", "content-length", "content-type", "date", "server"], + ) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class FileWrapperTests: + def setUp(self): + from tests.fixtureapps import filewrapper + + self.start_subprocess(filewrapper.app) + + def tearDown(self): + self.stop_subprocess() + + def test_filelike_http11(self): + to_send = b"GET /filelike HTTP/1.1\r\n\r\n" + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_filelike_nocl_http11(self): + to_send = b"GET /filelike_nocl HTTP/1.1\r\n\r\n" + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, 
"image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_filelike_shortcl_http11(self): + to_send = b"GET /filelike_shortcl HTTP/1.1\r\n\r\n" + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, 1) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377" in response_body) + # connection has not been closed + + def test_filelike_longcl_http11(self): + to_send = b"GET /filelike_longcl HTTP/1.1\r\n\r\n" + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_notfilelike_http11(self): + to_send = b"GET /notfilelike HTTP/1.1\r\n\r\n" + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_notfilelike_iobase_http11(self): + to_send = b"GET /notfilelike_iobase HTTP/1.1\r\n\r\n" + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + 
self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has not been closed + + def test_notfilelike_nocl_http11(self): + to_send = b"GET /notfilelike_nocl HTTP/1.1\r\n\r\n" + + self.connect() + + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed (no content-length) + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_notfilelike_shortcl_http11(self): + to_send = b"GET /notfilelike_shortcl HTTP/1.1\r\n\r\n" + + self.connect() + + for t in range(0, 2): + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, 1) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377" in response_body) + # connection has not been closed + + def test_notfilelike_longcl_http11(self): + to_send = b"GET /notfilelike_longcl HTTP/1.1\r\n\r\n" + + self.connect() + + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.1") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body) + 10) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) 
+ + def test_filelike_http10(self): + to_send = b"GET /filelike HTTP/1.0\r\n\r\n" + + self.connect() + + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_filelike_nocl_http10(self): + to_send = b"GET /filelike_nocl HTTP/1.0\r\n\r\n" + + self.connect() + + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_notfilelike_http10(self): + to_send = b"GET /notfilelike HTTP/1.0\r\n\r\n" + + self.connect() + + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + cl = int(headers["content-length"]) + self.assertEqual(cl, len(response_body)) + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + def test_notfilelike_nocl_http10(self): + to_send = b"GET /notfilelike_nocl HTTP/1.0\r\n\r\n" + + self.connect() + + self.sock.send(to_send) + with self.sock.makefile("rb", 0) as fp: + line, headers, response_body = 
read_http(fp) + self.assertline(line, "200", "OK", "HTTP/1.0") + ct = headers["content-type"] + self.assertEqual(ct, "image/jpeg") + self.assertTrue(b"\377\330\377" in response_body) + # connection has been closed (no content-length) + self.send_check_error(to_send) + self.assertRaises(ConnectionClosed, read_http, fp) + + +class TcpEchoTests(EchoTests, TcpTests, unittest.TestCase): + pass + + +class TcpPipeliningTests(PipeliningTests, TcpTests, unittest.TestCase): + pass + + +class TcpExpectContinueTests(ExpectContinueTests, TcpTests, unittest.TestCase): + pass + + +class TcpBadContentLengthTests(BadContentLengthTests, TcpTests, unittest.TestCase): + pass + + +class TcpNoContentLengthTests(NoContentLengthTests, TcpTests, unittest.TestCase): + pass + + +class TcpWriteCallbackTests(WriteCallbackTests, TcpTests, unittest.TestCase): + pass + + +class TcpTooLargeTests(TooLargeTests, TcpTests, unittest.TestCase): + pass + + +class TcpInternalServerErrorTests( + InternalServerErrorTests, TcpTests, unittest.TestCase +): + pass + + +class TcpInternalServerErrorTestsWithTraceback( + InternalServerErrorTestsWithTraceback, TcpTests, unittest.TestCase +): + pass + + +class TcpFileWrapperTests(FileWrapperTests, TcpTests, unittest.TestCase): + pass + + +if hasattr(socket, "AF_UNIX"): + + class FixtureUnixWSGIServer(server.UnixWSGIServer): + """A version of UnixWSGIServer that relays back what it's bound to.""" + + family = socket.AF_UNIX # Testing + + def __init__(self, application, queue, **kw): # pragma: no cover + # Coverage doesn't see this as it's ran in a separate process. + # To permit parallel testing, use a PID-dependent socket. 
+ kw["unix_socket"] = "/tmp/waitress.test-%d.sock" % os.getpid() + super().__init__(application, **kw) + queue.put(self.socket.getsockname()) + + class UnixTests(SubprocessTests): + + server = FixtureUnixWSGIServer + + def make_http_connection(self): + return UnixHTTPConnection(self.bound_to) + + def stop_subprocess(self): + super().stop_subprocess() + cleanup_unix_socket(self.bound_to) + + def send_check_error(self, to_send): + # Unlike inet domain sockets, Unix domain sockets can trigger a + # 'Broken pipe' error when the socket it closed. + try: + self.sock.send(to_send) + except OSError as exc: + valid_errors = {errno.EPIPE, errno.ENOTCONN} + self.assertIn(get_errno(exc), valid_errors) + + class UnixEchoTests(EchoTests, UnixTests, unittest.TestCase): + pass + + class UnixPipeliningTests(PipeliningTests, UnixTests, unittest.TestCase): + pass + + class UnixExpectContinueTests(ExpectContinueTests, UnixTests, unittest.TestCase): + pass + + class UnixBadContentLengthTests( + BadContentLengthTests, UnixTests, unittest.TestCase + ): + pass + + class UnixNoContentLengthTests(NoContentLengthTests, UnixTests, unittest.TestCase): + pass + + class UnixWriteCallbackTests(WriteCallbackTests, UnixTests, unittest.TestCase): + pass + + class UnixTooLargeTests(TooLargeTests, UnixTests, unittest.TestCase): + pass + + class UnixInternalServerErrorTests( + InternalServerErrorTests, UnixTests, unittest.TestCase + ): + pass + + class UnixInternalServerErrorTestsWithTraceback( + InternalServerErrorTestsWithTraceback, UnixTests, unittest.TestCase + ): + pass + + class UnixFileWrapperTests(FileWrapperTests, UnixTests, unittest.TestCase): + pass + + +def parse_headers(fp): + """Parses only RFC2822 headers from a file pointer.""" + headers = {} + + while True: + line = fp.readline() + + if line in (b"\r\n", b"\n", b""): + break + line = line.decode("iso-8859-1") + name, value = line.strip().split(":", 1) + headers[name.lower().strip()] = value.lower().strip() + + return headers + + +class 
UnixHTTPConnection(httplib.HTTPConnection):
+    """Patched version of HTTPConnection that uses Unix domain sockets."""
+
+    def __init__(self, path):
+        httplib.HTTPConnection.__init__(self, "localhost")
+        self.path = path
+
+    def connect(self):
+        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        sock.connect(self.path)
+        self.sock = sock
+
+    def close(self):
+        self.sock.close()
+
+
+class ConnectionClosed(Exception):
+    pass
+
+
+# stolen from gevent
+def read_http(fp): # pragma: no cover
+    try:
+        response_line = fp.readline()
+    except OSError as exc:
+        fp.close()
+        # On Linux errno 104 is ECONNRESET; in WinSock 10054 is WSAECONNRESET
+
+        if get_errno(exc) in (errno.ECONNABORTED, errno.ECONNRESET, 104, 10054):
+            raise ConnectionClosed
+        raise
+
+    if not response_line:
+        raise ConnectionClosed
+
+    header_lines = []
+
+    while True:
+        line = fp.readline()
+
+        if line in (b"\r\n", b"\n", b""):
+            break
+        else:
+            header_lines.append(line)
+    headers = dict()
+
+    for x in header_lines:
+        x = x.strip()
+
+        if not x:
+            continue
+        key, value = x.split(b": ", 1)
+        key = key.decode("iso-8859-1").lower()
+        value = value.decode("iso-8859-1")
+        assert key not in headers, "%s header duplicated" % key
+        headers[key] = value
+
+    if "content-length" in headers:
+        num = int(headers["content-length"])
+        body = b""
+        left = num
+
+        while left > 0:
+            data = fp.read(left)
+
+            if not data:
+                break
+            body += data
+            left -= len(data)
+    else:
+        # read until EOF
+        body = fp.read()
+
+    return response_line, headers, body
+
+
+# stolen from gevent
+def get_errno(exc): # pragma: no cover
+    """Get the error code out of socket.error objects.
+    socket.error in <2.5 does not have errno attribute
+    socket.error in 3.x does not allow indexing access
+    e.args[0] works for all.
+    There are cases when args[0] is not errno.
+    i.e. http://bugs.python.org/issue6471
+    Maybe there are cases when errno is set, but it is not the first argument?
+ """ + try: + if exc.errno is not None: + return exc.errno + except AttributeError: + pass + try: + return exc.args[0] + except IndexError: + return None + + +def chunks(l, n): + """Yield successive n-sized chunks from l.""" + + for i in range(0, len(l), n): + yield l[i : i + n] diff --git a/waitress/tests/test_init.py b/tests/test_init.py similarity index 95% rename from waitress/tests/test_init.py rename to tests/test_init.py index 66c34ce8..c824c21f 100644 --- a/waitress/tests/test_init.py +++ b/tests/test_init.py @@ -1,9 +1,10 @@ import unittest -class Test_serve(unittest.TestCase): +class Test_serve(unittest.TestCase): def _callFUT(self, app, **kw): from waitress import serve + return serve(app, **kw) def test_it(self): @@ -14,10 +15,11 @@ def test_it(self): self.assertEqual(result, None) self.assertEqual(server.ran, True) -class Test_serve_paste(unittest.TestCase): +class Test_serve_paste(unittest.TestCase): def _callFUT(self, app, **kw): from waitress import serve_paste + return serve_paste(app, None, **kw) def test_it(self): @@ -28,7 +30,8 @@ def test_it(self): self.assertEqual(result, 0) self.assertEqual(server.ran, True) -class DummyServerFactory(object): + +class DummyServerFactory: ran = False def __call__(self, app, **kw): @@ -40,7 +43,8 @@ def __call__(self, app, **kw): def run(self): self.ran = True -class DummyAdj(object): + +class DummyAdj: verbose = False def __init__(self, kw): diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 00000000..9e9f1cda --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,734 @@ +############################################################################## +# +# Copyright (c) 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. 
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE.
+#
+##############################################################################
+"""HTTP Request Parser tests
+"""
+import unittest
+
+from waitress.adjustments import Adjustments
+from waitress.parser import (
+    HTTPRequestParser,
+    ParsingError,
+    TransferEncodingNotImplemented,
+    crack_first_line,
+    get_header_lines,
+    split_uri,
+    unquote_bytes_to_wsgi,
+)
+from waitress.utilities import (
+    BadRequest,
+    RequestEntityTooLarge,
+    RequestHeaderFieldsTooLarge,
+    ServerNotImplemented,
+)
+
+
+class TestHTTPRequestParser(unittest.TestCase):
+    def setUp(self):
+
+        my_adj = Adjustments()
+        self.parser = HTTPRequestParser(my_adj)
+
+    def test_get_body_stream_None(self):
+        self.parser.body_rcv = None
+        result = self.parser.get_body_stream()
+        self.assertEqual(result.getvalue(), b"")
+
+    def test_get_body_stream_nonNone(self):
+        body_rcv = DummyBodyStream()
+        self.parser.body_rcv = body_rcv
+        result = self.parser.get_body_stream()
+        self.assertEqual(result, body_rcv)
+
+    def test_received_get_no_headers(self):
+        data = b"HTTP/1.0 GET /foobar\r\n\r\n"
+        result = self.parser.received(data)
+        self.assertEqual(result, 24)
+        self.assertTrue(self.parser.completed)
+        self.assertEqual(self.parser.headers, {})
+
+    def test_received_bad_host_header(self):
+        data = b"HTTP/1.0 GET /foobar\r\n Host: foo\r\n\r\n"
+        result = self.parser.received(data)
+        self.assertEqual(result, 36)
+        self.assertTrue(self.parser.completed)
+        self.assertEqual(self.parser.error.__class__, BadRequest)
+
+    def test_received_bad_transfer_encoding(self):
+        data = (
+            b"GET /foobar HTTP/1.1\r\n"
+            b"Transfer-Encoding: foo\r\n"
+            b"\r\n"
+            b"1d;\r\n"
+            b"This string has 29 characters\r\n"
+            b"0\r\n\r\n"
+        )
+        result = self.parser.received(data)
self.assertEqual(result, 48) + self.assertTrue(self.parser.completed) + self.assertEqual(self.parser.error.__class__, ServerNotImplemented) + + def test_received_nonsense_nothing(self): + data = b"\r\n\r\n" + result = self.parser.received(data) + self.assertEqual(result, 4) + self.assertTrue(self.parser.completed) + self.assertEqual(self.parser.headers, {}) + + def test_received_no_doublecr(self): + data = b"GET /foobar HTTP/8.4\r\n" + result = self.parser.received(data) + self.assertEqual(result, 22) + self.assertFalse(self.parser.completed) + self.assertEqual(self.parser.headers, {}) + + def test_received_already_completed(self): + self.parser.completed = True + result = self.parser.received(b"a") + self.assertEqual(result, 0) + + def test_received_cl_too_large(self): + + self.parser.adj.max_request_body_size = 2 + data = b"GET /foobar HTTP/8.4\r\nContent-Length: 10\r\n\r\n" + result = self.parser.received(data) + self.assertEqual(result, 44) + self.assertTrue(self.parser.completed) + self.assertTrue(isinstance(self.parser.error, RequestEntityTooLarge)) + + def test_received_headers_not_too_large_multiple_chunks(self): + + data = b"GET /foobar HTTP/8.4\r\nX-Foo: 1\r\n" + data2 = b"X-Foo-Other: 3\r\n\r\n" + self.parser.adj.max_request_header_size = len(data) + len(data2) + 1 + result = self.parser.received(data) + self.assertEqual(result, 32) + result = self.parser.received(data2) + self.assertEqual(result, 18) + self.assertTrue(self.parser.completed) + self.assertFalse(self.parser.error) + + def test_received_headers_too_large(self): + + self.parser.adj.max_request_header_size = 2 + data = b"GET /foobar HTTP/8.4\r\nX-Foo: 1\r\n\r\n" + result = self.parser.received(data) + self.assertEqual(result, 34) + self.assertTrue(self.parser.completed) + self.assertTrue(isinstance(self.parser.error, RequestHeaderFieldsTooLarge)) + + def test_received_body_too_large(self): + self.parser.adj.max_request_body_size = 2 + data = ( + b"GET /foobar HTTP/1.1\r\n" + 
b"Transfer-Encoding: chunked\r\n" + b"X-Foo: 1\r\n" + b"\r\n" + b"1d;\r\n" + b"This string has 29 characters\r\n" + b"0\r\n\r\n" + ) + + result = self.parser.received(data) + self.assertEqual(result, 62) + self.parser.received(data[result:]) + self.assertTrue(self.parser.completed) + self.assertTrue(isinstance(self.parser.error, RequestEntityTooLarge)) + + def test_received_error_from_parser(self): + data = ( + b"GET /foobar HTTP/1.1\r\n" + b"Transfer-Encoding: chunked\r\n" + b"X-Foo: 1\r\n" + b"\r\n" + b"garbage\r\n" + ) + # header + result = self.parser.received(data) + # body + result = self.parser.received(data[result:]) + self.assertEqual(result, 9) + self.assertTrue(self.parser.completed) + self.assertTrue(isinstance(self.parser.error, BadRequest)) + + def test_received_chunked_completed_sets_content_length(self): + data = ( + b"GET /foobar HTTP/1.1\r\n" + b"Transfer-Encoding: chunked\r\n" + b"X-Foo: 1\r\n" + b"\r\n" + b"1d\r\n" + b"This string has 29 characters\r\n" + b"0\r\n\r\n" + ) + result = self.parser.received(data) + self.assertEqual(result, 62) + data = data[result:] + result = self.parser.received(data) + self.assertTrue(self.parser.completed) + self.assertTrue(self.parser.error is None) + self.assertEqual(self.parser.headers["CONTENT_LENGTH"], "29") + + def test_parse_header_gardenpath(self): + data = b"GET /foobar HTTP/8.4\r\nfoo: bar\r\n" + self.parser.parse_header(data) + self.assertEqual(self.parser.first_line, b"GET /foobar HTTP/8.4") + self.assertEqual(self.parser.headers["FOO"], "bar") + + def test_parse_header_no_cr_in_headerplus(self): + data = b"GET /foobar HTTP/8.4" + + try: + self.parser.parse_header(data) + except ParsingError: + pass + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_bad_content_length(self): + data = b"GET /foobar HTTP/8.4\r\ncontent-length: abc\r\n" + + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Content-Length is invalid", e.args[0]) + else: # pragma: 
nocover + self.assertTrue(False) + + def test_parse_header_bad_content_length_plus(self): + data = b"GET /foobar HTTP/8.4\r\ncontent-length: +10\r\n" + + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Content-Length is invalid", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_bad_content_length_minus(self): + data = b"GET /foobar HTTP/8.4\r\ncontent-length: -10\r\n" + + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Content-Length is invalid", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_multiple_content_length(self): + data = b"GET /foobar HTTP/8.4\r\ncontent-length: 10\r\ncontent-length: 20\r\n" + + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Content-Length is invalid", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_11_te_chunked(self): + # NB: test that capitalization of header value is unimportant + data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: ChUnKed\r\n" + self.parser.parse_header(data) + self.assertEqual(self.parser.body_rcv.__class__.__name__, "ChunkedReceiver") + + def test_parse_header_transfer_encoding_invalid(self): + data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: gzip\r\n" + + try: + self.parser.parse_header(data) + except TransferEncodingNotImplemented as e: + self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_transfer_encoding_invalid_multiple(self): + + data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: gzip\r\ntransfer-encoding: chunked\r\n" + + try: + self.parser.parse_header(data) + except TransferEncodingNotImplemented as e: + self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def 
test_parse_header_transfer_encoding_invalid_whitespace(self): + data = b"GET /foobar HTTP/1.1\r\nTransfer-Encoding:\x85chunked\r\n" + + try: + self.parser.parse_header(data) + except TransferEncodingNotImplemented as e: + self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_transfer_encoding_invalid_unicode(self): + # This is the binary encoding for the UTF-8 character + # https://www.compart.com/en/unicode/U+212A "unicode character "K"" + # which if waitress were to accidentally do the wrong thing get + # lowercased to just the ascii "k" due to unicode collisions during + # transformation + data = b"GET /foobar HTTP/1.1\r\nTransfer-Encoding: chun\xe2\x84\xaaed\r\n" + + try: + self.parser.parse_header(data) + except TransferEncodingNotImplemented as e: + self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_11_expect_continue(self): + data = b"GET /foobar HTTP/1.1\r\nexpect: 100-continue\r\n" + self.parser.parse_header(data) + self.assertEqual(self.parser.expect_continue, True) + + def test_parse_header_connection_close(self): + data = b"GET /foobar HTTP/1.1\r\nConnection: close\r\n" + self.parser.parse_header(data) + self.assertEqual(self.parser.connection_close, True) + + def test_close_with_body_rcv(self): + body_rcv = DummyBodyStream() + self.parser.body_rcv = body_rcv + self.parser.close() + self.assertTrue(body_rcv.closed) + + def test_close_with_no_body_rcv(self): + self.parser.body_rcv = None + self.parser.close() # doesn't raise + + def test_parse_header_lf_only(self): + data = b"GET /foobar HTTP/8.4\nfoo: bar" + + try: + self.parser.parse_header(data) + except ParsingError: + pass + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_cr_only(self): + data = b"GET /foobar HTTP/8.4\rfoo: bar" + try: + self.parser.parse_header(data) + except 
ParsingError: + pass + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_extra_lf_in_header(self): + data = b"GET /foobar HTTP/8.4\r\nfoo: \nbar\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Bare CR or LF found in header line", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_extra_lf_in_first_line(self): + data = b"GET /foobar\n HTTP/8.4\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Bare CR or LF found in HTTP message", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_whitespace(self): + data = b"GET /foobar HTTP/8.4\r\nfoo : bar\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_whitespace_vtab(self): + data = b"GET /foobar HTTP/1.1\r\nfoo:\x0bbar\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_no_colon(self): + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nnotvalid\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_folding_spacing(self): + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\n\t\x0bbaz\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def test_parse_header_invalid_chars(self): + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nfoo: \x0bbaz\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + 
self.assertTrue(False) + + def test_parse_header_empty(self): + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nempty:\r\n" + self.parser.parse_header(data) + + self.assertIn("EMPTY", self.parser.headers) + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["EMPTY"], "") + self.assertEqual(self.parser.headers["FOO"], "bar") + + def test_parse_header_multiple_values(self): + data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever, more, please, yes\r\n" + self.parser.parse_header(data) + + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") + + def test_parse_header_multiple_values_header_folded(self): + data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever,\r\n more, please, yes\r\n" + self.parser.parse_header(data) + + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") + + def test_parse_header_multiple_values_header_folded_multiple(self): + data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever,\r\n more\r\nfoo: please, yes\r\n" + self.parser.parse_header(data) + + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") + + def test_parse_header_multiple_values_extra_space(self): + # Tests errata from: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 + data = b"GET /foobar HTTP/1.1\r\nfoo: abrowser/0.001 (C O M M E N T)\r\n" + self.parser.parse_header(data) + + self.assertIn("FOO", self.parser.headers) + self.assertEqual(self.parser.headers["FOO"], "abrowser/0.001 (C O M M E N T)") + + def test_parse_header_invalid_backtrack_bad(self): + data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nfoo: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\x10\r\n" + try: + self.parser.parse_header(data) + except ParsingError as e: + self.assertIn("Invalid header", e.args[0]) + else: # pragma: nocover + self.assertTrue(False) + + def 
test_parse_header_short_values(self): + data = b"GET /foobar HTTP/1.1\r\none: 1\r\ntwo: 22\r\n" + self.parser.parse_header(data) + + self.assertIn("ONE", self.parser.headers) + self.assertIn("TWO", self.parser.headers) + self.assertEqual(self.parser.headers["ONE"], "1") + self.assertEqual(self.parser.headers["TWO"], "22") + + +class Test_split_uri(unittest.TestCase): + def _callFUT(self, uri): + ( + self.proxy_scheme, + self.proxy_netloc, + self.path, + self.query, + self.fragment, + ) = split_uri(uri) + + def test_split_uri_unquoting_unneeded(self): + self._callFUT(b"http://localhost:8080/abc def") + self.assertEqual(self.path, "/abc def") + + def test_split_uri_unquoting_needed(self): + self._callFUT(b"http://localhost:8080/abc%20def") + self.assertEqual(self.path, "/abc def") + + def test_split_url_with_query(self): + self._callFUT(b"http://localhost:8080/abc?a=1&b=2") + self.assertEqual(self.path, "/abc") + self.assertEqual(self.query, "a=1&b=2") + + def test_split_url_with_query_empty(self): + self._callFUT(b"http://localhost:8080/abc?") + self.assertEqual(self.path, "/abc") + self.assertEqual(self.query, "") + + def test_split_url_with_fragment(self): + self._callFUT(b"http://localhost:8080/#foo") + self.assertEqual(self.path, "/") + self.assertEqual(self.fragment, "foo") + + def test_split_url_https(self): + self._callFUT(b"https://localhost:8080/") + self.assertEqual(self.path, "/") + self.assertEqual(self.proxy_scheme, "https") + self.assertEqual(self.proxy_netloc, "localhost:8080") + + def test_split_uri_unicode_error_raises_parsing_error(self): + # See https://github.com/Pylons/waitress/issues/64 + + # Either pass or throw a ParsingError, just don't throw another type of + # exception as that will cause the connection to close badly: + try: + self._callFUT(b"/\xd0") + except ParsingError: + pass + + def test_split_uri_path(self): + self._callFUT(b"//testing/whatever") + self.assertEqual(self.path, "//testing/whatever") + 
self.assertEqual(self.proxy_scheme, "") + self.assertEqual(self.proxy_netloc, "") + self.assertEqual(self.query, "") + self.assertEqual(self.fragment, "") + + def test_split_uri_path_query(self): + self._callFUT(b"//testing/whatever?a=1&b=2") + self.assertEqual(self.path, "//testing/whatever") + self.assertEqual(self.proxy_scheme, "") + self.assertEqual(self.proxy_netloc, "") + self.assertEqual(self.query, "a=1&b=2") + self.assertEqual(self.fragment, "") + + def test_split_uri_path_query_fragment(self): + self._callFUT(b"//testing/whatever?a=1&b=2#fragment") + self.assertEqual(self.path, "//testing/whatever") + self.assertEqual(self.proxy_scheme, "") + self.assertEqual(self.proxy_netloc, "") + self.assertEqual(self.query, "a=1&b=2") + self.assertEqual(self.fragment, "fragment") + + +class Test_get_header_lines(unittest.TestCase): + def _callFUT(self, data): + return get_header_lines(data) + + def test_get_header_lines(self): + result = self._callFUT(b"slam\r\nslim") + self.assertEqual(result, [b"slam", b"slim"]) + + def test_get_header_lines_folded(self): + # From RFC2616: + # HTTP/1.1 header field values can be folded onto multiple lines if the + # continuation line begins with a space or horizontal tab. All linear + # white space, including folding, has the same semantics as SP. A + # recipient MAY replace any linear white space with a single SP before + # interpreting the field value or forwarding the message downstream. + + # We are just preserving the whitespace that indicates folding. 
+ result = self._callFUT(b"slim\r\n slam") + self.assertEqual(result, [b"slim slam"]) + + def test_get_header_lines_tabbed(self): + result = self._callFUT(b"slam\r\n\tslim") + self.assertEqual(result, [b"slam\tslim"]) + + def test_get_header_lines_malformed(self): + # https://corte.si/posts/code/pathod/pythonservers/index.html + self.assertRaises(ParsingError, self._callFUT, b" Host: localhost\r\n\r\n") + + +class Test_crack_first_line(unittest.TestCase): + def _callFUT(self, line): + return crack_first_line(line) + + def test_crack_first_line_matchok(self): + result = self._callFUT(b"GET / HTTP/1.0") + self.assertEqual(result, (b"GET", b"/", b"1.0")) + + def test_crack_first_line_lowercase_method(self): + self.assertRaises(ParsingError, self._callFUT, b"get / HTTP/1.0") + + def test_crack_first_line_nomatch(self): + result = self._callFUT(b"GET / bleh") + self.assertEqual(result, (b"", b"", b"")) + + result = self._callFUT(b"GET /info?txtAirPlay&txtRAOP RTSP/1.0") + self.assertEqual(result, (b"", b"", b"")) + + def test_crack_first_line_missing_version(self): + result = self._callFUT(b"GET /") + self.assertEqual(result, (b"GET", b"/", b"")) + + +class TestHTTPRequestParserIntegration(unittest.TestCase): + def setUp(self): + my_adj = Adjustments() + self.parser = HTTPRequestParser(my_adj) + + def feed(self, data): + parser = self.parser + + for n in range(100): # make sure we never loop forever + consumed = parser.received(data) + data = data[consumed:] + + if parser.completed: + return + raise ValueError("Looping") # pragma: no cover + + def testSimpleGET(self): + data = ( + b"GET /foobar HTTP/8.4\r\n" + b"FirstName: mickey\r\n" + b"lastname: Mouse\r\n" + b"content-length: 6\r\n" + b"\r\n" + b"Hello." 
+ ) + parser = self.parser + self.feed(data) + self.assertTrue(parser.completed) + self.assertEqual(parser.version, "8.4") + self.assertFalse(parser.empty) + self.assertEqual( + parser.headers, + { + "FIRSTNAME": "mickey", + "LASTNAME": "Mouse", + "CONTENT_LENGTH": "6", + }, + ) + self.assertEqual(parser.path, "/foobar") + self.assertEqual(parser.command, "GET") + self.assertEqual(parser.query, "") + self.assertEqual(parser.proxy_scheme, "") + self.assertEqual(parser.proxy_netloc, "") + self.assertEqual(parser.get_body_stream().getvalue(), b"Hello.") + + def testComplexGET(self): + data = ( + b"GET /foo/a+%2B%2F%C3%A4%3D%26a%3Aint?d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6 HTTP/8.4\r\n" + b"FirstName: mickey\r\n" + b"lastname: Mouse\r\n" + b"content-length: 10\r\n" + b"\r\n" + b"Hello mickey." + ) + parser = self.parser + self.feed(data) + self.assertEqual(parser.command, "GET") + self.assertEqual(parser.version, "8.4") + self.assertFalse(parser.empty) + self.assertEqual( + parser.headers, + {"FIRSTNAME": "mickey", "LASTNAME": "Mouse", "CONTENT_LENGTH": "10"}, + ) + # path should be utf-8 encoded + self.assertEqual( + parser.path.encode("latin-1").decode("utf-8"), + b"/foo/a++/\xc3\xa4=&a:int".decode("utf-8"), + ) + # parser.request_uri should preserve the % escape sequences and the query string. + self.assertEqual( + parser.request_uri, + "/foo/a+%2B%2F%C3%A4%3D%26a%3Aint?d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6", + ) + self.assertEqual( + parser.query, "d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6" + ) + self.assertEqual(parser.get_body_stream().getvalue(), b"Hello mick") + + def testProxyGET(self): + data = ( + b"GET https://example.com:8080/foobar HTTP/8.4\r\n" + b"content-length: 6\r\n" + b"\r\n" + b"Hello." 
+ ) + parser = self.parser + self.feed(data) + self.assertTrue(parser.completed) + self.assertEqual(parser.version, "8.4") + self.assertFalse(parser.empty) + self.assertEqual(parser.headers, {"CONTENT_LENGTH": "6"}) + self.assertEqual(parser.path, "/foobar") + self.assertEqual(parser.command, "GET") + self.assertEqual(parser.proxy_scheme, "https") + self.assertEqual(parser.proxy_netloc, "example.com:8080") + self.assertEqual(parser.command, "GET") + self.assertEqual(parser.query, "") + self.assertEqual(parser.get_body_stream().getvalue(), b"Hello.") + + def testDuplicateHeaders(self): + # Ensure that headers with the same key get concatenated as per + # RFC2616. + data = ( + b"GET /foobar HTTP/8.4\r\n" + b"x-forwarded-for: 10.11.12.13\r\n" + b"x-forwarded-for: unknown,127.0.0.1\r\n" + b"X-Forwarded_for: 255.255.255.255\r\n" + b"content-length: 6\r\n" + b"\r\n" + b"Hello." + ) + self.feed(data) + self.assertTrue(self.parser.completed) + self.assertEqual( + self.parser.headers, + { + "CONTENT_LENGTH": "6", + "X_FORWARDED_FOR": "10.11.12.13, unknown,127.0.0.1", + }, + ) + + def testSpoofedHeadersDropped(self): + data = ( + b"GET /foobar HTTP/8.4\r\n" + b"x-auth_user: bob\r\n" + b"content-length: 6\r\n" + b"\r\n" + b"Hello." 
+ ) + self.feed(data) + self.assertTrue(self.parser.completed) + self.assertEqual( + self.parser.headers, + { + "CONTENT_LENGTH": "6", + }, + ) + + +class Test_unquote_bytes_to_wsgi(unittest.TestCase): + def _callFUT(self, v): + + return unquote_bytes_to_wsgi(v) + + def test_highorder(self): + val = b"/a%C5%9B" + result = self._callFUT(val) + # PEP 3333 urlunquoted-latin1-decoded-bytes + self.assertEqual(result, "/aÅ\x9b") + + +class DummyBodyStream: + def getfile(self): + return self + + def getbuf(self): + return self + + def close(self): + self.closed = True diff --git a/tests/test_proxy_headers.py b/tests/test_proxy_headers.py new file mode 100644 index 00000000..45f98785 --- /dev/null +++ b/tests/test_proxy_headers.py @@ -0,0 +1,761 @@ +import unittest + + +class TestProxyHeadersMiddleware(unittest.TestCase): + def _makeOne(self, app, **kw): + from waitress.proxy_headers import proxy_headers_middleware + + return proxy_headers_middleware(app, **kw) + + def _callFUT(self, app, **kw): + response = DummyResponse() + environ = DummyEnviron(**kw) + + def start_response(status, response_headers): + response.status = status + response.headers = response_headers + + response.steps = list(app(environ, start_response)) + response.body = b"".join(s for s in response.steps) + return response + + def test_get_environment_values_w_scheme_override_untrusted(self): + inner = DummyApp() + app = self._makeOne(inner) + response = self._callFUT( + app, + headers={ + "X_FOO": "BAR", + "X_FORWARDED_PROTO": "https", + }, + ) + self.assertEqual(response.status, "200 OK") + self.assertEqual(inner.environ["wsgi.url_scheme"], "http") + + def test_get_environment_values_w_scheme_override_trusted(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_headers={"x-forwarded-proto"}, + ) + response = self._callFUT( + app, + addr=["192.168.1.1", 8080], + headers={ + "X_FOO": "BAR", + "X_FORWARDED_PROTO": "https", + }, + ) + + environ = 
inner.environ + self.assertEqual(response.status, "200 OK") + self.assertEqual(environ["SERVER_PORT"], "443") + self.assertEqual(environ["SERVER_NAME"], "localhost") + self.assertEqual(environ["REMOTE_ADDR"], "192.168.1.1") + self.assertEqual(environ["HTTP_X_FOO"], "BAR") + + def test_get_environment_values_w_bogus_scheme_override(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_headers={"x-forwarded-proto"}, + ) + response = self._callFUT( + app, + addr=["192.168.1.1", 80], + headers={ + "X_FOO": "BAR", + "X_FORWARDED_PROTO": "http://p02n3e.com?url=http", + }, + ) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-Proto" malformed', response.body) + + def test_get_environment_warning_other_proxy_headers(self): + inner = DummyApp() + logger = DummyLogger() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + log_untrusted=True, + logger=logger, + ) + response = self._callFUT( + app, + addr=["192.168.1.1", 80], + headers={ + "X_FORWARDED_FOR": "[2001:db8::1]", + "FORWARDED": "For=198.51.100.2;host=example.com:8080;proto=https", + }, + ) + self.assertEqual(response.status, "200 OK") + + self.assertEqual(len(logger.logged), 1) + + environ = inner.environ + self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_get_environment_contains_all_headers_including_untrusted(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-by"}, + clear_untrusted=False, + ) + headers_orig = { + "X_FORWARDED_FOR": "198.51.100.2", + 
"X_FORWARDED_BY": "Waitress", + "X_FORWARDED_PROTO": "https", + "X_FORWARDED_HOST": "example.org", + } + response = self._callFUT( + app, + addr=["192.168.1.1", 80], + headers=headers_orig.copy(), + ) + self.assertEqual(response.status, "200 OK") + environ = inner.environ + for k, expected in headers_orig.items(): + result = environ["HTTP_%s" % k] + self.assertEqual(result, expected) + + def test_get_environment_contains_only_trusted_headers(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-by"}, + clear_untrusted=True, + ) + response = self._callFUT( + app, + addr=["192.168.1.1", 80], + headers={ + "X_FORWARDED_FOR": "198.51.100.2", + "X_FORWARDED_BY": "Waitress", + "X_FORWARDED_PROTO": "https", + "X_FORWARDED_HOST": "example.org", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["HTTP_X_FORWARDED_BY"], "Waitress") + self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) + self.assertNotIn("HTTP_X_FORWARDED_PROTO", environ) + self.assertNotIn("HTTP_X_FORWARDED_HOST", environ) + + def test_get_environment_clears_headers_if_untrusted_proxy(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="192.168.1.1", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-by"}, + clear_untrusted=True, + ) + response = self._callFUT( + app, + addr=["192.168.1.255", 80], + headers={ + "X_FORWARDED_FOR": "198.51.100.2", + "X_FORWARDED_BY": "Waitress", + "X_FORWARDED_PROTO": "https", + "X_FORWARDED_HOST": "example.org", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertNotIn("HTTP_X_FORWARDED_BY", environ) + self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) + self.assertNotIn("HTTP_X_FORWARDED_PROTO", environ) + self.assertNotIn("HTTP_X_FORWARDED_HOST", environ) + + def test_parse_proxy_headers_forwarded_for(self): + inner = DummyApp() + app = 
self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_headers={"x-forwarded-for"}, + ) + response = self._callFUT(app, headers={"X_FORWARDED_FOR": "192.0.2.1"}) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "192.0.2.1") + + def test_parse_proxy_headers_forwarded_for_v6_missing_brackets(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_headers={"x-forwarded-for"}, + ) + response = self._callFUT(app, headers={"X_FORWARDED_FOR": "2001:db8::0"}) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::0") + + def test_parse_proxy_headers_forwared_for_multiple(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={"x-forwarded-for"}, + ) + response = self._callFUT( + app, headers={"X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1"} + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + + def test_parse_forwarded_multiple_proxies_trust_only_two(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": ( + "For=192.0.2.1;host=fake.com, " + "For=198.51.100.2;host=example.com:8080, " + "For=203.0.113.1" + ), + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_forwarded_multiple_proxies(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + 
trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": ( + 'for="[2001:db8::1]:3821";host="example.com:8443";proto="https", ' + 'for=192.0.2.1;host="example.internal:8080"' + ), + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::1") + self.assertEqual(environ["REMOTE_PORT"], "3821") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8443") + self.assertEqual(environ["SERVER_PORT"], "8443") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_forwarded_multiple_proxies_minimal(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": ( + 'for="[2001:db8::1]";proto="https", ' + 'for=192.0.2.1;host="example.org"' + ), + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::1") + self.assertEqual(environ["SERVER_NAME"], "example.org") + self.assertEqual(environ["HTTP_HOST"], "example.org") + self.assertEqual(environ["SERVER_PORT"], "443") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_proxy_headers_forwarded_host_with_port(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com:8080", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") 
+ self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_proxy_headers_forwarded_host_without_port(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com") + self.assertEqual(environ["SERVER_PORT"], "80") + + def test_parse_proxy_headers_forwarded_host_with_forwarded_port(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + "x-forwarded-port", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com", + "X_FORWARDED_PORT": "8080", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_proxy_headers_forwarded_host_multiple_with_forwarded_port(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=2, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + "x-forwarded-port", + }, + ) + response = 
self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com, example.org", + "X_FORWARDED_PORT": "8080", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_proxy_headers_forwarded_host_multiple_with_forwarded_port_limit_one_trusted( + self, + ): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={ + "x-forwarded-for", + "x-forwarded-proto", + "x-forwarded-host", + "x-forwarded-port", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com, example.org", + "X_FORWARDED_PORT": "8080", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "203.0.113.1") + self.assertEqual(environ["SERVER_NAME"], "example.org") + self.assertEqual(environ["HTTP_HOST"], "example.org:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + + def test_parse_forwarded(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": "For=198.51.100.2:5858;host=example.com:8080;proto=https", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["REMOTE_PORT"], "5858") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + 
self.assertEqual(environ["SERVER_PORT"], "8080") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_forwarded_empty_pair(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": "For=198.51.100.2;;proto=https;by=_unused", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + + def test_parse_forwarded_pair_token_whitespace(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": "For=198.51.100.2; proto =https", + }, + ) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "Forwarded" malformed', response.body) + + def test_parse_forwarded_pair_value_whitespace(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": 'For= "198.51.100.2"; proto =https', + }, + ) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "Forwarded" malformed', response.body) + + def test_parse_forwarded_pair_no_equals(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + ) + response = self._callFUT(app, headers={"FORWARDED": "For"}) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "Forwarded" malformed', response.body) + + def test_parse_forwarded_warning_unknown_token(self): + inner = DummyApp() + logger = DummyLogger() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"forwarded"}, + 
logger=logger, + ) + response = self._callFUT( + app, + headers={ + "FORWARDED": ( + "For=198.51.100.2;host=example.com:8080;proto=https;" + 'unknown="yolo"' + ), + }, + ) + self.assertEqual(response.status, "200 OK") + + self.assertEqual(len(logger.logged), 1) + self.assertIn("Unknown Forwarded token", logger.logged[0]) + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:8080") + self.assertEqual(environ["SERVER_PORT"], "8080") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_no_valid_proxy_headers(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_FOR": "198.51.100.2", + "FORWARDED": "For=198.51.100.2;host=example.com:8080;proto=https", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["REMOTE_ADDR"], "127.0.0.1") + self.assertEqual(environ["SERVER_NAME"], "localhost") + self.assertEqual(environ["HTTP_HOST"], "192.168.1.1:80") + self.assertEqual(environ["SERVER_PORT"], "8080") + self.assertEqual(environ["wsgi.url_scheme"], "http") + + def test_parse_multiple_x_forwarded_proto(self): + inner = DummyApp() + logger = DummyLogger() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-proto"}, + logger=logger, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_PROTO": "http, https", + }, + ) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-Proto" malformed', response.body) + + def test_parse_multiple_x_forwarded_port(self): + inner = DummyApp() + logger = DummyLogger() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-port"}, + 
logger=logger, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_PORT": "443, 80", + }, + ) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-Port" malformed', response.body) + + def test_parse_forwarded_port_wrong_proto_port_80(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={ + "x-forwarded-port", + "x-forwarded-host", + "x-forwarded-proto", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_PORT": "80", + "X_FORWARDED_PROTO": "https", + "X_FORWARDED_HOST": "example.com", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:80") + self.assertEqual(environ["SERVER_PORT"], "80") + self.assertEqual(environ["wsgi.url_scheme"], "https") + + def test_parse_forwarded_port_wrong_proto_port_443(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={ + "x-forwarded-port", + "x-forwarded-host", + "x-forwarded-proto", + }, + ) + response = self._callFUT( + app, + headers={ + "X_FORWARDED_PORT": "443", + "X_FORWARDED_PROTO": "http", + "X_FORWARDED_HOST": "example.com", + }, + ) + self.assertEqual(response.status, "200 OK") + + environ = inner.environ + self.assertEqual(environ["SERVER_NAME"], "example.com") + self.assertEqual(environ["HTTP_HOST"], "example.com:443") + self.assertEqual(environ["SERVER_PORT"], "443") + self.assertEqual(environ["wsgi.url_scheme"], "http") + + def test_parse_forwarded_for_bad_quote(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-for"}, + ) + response = self._callFUT(app, headers={"X_FORWARDED_FOR": '"foo'}) + self.assertEqual(response.status, "400 Bad Request") + 
self.assertIn(b'Header "X-Forwarded-For" malformed', response.body) + + def test_parse_forwarded_host_bad_quote(self): + inner = DummyApp() + app = self._makeOne( + inner, + trusted_proxy="*", + trusted_proxy_count=1, + trusted_proxy_headers={"x-forwarded-host"}, + ) + response = self._callFUT(app, headers={"X_FORWARDED_HOST": '"foo'}) + self.assertEqual(response.status, "400 Bad Request") + self.assertIn(b'Header "X-Forwarded-Host" malformed', response.body) + + +class DummyLogger: + def __init__(self): + self.logged = [] + + def warning(self, msg, *args): + self.logged.append(msg % args) + + +class DummyApp: + def __call__(self, environ, start_response): + self.environ = environ + start_response("200 OK", [("Content-Type", "text/plain")]) + yield b"hello" + + +class DummyResponse: + status = None + headers = None + body = None + + +def DummyEnviron( + addr=("127.0.0.1", 8080), + scheme="http", + server="localhost", + headers=None, +): + environ = { + "REMOTE_ADDR": addr[0], + "REMOTE_HOST": addr[0], + "REMOTE_PORT": addr[1], + "SERVER_PORT": str(addr[1]), + "SERVER_NAME": server, + "wsgi.url_scheme": scheme, + "HTTP_HOST": "192.168.1.1:80", + } + if headers: + environ.update( + { + "HTTP_" + key.upper().replace("-", "_"): value + for key, value in headers.items() + } + ) + return environ diff --git a/tests/test_receiver.py b/tests/test_receiver.py new file mode 100644 index 00000000..d160cac4 --- /dev/null +++ b/tests/test_receiver.py @@ -0,0 +1,293 @@ +import unittest + +import pytest + + +class TestFixedStreamReceiver(unittest.TestCase): + def _makeOne(self, cl, buf): + from waitress.receiver import FixedStreamReceiver + + return FixedStreamReceiver(cl, buf) + + def test_received_remain_lt_1(self): + buf = DummyBuffer() + inst = self._makeOne(0, buf) + result = inst.received("a") + self.assertEqual(result, 0) + self.assertEqual(inst.completed, True) + + def test_received_remain_lte_datalen(self): + buf = DummyBuffer() + inst = self._makeOne(1, buf) + result = 
inst.received("aa") + self.assertEqual(result, 1) + self.assertEqual(inst.completed, True) + self.assertEqual(inst.completed, 1) + self.assertEqual(inst.remain, 0) + self.assertEqual(buf.data, ["a"]) + + def test_received_remain_gt_datalen(self): + buf = DummyBuffer() + inst = self._makeOne(10, buf) + result = inst.received("aa") + self.assertEqual(result, 2) + self.assertEqual(inst.completed, False) + self.assertEqual(inst.remain, 8) + self.assertEqual(buf.data, ["aa"]) + + def test_getfile(self): + buf = DummyBuffer() + inst = self._makeOne(10, buf) + self.assertEqual(inst.getfile(), buf) + + def test_getbuf(self): + buf = DummyBuffer() + inst = self._makeOne(10, buf) + self.assertEqual(inst.getbuf(), buf) + + def test___len__(self): + buf = DummyBuffer(["1", "2"]) + inst = self._makeOne(10, buf) + self.assertEqual(inst.__len__(), 2) + + +class TestChunkedReceiver(unittest.TestCase): + def _makeOne(self, buf): + from waitress.receiver import ChunkedReceiver + + return ChunkedReceiver(buf) + + def test_alreadycompleted(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.completed = True + result = inst.received(b"a") + self.assertEqual(result, 0) + self.assertEqual(inst.completed, True) + + def test_received_remain_gt_zero(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.chunk_remainder = 100 + result = inst.received(b"a") + self.assertEqual(inst.chunk_remainder, 99) + self.assertEqual(result, 1) + self.assertEqual(inst.completed, False) + + def test_received_control_line_notfinished(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + result = inst.received(b"a") + self.assertEqual(inst.control_line, b"a") + self.assertEqual(result, 1) + self.assertEqual(inst.completed, False) + + def test_received_control_line_finished_garbage_in_input(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + result = inst.received(b"garbage\r\n") + self.assertEqual(result, 9) + self.assertTrue(inst.error) + + def 
test_received_control_line_finished_all_chunks_not_received(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + result = inst.received(b"a;discard\r\n") + self.assertEqual(inst.control_line, b"") + self.assertEqual(inst.chunk_remainder, 10) + self.assertEqual(inst.all_chunks_received, False) + self.assertEqual(result, 11) + self.assertEqual(inst.completed, False) + + def test_received_control_line_finished_all_chunks_received(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + result = inst.received(b"0;discard\r\n") + self.assertEqual(inst.control_line, b"") + self.assertEqual(inst.all_chunks_received, True) + self.assertEqual(result, 11) + self.assertEqual(inst.completed, False) + + def test_received_trailer_startswith_crlf(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.all_chunks_received = True + result = inst.received(b"\r\n") + self.assertEqual(result, 2) + self.assertEqual(inst.completed, True) + + def test_received_trailer_startswith_lf(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.all_chunks_received = True + result = inst.received(b"\n") + self.assertEqual(result, 1) + self.assertEqual(inst.completed, False) + + def test_received_trailer_not_finished(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.all_chunks_received = True + result = inst.received(b"a") + self.assertEqual(result, 1) + self.assertEqual(inst.completed, False) + + def test_received_trailer_finished(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + inst.all_chunks_received = True + result = inst.received(b"abc\r\n\r\n") + self.assertEqual(inst.trailer, b"abc\r\n\r\n") + self.assertEqual(result, 7) + self.assertEqual(inst.completed, True) + + def test_getfile(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + self.assertEqual(inst.getfile(), buf) + + def test_getbuf(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + self.assertEqual(inst.getbuf(), buf) + + def test___len__(self): + buf = 
DummyBuffer(["1", "2"]) + inst = self._makeOne(buf) + self.assertEqual(inst.__len__(), 2) + + def test_received_chunk_is_properly_terminated(self): + buf = DummyBuffer() + inst = self._makeOne(buf) + data = b"4\r\nWiki\r\n" + result = inst.received(data) + self.assertEqual(result, len(data)) + self.assertEqual(inst.completed, False) + self.assertEqual(buf.data[0], b"Wiki") + + def test_received_chunk_not_properly_terminated(self): + from waitress.utilities import BadRequest + + buf = DummyBuffer() + inst = self._makeOne(buf) + data = b"4\r\nWikibadchunk\r\n" + result = inst.received(data) + self.assertEqual(result, len(data)) + self.assertEqual(inst.completed, False) + self.assertEqual(buf.data[0], b"Wiki") + self.assertEqual(inst.error.__class__, BadRequest) + + def test_received_multiple_chunks(self): + from waitress.utilities import BadRequest + + buf = DummyBuffer() + inst = self._makeOne(buf) + data = ( + b"4\r\n" + b"Wiki\r\n" + b"5\r\n" + b"pedia\r\n" + b"E\r\n" + b" in\r\n" + b"\r\n" + b"chunks.\r\n" + b"0\r\n" + b"\r\n" + ) + result = inst.received(data) + self.assertEqual(result, len(data)) + self.assertEqual(inst.completed, True) + self.assertEqual(b"".join(buf.data), b"Wikipedia in\r\n\r\nchunks.") + self.assertEqual(inst.error, None) + + def test_received_multiple_chunks_split(self): + from waitress.utilities import BadRequest + + buf = DummyBuffer() + inst = self._makeOne(buf) + data1 = b"4\r\nWiki\r" + result = inst.received(data1) + self.assertEqual(result, len(data1)) + + data2 = ( + b"\n5\r\n" + b"pedia\r\n" + b"E\r\n" + b" in\r\n" + b"\r\n" + b"chunks.\r\n" + b"0\r\n" + b"\r\n" + ) + + result = inst.received(data2) + self.assertEqual(result, len(data2)) + + self.assertEqual(inst.completed, True) + self.assertEqual(b"".join(buf.data), b"Wikipedia in\r\n\r\nchunks.") + self.assertEqual(inst.error, None) + + +class TestChunkedReceiverParametrized: + def _makeOne(self, buf): + from waitress.receiver import ChunkedReceiver + + return 
ChunkedReceiver(buf) + + @pytest.mark.parametrize( + "invalid_extension", [b"\n", b"invalid=", b"\r", b"invalid = true"] + ) + def test_received_invalid_extensions(self, invalid_extension): + from waitress.utilities import BadRequest + + buf = DummyBuffer() + inst = self._makeOne(buf) + data = b"4;" + invalid_extension + b"\r\ntest\r\n" + result = inst.received(data) + assert result == len(data) + assert inst.error.__class__ == BadRequest + assert inst.error.body == "Invalid chunk extension" + + @pytest.mark.parametrize( + "valid_extension", [b"test", b"valid=true", b"valid=true;other=true"] + ) + def test_received_valid_extensions(self, valid_extension): + # While waitress may ignore extensions in Chunked Encoding, we do want + # to make sure that we don't fail when we do encounter one that is + # valid + buf = DummyBuffer() + inst = self._makeOne(buf) + data = b"4;" + valid_extension + b"\r\ntest\r\n" + result = inst.received(data) + assert result == len(data) + assert inst.error == None + + @pytest.mark.parametrize( + "invalid_size", [b"0x04", b"+0x04", b"x04", b"+04", b" 04", b" 0x04"] + ) + def test_received_invalid_size(self, invalid_size): + from waitress.utilities import BadRequest + + buf = DummyBuffer() + inst = self._makeOne(buf) + data = invalid_size + b"\r\ntest\r\n" + result = inst.received(data) + assert result == len(data) + assert inst.error.__class__ == BadRequest + assert inst.error.body == "Invalid chunk size" + + +class DummyBuffer: + def __init__(self, data=None): + if data is None: + data = [] + self.data = data + + def append(self, s): + self.data.append(s) + + def getfile(self): + return self + + def __len__(self): + return len(self.data) diff --git a/waitress/tests/test_regression.py b/tests/test_regression.py similarity index 96% rename from waitress/tests/test_regression.py rename to tests/test_regression.py index f43895e1..840e599e 100644 --- a/waitress/tests/test_regression.py +++ b/tests/test_regression.py @@ -15,8 +15,9 @@ """ import 
doctest -class FakeSocket: # pragma: no cover - data = '' + +class FakeSocket: # pragma: no cover + data = "" setblocking = lambda *_: None close = lambda *_: None @@ -27,14 +28,15 @@ def fileno(self): return self.no def getpeername(self): - return ('localhost', self.no) + return ("localhost", self.no) def send(self, data): self.data += data return len(data) def recv(self, data): - return 'data' + return "data" + def zombies_test(): """Regression test for HTTPChannel.maintenance method @@ -137,8 +139,8 @@ def zombies_test(): >>> channel4.last_activity != last_active True + """ -""" def test_suite(): return doctest.DocTestSuite() diff --git a/waitress/tests/test_runner.py b/tests/test_runner.py similarity index 56% rename from waitress/tests/test_runner.py rename to tests/test_runner.py index fa927f0a..4cf6f6fd 100644 --- a/waitress/tests/test_runner.py +++ b/tests/test_runner.py @@ -2,124 +2,111 @@ import os import sys -if sys.version_info[:2] == (2, 6): # pragma: no cover +if sys.version_info[:2] == (2, 6): # pragma: no cover import unittest2 as unittest -else: # pragma: no cover +else: # pragma: no cover import unittest from waitress import runner -class Test_match(unittest.TestCase): +class Test_match(unittest.TestCase): def test_empty(self): - self.assertRaisesRegexp( - ValueError, "^Malformed application ''$", - runner.match, '') + self.assertRaisesRegex( + ValueError, "^Malformed application ''$", runner.match, "" + ) def test_module_only(self): - self.assertRaisesRegexp( - ValueError, r"^Malformed application 'foo\.bar'$", - runner.match, 'foo.bar') + self.assertRaisesRegex( + ValueError, r"^Malformed application 'foo\.bar'$", runner.match, "foo.bar" + ) def test_bad_module(self): - self.assertRaisesRegexp( + self.assertRaisesRegex( ValueError, r"^Malformed application 'foo#bar:barney'$", - runner.match, 'foo#bar:barney') + runner.match, + "foo#bar:barney", + ) def test_module_obj(self): self.assertTupleEqual( - runner.match('foo.bar:fred.barney'), - 
('foo.bar', 'fred.barney')) + runner.match("foo.bar:fred.barney"), ("foo.bar", "fred.barney") + ) -class Test_resolve(unittest.TestCase): +class Test_resolve(unittest.TestCase): def test_bad_module(self): self.assertRaises( - ImportError, - runner.resolve, 'nonexistent', 'nonexistent_function') + ImportError, runner.resolve, "nonexistent", "nonexistent_function" + ) def test_nonexistent_function(self): - self.assertRaisesRegexp( + self.assertRaisesRegex( AttributeError, r"has no attribute 'nonexistent_function'", - runner.resolve, 'os.path', 'nonexistent_function') + runner.resolve, + "os.path", + "nonexistent_function", + ) def test_simple_happy_path(self): from os.path import exists - self.assertIs(runner.resolve('os.path', 'exists'), exists) + + self.assertIs(runner.resolve("os.path", "exists"), exists) def test_complex_happy_path(self): # Ensure we can recursively resolve object attributes if necessary. - self.assertEquals( - runner.resolve('os.path', 'exists.__name__'), - 'exists') + self.assertEqual(runner.resolve("os.path", "exists.__name__"), "exists") -class Test_run(unittest.TestCase): +class Test_run(unittest.TestCase): def match_output(self, argv, code, regex): - argv = ['waitress-serve'] + argv + argv = ["waitress-serve"] + argv with capture() as captured: self.assertEqual(runner.run(argv=argv), code) - self.assertRegexpMatches(captured.getvalue(), regex) + self.assertRegex(captured.getvalue(), regex) captured.close() def test_bad(self): - self.match_output( - ['--bad-opt'], - 1, - '^Error: option --bad-opt not recognized') + self.match_output(["--bad-opt"], 1, "^Error: option --bad-opt not recognized") def test_help(self): - self.match_output( - ['--help'], - 0, - "^Usage:\n\n waitress-serve") + self.match_output(["--help"], 0, "^Usage:\n\n waitress-serve") def test_no_app(self): - self.match_output( - [], - 1, - "^Error: Specify one application only") + self.match_output([], 1, "^Error: Specify one application only") def test_multiple_apps_app(self): 
- self.match_output( - ['a:a', 'b:b'], - 1, - "^Error: Specify one application only") + self.match_output(["a:a", "b:b"], 1, "^Error: Specify one application only") def test_bad_apps_app(self): - self.match_output( - ['a'], - 1, - "^Error: Malformed application 'a'") + self.match_output(["a"], 1, "^Error: Malformed application 'a'") def test_bad_app_module(self): - self.match_output( - ['nonexistent:a'], - 1, - "^Error: Bad module 'nonexistent'") + self.match_output(["nonexistent:a"], 1, "^Error: Bad module 'nonexistent'") self.match_output( - ['nonexistent:a'], + ["nonexistent:a"], 1, ( r"There was an exception \((ImportError|ModuleNotFoundError)\) " "importing your module.\n\nIt had these arguments: \n" "1. No module named '?nonexistent'?" - ) + ), ) def test_cwd_added_to_path(self): def null_serve(app, **kw): pass + sys_path = sys.path current_dir = os.getcwd() try: os.chdir(os.path.dirname(__file__)) argv = [ - 'waitress-serve', - 'fixtureapps.runner:app', + "waitress-serve", + "fixtureapps.runner:app", ] self.assertEqual(runner.run(argv=argv, _serve=null_serve), 0) finally: @@ -128,37 +115,40 @@ def null_serve(app, **kw): def test_bad_app_object(self): self.match_output( - ['waitress.tests.fixtureapps.runner:a'], - 1, - "^Error: Bad object name 'a'") + ["tests.fixtureapps.runner:a"], 1, "^Error: Bad object name 'a'" + ) def test_simple_call(self): - import waitress.tests.fixtureapps.runner as _apps + from tests.fixtureapps import runner as _apps + def check_server(app, **kw): self.assertIs(app, _apps.app) - self.assertDictEqual(kw, {'port': '80'}) + self.assertDictEqual(kw, {"port": "80"}) + argv = [ - 'waitress-serve', - '--port=80', - 'waitress.tests.fixtureapps.runner:app', + "waitress-serve", + "--port=80", + "tests.fixtureapps.runner:app", ] self.assertEqual(runner.run(argv=argv, _serve=check_server), 0) def test_returned_app(self): - import waitress.tests.fixtureapps.runner as _apps + from tests.fixtureapps import runner as _apps + def check_server(app, 
**kw): self.assertIs(app, _apps.app) - self.assertDictEqual(kw, {'port': '80'}) + self.assertDictEqual(kw, {"port": "80"}) + argv = [ - 'waitress-serve', - '--port=80', - '--call', - 'waitress.tests.fixtureapps.runner:returns_app', + "waitress-serve", + "--port=80", + "--call", + "tests.fixtureapps.runner:returns_app", ] self.assertEqual(runner.run(argv=argv, _serve=check_server), 0) -class Test_helper(unittest.TestCase): +class Test_helper(unittest.TestCase): def test_exception_logging(self): from waitress.runner import show_exception @@ -172,10 +162,7 @@ def test_exception_logging(self): raise ImportError("My reason") except ImportError: self.assertEqual(show_exception(sys.stderr), None) - self.assertRegexpMatches( - captured.getvalue(), - regex - ) + self.assertRegex(captured.getvalue(), regex) captured.close() regex = ( @@ -188,16 +175,15 @@ def test_exception_logging(self): raise ImportError except ImportError: self.assertEqual(show_exception(sys.stderr), None) - self.assertRegexpMatches( - captured.getvalue(), - regex - ) + self.assertRegex(captured.getvalue(), regex) captured.close() + @contextlib.contextmanager def capture(): - from waitress.compat import NativeIO - fd = NativeIO() + from io import StringIO + + fd = StringIO() sys.stdout = fd sys.stderr = fd yield fd diff --git a/waitress/tests/test_server.py b/tests/test_server.py similarity index 56% rename from waitress/tests/test_server.py rename to tests/test_server.py index 39b90b3d..6edc3b24 100644 --- a/waitress/tests/test_server.py +++ b/tests/test_server.py @@ -4,23 +4,36 @@ dummy_app = object() -class TestWSGIServer(unittest.TestCase): - def _makeOne(self, application=dummy_app, host='127.0.0.1', port=0, - _dispatcher=None, adj=None, map=None, _start=True, - _sock=None, _server=None): +class TestWSGIServer(unittest.TestCase): + def _makeOne( + self, + application=dummy_app, + host="127.0.0.1", + port=0, + _dispatcher=None, + adj=None, + map=None, + _start=True, + _sock=None, + _server=None, + ): 
from waitress.server import create_server - return create_server( + + self.inst = create_server( application, host=host, port=port, map=map, _dispatcher=_dispatcher, _start=_start, - _sock=_sock) + _sock=_sock, + ) + return self.inst - def _makeOneWithMap(self, adj=None, _start=True, host='127.0.0.1', - port=0, app=dummy_app): + def _makeOneWithMap( + self, adj=None, _start=True, host="127.0.0.1", port=0, app=dummy_app + ): sock = DummySock() task_dispatcher = DummyTaskDispatcher() map = {} @@ -34,21 +47,55 @@ def _makeOneWithMap(self, adj=None, _start=True, host='127.0.0.1', _start=_start, ) - def _makeOneWithMulti(self, adj=None, _start=True, - app=dummy_app, listen="127.0.0.1:0 127.0.0.1:0"): + def _makeOneWithMulti( + self, adj=None, _start=True, app=dummy_app, listen="127.0.0.1:0 127.0.0.1:0" + ): sock = DummySock() task_dispatcher = DummyTaskDispatcher() map = {} from waitress.server import create_server - return create_server( + + self.inst = create_server( app, listen=listen, map=map, _dispatcher=task_dispatcher, _start=_start, - _sock=sock) + _sock=sock, + ) + return self.inst + + def _makeWithSockets( + self, + application=dummy_app, + _dispatcher=None, + map=None, + _start=True, + _sock=None, + _server=None, + sockets=None, + ): + from waitress.server import create_server + + _sockets = [] + if sockets is not None: + _sockets = sockets + self.inst = create_server( + application, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + sockets=_sockets, + ) + return self.inst + + def tearDown(self): + if self.inst is not None: + self.inst.close() def test_ctor_app_is_None(self): + self.inst = None self.assertRaises(ValueError, self._makeOneWithMap, app=None) def test_ctor_start_true(self): @@ -58,36 +105,17 @@ def test_ctor_start_true(self): def test_ctor_makes_dispatcher(self): inst = self._makeOne(_start=False, map={}) - self.assertEqual(inst.task_dispatcher.__class__.__name__, - 'ThreadedTaskDispatcher') + self.assertEqual( + 
inst.task_dispatcher.__class__.__name__, "ThreadedTaskDispatcher" + ) def test_ctor_start_false(self): inst = self._makeOneWithMap(_start=False) self.assertEqual(inst.accepting, False) - def test_get_server_name_empty(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name('') - self.assertTrue(result) - - def test_get_server_name_with_ip(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name('127.0.0.1') - self.assertTrue(result) - - def test_get_server_name_with_hostname(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name('fred.flintstone.com') - self.assertEqual(result, 'fred.flintstone.com') - - def test_get_server_name_0000(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name('0.0.0.0') - self.assertEqual(result, 'localhost') - def test_get_server_multi(self): inst = self._makeOneWithMulti() - self.assertEqual(inst.__class__.__name__, 'MultiSocketServer') + self.assertEqual(inst.__class__.__name__, "MultiSocketServer") def test_run(self): inst = self._makeOneWithMap(_start=False) @@ -105,6 +133,7 @@ def test_run_base_server(self): def test_pull_trigger(self): inst = self._makeOneWithMap(_start=False) + inst.trigger.close() inst.trigger = DummyTrigger() inst.pull_trigger() self.assertEqual(inst.trigger.pulled, True) @@ -125,8 +154,9 @@ def test_readable_maplen_gt_connection_limit(self): inst = self._makeOneWithMap() inst.accepting = True inst.adj = DummyAdj - inst._map = {'a': 1, 'b': 2} + inst._map = {"a": 1, "b": 2} self.assertFalse(inst.readable()) + self.assertTrue(inst.in_connection_overflow) def test_readable_maplen_lt_connection_limit(self): inst = self._makeOneWithMap() @@ -134,9 +164,23 @@ def test_readable_maplen_lt_connection_limit(self): inst.adj = DummyAdj inst._map = {} self.assertTrue(inst.readable()) + self.assertFalse(inst.in_connection_overflow) + + def test_readable_maplen_toggles_connection_overflow(self): + inst = 
self._makeOneWithMap() + inst.accepting = True + inst.adj = DummyAdj + inst._map = {"a": 1, "b": 2} + self.assertFalse(inst.in_connection_overflow) + self.assertFalse(inst.readable()) + self.assertTrue(inst.in_connection_overflow) + inst._map = {} + self.assertTrue(inst.readable()) + self.assertFalse(inst.in_connection_overflow) def test_readable_maintenance_false(self): import time + inst = self._makeOneWithMap() then = time.time() + 1000 inst.next_channel_cleanup = then @@ -179,8 +223,10 @@ def test_handle_accept_other_socket_error(self): eaborted = socket.error(errno.ECONNABORTED) inst.socket = DummySock(toraise=eaborted) inst.adj = DummyAdj + def foo(): - raise socket.error + raise OSError + inst.accept = foo inst.logger = DummyLogger() inst.handle_accept() @@ -196,14 +242,15 @@ def test_handle_accept_noerror(self): inst.channel_class = lambda *arg, **kw: L.append(arg) inst.handle_accept() self.assertEqual(inst.socket.accepted, True) - self.assertEqual(innersock.opts, [('level', 'optname', 'value')]) + self.assertEqual(innersock.opts, [("level", "optname", "value")]) self.assertEqual(L, [(inst, innersock, None, inst.adj)]) def test_maintenance(self): inst = self._makeOneWithMap() - class DummyChannel(object): + class DummyChannel: requests = [] + zombie = DummyChannel() zombie.last_activity = 0 zombie.running_tasks = False @@ -212,30 +259,105 @@ class DummyChannel(object): self.assertEqual(zombie.will_close, True) def test_backward_compatibility(self): - from waitress.server import WSGIServer, TcpWSGIServer from waitress.adjustments import Adjustments + from waitress.server import TcpWSGIServer, WSGIServer + self.assertTrue(WSGIServer is TcpWSGIServer) - inst = WSGIServer(None, _start=False, port=1234) + self.inst = WSGIServer(None, _start=False, port=1234) # Ensure the adjustment was actually applied. 
self.assertNotEqual(Adjustments.port, 1234) - self.assertEqual(inst.adj.port, 1234) + self.assertEqual(self.inst.adj.port, 1234) + + def test_create_with_one_tcp_socket(self): + from waitress.server import TcpWSGIServer + + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] + sockets[0].bind(("127.0.0.1", 0)) + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertTrue(isinstance(inst, TcpWSGIServer)) + + def test_create_with_multiple_tcp_sockets(self): + from waitress.server import MultiSocketServer + + sockets = [ + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + socket.socket(socket.AF_INET, socket.SOCK_STREAM), + ] + sockets[0].bind(("127.0.0.1", 0)) + sockets[1].bind(("127.0.0.1", 0)) + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertTrue(isinstance(inst, MultiSocketServer)) + self.assertEqual(len(inst.effective_listen), 2) + + def test_create_with_one_socket_should_not_bind_socket(self): + innersock = DummySock() + sockets = [DummySock(acceptresult=(innersock, None))] + sockets[0].bind(("127.0.0.1", 80)) + sockets[0].bind_called = False + inst = self._makeWithSockets(_start=False, sockets=sockets) + self.assertEqual(inst.socket.bound, ("127.0.0.1", 80)) + self.assertFalse(inst.socket.bind_called) + + def test_create_with_one_socket_handle_accept_noerror(self): + innersock = DummySock() + sockets = [DummySock(acceptresult=(innersock, None))] + sockets[0].bind(("127.0.0.1", 80)) + inst = self._makeWithSockets(sockets=sockets) + L = [] + inst.channel_class = lambda *arg, **kw: L.append(arg) + inst.adj = DummyAdj + inst.handle_accept() + self.assertEqual(sockets[0].accepted, True) + self.assertEqual(innersock.opts, [("level", "optname", "value")]) + self.assertEqual(L, [(inst, innersock, None, inst.adj)]) + -if hasattr(socket, 'AF_UNIX'): +if hasattr(socket, "AF_UNIX"): class TestUnixWSGIServer(unittest.TestCase): - unix_socket = '/tmp/waitress.test.sock' + unix_socket = "/tmp/waitress.test.sock" def 
_makeOne(self, _start=True, _sock=None): from waitress.server import create_server - return create_server( + + self.inst = create_server( dummy_app, map={}, _start=_start, _sock=_sock, _dispatcher=DummyTaskDispatcher(), unix_socket=self.unix_socket, - unix_socket_perms='600' + unix_socket_perms="600", + ) + return self.inst + + def _makeWithSockets( + self, + application=dummy_app, + _dispatcher=None, + map=None, + _start=True, + _sock=None, + _server=None, + sockets=None, + ): + from waitress.server import create_server + + _sockets = [] + if sockets is not None: + _sockets = sockets + self.inst = create_server( + application, + map=map, + _dispatcher=_dispatcher, + _start=_start, + _sock=_sock, + sockets=_sockets, ) + return self.inst + + def tearDown(self): + self.inst.close() def _makeDummy(self, *args, **kwargs): sock = DummySock(*args, **kwargs) @@ -261,33 +383,54 @@ def test_handle_accept(self): inst.handle_accept() self.assertEqual(inst.socket.accepted, True) self.assertEqual(client.opts, []) - self.assertEqual( - L, - [(inst, client, ('localhost', None), inst.adj)] - ) + self.assertEqual(L, [(inst, client, ("localhost", None), inst.adj)]) def test_creates_new_sockinfo(self): from waitress.server import UnixWSGIServer - inst = UnixWSGIServer( - dummy_app, - unix_socket=self.unix_socket, - unix_socket_perms='600' + + self.inst = UnixWSGIServer( + dummy_app, unix_socket=self.unix_socket, unix_socket_perms="600" + ) + + self.assertEqual(self.inst.sockinfo[0], socket.AF_UNIX) + + def test_create_with_unix_socket(self): + from waitress.server import ( + BaseWSGIServer, + MultiSocketServer, + TcpWSGIServer, + UnixWSGIServer, + ) + + sockets = [ + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), + socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), + ] + inst = self._makeWithSockets(sockets=sockets, _start=False) + self.assertTrue(isinstance(inst, MultiSocketServer)) + server = list( + filter(lambda s: isinstance(s, BaseWSGIServer), inst.map.values()) ) + 
self.assertTrue(isinstance(server[0], UnixWSGIServer)) + self.assertTrue(isinstance(server[1], UnixWSGIServer)) - self.assertEqual(inst.sockinfo[0], socket.AF_UNIX) -class DummySock(object): +class DummySock(socket.socket): accepted = False blocking = False family = socket.AF_INET + type = socket.SOCK_STREAM + proto = 0 def __init__(self, toraise=None, acceptresult=(None, None)): self.toraise = toraise self.acceptresult = acceptresult self.bound = None self.opts = [] + self.bind_called = False def bind(self, addr): + self.bind_called = True self.bound = addr def accept(self): @@ -303,7 +446,7 @@ def fileno(self): return 10 def getpeername(self): - return '127.0.0.1' + return "127.0.0.1" def setsockopt(self, *arg): self.opts.append(arg) @@ -317,8 +460,11 @@ def listen(self, num): def getsockname(self): return self.bound -class DummyTaskDispatcher(object): + def close(self): + pass + +class DummyTaskDispatcher: def __init__(self): self.tasks = [] @@ -328,38 +474,43 @@ def add_task(self, task): def shutdown(self): self.was_shutdown = True -class DummyTask(object): + +class DummyTask: serviced = False start_response_called = False wrote_header = False - status = '200 OK' + status = "200 OK" def __init__(self): self.response_headers = {} - self.written = '' + self.written = "" - def service(self): # pragma: no cover + def service(self): # pragma: no cover self.serviced = True + class DummyAdj: connection_limit = 1 log_socket_errors = True - socket_options = [('level', 'optname', 'value')] + socket_options = [("level", "optname", "value")] cleanup_interval = 900 channel_timeout = 300 -class DummyAsyncore(object): +class DummyAsyncore: def loop(self, timeout=30.0, use_poll=False, map=None, count=None): raise SystemExit -class DummyTrigger(object): +class DummyTrigger: def pull_trigger(self): self.pulled = True -class DummyLogger(object): + def close(self): + pass + +class DummyLogger: def __init__(self): self.logged = [] diff --git a/waitress/tests/test_task.py 
b/tests/test_task.py similarity index 50% rename from waitress/tests/test_task.py rename to tests/test_task.py index 2a2759a2..47868e15 100644 --- a/waitress/tests/test_task.py +++ b/tests/test_task.py @@ -1,31 +1,32 @@ -import unittest import io +import unittest -class TestThreadedTaskDispatcher(unittest.TestCase): +class TestThreadedTaskDispatcher(unittest.TestCase): def _makeOne(self): from waitress.task import ThreadedTaskDispatcher - return ThreadedTaskDispatcher() - def test_handler_thread_task_is_None(self): - inst = self._makeOne() - inst.threads[0] = True - inst.queue.put(None) - inst.handler_thread(0) - self.assertEqual(inst.stop_count, -1) - self.assertEqual(inst.threads, {}) + return ThreadedTaskDispatcher() def test_handler_thread_task_raises(self): - from waitress.task import JustTesting inst = self._makeOne() - inst.threads[0] = True + inst.threads.add(0) inst.logger = DummyLogger() - task = DummyTask(JustTesting) + + class BadDummyTask(DummyTask): + def service(self): + super().service() + inst.stop_count += 1 + raise Exception + + task = BadDummyTask() inst.logger = DummyLogger() - inst.queue.put(task) + inst.queue.append(task) + inst.active_count += 1 inst.handler_thread(0) - self.assertEqual(inst.stop_count, -1) - self.assertEqual(inst.threads, {}) + self.assertEqual(inst.stop_count, 0) + self.assertEqual(inst.active_count, 0) + self.assertEqual(inst.threads, set()) self.assertEqual(len(inst.logger.logged), 1) def test_set_thread_count_increase(self): @@ -33,221 +34,280 @@ def test_set_thread_count_increase(self): L = [] inst.start_new_thread = lambda *x: L.append(x) inst.set_thread_count(1) - self.assertEqual(L, [(inst.handler_thread, (0,))]) + self.assertEqual(L, [(inst.handler_thread, 0)]) def test_set_thread_count_increase_with_existing(self): inst = self._makeOne() L = [] - inst.threads = {0: 1} + inst.threads = {0} inst.start_new_thread = lambda *x: L.append(x) inst.set_thread_count(2) - self.assertEqual(L, [(inst.handler_thread, (1,))]) + 
self.assertEqual(L, [(inst.handler_thread, 1)]) def test_set_thread_count_decrease(self): inst = self._makeOne() - inst.threads = {'a': 1, 'b': 2} + inst.threads = {0, 1} inst.set_thread_count(1) - self.assertEqual(inst.queue.qsize(), 1) - self.assertEqual(inst.queue.get(), None) + self.assertEqual(inst.stop_count, 1) def test_set_thread_count_same(self): inst = self._makeOne() L = [] inst.start_new_thread = lambda *x: L.append(x) - inst.threads = {0: 1} + inst.threads = {0} inst.set_thread_count(1) self.assertEqual(L, []) - def test_add_task(self): + def test_add_task_with_idle_threads(self): task = DummyTask() inst = self._makeOne() + inst.threads.add(0) + inst.queue_logger = DummyLogger() inst.add_task(task) - self.assertEqual(inst.queue.qsize(), 1) - self.assertTrue(task.deferred) + self.assertEqual(len(inst.queue), 1) + self.assertEqual(len(inst.queue_logger.logged), 0) - def test_add_task_defer_raises(self): - task = DummyTask(ValueError) + def test_add_task_with_all_busy_threads(self): + task = DummyTask() inst = self._makeOne() - self.assertRaises(ValueError, inst.add_task, task) - self.assertEqual(inst.queue.qsize(), 0) - self.assertTrue(task.deferred) - self.assertTrue(task.cancelled) + inst.queue_logger = DummyLogger() + inst.add_task(task) + self.assertEqual(len(inst.queue_logger.logged), 1) + inst.add_task(task) + self.assertEqual(len(inst.queue_logger.logged), 2) def test_shutdown_one_thread(self): inst = self._makeOne() - inst.threads[0] = 1 + inst.threads.add(0) inst.logger = DummyLogger() task = DummyTask() - inst.queue.put(task) - self.assertEqual(inst.shutdown(timeout=.01), True) - self.assertEqual(inst.logger.logged, ['1 thread(s) still running']) + inst.queue.append(task) + self.assertEqual(inst.shutdown(timeout=0.01), True) + self.assertEqual( + inst.logger.logged, + [ + "1 thread(s) still running", + "Canceling 1 pending task(s)", + ], + ) self.assertEqual(task.cancelled, True) def test_shutdown_no_threads(self): inst = self._makeOne() - 
self.assertEqual(inst.shutdown(timeout=.01), True) + self.assertEqual(inst.shutdown(timeout=0.01), True) def test_shutdown_no_cancel_pending(self): inst = self._makeOne() - self.assertEqual(inst.shutdown(cancel_pending=False, timeout=.01), - False) + self.assertEqual(inst.shutdown(cancel_pending=False, timeout=0.01), False) -class TestTask(unittest.TestCase): +class TestTask(unittest.TestCase): def _makeOne(self, channel=None, request=None): if channel is None: channel = DummyChannel() if request is None: request = DummyParser() from waitress.task import Task + return Task(channel, request) def test_ctor_version_not_in_known(self): request = DummyParser() - request.version = '8.4' + request.version = "8.4" inst = self._makeOne(request=request) - self.assertEqual(inst.version, '1.0') - - def test_cancel(self): - inst = self._makeOne() - inst.cancel() - self.assertTrue(inst.close_on_finish) - - def test_defer(self): - inst = self._makeOne() - self.assertEqual(inst.defer(), None) + self.assertEqual(inst.version, "1.0") def test_build_response_header_bad_http_version(self): inst = self._makeOne() inst.request = DummyParser() - inst.version = '8.4' + inst.version = "8.4" self.assertRaises(AssertionError, inst.build_response_header) def test_build_response_header_v10_keepalive_no_content_length(self): inst = self._makeOne() inst.request = DummyParser() - inst.request.headers['CONNECTION'] = 'keep-alive' - inst.version = '1.0' + inst.request.headers["CONNECTION"] = "keep-alive" + inst.version = "1.0" result = inst.build_response_header() lines = filter_lines(result) self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b'HTTP/1.0 200 OK') - self.assertEqual(lines[1], b'Connection: close') - self.assertTrue(lines[2].startswith(b'Date:')) - self.assertEqual(lines[3], b'Server: waitress') + self.assertEqual(lines[0], b"HTTP/1.0 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], 
b"Server: waitress") self.assertEqual(inst.close_on_finish, True) - self.assertTrue(('Connection', 'close') in inst.response_headers) + self.assertTrue(("Connection", "close") in inst.response_headers) def test_build_response_header_v10_keepalive_with_content_length(self): inst = self._makeOne() inst.request = DummyParser() - inst.request.headers['CONNECTION'] = 'keep-alive' - inst.response_headers = [('Content-Length', '10')] - inst.version = '1.0' + inst.request.headers["CONNECTION"] = "keep-alive" + inst.response_headers = [("Content-Length", "10")] + inst.version = "1.0" inst.content_length = 0 result = inst.build_response_header() lines = filter_lines(result) self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b'HTTP/1.0 200 OK') - self.assertEqual(lines[1], b'Connection: Keep-Alive') - self.assertEqual(lines[2], b'Content-Length: 10') - self.assertTrue(lines[3].startswith(b'Date:')) - self.assertEqual(lines[4], b'Server: waitress') + self.assertEqual(lines[0], b"HTTP/1.0 200 OK") + self.assertEqual(lines[1], b"Connection: Keep-Alive") + self.assertEqual(lines[2], b"Content-Length: 10") + self.assertTrue(lines[3].startswith(b"Date:")) + self.assertEqual(lines[4], b"Server: waitress") self.assertEqual(inst.close_on_finish, False) def test_build_response_header_v11_connection_closed_by_client(self): inst = self._makeOne() inst.request = DummyParser() - inst.version = '1.1' - inst.request.headers['CONNECTION'] = 'close' + inst.version = "1.1" + inst.request.headers["CONNECTION"] = "close" result = inst.build_response_header() lines = filter_lines(result) self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b'HTTP/1.1 200 OK') - self.assertEqual(lines[1], b'Connection: close') - self.assertTrue(lines[2].startswith(b'Date:')) - self.assertEqual(lines[3], b'Server: waitress') - self.assertEqual(lines[4], b'Transfer-Encoding: chunked') - self.assertTrue(('Connection', 'close') in inst.response_headers) + self.assertEqual(lines[0], b"HTTP/1.1 200 OK") 
+ self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(lines[4], b"Transfer-Encoding: chunked") + self.assertTrue(("Connection", "close") in inst.response_headers) self.assertEqual(inst.close_on_finish, True) def test_build_response_header_v11_connection_keepalive_by_client(self): inst = self._makeOne() inst.request = DummyParser() - inst.request.headers['CONNECTION'] = 'keep-alive' - inst.version = '1.1' + inst.request.headers["CONNECTION"] = "keep-alive" + inst.version = "1.1" result = inst.build_response_header() lines = filter_lines(result) self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b'HTTP/1.1 200 OK') - self.assertEqual(lines[1], b'Connection: close') - self.assertTrue(lines[2].startswith(b'Date:')) - self.assertEqual(lines[3], b'Server: waitress') - self.assertEqual(lines[4], b'Transfer-Encoding: chunked') - self.assertTrue(('Connection', 'close') in inst.response_headers) + self.assertEqual(lines[0], b"HTTP/1.1 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(lines[4], b"Transfer-Encoding: chunked") + self.assertTrue(("Connection", "close") in inst.response_headers) self.assertEqual(inst.close_on_finish, True) def test_build_response_header_v11_200_no_content_length(self): inst = self._makeOne() inst.request = DummyParser() - inst.version = '1.1' + inst.version = "1.1" result = inst.build_response_header() lines = filter_lines(result) self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b'HTTP/1.1 200 OK') - self.assertEqual(lines[1], b'Connection: close') - self.assertTrue(lines[2].startswith(b'Date:')) - self.assertEqual(lines[3], b'Server: waitress') - self.assertEqual(lines[4], b'Transfer-Encoding: chunked') + self.assertEqual(lines[0], b"HTTP/1.1 200 OK") + self.assertEqual(lines[1], 
b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(lines[4], b"Transfer-Encoding: chunked") + self.assertEqual(inst.close_on_finish, True) + self.assertTrue(("Connection", "close") in inst.response_headers) + + def test_build_response_header_v11_204_no_content_length_or_transfer_encoding(self): + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for any response with a status code of 1xx or 204. + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + inst.status = "204 No Content" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.1 204 No Content") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") self.assertEqual(inst.close_on_finish, True) - self.assertTrue(('Connection', 'close') in inst.response_headers) + self.assertTrue(("Connection", "close") in inst.response_headers) + + def test_build_response_header_v11_1xx_no_content_length_or_transfer_encoding(self): + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for any response with a status code of 1xx or 204. 
+ inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + inst.status = "100 Continue" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.1 100 Continue") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(inst.close_on_finish, True) + self.assertTrue(("Connection", "close") in inst.response_headers) + + def test_build_response_header_v11_304_no_content_length_or_transfer_encoding(self): + # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length + # for any response with a status code of 1xx, 204 or 304. + inst = self._makeOne() + inst.request = DummyParser() + inst.version = "1.1" + inst.status = "304 Not Modified" + result = inst.build_response_header() + lines = filter_lines(result) + self.assertEqual(len(lines), 4) + self.assertEqual(lines[0], b"HTTP/1.1 304 Not Modified") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") + self.assertEqual(inst.close_on_finish, True) + self.assertTrue(("Connection", "close") in inst.response_headers) def test_build_response_header_via_added(self): inst = self._makeOne() inst.request = DummyParser() - inst.version = '1.0' - inst.response_headers = [('Server', 'abc')] + inst.version = "1.0" + inst.response_headers = [("Server", "abc")] result = inst.build_response_header() lines = filter_lines(result) self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b'HTTP/1.0 200 OK') - self.assertEqual(lines[1], b'Connection: close') - self.assertTrue(lines[2].startswith(b'Date:')) - self.assertEqual(lines[3], b'Server: abc') - self.assertEqual(lines[4], b'Via: waitress') + self.assertEqual(lines[0], b"HTTP/1.0 200 OK") + self.assertEqual(lines[1], b"Connection: close") + 
self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: abc") + self.assertEqual(lines[4], b"Via: waitress") def test_build_response_header_date_exists(self): inst = self._makeOne() inst.request = DummyParser() - inst.version = '1.0' - inst.response_headers = [('Date', 'date')] + inst.version = "1.0" + inst.response_headers = [("Date", "date")] result = inst.build_response_header() lines = filter_lines(result) self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b'HTTP/1.0 200 OK') - self.assertEqual(lines[1], b'Connection: close') - self.assertTrue(lines[2].startswith(b'Date:')) - self.assertEqual(lines[3], b'Server: waitress') + self.assertEqual(lines[0], b"HTTP/1.0 200 OK") + self.assertEqual(lines[1], b"Connection: close") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") def test_build_response_header_preexisting_content_length(self): inst = self._makeOne() inst.request = DummyParser() - inst.version = '1.1' + inst.version = "1.1" inst.content_length = 100 result = inst.build_response_header() lines = filter_lines(result) self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b'HTTP/1.1 200 OK') - self.assertEqual(lines[1], b'Content-Length: 100') - self.assertTrue(lines[2].startswith(b'Date:')) - self.assertEqual(lines[3], b'Server: waitress') + self.assertEqual(lines[0], b"HTTP/1.1 200 OK") + self.assertEqual(lines[1], b"Content-Length: 100") + self.assertTrue(lines[2].startswith(b"Date:")) + self.assertEqual(lines[3], b"Server: waitress") def test_remove_content_length_header(self): inst = self._makeOne() - inst.response_headers = [('Content-Length', '70')] + inst.response_headers = [("Content-Length", "70")] inst.remove_content_length_header() self.assertEqual(inst.response_headers, []) + def test_remove_content_length_header_with_other(self): + inst = self._makeOne() + inst.response_headers = [ + ("Content-Length", "70"), + ("Content-Type", "text/html"), + ] + 
inst.remove_content_length_header() + self.assertEqual(inst.response_headers, [("Content-Type", "text/html")]) + def test_start(self): inst = self._makeOne() inst.start() @@ -271,35 +331,35 @@ def test_finish_chunked_response(self): inst.wrote_header = True inst.chunked_response = True inst.finish() - self.assertEqual(inst.channel.written, b'0\r\n\r\n') + self.assertEqual(inst.channel.written, b"0\r\n\r\n") def test_write_wrote_header(self): inst = self._makeOne() inst.wrote_header = True inst.complete = True inst.content_length = 3 - inst.write(b'abc') - self.assertEqual(inst.channel.written, b'abc') + inst.write(b"abc") + self.assertEqual(inst.channel.written, b"abc") def test_write_header_not_written(self): inst = self._makeOne() inst.wrote_header = False inst.complete = True - inst.write(b'abc') + inst.write(b"abc") self.assertTrue(inst.channel.written) self.assertEqual(inst.wrote_header, True) def test_write_start_response_uncalled(self): inst = self._makeOne() - self.assertRaises(RuntimeError, inst.write, b'') + self.assertRaises(RuntimeError, inst.write, b"") def test_write_chunked_response(self): inst = self._makeOne() inst.wrote_header = True inst.chunked_response = True inst.complete = True - inst.write(b'abc') - self.assertEqual(inst.channel.written, b'3\r\nabc\r\n') + inst.write(b"abc") + self.assertEqual(inst.channel.written, b"3\r\nabc\r\n") def test_write_preexisting_content_length(self): inst = self._makeOne() @@ -307,25 +367,28 @@ def test_write_preexisting_content_length(self): inst.complete = True inst.content_length = 1 inst.logger = DummyLogger() - inst.write(b'abc') + inst.write(b"abc") self.assertTrue(inst.channel.written) self.assertEqual(inst.logged_write_excess, True) self.assertEqual(len(inst.logger.logged), 1) -class TestWSGITask(unittest.TestCase): +class TestWSGITask(unittest.TestCase): def _makeOne(self, channel=None, request=None): if channel is None: channel = DummyChannel() if request is None: request = DummyParser() from 
waitress.task import WSGITask + return WSGITask(channel, request) def test_service(self): inst = self._makeOne() + def execute(): inst.executed = True + inst.execute = execute inst.complete = True inst.service() @@ -336,9 +399,12 @@ def execute(): def test_service_server_raises_socket_error(self): import socket + inst = self._makeOne() + def execute(): - raise socket.error + raise OSError + inst.execute = execute self.assertRaises(socket.error, inst.service) self.assertTrue(inst.start_time) @@ -347,41 +413,45 @@ def execute(): def test_execute_app_calls_start_response_twice_wo_exc_info(self): def app(environ, start_response): - start_response('200 OK', []) - start_response('200 OK', []) + start_response("200 OK", []) + start_response("200 OK", []) + inst = self._makeOne() inst.channel.server.application = app self.assertRaises(AssertionError, inst.execute) def test_execute_app_calls_start_response_w_exc_info_complete(self): def app(environ, start_response): - start_response('200 OK', [], [ValueError, ValueError(), None]) - return [b'a'] + start_response("200 OK", [], [ValueError, ValueError(), None]) + return [b"a"] + inst = self._makeOne() inst.complete = True inst.channel.server.application = app inst.execute() self.assertTrue(inst.complete) - self.assertEqual(inst.status, '200 OK') + self.assertEqual(inst.status, "200 OK") self.assertTrue(inst.channel.written) def test_execute_app_calls_start_response_w_excinf_headers_unwritten(self): def app(environ, start_response): - start_response('200 OK', [], [ValueError, None, None]) - return [b'a'] + start_response("200 OK", [], [ValueError, None, None]) + return [b"a"] + inst = self._makeOne() inst.wrote_header = False inst.channel.server.application = app - inst.response_headers = [('a', 'b')] + inst.response_headers = [("a", "b")] inst.execute() self.assertTrue(inst.complete) - self.assertEqual(inst.status, '200 OK') + self.assertEqual(inst.status, "200 OK") self.assertTrue(inst.channel.written) - 
self.assertFalse(('a','b') in inst.response_headers) + self.assertFalse(("a", "b") in inst.response_headers) def test_execute_app_calls_start_response_w_excinf_headers_written(self): def app(environ, start_response): - start_response('200 OK', [], [ValueError, ValueError(), None]) + start_response("200 OK", [], [ValueError, ValueError(), None]) + inst = self._makeOne() inst.complete = True inst.wrote_header = True @@ -390,67 +460,76 @@ def app(environ, start_response): def test_execute_bad_header_key(self): def app(environ, start_response): - start_response('200 OK', [(None, 'a')]) + start_response("200 OK", [(None, "a")]) + inst = self._makeOne() inst.channel.server.application = app self.assertRaises(AssertionError, inst.execute) def test_execute_bad_header_value(self): def app(environ, start_response): - start_response('200 OK', [('a', None)]) + start_response("200 OK", [("a", None)]) + inst = self._makeOne() inst.channel.server.application = app self.assertRaises(AssertionError, inst.execute) def test_execute_hopbyhop_header(self): def app(environ, start_response): - start_response('200 OK', [('Connection', 'close')]) + start_response("200 OK", [("Connection", "close")]) + inst = self._makeOne() inst.channel.server.application = app self.assertRaises(AssertionError, inst.execute) def test_execute_bad_header_value_control_characters(self): def app(environ, start_response): - start_response('200 OK', [('a', '\n')]) + start_response("200 OK", [("a", "\n")]) + inst = self._makeOne() inst.channel.server.application = app self.assertRaises(ValueError, inst.execute) def test_execute_bad_header_name_control_characters(self): def app(environ, start_response): - start_response('200 OK', [('a\r', 'value')]) + start_response("200 OK", [("a\r", "value")]) + inst = self._makeOne() inst.channel.server.application = app self.assertRaises(ValueError, inst.execute) def test_execute_bad_status_control_characters(self): def app(environ, start_response): - start_response('200 
OK\r', []) + start_response("200 OK\r", []) + inst = self._makeOne() inst.channel.server.application = app self.assertRaises(ValueError, inst.execute) def test_preserve_header_value_order(self): def app(environ, start_response): - write = start_response('200 OK', [('C', 'b'), ('A', 'b'), ('A', 'a')]) - write(b'abc') + write = start_response("200 OK", [("C", "b"), ("A", "b"), ("A", "a")]) + write(b"abc") return [] + inst = self._makeOne() inst.channel.server.application = app inst.execute() - self.assertTrue(b'A: b\r\nA: a\r\nC: b\r\n' in inst.channel.written) + self.assertTrue(b"A: b\r\nA: a\r\nC: b\r\n" in inst.channel.written) def test_execute_bad_status_value(self): def app(environ, start_response): start_response(None, []) + inst = self._makeOne() inst.channel.server.application = app self.assertRaises(AssertionError, inst.execute) def test_execute_with_content_length_header(self): def app(environ, start_response): - start_response('200 OK', [('Content-Length', '1')]) - return [b'a'] + start_response("200 OK", [("Content-Length", "1")]) + return [b"a"] + inst = self._makeOne() inst.channel.server.application = app inst.execute() @@ -458,18 +537,20 @@ def app(environ, start_response): def test_execute_app_calls_write(self): def app(environ, start_response): - write = start_response('200 OK', [('Content-Length', '3')]) - write(b'abc') + write = start_response("200 OK", [("Content-Length", "3")]) + write(b"abc") return [] + inst = self._makeOne() inst.channel.server.application = app inst.execute() - self.assertEqual(inst.channel.written[-3:], b'abc') + self.assertEqual(inst.channel.written[-3:], b"abc") def test_execute_app_returns_len1_chunk_without_cl(self): def app(environ, start_response): - start_response('200 OK', []) - return [b'abc'] + start_response("200 OK", []) + return [b"abc"] + inst = self._makeOne() inst.channel.server.application = app inst.execute() @@ -477,8 +558,9 @@ def app(environ, start_response): def 
test_execute_app_returns_empty_chunk_as_first(self): def app(environ, start_response): - start_response('200 OK', []) - return ['', b'abc'] + start_response("200 OK", []) + return ["", b"abc"] + inst = self._makeOne() inst.channel.server.application = app inst.execute() @@ -486,8 +568,9 @@ def app(environ, start_response): def test_execute_app_returns_too_many_bytes(self): def app(environ, start_response): - start_response('200 OK', [('Content-Length', '1')]) - return [b'abc'] + start_response("200 OK", [("Content-Length", "1")]) + return [b"abc"] + inst = self._makeOne() inst.channel.server.application = app inst.logger = DummyLogger() @@ -497,8 +580,9 @@ def app(environ, start_response): def test_execute_app_returns_too_few_bytes(self): def app(environ, start_response): - start_response('200 OK', [('Content-Length', '3')]) - return [b'a'] + start_response("200 OK", [("Content-Length", "3")]) + return [b"a"] + inst = self._makeOne() inst.channel.server.application = app inst.logger = DummyLogger() @@ -508,24 +592,58 @@ def app(environ, start_response): def test_execute_app_do_not_warn_on_head(self): def app(environ, start_response): - start_response('200 OK', [('Content-Length', '3')]) - return [b''] + start_response("200 OK", [("Content-Length", "3")]) + return [b""] + inst = self._makeOne() - inst.request.command = 'HEAD' + inst.request.command = "HEAD" inst.channel.server.application = app inst.logger = DummyLogger() inst.execute() self.assertEqual(inst.close_on_finish, True) self.assertEqual(len(inst.logger.logged), 0) + def test_execute_app_without_body_204_logged(self): + def app(environ, start_response): + start_response("204 No Content", [("Content-Length", "3")]) + return [b"abc"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.logger = DummyLogger() + inst.execute() + self.assertEqual(inst.close_on_finish, True) + self.assertNotIn(b"abc", inst.channel.written) + self.assertNotIn(b"Content-Length", inst.channel.written) + 
self.assertNotIn(b"Transfer-Encoding", inst.channel.written) + self.assertEqual(len(inst.logger.logged), 1) + + def test_execute_app_without_body_304_logged(self): + def app(environ, start_response): + start_response("304 Not Modified", [("Content-Length", "3")]) + return [b"abc"] + + inst = self._makeOne() + inst.channel.server.application = app + inst.logger = DummyLogger() + inst.execute() + self.assertEqual(inst.close_on_finish, True) + self.assertNotIn(b"abc", inst.channel.written) + self.assertNotIn(b"Content-Length", inst.channel.written) + self.assertNotIn(b"Transfer-Encoding", inst.channel.written) + self.assertEqual(len(inst.logger.logged), 1) + def test_execute_app_returns_closeable(self): class closeable(list): def close(self): self.closed = True - foo = closeable([b'abc']) + + foo = closeable([b"abc"]) + def app(environ, start_response): - start_response('200 OK', [('Content-Length', '3')]) + start_response("200 OK", [("Content-Length", "3")]) return foo + inst = self._makeOne() inst.channel.server.application = app inst.execute() @@ -533,47 +651,56 @@ def app(environ, start_response): def test_execute_app_returns_filewrapper_prepare_returns_True(self): from waitress.buffers import ReadOnlyFileBasedBuffer - f = io.BytesIO(b'abc') + + f = io.BytesIO(b"abc") app_iter = ReadOnlyFileBasedBuffer(f, 8192) + def app(environ, start_response): - start_response('200 OK', [('Content-Length', '3')]) + start_response("200 OK", [("Content-Length", "3")]) return app_iter + inst = self._makeOne() inst.channel.server.application = app inst.execute() - self.assertTrue(inst.channel.written) # header + self.assertTrue(inst.channel.written) # header self.assertEqual(inst.channel.otherdata, [app_iter]) def test_execute_app_returns_filewrapper_prepare_returns_True_nocl(self): from waitress.buffers import ReadOnlyFileBasedBuffer - f = io.BytesIO(b'abc') + + f = io.BytesIO(b"abc") app_iter = ReadOnlyFileBasedBuffer(f, 8192) + def app(environ, start_response): - 
start_response('200 OK', []) + start_response("200 OK", []) return app_iter + inst = self._makeOne() inst.channel.server.application = app inst.execute() - self.assertTrue(inst.channel.written) # header + self.assertTrue(inst.channel.written) # header self.assertEqual(inst.channel.otherdata, [app_iter]) self.assertEqual(inst.content_length, 3) def test_execute_app_returns_filewrapper_prepare_returns_True_badcl(self): from waitress.buffers import ReadOnlyFileBasedBuffer - f = io.BytesIO(b'abc') + + f = io.BytesIO(b"abc") app_iter = ReadOnlyFileBasedBuffer(f, 8192) + def app(environ, start_response): - start_response('200 OK', []) + start_response("200 OK", []) return app_iter + inst = self._makeOne() inst.channel.server.application = app inst.content_length = 10 - inst.response_headers = [('Content-Length', '10')] + inst.response_headers = [("Content-Length", "10")] inst.execute() - self.assertTrue(inst.channel.written) # header + self.assertTrue(inst.channel.written) # header self.assertEqual(inst.channel.otherdata, [app_iter]) self.assertEqual(inst.content_length, 3) - self.assertEqual(dict(inst.response_headers)['Content-Length'], '3') + self.assertEqual(dict(inst.response_headers)["Content-Length"], "3") def test_get_environment_already_cached(self): inst = self._makeOne() @@ -583,315 +710,265 @@ def test_get_environment_already_cached(self): def test_get_environment_path_startswith_more_than_one_slash(self): inst = self._makeOne() request = DummyParser() - request.path = '///abc' + request.path = "///abc" inst.request = request environ = inst.get_environment() - self.assertEqual(environ['PATH_INFO'], '/abc') + self.assertEqual(environ["PATH_INFO"], "/abc") def test_get_environment_path_empty(self): inst = self._makeOne() request = DummyParser() - request.path = '' + request.path = "" inst.request = request environ = inst.get_environment() - self.assertEqual(environ['PATH_INFO'], '') + self.assertEqual(environ["PATH_INFO"], "") def 
test_get_environment_no_query(self): inst = self._makeOne() request = DummyParser() inst.request = request environ = inst.get_environment() - self.assertEqual(environ['QUERY_STRING'], '') + self.assertEqual(environ["QUERY_STRING"], "") def test_get_environment_with_query(self): inst = self._makeOne() request = DummyParser() - request.query = 'abc' + request.query = "abc" inst.request = request environ = inst.get_environment() - self.assertEqual(environ['QUERY_STRING'], 'abc') + self.assertEqual(environ["QUERY_STRING"], "abc") def test_get_environ_with_url_prefix_miss(self): inst = self._makeOne() - inst.channel.server.adj.url_prefix = '/foo' + inst.channel.server.adj.url_prefix = "/foo" request = DummyParser() - request.path = '/bar' + request.path = "/bar" inst.request = request environ = inst.get_environment() - self.assertEqual(environ['PATH_INFO'], '/bar') - self.assertEqual(environ['SCRIPT_NAME'], '/foo') + self.assertEqual(environ["PATH_INFO"], "/bar") + self.assertEqual(environ["SCRIPT_NAME"], "/foo") def test_get_environ_with_url_prefix_hit(self): inst = self._makeOne() - inst.channel.server.adj.url_prefix = '/foo' + inst.channel.server.adj.url_prefix = "/foo" request = DummyParser() - request.path = '/foo/fuz' + request.path = "/foo/fuz" inst.request = request environ = inst.get_environment() - self.assertEqual(environ['PATH_INFO'], '/fuz') - self.assertEqual(environ['SCRIPT_NAME'], '/foo') + self.assertEqual(environ["PATH_INFO"], "/fuz") + self.assertEqual(environ["SCRIPT_NAME"], "/foo") def test_get_environ_with_url_prefix_empty_path(self): inst = self._makeOne() - inst.channel.server.adj.url_prefix = '/foo' + inst.channel.server.adj.url_prefix = "/foo" request = DummyParser() - request.path = '/foo' + request.path = "/foo" inst.request = request environ = inst.get_environment() - self.assertEqual(environ['PATH_INFO'], '') - self.assertEqual(environ['SCRIPT_NAME'], '/foo') + self.assertEqual(environ["PATH_INFO"], "") + 
self.assertEqual(environ["SCRIPT_NAME"], "/foo") def test_get_environment_values(self): import sys - inst = self._makeOne() - request = DummyParser() - request.headers = { - 'CONTENT_TYPE': 'abc', - 'CONTENT_LENGTH': '10', - 'X_FOO': 'BAR', - 'CONNECTION': 'close', - } - request.query = 'abc' - inst.request = request - environ = inst.get_environment() - - # nail the keys of environ - self.assertEqual(sorted(environ.keys()), [ - 'CONTENT_LENGTH', 'CONTENT_TYPE', 'HTTP_CONNECTION', 'HTTP_X_FOO', - 'PATH_INFO', 'QUERY_STRING', 'REMOTE_ADDR', 'REQUEST_METHOD', - 'SCRIPT_NAME', 'SERVER_NAME', 'SERVER_PORT', 'SERVER_PROTOCOL', - 'SERVER_SOFTWARE', 'wsgi.errors', 'wsgi.file_wrapper', 'wsgi.input', - 'wsgi.multiprocess', 'wsgi.multithread', 'wsgi.run_once', - 'wsgi.url_scheme', 'wsgi.version']) - - self.assertEqual(environ['REQUEST_METHOD'], 'GET') - self.assertEqual(environ['SERVER_PORT'], '80') - self.assertEqual(environ['SERVER_NAME'], 'localhost') - self.assertEqual(environ['SERVER_SOFTWARE'], 'waitress') - self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.0') - self.assertEqual(environ['SCRIPT_NAME'], '') - self.assertEqual(environ['HTTP_CONNECTION'], 'close') - self.assertEqual(environ['PATH_INFO'], '/') - self.assertEqual(environ['QUERY_STRING'], 'abc') - self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') - self.assertEqual(environ['CONTENT_TYPE'], 'abc') - self.assertEqual(environ['CONTENT_LENGTH'], '10') - self.assertEqual(environ['HTTP_X_FOO'], 'BAR') - self.assertEqual(environ['wsgi.version'], (1, 0)) - self.assertEqual(environ['wsgi.url_scheme'], 'http') - self.assertEqual(environ['wsgi.errors'], sys.stderr) - self.assertEqual(environ['wsgi.multithread'], True) - self.assertEqual(environ['wsgi.multiprocess'], False) - self.assertEqual(environ['wsgi.run_once'], False) - self.assertEqual(environ['wsgi.input'], 'stream') - self.assertEqual(inst.environ, environ) - - def test_get_environment_values_w_scheme_override_untrusted(self): - inst = self._makeOne() - 
request = DummyParser() - request.headers = { - 'CONTENT_TYPE': 'abc', - 'CONTENT_LENGTH': '10', - 'X_FOO': 'BAR', - 'X_FORWARDED_PROTO': 'https', - 'CONNECTION': 'close', - } - request.query = 'abc' - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ['wsgi.url_scheme'], 'http') - def test_get_environment_values_w_scheme_override_trusted(self): - import sys inst = self._makeOne() - inst.channel.addr = ['192.168.1.1'] - inst.channel.server.adj.trusted_proxy = '192.168.1.1' request = DummyParser() request.headers = { - 'CONTENT_TYPE': 'abc', - 'CONTENT_LENGTH': '10', - 'X_FOO': 'BAR', - 'X_FORWARDED_PROTO': 'https', - 'CONNECTION': 'close', + "CONTENT_TYPE": "abc", + "CONTENT_LENGTH": "10", + "X_FOO": "BAR", + "CONNECTION": "close", } - request.query = 'abc' + request.query = "abc" inst.request = request environ = inst.get_environment() # nail the keys of environ - self.assertEqual(sorted(environ.keys()), [ - 'CONTENT_LENGTH', 'CONTENT_TYPE', 'HTTP_CONNECTION', 'HTTP_X_FOO', - 'PATH_INFO', 'QUERY_STRING', 'REMOTE_ADDR', 'REQUEST_METHOD', - 'SCRIPT_NAME', 'SERVER_NAME', 'SERVER_PORT', 'SERVER_PROTOCOL', - 'SERVER_SOFTWARE', 'wsgi.errors', 'wsgi.file_wrapper', 'wsgi.input', - 'wsgi.multiprocess', 'wsgi.multithread', 'wsgi.run_once', - 'wsgi.url_scheme', 'wsgi.version']) - - self.assertEqual(environ['REQUEST_METHOD'], 'GET') - self.assertEqual(environ['SERVER_PORT'], '80') - self.assertEqual(environ['SERVER_NAME'], 'localhost') - self.assertEqual(environ['SERVER_SOFTWARE'], 'waitress') - self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.0') - self.assertEqual(environ['SCRIPT_NAME'], '') - self.assertEqual(environ['HTTP_CONNECTION'], 'close') - self.assertEqual(environ['PATH_INFO'], '/') - self.assertEqual(environ['QUERY_STRING'], 'abc') - self.assertEqual(environ['REMOTE_ADDR'], '192.168.1.1') - self.assertEqual(environ['CONTENT_TYPE'], 'abc') - self.assertEqual(environ['CONTENT_LENGTH'], '10') - 
self.assertEqual(environ['HTTP_X_FOO'], 'BAR') - self.assertEqual(environ['wsgi.version'], (1, 0)) - self.assertEqual(environ['wsgi.url_scheme'], 'https') - self.assertEqual(environ['wsgi.errors'], sys.stderr) - self.assertEqual(environ['wsgi.multithread'], True) - self.assertEqual(environ['wsgi.multiprocess'], False) - self.assertEqual(environ['wsgi.run_once'], False) - self.assertEqual(environ['wsgi.input'], 'stream') + self.assertEqual( + sorted(environ.keys()), + [ + "CONTENT_LENGTH", + "CONTENT_TYPE", + "HTTP_CONNECTION", + "HTTP_X_FOO", + "PATH_INFO", + "QUERY_STRING", + "REMOTE_ADDR", + "REMOTE_HOST", + "REMOTE_PORT", + "REQUEST_METHOD", + "REQUEST_URI", + "SCRIPT_NAME", + "SERVER_NAME", + "SERVER_PORT", + "SERVER_PROTOCOL", + "SERVER_SOFTWARE", + "waitress.client_disconnected", + "wsgi.errors", + "wsgi.file_wrapper", + "wsgi.input", + "wsgi.input_terminated", + "wsgi.multiprocess", + "wsgi.multithread", + "wsgi.run_once", + "wsgi.url_scheme", + "wsgi.version", + ], + ) + + self.assertEqual(environ["REQUEST_METHOD"], "GET") + self.assertEqual(environ["SERVER_PORT"], "80") + self.assertEqual(environ["SERVER_NAME"], "localhost") + self.assertEqual(environ["SERVER_SOFTWARE"], "waitress") + self.assertEqual(environ["SERVER_PROTOCOL"], "HTTP/1.0") + self.assertEqual(environ["SCRIPT_NAME"], "") + self.assertEqual(environ["HTTP_CONNECTION"], "close") + self.assertEqual(environ["PATH_INFO"], "/") + self.assertEqual(environ["QUERY_STRING"], "abc") + self.assertEqual(environ["REMOTE_ADDR"], "127.0.0.1") + self.assertEqual(environ["REMOTE_HOST"], "127.0.0.1") + self.assertEqual(environ["REMOTE_PORT"], "39830") + self.assertEqual(environ["CONTENT_TYPE"], "abc") + self.assertEqual(environ["CONTENT_LENGTH"], "10") + self.assertEqual(environ["HTTP_X_FOO"], "BAR") + self.assertEqual(environ["wsgi.version"], (1, 0)) + self.assertEqual(environ["wsgi.url_scheme"], "http") + self.assertEqual(environ["wsgi.errors"], sys.stderr) + self.assertEqual(environ["wsgi.multithread"], 
True) + self.assertEqual(environ["wsgi.multiprocess"], False) + self.assertEqual(environ["wsgi.run_once"], False) + self.assertEqual(environ["wsgi.input"], "stream") + self.assertEqual(environ["wsgi.input_terminated"], True) self.assertEqual(inst.environ, environ) - def test_get_environment_values_w_bogus_scheme_override(self): - inst = self._makeOne() - inst.channel.addr = ['192.168.1.1'] - inst.channel.server.adj.trusted_proxy = '192.168.1.1' - request = DummyParser() - request.headers = { - 'CONTENT_TYPE': 'abc', - 'CONTENT_LENGTH': '10', - 'X_FOO': 'BAR', - 'X_FORWARDED_PROTO': 'http://p02n3e.com?url=http', - 'CONNECTION': 'close', - } - request.query = 'abc' - inst.request = request - self.assertRaises(ValueError, inst.get_environment) class TestErrorTask(unittest.TestCase): - def _makeOne(self, channel=None, request=None): if channel is None: channel = DummyChannel() if request is None: request = DummyParser() - request.error = DummyError() + request.error = self._makeDummyError() from waitress.task import ErrorTask + return ErrorTask(channel, request) + def _makeDummyError(self): + from waitress.utilities import Error + + e = Error("body") + e.code = 432 + e.reason = "Too Ugly" + return e + def test_execute_http_10(self): inst = self._makeOne() inst.execute() lines = filter_lines(inst.channel.written) self.assertEqual(len(lines), 9) - self.assertEqual(lines[0], b'HTTP/1.0 432 Too Ugly') - self.assertEqual(lines[1], b'Connection: close') - self.assertEqual(lines[2], b'Content-Length: 43') - self.assertEqual(lines[3], b'Content-Type: text/plain') + self.assertEqual(lines[0], b"HTTP/1.0 432 Too Ugly") + self.assertEqual(lines[1], b"Connection: close") + self.assertEqual(lines[2], b"Content-Length: 43") + self.assertEqual(lines[3], b"Content-Type: text/plain; charset=utf-8") self.assertTrue(lines[4]) - self.assertEqual(lines[5], b'Server: waitress') - self.assertEqual(lines[6], b'Too Ugly') - self.assertEqual(lines[7], b'body') - self.assertEqual(lines[8], 
b'(generated by waitress)') + self.assertEqual(lines[5], b"Server: waitress") + self.assertEqual(lines[6], b"Too Ugly") + self.assertEqual(lines[7], b"body") + self.assertEqual(lines[8], b"(generated by waitress)") def test_execute_http_11(self): inst = self._makeOne() - inst.version = '1.1' + inst.version = "1.1" inst.execute() lines = filter_lines(inst.channel.written) - self.assertEqual(len(lines), 8) - self.assertEqual(lines[0], b'HTTP/1.1 432 Too Ugly') - self.assertEqual(lines[1], b'Content-Length: 43') - self.assertEqual(lines[2], b'Content-Type: text/plain') - self.assertTrue(lines[3]) - self.assertEqual(lines[4], b'Server: waitress') - self.assertEqual(lines[5], b'Too Ugly') - self.assertEqual(lines[6], b'body') - self.assertEqual(lines[7], b'(generated by waitress)') + self.assertEqual(len(lines), 9) + self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") + self.assertEqual(lines[1], b"Connection: close") + self.assertEqual(lines[2], b"Content-Length: 43") + self.assertEqual(lines[3], b"Content-Type: text/plain; charset=utf-8") + self.assertTrue(lines[4]) + self.assertEqual(lines[5], b"Server: waitress") + self.assertEqual(lines[6], b"Too Ugly") + self.assertEqual(lines[7], b"body") + self.assertEqual(lines[8], b"(generated by waitress)") def test_execute_http_11_close(self): inst = self._makeOne() - inst.version = '1.1' - inst.request.headers['CONNECTION'] = 'close' + inst.version = "1.1" + inst.request.headers["CONNECTION"] = "close" inst.execute() lines = filter_lines(inst.channel.written) self.assertEqual(len(lines), 9) - self.assertEqual(lines[0], b'HTTP/1.1 432 Too Ugly') - self.assertEqual(lines[1], b'Connection: close') - self.assertEqual(lines[2], b'Content-Length: 43') - self.assertEqual(lines[3], b'Content-Type: text/plain') + self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") + self.assertEqual(lines[1], b"Connection: close") + self.assertEqual(lines[2], b"Content-Length: 43") + self.assertEqual(lines[3], b"Content-Type: text/plain; 
charset=utf-8") self.assertTrue(lines[4]) - self.assertEqual(lines[5], b'Server: waitress') - self.assertEqual(lines[6], b'Too Ugly') - self.assertEqual(lines[7], b'body') - self.assertEqual(lines[8], b'(generated by waitress)') + self.assertEqual(lines[5], b"Server: waitress") + self.assertEqual(lines[6], b"Too Ugly") + self.assertEqual(lines[7], b"body") + self.assertEqual(lines[8], b"(generated by waitress)") - def test_execute_http_11_keep(self): + def test_execute_http_11_keep_forces_close(self): inst = self._makeOne() - inst.version = '1.1' - inst.request.headers['CONNECTION'] = 'keep-alive' + inst.version = "1.1" + inst.request.headers["CONNECTION"] = "keep-alive" inst.execute() lines = filter_lines(inst.channel.written) - self.assertEqual(len(lines), 8) - self.assertEqual(lines[0], b'HTTP/1.1 432 Too Ugly') - self.assertEqual(lines[1], b'Content-Length: 43') - self.assertEqual(lines[2], b'Content-Type: text/plain') - self.assertTrue(lines[3]) - self.assertEqual(lines[4], b'Server: waitress') - self.assertEqual(lines[5], b'Too Ugly') - self.assertEqual(lines[6], b'body') - self.assertEqual(lines[7], b'(generated by waitress)') - - -class DummyError(object): - code = '432' - reason = 'Too Ugly' - body = 'body' - -class DummyTask(object): + self.assertEqual(len(lines), 9) + self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") + self.assertEqual(lines[1], b"Connection: close") + self.assertEqual(lines[2], b"Content-Length: 43") + self.assertEqual(lines[3], b"Content-Type: text/plain; charset=utf-8") + self.assertTrue(lines[4]) + self.assertEqual(lines[5], b"Server: waitress") + self.assertEqual(lines[6], b"Too Ugly") + self.assertEqual(lines[7], b"body") + self.assertEqual(lines[8], b"(generated by waitress)") + + +class DummyTask: serviced = False - deferred = False cancelled = False - def __init__(self, toraise=None): - self.toraise = toraise - def service(self): self.serviced = True - if self.toraise: - raise self.toraise - - def defer(self): - self.deferred 
= True - if self.toraise: - raise self.toraise def cancel(self): self.cancelled = True -class DummyAdj(object): + +class DummyAdj: log_socket_errors = True - ident = 'waitress' - host = '127.0.0.1' + ident = "waitress" + host = "127.0.0.1" port = 80 - url_prefix = '' - trusted_proxy = None + url_prefix = "" + -class DummyServer(object): - server_name = 'localhost' +class DummyServer: + server_name = "localhost" effective_port = 80 def __init__(self): self.adj = DummyAdj() -class DummyChannel(object): + +class DummyChannel: closed_when_done = False adj = DummyAdj() creation_time = 0 - addr = ['127.0.0.1'] + addr = ("127.0.0.1", 39830) + + def check_client_disconnected(self): + # For now, until we have tests handling this feature + return False def __init__(self, server=None): if server is None: server = DummyServer() self.server = server - self.written = b'' + self.written = b"" self.otherdata = [] def write_soon(self, data): @@ -901,12 +978,14 @@ def write_soon(self, data): self.otherdata.append(data) return len(data) -class DummyParser(object): - version = '1.0' - command = 'GET' - path = '/' - query = '' - url_scheme = 'http' + +class DummyParser: + version = "1.0" + command = "GET" + path = "/" + request_uri = "/" + query = "" + url_scheme = "http" expect_continue = False headers_finished = False @@ -914,18 +993,19 @@ def __init__(self): self.headers = {} def get_body_stream(self): - return 'stream' + return "stream" + def filter_lines(s): - return list(filter(None, s.split(b'\r\n'))) + return list(filter(None, s.split(b"\r\n"))) -class DummyLogger(object): +class DummyLogger: def __init__(self): self.logged = [] - def warning(self, msg): - self.logged.append(msg) + def warning(self, msg, *args): + self.logged.append(msg % args) - def exception(self, msg): - self.logged.append(msg) + def exception(self, msg, *args): + self.logged.append(msg % args) diff --git a/waitress/tests/test_trigger.py b/tests/test_trigger.py similarity index 87% rename from 
waitress/tests/test_trigger.py rename to tests/test_trigger.py index bfff16e4..265679a4 100644 --- a/waitress/tests/test_trigger.py +++ b/tests/test_trigger.py @@ -1,29 +1,33 @@ -import unittest import os import sys +import unittest if not sys.platform.startswith("win"): class Test_trigger(unittest.TestCase): - def _makeOne(self, map): from waitress.trigger import trigger - return trigger(map) + + self.inst = trigger(map) + return self.inst + + def tearDown(self): + self.inst.close() # prevent __del__ warning from file_dispatcher def test__close(self): map = {} inst = self._makeOne(map) - fd = os.open(os.path.abspath(__file__), os.O_RDONLY) - inst._fds = (fd,) + fd1, fd2 = inst._fds inst.close() - self.assertRaises(OSError, os.read, fd, 1) + self.assertRaises(OSError, os.read, fd1, 1) + self.assertRaises(OSError, os.read, fd2, 1) def test__physical_pull(self): map = {} inst = self._makeOne(map) inst._physical_pull() r = os.read(inst._fds[0], 1) - self.assertEqual(r, b'x') + self.assertEqual(r, b"x") def test_readable(self): map = {} @@ -57,7 +61,7 @@ def test_pull_trigger_nothunk(self): inst = self._makeOne(map) self.assertEqual(inst.pull_trigger(), None) r = os.read(inst._fds[0], 1) - self.assertEqual(r, b'x') + self.assertEqual(r, b"x") def test_pull_trigger_thunk(self): map = {} @@ -65,7 +69,7 @@ def test_pull_trigger_thunk(self): self.assertEqual(inst.pull_trigger(True), None) self.assertEqual(len(inst.thunks), 1) r = os.read(inst._fds[0], 1) - self.assertEqual(r, b'x') + self.assertEqual(r, b"x") def test_handle_read_socket_error(self): map = {} @@ -94,8 +98,10 @@ def test_handle_read_thunk(self): def test_handle_read_thunk_error(self): map = {} inst = self._makeOne(map) + def errorthunk(): raise ValueError + inst.pull_trigger(errorthunk) L = [] inst.log_info = lambda *arg: L.append(arg) diff --git a/waitress/tests/test_utilities.py b/tests/test_utilities.py similarity index 60% rename from waitress/tests/test_utilities.py rename to tests/test_utilities.py 
index 73f6c7b7..ea08477e 100644 --- a/waitress/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -14,108 +14,128 @@ import unittest -class Test_parse_http_date(unittest.TestCase): +class Test_parse_http_date(unittest.TestCase): def _callFUT(self, v): from waitress.utilities import parse_http_date + return parse_http_date(v) def test_rfc850(self): - val = 'Tuesday, 08-Feb-94 14:15:29 GMT' + val = "Tuesday, 08-Feb-94 14:15:29 GMT" result = self._callFUT(val) self.assertEqual(result, 760716929) def test_rfc822(self): - val = 'Sun, 08 Feb 1994 14:15:29 GMT' + val = "Sun, 08 Feb 1994 14:15:29 GMT" result = self._callFUT(val) self.assertEqual(result, 760716929) def test_neither(self): - val = '' + val = "" result = self._callFUT(val) self.assertEqual(result, 0) -class Test_build_http_date(unittest.TestCase): +class Test_build_http_date(unittest.TestCase): def test_rountdrip(self): - from waitress.utilities import build_http_date, parse_http_date from time import time + + from waitress.utilities import build_http_date, parse_http_date + t = int(time()) self.assertEqual(t, parse_http_date(build_http_date(t))) -class Test_unpack_rfc850(unittest.TestCase): +class Test_unpack_rfc850(unittest.TestCase): def _callFUT(self, val): - from waitress.utilities import unpack_rfc850, rfc850_reg + from waitress.utilities import rfc850_reg, unpack_rfc850 + return unpack_rfc850(rfc850_reg.match(val.lower())) def test_it(self): - val = 'Tuesday, 08-Feb-94 14:15:29 GMT' + val = "Tuesday, 08-Feb-94 14:15:29 GMT" result = self._callFUT(val) self.assertEqual(result, (1994, 2, 8, 14, 15, 29, 0, 0, 0)) -class Test_unpack_rfc_822(unittest.TestCase): +class Test_unpack_rfc_822(unittest.TestCase): def _callFUT(self, val): - from waitress.utilities import unpack_rfc822, rfc822_reg + from waitress.utilities import rfc822_reg, unpack_rfc822 + return unpack_rfc822(rfc822_reg.match(val.lower())) def test_it(self): - val = 'Sun, 08 Feb 1994 14:15:29 GMT' + val = "Sun, 08 Feb 1994 14:15:29 GMT" 
result = self._callFUT(val) self.assertEqual(result, (1994, 2, 8, 14, 15, 29, 0, 0, 0)) -class Test_find_double_newline(unittest.TestCase): +class Test_find_double_newline(unittest.TestCase): def _callFUT(self, val): from waitress.utilities import find_double_newline + return find_double_newline(val) def test_empty(self): - self.assertEqual(self._callFUT(b''), -1) + self.assertEqual(self._callFUT(b""), -1) def test_one_linefeed(self): - self.assertEqual(self._callFUT(b'\n'), -1) + self.assertEqual(self._callFUT(b"\n"), -1) def test_double_linefeed(self): - self.assertEqual(self._callFUT(b'\n\n'), 2) + self.assertEqual(self._callFUT(b"\n\n"), -1) def test_one_crlf(self): - self.assertEqual(self._callFUT(b'\r\n'), -1) + self.assertEqual(self._callFUT(b"\r\n"), -1) def test_double_crfl(self): - self.assertEqual(self._callFUT(b'\r\n\r\n'), 4) + self.assertEqual(self._callFUT(b"\r\n\r\n"), 4) def test_mixed(self): - self.assertEqual(self._callFUT(b'\n\n00\r\n\r\n'), 2) + self.assertEqual(self._callFUT(b"\n\n00\r\n\r\n"), 8) -class Test_logging_dispatcher(unittest.TestCase): - - def _makeOne(self): - from waitress.utilities import logging_dispatcher - return logging_dispatcher(map={}) - - def test_log_info(self): - import logging - inst = self._makeOne() - logger = DummyLogger() - inst.logger = logger - inst.log_info('message', 'warning') - self.assertEqual(logger.severity, logging.WARN) - self.assertEqual(logger.message, 'message') class TestBadRequest(unittest.TestCase): - def _makeOne(self): from waitress.utilities import BadRequest + return BadRequest(1) def test_it(self): inst = self._makeOne() self.assertEqual(inst.body, 1) -class DummyLogger(object): - def log(self, severity, message): - self.severity = severity - self.message = message +class Test_undquote(unittest.TestCase): + def _callFUT(self, value): + from waitress.utilities import undquote + + return undquote(value) + + def test_empty(self): + self.assertEqual(self._callFUT(""), "") + + def 
test_quoted(self): + self.assertEqual(self._callFUT('"test"'), "test") + + def test_unquoted(self): + self.assertEqual(self._callFUT("test"), "test") + + def test_quoted_backslash_quote(self): + self.assertEqual(self._callFUT('"\\""'), '"') + + def test_quoted_htab(self): + self.assertEqual(self._callFUT('"\t"'), "\t") + + def test_quoted_backslash_htab(self): + self.assertEqual(self._callFUT('"\\\t"'), "\t") + + def test_quoted_backslash_invalid(self): + self.assertRaises(ValueError, self._callFUT, '"\\"') + + def test_invalid_quoting(self): + self.assertRaises(ValueError, self._callFUT, '"test') + + def test_invalid_quoting_single_quote(self): + self.assertRaises(ValueError, self._callFUT, '"') diff --git a/tests/test_wasyncore.py b/tests/test_wasyncore.py new file mode 100644 index 00000000..e833c7ed --- /dev/null +++ b/tests/test_wasyncore.py @@ -0,0 +1,1801 @@ +import _thread as thread +import contextlib +import errno +import functools +import gc +from io import BytesIO +import os +import re +import select +import socket +import struct +import sys +import threading +import time +import unittest +import warnings + +from waitress import compat, wasyncore as asyncore + +TIMEOUT = 3 +HAS_UNIX_SOCKETS = hasattr(socket, "AF_UNIX") +HOST = "localhost" +HOSTv4 = "127.0.0.1" +HOSTv6 = "::1" + +# Filename used for testing + +if os.name == "java": # pragma: no cover + # Jython disallows @ in module names + TESTFN = "$test" +else: + TESTFN = "@test" + +TESTFN = f"{TESTFN}_{os.getpid()}_tmp" + + +class DummyLogger: # pragma: no cover + def __init__(self): + self.messages = [] + + def log(self, severity, message): + self.messages.append((severity, message)) + + +class WarningsRecorder: # pragma: no cover + """Convenience wrapper for the warnings list returned on + entry to the warnings.catch_warnings() context manager. 
+ """ + + def __init__(self, warnings_list): + self._warnings = warnings_list + self._last = 0 + + @property + def warnings(self): + return self._warnings[self._last :] + + def reset(self): + self._last = len(self._warnings) + + +def _filterwarnings(filters, quiet=False): # pragma: no cover + """Catch the warnings, then check if all the expected + warnings have been raised and re-raise unexpected warnings. + If 'quiet' is True, only re-raise the unexpected warnings. + """ + # Clear the warning registry of the calling module + # in order to re-raise the warnings. + frame = sys._getframe(2) + registry = frame.f_globals.get("__warningregistry__") + + if registry: + registry.clear() + with warnings.catch_warnings(record=True) as w: + # Set filter "always" to record all warnings. Because + # test_warnings swap the module, we need to look up in + # the sys.modules dictionary. + sys.modules["warnings"].simplefilter("always") + yield WarningsRecorder(w) + # Filter the recorded warnings + reraise = list(w) + missing = [] + + for msg, cat in filters: + seen = False + + for w in reraise[:]: + warning = w.message + # Filter out the matching messages + + if re.match(msg, str(warning), re.I) and issubclass(warning.__class__, cat): + seen = True + reraise.remove(w) + + if not seen and not quiet: + # This filter caught nothing + missing.append((msg, cat.__name__)) + + if reraise: + raise AssertionError("unhandled warning %s" % reraise[0]) + + if missing: + raise AssertionError("filter (%r, %s) did not catch any warning" % missing[0]) + + +@contextlib.contextmanager +def check_warnings(*filters, **kwargs): # pragma: no cover + """Context manager to silence warnings. 
+ + Accept 2-tuples as positional arguments: + ("message regexp", WarningCategory) + + Optional argument: + - if 'quiet' is True, it does not fail if a filter catches nothing + (default True without argument, + default False if some filters are defined) + + Without argument, it defaults to: + check_warnings(("", Warning), quiet=True) + """ + quiet = kwargs.get("quiet") + + if not filters: + filters = (("", Warning),) + # Preserve backward compatibility + + if quiet is None: + quiet = True + + return _filterwarnings(filters, quiet) + + +def gc_collect(): # pragma: no cover + """Force as many objects as possible to be collected. + + In non-CPython implementations of Python, this is needed because timely + deallocation is not guaranteed by the garbage collector. (Even in CPython + this can be the case in case of reference cycles.) This means that __del__ + methods may be called later than expected and weakrefs may remain alive for + longer than expected. This function tries its best to force all garbage + objects to disappear. + """ + gc.collect() + + if sys.platform.startswith("java"): + time.sleep(0.1) + gc.collect() + gc.collect() + + +def threading_setup(): # pragma: no cover + return (thread._count(), None) + + +def threading_cleanup(*original_values): # pragma: no cover + global environment_altered + + _MAX_COUNT = 100 + + for count in range(_MAX_COUNT): + values = (thread._count(), None) + + if values == original_values: + break + + if not count: + # Display a warning at the first iteration + environment_altered = True + sys.stderr.write( + "Warning -- threading_cleanup() failed to cleanup " + "%s threads" % (values[0] - original_values[0]) + ) + sys.stderr.flush() + + values = None + + time.sleep(0.01) + gc_collect() + + +def reap_threads(func): # pragma: no cover + """Use this function when threads are being used. This will + ensure that the threads are cleaned up even when the test fails. 
+ """ + + @functools.wraps(func) + def decorator(*args): + key = threading_setup() + try: + return func(*args) + finally: + threading_cleanup(*key) + + return decorator + + +def join_thread(thread, timeout=30.0): # pragma: no cover + """Join a thread. Raise an AssertionError if the thread is still alive + after timeout seconds. + """ + thread.join(timeout) + + if thread.is_alive(): + msg = "failed to join the thread in %.1f seconds" % timeout + raise AssertionError(msg) + + +def bind_port(sock, host=HOST): # pragma: no cover + """Bind the socket to a free port and return the port number. Relies on + ephemeral ports in order to ensure we are using an unbound port. This is + important as many tests may be running simultaneously, especially in a + buildbot environment. This method raises an exception if the sock.family + is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR + or SO_REUSEPORT set on it. Tests should *never* set these socket options + for TCP/IP sockets. The only case for setting these options is testing + multicasting via multiple UDP sockets. + + Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e. + on Windows), it will be set on the socket. This will prevent anyone else + from bind()'ing to our host/port for the duration of the test. + """ + + if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: + if hasattr(socket, "SO_REUSEADDR"): + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: + raise RuntimeError( + "tests should never set the SO_REUSEADDR " + "socket option on TCP/IP sockets!" + ) + + if hasattr(socket, "SO_REUSEPORT"): + try: + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1: + raise RuntimeError( + "tests should never set the SO_REUSEPORT " + "socket option on TCP/IP sockets!" 
+ )
+ except OSError:
+ # Python's socket module was compiled using modern headers
+ # thus defining SO_REUSEPORT but this process is running
+ # under an older kernel that does not support SO_REUSEPORT.
+ pass
+
+ if hasattr(socket, "SO_EXCLUSIVEADDRUSE"):
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1)
+
+ sock.bind((host, 0))
+ port = sock.getsockname()[1]
+
+ return port
+
+
+@contextlib.contextmanager
+def closewrapper(sock): # pragma: no cover
+ try:
+ yield sock
+ finally:
+ sock.close()
+
+
+class dummysocket: # pragma: no cover
+ def __init__(self):
+ self.closed = False
+
+ def close(self):
+ self.closed = True
+
+ def fileno(self):
+ return 42
+
+ def setblocking(self, yesno):
+ self.isblocking = yesno
+
+ def getpeername(self):
+ return "peername"
+
+
+class dummychannel: # pragma: no cover
+ def __init__(self):
+ self.socket = dummysocket()
+
+ def close(self):
+ self.socket.close()
+
+
+class exitingdummy: # pragma: no cover
+ def __init__(self):
+ pass
+
+ def handle_read_event(self):
+ raise asyncore.ExitNow()
+
+ handle_write_event = handle_read_event
+ handle_close = handle_read_event
+ handle_expt_event = handle_read_event
+
+
+class crashingdummy:
+ def __init__(self):
+ self.error_handled = False
+
+ def handle_read_event(self):
+ raise Exception()
+
+ handle_write_event = handle_read_event
+ handle_close = handle_read_event
+ handle_expt_event = handle_read_event
+
+ def handle_error(self):
+ self.error_handled = True
+
+
+# used when testing senders; just collects what it gets until newline is sent
+def capture_server(evt, buf, serv): # pragma: no cover
+ try:
+ serv.listen(0)
+ conn, addr = serv.accept()
+ except socket.timeout:
+ pass
+ else:
+ n = 200
+ start = time.time()
+
+ while n > 0 and time.time() - start < 3.0:
+ r, w, e = select.select([conn], [], [], 0.1)
+
+ if r:
+ n -= 1
+ data = conn.recv(10)
+ # keep everything except for the newline terminator
+ buf.write(data.replace(b"\n", b""))
+
+ if b"\n" in data:
+
break + time.sleep(0.01) + + conn.close() + finally: + serv.close() + evt.set() + + +def bind_unix_socket(sock, addr): # pragma: no cover + """Bind a unix socket, raising SkipTest if PermissionError is raised.""" + assert sock.family == socket.AF_UNIX + try: + sock.bind(addr) + except PermissionError: + sock.close() + raise unittest.SkipTest("cannot bind AF_UNIX sockets") + + +def bind_af_aware(sock, addr): + """Helper function to bind a socket according to its family.""" + + if HAS_UNIX_SOCKETS and sock.family == socket.AF_UNIX: + # Make sure the path doesn't exist. + unlink(addr) + bind_unix_socket(sock, addr) + else: + sock.bind(addr) + + +if sys.platform.startswith("win"): # pragma: no cover + + def _waitfor(func, pathname, waitall=False): + # Perform the operation + func(pathname) + # Now setup the wait loop + + if waitall: + dirname = pathname + else: + dirname, name = os.path.split(pathname) + dirname = dirname or "." + # Check for `pathname` to be removed from the filesystem. + # The exponential backoff of the timeout amounts to a total + # of ~1 second after which the deletion is probably an error + # anyway. + # Testing on an i7@4.3GHz shows that usually only 1 iteration is + # required when contention occurs. + timeout = 0.001 + + while timeout < 1.0: + # Note we are only testing for the existence of the file(s) in + # the contents of the directory regardless of any security or + # access rights. If we have made it this far, we have sufficient + # permissions to do that much using Python's equivalent of the + # Windows API FindFirstFile. + # Other Windows APIs can fail or give incorrect results when + # dealing with files that are pending deletion. 
+ L = os.listdir(dirname) + + if not (L if waitall else name in L): + return + # Increase the timeout and try again + time.sleep(timeout) + timeout *= 2 + warnings.warn( + "tests may fail, delete still pending for " + pathname, + RuntimeWarning, + stacklevel=4, + ) + + def _unlink(filename): + _waitfor(os.unlink, filename) + +else: + _unlink = os.unlink + + +def unlink(filename): + try: + _unlink(filename) + except OSError: + pass + + +def _is_ipv6_enabled(): # pragma: no cover + """Check whether IPv6 is enabled on this host.""" + + if compat.HAS_IPV6: + sock = None + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.bind(("::1", 0)) + + return True + except OSError: + pass + finally: + if sock: + sock.close() + + return False + + +IPV6_ENABLED = _is_ipv6_enabled() + + +class HelperFunctionTests(unittest.TestCase): + def test_readwriteexc(self): + # Check exception handling behavior of read, write and _exception + + # check that ExitNow exceptions in the object handler method + # bubbles all the way up through asyncore read/write/_exception calls + tr1 = exitingdummy() + self.assertRaises(asyncore.ExitNow, asyncore.read, tr1) + self.assertRaises(asyncore.ExitNow, asyncore.write, tr1) + self.assertRaises(asyncore.ExitNow, asyncore._exception, tr1) + + # check that an exception other than ExitNow in the object handler + # method causes the handle_error method to get called + tr2 = crashingdummy() + asyncore.read(tr2) + self.assertEqual(tr2.error_handled, True) + + tr2 = crashingdummy() + asyncore.write(tr2) + self.assertEqual(tr2.error_handled, True) + + tr2 = crashingdummy() + asyncore._exception(tr2) + self.assertEqual(tr2.error_handled, True) + + # asyncore.readwrite uses constants in the select module that + # are not present in Windows systems (see this thread: + # http://mail.python.org/pipermail/python-list/2001-October/109973.html) + # These constants should be present as long as poll is available + + @unittest.skipUnless(hasattr(select, 
"poll"), "select.poll required") + def test_readwrite(self): + # Check that correct methods are called by readwrite() + + attributes = ("read", "expt", "write", "closed", "error_handled") + + expected = ( + (select.POLLIN, "read"), + (select.POLLPRI, "expt"), + (select.POLLOUT, "write"), + (select.POLLERR, "closed"), + (select.POLLHUP, "closed"), + (select.POLLNVAL, "closed"), + ) + + class testobj: + def __init__(self): + self.read = False + self.write = False + self.closed = False + self.expt = False + self.error_handled = False + + def handle_read_event(self): + self.read = True + + def handle_write_event(self): + self.write = True + + def handle_close(self): + self.closed = True + + def handle_expt_event(self): + self.expt = True + + # def handle_error(self): + # self.error_handled = True + + for flag, expectedattr in expected: + tobj = testobj() + self.assertEqual(getattr(tobj, expectedattr), False) + asyncore.readwrite(tobj, flag) + + # Only the attribute modified by the routine we expect to be + # called should be True. 
+ + for attr in attributes: + self.assertEqual(getattr(tobj, attr), attr == expectedattr) + + # check that ExitNow exceptions in the object handler method + # bubbles all the way up through asyncore readwrite call + tr1 = exitingdummy() + self.assertRaises(asyncore.ExitNow, asyncore.readwrite, tr1, flag) + + # check that an exception other than ExitNow in the object handler + # method causes the handle_error method to get called + tr2 = crashingdummy() + self.assertEqual(tr2.error_handled, False) + asyncore.readwrite(tr2, flag) + self.assertEqual(tr2.error_handled, True) + + def test_closeall(self): + self.closeall_check(False) + + def test_closeall_default(self): + self.closeall_check(True) + + def closeall_check(self, usedefault): + # Check that close_all() closes everything in a given map + + l = [] + testmap = {} + + for i in range(10): + c = dummychannel() + l.append(c) + self.assertEqual(c.socket.closed, False) + testmap[i] = c + + if usedefault: + socketmap = asyncore.socket_map + try: + asyncore.socket_map = testmap + asyncore.close_all() + finally: + testmap, asyncore.socket_map = asyncore.socket_map, socketmap + else: + asyncore.close_all(testmap) + + self.assertEqual(len(testmap), 0) + + for c in l: + self.assertEqual(c.socket.closed, True) + + def test_compact_traceback(self): + try: + raise Exception("I don't like spam!") + except: + real_t, real_v, real_tb = sys.exc_info() + r = asyncore.compact_traceback() + + (f, function, line), t, v, info = r + self.assertEqual(os.path.split(f)[-1], "test_wasyncore.py") + self.assertEqual(function, "test_compact_traceback") + self.assertEqual(t, real_t) + self.assertEqual(v, real_v) + self.assertEqual(info, f"[{f}|{function}|{line}]") + + +class DispatcherTests(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + asyncore.close_all() + + def test_basic(self): + d = asyncore.dispatcher() + self.assertEqual(d.readable(), True) + self.assertEqual(d.writable(), True) + + def test_repr(self): + d = 
asyncore.dispatcher()
+ self.assertEqual(repr(d), "<waitress.wasyncore.dispatcher at %#x>" % id(d))
+
+ def test_log_info(self):
+ import logging
+
+ inst = asyncore.dispatcher(map={})
+ logger = DummyLogger()
+ inst.logger = logger
+ inst.log_info("message", "warning")
+ self.assertEqual(logger.messages, [(logging.WARN, "message")])
+
+ def test_log(self):
+ import logging
+
+ inst = asyncore.dispatcher()
+ logger = DummyLogger()
+ inst.logger = logger
+ inst.log("message")
+ self.assertEqual(logger.messages, [(logging.DEBUG, "message")])
+
+ def test_unhandled(self):
+ import logging
+
+ inst = asyncore.dispatcher()
+ logger = DummyLogger()
+ inst.logger = logger
+
+ inst.handle_expt()
+ inst.handle_read()
+ inst.handle_write()
+ inst.handle_connect()
+
+ expected = [
+ (logging.WARN, "unhandled incoming priority event"),
+ (logging.WARN, "unhandled read event"),
+ (logging.WARN, "unhandled write event"),
+ (logging.WARN, "unhandled connect event"),
+ ]
+ self.assertEqual(logger.messages, expected)
+
+ def test_strerror(self):
+ # refers to bug #8573
+ err = asyncore._strerror(errno.EPERM)
+
+ if hasattr(os, "strerror"):
+ self.assertEqual(err, os.strerror(errno.EPERM))
+ err = asyncore._strerror(-1)
+ self.assertTrue(err != "")
+
+
+class dispatcherwithsend_noread(asyncore.dispatcher_with_send): # pragma: no cover
+ def readable(self):
+ return False
+
+ def handle_connect(self):
+ pass
+
+
+class DispatcherWithSendTests(unittest.TestCase):
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ asyncore.close_all()
+
+ @reap_threads
+ def test_send(self):
+ evt = threading.Event()
+ sock = socket.socket()
+ sock.settimeout(3)
+ port = bind_port(sock)
+
+ cap = BytesIO()
+ args = (evt, cap, sock)
+ t = threading.Thread(target=capture_server, args=args)
+ t.start()
+ try:
+ # wait a little longer for the server to initialize (it sometimes
+ # refuses connections on slow machines without this wait)
+ time.sleep(0.2)
+
+ data = b"Suppose there isn't a 16-ton weight?"
+ d = dispatcherwithsend_noread() + d.create_socket() + d.connect((HOST, port)) + + # give time for socket to connect + time.sleep(0.1) + + d.send(data) + d.send(data) + d.send(b"\n") + + n = 1000 + + while d.out_buffer and n > 0: # pragma: no cover + asyncore.poll() + n -= 1 + + evt.wait() + + self.assertEqual(cap.getvalue(), data * 2) + finally: + join_thread(t, timeout=TIMEOUT) + + +@unittest.skipUnless( + hasattr(asyncore, "file_wrapper"), "asyncore.file_wrapper required" +) +class FileWrapperTest(unittest.TestCase): + def setUp(self): + self.d = b"It's not dead, it's sleeping!" + with open(TESTFN, "wb") as file: + file.write(self.d) + + def tearDown(self): + unlink(TESTFN) + + def test_recv(self): + fd = os.open(TESTFN, os.O_RDONLY) + w = asyncore.file_wrapper(fd) + os.close(fd) + + self.assertNotEqual(w.fd, fd) + self.assertNotEqual(w.fileno(), fd) + self.assertEqual(w.recv(13), b"It's not dead") + self.assertEqual(w.read(6), b", it's") + w.close() + self.assertRaises(OSError, w.read, 1) + + def test_send(self): + d1 = b"Come again?" + d2 = b"I want to buy some cheese." 
+ fd = os.open(TESTFN, os.O_WRONLY | os.O_APPEND) + w = asyncore.file_wrapper(fd) + os.close(fd) + + w.write(d1) + w.send(d2) + w.close() + with open(TESTFN, "rb") as file: + self.assertEqual(file.read(), self.d + d1 + d2) + + @unittest.skipUnless( + hasattr(asyncore, "file_dispatcher"), "asyncore.file_dispatcher required" + ) + def test_dispatcher(self): + fd = os.open(TESTFN, os.O_RDONLY) + data = [] + + class FileDispatcher(asyncore.file_dispatcher): + def handle_read(self): + data.append(self.recv(29)) + + FileDispatcher(fd) + os.close(fd) + asyncore.loop(timeout=0.01, use_poll=True, count=2) + self.assertEqual(b"".join(data), self.d) + + def test_resource_warning(self): + # Issue #11453 + got_warning = False + + while got_warning is False: + # we try until we get the outcome we want because this + # test is not deterministic (gc_collect() may not + fd = os.open(TESTFN, os.O_RDONLY) + f = asyncore.file_wrapper(fd) + + os.close(fd) + + try: + with check_warnings(("", ResourceWarning)): + f = None + gc_collect() + except AssertionError: # pragma: no cover + pass + else: + got_warning = True + + def test_close_twice(self): + fd = os.open(TESTFN, os.O_RDONLY) + f = asyncore.file_wrapper(fd) + os.close(fd) + + os.close(f.fd) # file_wrapper dupped fd + with self.assertRaises(OSError): + f.close() + + self.assertEqual(f.fd, -1) + # calling close twice should not fail + f.close() + + +class BaseTestHandler(asyncore.dispatcher): # pragma: no cover + def __init__(self, sock=None): + asyncore.dispatcher.__init__(self, sock) + self.flag = False + + def handle_accept(self): + raise Exception("handle_accept not supposed to be called") + + def handle_accepted(self): + raise Exception("handle_accepted not supposed to be called") + + def handle_connect(self): + raise Exception("handle_connect not supposed to be called") + + def handle_expt(self): + raise Exception("handle_expt not supposed to be called") + + def handle_close(self): + raise Exception("handle_close not supposed 
to be called") + + def handle_error(self): + raise + + +class BaseServer(asyncore.dispatcher): + """A server which listens on an address and dispatches the + connection to a handler. + """ + + def __init__(self, family, addr, handler=BaseTestHandler): + asyncore.dispatcher.__init__(self) + self.create_socket(family) + self.set_reuse_addr() + bind_af_aware(self.socket, addr) + self.listen(5) + self.handler = handler + + @property + def address(self): + return self.socket.getsockname() + + def handle_accepted(self, sock, addr): + self.handler(sock) + + def handle_error(self): # pragma: no cover + raise + + +class BaseClient(BaseTestHandler): + def __init__(self, family, address): + BaseTestHandler.__init__(self) + self.create_socket(family) + self.connect(address) + + def handle_connect(self): + pass + + +class BaseTestAPI: + def tearDown(self): + asyncore.close_all(ignore_all=True) + + def loop_waiting_for_flag(self, instance, timeout=5): # pragma: no cover + timeout = float(timeout) / 100 + count = 100 + + while asyncore.socket_map and count > 0: + asyncore.loop(timeout=0.01, count=1, use_poll=self.use_poll) + + if instance.flag: + return + count -= 1 + time.sleep(timeout) + self.fail("flag not set") + + def test_handle_connect(self): + # make sure handle_connect is called on connect() + + class TestClient(BaseClient): + def handle_connect(self): + self.flag = True + + server = BaseServer(self.family, self.addr) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_accept(self): + # make sure handle_accept() is called when a client connects + + class TestListener(BaseTestHandler): + def __init__(self, family, addr): + BaseTestHandler.__init__(self) + self.create_socket(family) + bind_af_aware(self.socket, addr) + self.listen(5) + self.address = self.socket.getsockname() + + def handle_accept(self): + self.flag = True + + server = TestListener(self.family, self.addr) + client = BaseClient(self.family, 
server.address) + self.loop_waiting_for_flag(server) + + def test_handle_accepted(self): + # make sure handle_accepted() is called when a client connects + + class TestListener(BaseTestHandler): + def __init__(self, family, addr): + BaseTestHandler.__init__(self) + self.create_socket(family) + bind_af_aware(self.socket, addr) + self.listen(5) + self.address = self.socket.getsockname() + + def handle_accept(self): + asyncore.dispatcher.handle_accept(self) + + def handle_accepted(self, sock, addr): + sock.close() + self.flag = True + + server = TestListener(self.family, self.addr) + client = BaseClient(self.family, server.address) + self.loop_waiting_for_flag(server) + + def test_handle_read(self): + # make sure handle_read is called on data received + + class TestClient(BaseClient): + def handle_read(self): + self.flag = True + + class TestHandler(BaseTestHandler): + def __init__(self, conn): + BaseTestHandler.__init__(self, conn) + self.send(b"x" * 1024) + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_write(self): + # make sure handle_write is called + + class TestClient(BaseClient): + def handle_write(self): + self.flag = True + + server = BaseServer(self.family, self.addr) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_close(self): + # make sure handle_close is called when the other end closes + # the connection + + class TestClient(BaseClient): + def handle_read(self): + # in order to make handle_close be called we are supposed + # to make at least one recv() call + self.recv(1024) + + def handle_close(self): + self.flag = True + self.close() + + class TestHandler(BaseTestHandler): + def __init__(self, conn): + BaseTestHandler.__init__(self, conn) + self.close() + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + 
self.loop_waiting_for_flag(client) + + def test_handle_close_after_conn_broken(self): + # Check that ECONNRESET/EPIPE is correctly handled (issues #5661 and + # #11265). + + data = b"\0" * 128 + + class TestClient(BaseClient): + def handle_write(self): + self.send(data) + + def handle_close(self): + self.flag = True + self.close() + + def handle_expt(self): # pragma: no cover + # needs to exist for MacOS testing + self.flag = True + self.close() + + class TestHandler(BaseTestHandler): + def handle_read(self): + self.recv(len(data)) + self.close() + + def writable(self): + return False + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + @unittest.skipIf( + sys.platform.startswith("sunos"), "OOB support is broken on Solaris" + ) + def test_handle_expt(self): + # Make sure handle_expt is called on OOB data received. + # Note: this might fail on some platforms as OOB data is + # tenuously supported and rarely used. 
+ + if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: + self.skipTest("Not applicable to AF_UNIX sockets.") + + if sys.platform == "darwin" and self.use_poll: # pragma: no cover + self.skipTest("poll may fail on macOS; see issue #28087") + + class TestClient(BaseClient): + def handle_expt(self): + self.socket.recv(1024, socket.MSG_OOB) + self.flag = True + + class TestHandler(BaseTestHandler): + def __init__(self, conn): + BaseTestHandler.__init__(self, conn) + self.socket.send(chr(244).encode("latin-1"), socket.MSG_OOB) + + server = BaseServer(self.family, self.addr, TestHandler) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_handle_error(self): + class TestClient(BaseClient): + def handle_write(self): + 1.0 / 0 + + def handle_error(self): + self.flag = True + try: + raise + except ZeroDivisionError: + pass + else: # pragma: no cover + raise Exception("exception not raised") + + server = BaseServer(self.family, self.addr) + client = TestClient(self.family, server.address) + self.loop_waiting_for_flag(client) + + def test_connection_attributes(self): + server = BaseServer(self.family, self.addr) + client = BaseClient(self.family, server.address) + + # we start disconnected + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + # this can't be taken for granted across all platforms + # self.assertFalse(client.connected) + self.assertFalse(client.accepting) + + # execute some loops so that client connects to server + asyncore.loop(timeout=0.01, use_poll=self.use_poll, count=100) + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + self.assertTrue(client.connected) + self.assertFalse(client.accepting) + + # disconnect the client + client.close() + self.assertFalse(server.connected) + self.assertTrue(server.accepting) + self.assertFalse(client.connected) + self.assertFalse(client.accepting) + + # stop serving + server.close() + self.assertFalse(server.connected) + 
self.assertFalse(server.accepting) + + def test_create_socket(self): + s = asyncore.dispatcher() + s.create_socket(self.family) + # self.assertEqual(s.socket.type, socket.SOCK_STREAM) + self.assertEqual(s.socket.family, self.family) + self.assertEqual(s.socket.gettimeout(), 0) + # self.assertFalse(s.socket.get_inheritable()) + + def test_bind(self): + if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: + self.skipTest("Not applicable to AF_UNIX sockets.") + s1 = asyncore.dispatcher() + s1.create_socket(self.family) + s1.bind(self.addr) + s1.listen(5) + port = s1.socket.getsockname()[1] + + s2 = asyncore.dispatcher() + s2.create_socket(self.family) + # EADDRINUSE indicates the socket was correctly bound + self.assertRaises(socket.error, s2.bind, (self.addr[0], port)) + + def test_set_reuse_addr(self): # pragma: no cover + if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: + self.skipTest("Not applicable to AF_UNIX sockets.") + + with closewrapper(socket.socket(self.family)) as sock: + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + except OSError: + unittest.skip("SO_REUSEADDR not supported on this platform") + else: + # if SO_REUSEADDR succeeded for sock we expect asyncore + # to do the same + s = asyncore.dispatcher(socket.socket(self.family)) + self.assertFalse( + s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) + ) + s.socket.close() + s.create_socket(self.family) + s.set_reuse_addr() + self.assertTrue( + s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) + ) + + @reap_threads + def test_quick_connect(self): # pragma: no cover + # see: http://bugs.python.org/issue10340 + + if self.family not in (socket.AF_INET, getattr(socket, "AF_INET6", object())): + self.skipTest("test specific to AF_INET and AF_INET6") + + server = BaseServer(self.family, self.addr) + # run the thread 500 ms: the socket should be connected in 200 ms + t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=5)) + t.start() + try: + 
sock = socket.socket(self.family, socket.SOCK_STREAM)
+ with closewrapper(sock) as s:
+ s.settimeout(0.2)
+ s.setsockopt(
+ socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
+ )
+
+ try:
+ s.connect(server.address)
+ except OSError:
+ pass
+ finally:
+ join_thread(t, timeout=TIMEOUT)
+
+
+class BaseTestAPI_UseIPv4Sockets(BaseTestAPI):
+ family = socket.AF_INET
+ addr = (HOST, 0)
+
+
+@unittest.skipUnless(IPV6_ENABLED, "IPv6 support required")
+class BaseTestAPI_UseIPv6Sockets(BaseTestAPI):
+ family = socket.AF_INET6
+ addr = (HOSTv6, 0)
+
+
+@unittest.skipUnless(HAS_UNIX_SOCKETS, "Unix sockets required")
+class BaseTestAPI_UseUnixSockets(BaseTestAPI):
+ if HAS_UNIX_SOCKETS:
+ family = socket.AF_UNIX
+ addr = TESTFN
+
+ def tearDown(self):
+ unlink(self.addr)
+ BaseTestAPI.tearDown(self)
+
+
+class TestAPI_UseIPv4Select(BaseTestAPI_UseIPv4Sockets, unittest.TestCase):
+ use_poll = False
+
+
+@unittest.skipUnless(hasattr(select, "poll"), "select.poll required")
+class TestAPI_UseIPv4Poll(BaseTestAPI_UseIPv4Sockets, unittest.TestCase):
+ use_poll = True
+
+
+class TestAPI_UseIPv6Select(BaseTestAPI_UseIPv6Sockets, unittest.TestCase):
+ use_poll = False
+
+
+@unittest.skipUnless(hasattr(select, "poll"), "select.poll required")
+class TestAPI_UseIPv6Poll(BaseTestAPI_UseIPv6Sockets, unittest.TestCase):
+ use_poll = True
+
+
+class TestAPI_UseUnixSocketsSelect(BaseTestAPI_UseUnixSockets, unittest.TestCase):
+ use_poll = False
+
+
+@unittest.skipUnless(hasattr(select, "poll"), "select.poll required")
+class TestAPI_UseUnixSocketsPoll(BaseTestAPI_UseUnixSockets, unittest.TestCase):
+ use_poll = True
+
+
+class Test__strerror(unittest.TestCase):
+ def _callFUT(self, err):
+ from waitress.wasyncore import _strerror
+
+ return _strerror(err)
+
+ def test_gardenpath(self):
+ self.assertEqual(self._callFUT(1), "Operation not permitted")
+
+ def test_unknown(self):
+ self.assertEqual(self._callFUT("wut"), "Unknown error wut")
+
+
+class Test_read(unittest.TestCase):
+
def _callFUT(self, dispatcher): + from waitress.wasyncore import read + + return read(dispatcher) + + def test_gardenpath(self): + inst = DummyDispatcher() + self._callFUT(inst) + self.assertTrue(inst.read_event_handled) + self.assertFalse(inst.error_handled) + + def test_reraised(self): + from waitress.wasyncore import ExitNow + + inst = DummyDispatcher(ExitNow) + self.assertRaises(ExitNow, self._callFUT, inst) + self.assertTrue(inst.read_event_handled) + self.assertFalse(inst.error_handled) + + def test_non_reraised(self): + inst = DummyDispatcher(OSError) + self._callFUT(inst) + self.assertTrue(inst.read_event_handled) + self.assertTrue(inst.error_handled) + + +class Test_write(unittest.TestCase): + def _callFUT(self, dispatcher): + from waitress.wasyncore import write + + return write(dispatcher) + + def test_gardenpath(self): + inst = DummyDispatcher() + self._callFUT(inst) + self.assertTrue(inst.write_event_handled) + self.assertFalse(inst.error_handled) + + def test_reraised(self): + from waitress.wasyncore import ExitNow + + inst = DummyDispatcher(ExitNow) + self.assertRaises(ExitNow, self._callFUT, inst) + self.assertTrue(inst.write_event_handled) + self.assertFalse(inst.error_handled) + + def test_non_reraised(self): + inst = DummyDispatcher(OSError) + self._callFUT(inst) + self.assertTrue(inst.write_event_handled) + self.assertTrue(inst.error_handled) + + +class Test__exception(unittest.TestCase): + def _callFUT(self, dispatcher): + from waitress.wasyncore import _exception + + return _exception(dispatcher) + + def test_gardenpath(self): + inst = DummyDispatcher() + self._callFUT(inst) + self.assertTrue(inst.expt_event_handled) + self.assertFalse(inst.error_handled) + + def test_reraised(self): + from waitress.wasyncore import ExitNow + + inst = DummyDispatcher(ExitNow) + self.assertRaises(ExitNow, self._callFUT, inst) + self.assertTrue(inst.expt_event_handled) + self.assertFalse(inst.error_handled) + + def test_non_reraised(self): + inst = 
DummyDispatcher(OSError) + self._callFUT(inst) + self.assertTrue(inst.expt_event_handled) + self.assertTrue(inst.error_handled) + + +@unittest.skipUnless(hasattr(select, "poll"), "select.poll required") +class Test_readwrite(unittest.TestCase): + def _callFUT(self, obj, flags): + from waitress.wasyncore import readwrite + + return readwrite(obj, flags) + + def test_handle_read_event(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.read_event_handled) + + def test_handle_write_event(self): + flags = 0 + flags |= select.POLLOUT + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.write_event_handled) + + def test_handle_expt_event(self): + flags = 0 + flags |= select.POLLPRI + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.expt_event_handled) + + def test_handle_close(self): + flags = 0 + flags |= select.POLLHUP + inst = DummyDispatcher() + self._callFUT(inst, flags) + self.assertTrue(inst.close_handled) + + def test_socketerror_not_in_disconnected(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(socket.error(errno.EALREADY, "EALREADY")) + self._callFUT(inst, flags) + self.assertTrue(inst.read_event_handled) + self.assertTrue(inst.error_handled) + + def test_socketerror_in_disconnected(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(socket.error(errno.ECONNRESET, "ECONNRESET")) + self._callFUT(inst, flags) + self.assertTrue(inst.read_event_handled) + self.assertTrue(inst.close_handled) + + def test_exception_in_reraised(self): + from waitress import wasyncore + + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(wasyncore.ExitNow) + self.assertRaises(wasyncore.ExitNow, self._callFUT, inst, flags) + self.assertTrue(inst.read_event_handled) + + def test_exception_not_in_reraised(self): + flags = 0 + flags |= select.POLLIN + inst = DummyDispatcher(ValueError) + self._callFUT(inst, flags) + 
self.assertTrue(inst.error_handled) + + +class Test_poll(unittest.TestCase): + def _callFUT(self, timeout=0.0, map=None): + from waitress.wasyncore import poll + + return poll(timeout, map) + + def test_nothing_writable_nothing_readable_but_map_not_empty(self): + # i read the mock.patch docs. nerp. + dummy_time = DummyTime() + map = {0: DummyDispatcher()} + try: + from waitress import wasyncore + + old_time = wasyncore.time + wasyncore.time = dummy_time + result = self._callFUT(map=map) + finally: + wasyncore.time = old_time + self.assertEqual(result, None) + self.assertEqual(dummy_time.sleepvals, [0.0]) + + def test_select_raises_EINTR(self): + # i read the mock.patch docs. nerp. + dummy_select = DummySelect(select.error(errno.EINTR)) + disp = DummyDispatcher() + disp.readable = lambda: True + map = {0: disp} + try: + from waitress import wasyncore + + old_select = wasyncore.select + wasyncore.select = dummy_select + result = self._callFUT(map=map) + finally: + wasyncore.select = old_select + self.assertEqual(result, None) + self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)]) + + def test_select_raises_non_EINTR(self): + # i read the mock.patch docs. nerp. + dummy_select = DummySelect(select.error(errno.EBADF)) + disp = DummyDispatcher() + disp.readable = lambda: True + map = {0: disp} + try: + from waitress import wasyncore + + old_select = wasyncore.select + wasyncore.select = dummy_select + self.assertRaises(select.error, self._callFUT, map=map) + finally: + wasyncore.select = old_select + self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)]) + + +class Test_poll2(unittest.TestCase): + def _callFUT(self, timeout=0.0, map=None): + from waitress.wasyncore import poll2 + + return poll2(timeout, map) + + def test_select_raises_EINTR(self): + # i read the mock.patch docs. nerp. 
+ pollster = DummyPollster(exc=select.error(errno.EINTR)) + dummy_select = DummySelect(pollster=pollster) + disp = DummyDispatcher() + map = {0: disp} + try: + from waitress import wasyncore + + old_select = wasyncore.select + wasyncore.select = dummy_select + self._callFUT(map=map) + finally: + wasyncore.select = old_select + self.assertEqual(pollster.polled, [0.0]) + + def test_select_raises_non_EINTR(self): + # i read the mock.patch docs. nerp. + pollster = DummyPollster(exc=select.error(errno.EBADF)) + dummy_select = DummySelect(pollster=pollster) + disp = DummyDispatcher() + map = {0: disp} + try: + from waitress import wasyncore + + old_select = wasyncore.select + wasyncore.select = dummy_select + self.assertRaises(select.error, self._callFUT, map=map) + finally: + wasyncore.select = old_select + self.assertEqual(pollster.polled, [0.0]) + + +class Test_dispatcher(unittest.TestCase): + def _makeOne(self, sock=None, map=None): + from waitress.wasyncore import dispatcher + + return dispatcher(sock=sock, map=map) + + def test_unexpected_getpeername_exc(self): + sock = dummysocket() + + def getpeername(): + raise OSError(errno.EBADF) + + map = {} + sock.getpeername = getpeername + self.assertRaises(socket.error, self._makeOne, sock=sock, map=map) + self.assertEqual(map, {}) + + def test___repr__accepting(self): + sock = dummysocket() + map = {} + inst = self._makeOne(sock=sock, map=map) + inst.accepting = True + inst.addr = ("localhost", 8080) + result = repr(inst) + expected = "= self.adj.send_bytes. We need to do this now, or it - # won't get done. - flush = self._flush_some_if_lockable - self.force_flush = False - elif (self.total_outbufs_len() >= self.adj.send_bytes): - # 1. There's a running task, so we need to try to lock - # the outbuf before sending - # 2. Only try to send if the data in the out buffer is larger - # than self.adj_bytes to avoid TCP fragmentation - flush = self._flush_some_if_lockable - else: - # 1. 
There's not enough data in the out buffer to bother to send - # right now. - flush = None - - if flush: - try: - flush() - except socket.error: - if self.adj.log_socket_errors: - self.logger.exception('Socket error') - self.will_close = True - except: - self.logger.exception('Unexpected exception when flushing') - self.will_close = True - - if self.close_when_flushed and not self.any_outbuf_has_data(): - self.close_when_flushed = False - self.will_close = True - - if self.will_close: - self.handle_close() - - def readable(self): - # We might want to create a new task. We can only do this if: - # 1. We're not already about to close the connection. - # 2. There's no already currently running task(s). - # 3. There's no data in the output buffer that needs to be sent - # before we potentially create a new task. - return not (self.will_close or self.requests or - self.any_outbuf_has_data()) - - def handle_read(self): - try: - data = self.recv(self.adj.recv_bytes) - except socket.error: - if self.adj.log_socket_errors: - self.logger.exception('Socket error') - self.handle_close() - return - if data: - self.last_activity = time.time() - self.received(data) - - def received(self, data): - """ - Receives input asynchronously and assigns one or more requests to the - channel. - """ - # Preconditions: there's no task(s) already running - request = self.request - requests = [] - - if not data: - return False - - while data: - if request is None: - request = self.parser_class(self.adj) - n = request.received(data) - if request.expect_continue and request.headers_finished: - # guaranteed by parser to be a 1.1 request - request.expect_continue = False - if not self.sent_continue: - # there's no current task, so we don't need to try to - # lock the outbuf to append to it. - self.outbufs[-1].append(b'HTTP/1.1 100 Continue\r\n\r\n') - self.sent_continue = True - self._flush_some() - request.completed = False - if request.completed: - # The request (with the body) is ready to use. 
- self.request = None - if not request.empty: - requests.append(request) - request = None - else: - self.request = request - if n >= len(data): - break - data = data[n:] - - if requests: - self.requests = requests - self.server.add_task(self) - - return True - - def _flush_some_if_lockable(self): - # Since our task may be appending to the outbuf, we try to acquire - # the lock, but we don't block if we can't. - locked = self.outbuf_lock.acquire(False) - if locked: - try: - self._flush_some() - finally: - self.outbuf_lock.release() - - def _flush_some(self): - # Send as much data as possible to our client - - sent = 0 - dobreak = False - - while True: - outbuf = self.outbufs[0] - # use outbuf.__len__ rather than len(outbuf) FBO of not getting - # OverflowError on Python 2 - outbuflen = outbuf.__len__() - if outbuflen <= 0: - # self.outbufs[-1] must always be a writable outbuf - if len(self.outbufs) > 1: - toclose = self.outbufs.pop(0) - try: - toclose.close() - except: - self.logger.exception( - 'Unexpected error when closing an outbuf') - continue # pragma: no cover (coverage bug, it is hit) - else: - if hasattr(outbuf, 'prune'): - outbuf.prune() - dobreak = True - - while outbuflen > 0: - chunk = outbuf.get(self.adj.send_bytes) - num_sent = self.send(chunk) - if num_sent: - outbuf.skip(num_sent, True) - outbuflen -= num_sent - sent += num_sent - else: - dobreak = True - break - - if dobreak: - break - - if sent: - self.last_activity = time.time() - return True - - return False - - def handle_close(self): - for outbuf in self.outbufs: - try: - outbuf.close() - except: - self.logger.exception( - 'Unknown exception while trying to close outbuf') - self.connected = False - asyncore.dispatcher.close(self) - - def add_channel(self, map=None): - """See asyncore.dispatcher - - This hook keeps track of opened channels. 
- """ - asyncore.dispatcher.add_channel(self, map) - self.server.active_channels[self._fileno] = self - - def del_channel(self, map=None): - """See asyncore.dispatcher - - This hook keeps track of closed channels. - """ - fd = self._fileno # next line sets this to None - asyncore.dispatcher.del_channel(self, map) - ac = self.server.active_channels - if fd in ac: - del ac[fd] - - # - # SYNCHRONOUS METHODS - # - - def write_soon(self, data): - if data: - # the async mainloop might be popping data off outbuf; we can - # block here waiting for it because we're in a task thread - with self.outbuf_lock: - if data.__class__ is ReadOnlyFileBasedBuffer: - # they used wsgi.file_wrapper - self.outbufs.append(data) - nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) - self.outbufs.append(nextbuf) - else: - self.outbufs[-1].append(data) - # XXX We might eventually need to pull the trigger here (to - # instruct select to stop blocking), but it slows things down so - # much that I'll hold off for now; "server push" on otherwise - # unbusy systems may suffer. 
- return len(data) - return 0 - - def service(self): - """Execute all pending requests """ - with self.task_lock: - while self.requests: - request = self.requests[0] - if request.error: - task = self.error_task_class(self, request) - else: - task = self.task_class(self, request) - try: - task.service() - except: - self.logger.exception('Exception when serving %s' % - task.request.path) - if not task.wrote_header: - if self.adj.expose_tracebacks: - body = traceback.format_exc() - else: - body = ('The server encountered an unexpected ' - 'internal server error') - req_version = request.version - req_headers = request.headers - request = self.parser_class(self.adj) - request.error = InternalServerError(body) - # copy some original request attributes to fulfill - # HTTP 1.1 requirements - request.version = req_version - try: - request.headers['CONNECTION'] = req_headers[ - 'CONNECTION'] - except KeyError: - pass - task = self.error_task_class(self, request) - task.service() # must not fail - else: - task.close_on_finish = True - # we cannot allow self.requests to drop to empty til - # here; otherwise the mainloop gets confused - if task.close_on_finish: - self.close_when_flushed = True - for request in self.requests: - request.close() - self.requests = [] - else: - request = self.requests.pop(0) - request.close() - - self.force_flush = True - self.server.pull_trigger() - self.last_activity = time.time() - - def cancel(self): - """ Cancels all pending requests """ - self.force_flush = True - self.last_activity = time.time() - self.requests = [] - - def defer(self): - pass diff --git a/waitress/compat.py b/waitress/compat.py deleted file mode 100644 index 700f7a1e..00000000 --- a/waitress/compat.py +++ /dev/null @@ -1,140 +0,0 @@ -import sys -import types -import platform -import warnings - -try: - import urlparse -except ImportError: # pragma: no cover - from urllib import parse as urlparse - -# True if we are running on Python 3. 
-PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 - -# True if we are running on Windows -WIN = platform.system() == 'Windows' - -if PY3: # pragma: no cover - string_types = str, - integer_types = int, - class_types = type, - text_type = str - binary_type = bytes - long = int -else: - string_types = basestring, - integer_types = (int, long) - class_types = (type, types.ClassType) - text_type = unicode - binary_type = str - long = long - -if PY3: # pragma: no cover - from urllib.parse import unquote_to_bytes - def unquote_bytes_to_wsgi(bytestring): - return unquote_to_bytes(bytestring).decode('latin-1') -else: - from urlparse import unquote as unquote_to_bytes - def unquote_bytes_to_wsgi(bytestring): - return unquote_to_bytes(bytestring) - -def text_(s, encoding='latin-1', errors='strict'): - """ If ``s`` is an instance of ``binary_type``, return - ``s.decode(encoding, errors)``, otherwise return ``s``""" - if isinstance(s, binary_type): - return s.decode(encoding, errors) - return s # pragma: no cover - -if PY3: # pragma: no cover - def tostr(s): - if isinstance(s, text_type): - s = s.encode('latin-1') - return str(s, 'latin-1', 'strict') - - def tobytes(s): - return bytes(s, 'latin-1') -else: - tostr = str - - def tobytes(s): - return s - -try: - from Queue import ( - Queue, - Empty, - ) -except ImportError: # pragma: no cover - from queue import ( - Queue, - Empty, - ) - -if PY3: # pragma: no cover - import builtins - exec_ = getattr(builtins, "exec") - - def reraise(tp, value, tb=None): - if value is None: - value = tp - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value - - del builtins - -else: # pragma: no cover - def exec_(code, globs=None, locs=None): - """Execute code in a namespace.""" - if globs is None: - frame = sys._getframe(1) - globs = frame.f_globals - if locs is None: - locs = frame.f_locals - del frame - elif locs is None: - locs = globs - exec("""exec code in globs, locs""") - - exec_("""def 
reraise(tp, value, tb=None): - raise tp, value, tb -""") - -try: - from StringIO import StringIO as NativeIO -except ImportError: # pragma: no cover - from io import StringIO as NativeIO - -try: - import httplib -except ImportError: # pragma: no cover - from http import client as httplib - -try: - MAXINT = sys.maxint -except AttributeError: # pragma: no cover - MAXINT = sys.maxsize - - -# Fix for issue reported in https://github.com/Pylons/waitress/issues/138, -# Python on Windows may not define IPPROTO_IPV6 in socket. -import socket - -HAS_IPV6 = socket.has_ipv6 - -if hasattr(socket, 'IPPROTO_IPV6') and hasattr(socket, 'IPV6_V6ONLY'): - IPPROTO_IPV6 = socket.IPPROTO_IPV6 - IPV6_V6ONLY = socket.IPV6_V6ONLY -else: # pragma: no cover - if WIN: - IPPROTO_IPV6 = 41 - IPV6_V6ONLY = 27 - else: - warnings.warn( - 'OS does not support required IPv6 socket flags. This is requirement ' - 'for Waitress. Please open an issue at https://github.com/Pylons/waitress. ' - 'IPv6 support has been disabled.', - RuntimeWarning - ) - HAS_IPV6 = False diff --git a/waitress/parser.py b/waitress/parser.py deleted file mode 100644 index 6d2f3409..00000000 --- a/waitress/parser.py +++ /dev/null @@ -1,313 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001, 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""HTTP Request Parser - -This server uses asyncore to accept connections and do initial -processing but threads to do work. 
-""" -import re -from io import BytesIO - -from waitress.compat import ( - tostr, - urlparse, - unquote_bytes_to_wsgi, -) - -from waitress.buffers import OverflowableBuffer - -from waitress.receiver import ( - FixedStreamReceiver, - ChunkedReceiver, -) - -from waitress.utilities import ( - find_double_newline, - RequestEntityTooLarge, - RequestHeaderFieldsTooLarge, - BadRequest, -) - -class ParsingError(Exception): - pass - -class HTTPRequestParser(object): - """A structure that collects the HTTP request. - - Once the stream is completed, the instance is passed to - a server task constructor. - """ - completed = False # Set once request is completed. - empty = False # Set if no request was made. - expect_continue = False # client sent "Expect: 100-continue" header - headers_finished = False # True when headers have been read - header_plus = b'' - chunked = False - content_length = 0 - header_bytes_received = 0 - body_bytes_received = 0 - body_rcv = None - version = '1.0' - error = None - connection_close = False - - # Other attributes: first_line, header, headers, command, uri, version, - # path, query, fragment - - def __init__(self, adj): - """ - adj is an Adjustments object. - """ - # headers is a mapping containing keys translated to uppercase - # with dashes turned into underscores. - self.headers = {} - self.adj = adj - - def received(self, data): - """ - Receives the HTTP stream for one request. Returns the number of - bytes consumed. Sets the completed flag once both the header and the - body have been received. - """ - if self.completed: - return 0 # Can't consume any more. - datalen = len(data) - br = self.body_rcv - if br is None: - # In header. - s = self.header_plus + data - index = find_double_newline(s) - if index >= 0: - # Header finished. - header_plus = s[:index] - consumed = len(data) - (len(s) - index) - # Remove preceeding blank lines. 
- header_plus = header_plus.lstrip() - if not header_plus: - self.empty = True - self.completed = True - else: - try: - self.parse_header(header_plus) - except ParsingError as e: - self.error = BadRequest(e.args[0]) - self.completed = True - else: - if self.body_rcv is None: - # no content-length header and not a t-e: chunked - # request - self.completed = True - if self.content_length > 0: - max_body = self.adj.max_request_body_size - # we won't accept this request if the content-length - # is too large - if self.content_length >= max_body: - self.error = RequestEntityTooLarge( - 'exceeds max_body of %s' % max_body) - self.completed = True - self.headers_finished = True - return consumed - else: - # Header not finished yet. - self.header_bytes_received += datalen - max_header = self.adj.max_request_header_size - if self.header_bytes_received >= max_header: - # malformed header, we need to construct some request - # on our own. we disregard the incoming(?) requests HTTP - # version and just use 1.0. IOW someone just sent garbage - # over the wire - self.parse_header(b'GET / HTTP/1.0\n') - self.error = RequestHeaderFieldsTooLarge( - 'exceeds max_header of %s' % max_header) - self.completed = True - self.header_plus = s - return datalen - else: - # In body. - consumed = br.received(data) - self.body_bytes_received += consumed - max_body = self.adj.max_request_body_size - if self.body_bytes_received >= max_body: - # this will only be raised during t-e: chunked requests - self.error = RequestEntityTooLarge( - 'exceeds max_body of %s' % max_body) - self.completed = True - elif br.error: - # garbage in chunked encoding input probably - self.error = br.error - self.completed = True - elif br.completed: - # The request (with the body) is ready to use. - self.completed = True - if self.chunked: - # We've converted the chunked transfer encoding request - # body into a normal request body, so we know its content - # length; set the header here. 
We already popped the - # TRANSFER_ENCODING header in parse_header, so this will - # appear to the client to be an entirely non-chunked HTTP - # request with a valid content-length. - self.headers['CONTENT_LENGTH'] = str(br.__len__()) - return consumed - - def parse_header(self, header_plus): - """ - Parses the header_plus block of text (the headers plus the - first line of the request). - """ - index = header_plus.find(b'\n') - if index >= 0: - first_line = header_plus[:index].rstrip() - header = header_plus[index + 1:] - else: - first_line = header_plus.rstrip() - header = b'' - - self.first_line = first_line # for testing - - lines = get_header_lines(header) - - headers = self.headers - for line in lines: - index = line.find(b':') - if index > 0: - key = line[:index] - if b'_' in key: - continue - value = line[index + 1:].strip() - key1 = tostr(key.upper().replace(b'-', b'_')) - # If a header already exists, we append subsequent values - # seperated by a comma. Applications already need to handle - # the comma seperated values, as HTTP front ends might do - # the concatenation for you (behavior specified in RFC2616). - try: - headers[key1] += tostr(b', ' + value) - except KeyError: - headers[key1] = tostr(value) - # else there's garbage in the headers? 
- - # command, uri, version will be bytes - command, uri, version = crack_first_line(first_line) - version = tostr(version) - command = tostr(command) - self.command = command - self.version = version - (self.proxy_scheme, - self.proxy_netloc, - self.path, - self.query, self.fragment) = split_uri(uri) - self.url_scheme = self.adj.url_scheme - connection = headers.get('CONNECTION', '') - - if version == '1.0': - if connection.lower() != 'keep-alive': - self.connection_close = True - - if version == '1.1': - # since the server buffers data from chunked transfers and clients - # never need to deal with chunked requests, downstream clients - # should not see the HTTP_TRANSFER_ENCODING header; we pop it - # here - te = headers.pop('TRANSFER_ENCODING', '') - if te.lower() == 'chunked': - self.chunked = True - buf = OverflowableBuffer(self.adj.inbuf_overflow) - self.body_rcv = ChunkedReceiver(buf) - expect = headers.get('EXPECT', '').lower() - self.expect_continue = expect == '100-continue' - if connection.lower() == 'close': - self.connection_close = True - - if not self.chunked: - try: - cl = int(headers.get('CONTENT_LENGTH', 0)) - except ValueError: - cl = 0 - self.content_length = cl - if cl > 0: - buf = OverflowableBuffer(self.adj.inbuf_overflow) - self.body_rcv = FixedStreamReceiver(cl, buf) - - def get_body_stream(self): - body_rcv = self.body_rcv - if body_rcv is not None: - return body_rcv.getfile() - else: - return BytesIO() - - def close(self): - body_rcv = self.body_rcv - if body_rcv is not None: - body_rcv.getbuf().close() - -def split_uri(uri): - # urlsplit handles byte input by returning bytes on py3, so - # scheme, netloc, path, query, and fragment are bytes - try: - scheme, netloc, path, query, fragment = urlparse.urlsplit(uri) - except UnicodeError: - raise ParsingError('Bad URI') - return ( - tostr(scheme), - tostr(netloc), - unquote_bytes_to_wsgi(path), - tostr(query), - tostr(fragment), - ) - -def get_header_lines(header): - """ - Splits the header 
into lines, putting multi-line headers together. - """ - r = [] - lines = header.split(b'\n') - for line in lines: - if line.startswith((b' ', b'\t')): - if not r: - # http://corte.si/posts/code/pathod/pythonservers/index.html - raise ParsingError('Malformed header line "%s"' % tostr(line)) - r[-1] += line - else: - r.append(line) - return r - -first_line_re = re.compile( - b'([^ ]+) ' - b'((?:[^ :?#]+://[^ ?#/]*(?:[0-9]{1,5})?)?[^ ]+)' - b'(( HTTP/([0-9.]+))$|$)' -) - -def crack_first_line(line): - m = first_line_re.match(line) - if m is not None and m.end() == len(line): - if m.group(3): - version = m.group(5) - else: - version = None - method = m.group(1) - - # the request methods that are currently defined are all uppercase: - # https://www.iana.org/assignments/http-methods/http-methods.xhtml and - # the request method is case sensitive according to - # https://tools.ietf.org/html/rfc7231#section-4.1 - - # By disallowing anything but uppercase methods we save poor - # unsuspecting souls from sending lowercase HTTP methods to waitress - # and having the request complete, while servers like nginx drop the - # request onto the floor. - if method != method.upper(): - raise ParsingError('Malformed HTTP method "%s"' % tostr(method)) - uri = m.group(2) - return method, uri, version - else: - return b'', b'', b'' diff --git a/waitress/task.py b/waitress/task.py deleted file mode 100644 index 4ce410cf..00000000 --- a/waitress/task.py +++ /dev/null @@ -1,528 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001, 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. 
-# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## - -import socket -import sys -import threading -import time - -from waitress.buffers import ReadOnlyFileBasedBuffer - -from waitress.compat import ( - tobytes, - Queue, - Empty, - reraise, -) - -from waitress.utilities import ( - build_http_date, - logger, -) - -rename_headers = { # or keep them without the HTTP_ prefix added - 'CONTENT_LENGTH': 'CONTENT_LENGTH', - 'CONTENT_TYPE': 'CONTENT_TYPE', -} - -hop_by_hop = frozenset(( - 'connection', - 'keep-alive', - 'proxy-authenticate', - 'proxy-authorization', - 'te', - 'trailers', - 'transfer-encoding', - 'upgrade' -)) - -class JustTesting(Exception): - pass - -class ThreadedTaskDispatcher(object): - """A Task Dispatcher that creates a thread for each task. - """ - stop_count = 0 # Number of threads that will stop soon. - logger = logger - - def __init__(self): - self.threads = {} # { thread number -> 1 } - self.queue = Queue() - self.thread_mgmt_lock = threading.Lock() - - def start_new_thread(self, target, args): - t = threading.Thread(target=target, name='waitress', args=args) - t.daemon = True - t.start() - - def handler_thread(self, thread_no): - threads = self.threads - try: - while threads.get(thread_no): - task = self.queue.get() - if task is None: - # Special value: kill this thread. 
- break - try: - task.service() - except Exception as e: - self.logger.exception( - 'Exception when servicing %r' % task) - if isinstance(e, JustTesting): - break - finally: - with self.thread_mgmt_lock: - self.stop_count -= 1 - threads.pop(thread_no, None) - - def set_thread_count(self, count): - with self.thread_mgmt_lock: - threads = self.threads - thread_no = 0 - running = len(threads) - self.stop_count - while running < count: - # Start threads. - while thread_no in threads: - thread_no = thread_no + 1 - threads[thread_no] = 1 - running += 1 - self.start_new_thread(self.handler_thread, (thread_no,)) - thread_no = thread_no + 1 - if running > count: - # Stop threads. - to_stop = running - count - self.stop_count += to_stop - for n in range(to_stop): - self.queue.put(None) - running -= 1 - - def add_task(self, task): - try: - task.defer() - self.queue.put(task) - except: - task.cancel() - raise - - def shutdown(self, cancel_pending=True, timeout=5): - self.set_thread_count(0) - # Ensure the threads shut down. - threads = self.threads - expiration = time.time() + timeout - while threads: - if time.time() >= expiration: - self.logger.warning( - "%d thread(s) still running" % - len(threads)) - break - time.sleep(0.1) - if cancel_pending: - # Cancel remaining tasks. - try: - queue = self.queue - while not queue.empty(): - task = queue.get() - if task is not None: - task.cancel() - except Empty: # pragma: no cover - pass - return True - return False - -class Task(object): - close_on_finish = False - status = '200 OK' - wrote_header = False - start_time = 0 - content_length = None - content_bytes_written = 0 - logged_write_excess = False - complete = False - chunked_response = False - logger = logger - - def __init__(self, channel, request): - self.channel = channel - self.request = request - self.response_headers = [] - version = request.version - if version not in ('1.0', '1.1'): - # fall back to a version we support. 
- version = '1.0' - self.version = version - - def service(self): - try: - try: - self.start() - self.execute() - self.finish() - except socket.error: - self.close_on_finish = True - if self.channel.adj.log_socket_errors: - raise - finally: - pass - - def cancel(self): - self.close_on_finish = True - - def defer(self): - pass - - def build_response_header(self): - version = self.version - # Figure out whether the connection should be closed. - connection = self.request.headers.get('CONNECTION', '').lower() - response_headers = self.response_headers - content_length_header = None - date_header = None - server_header = None - connection_close_header = None - - for i, (headername, headerval) in enumerate(response_headers): - headername = '-'.join( - [x.capitalize() for x in headername.split('-')] - ) - if headername == 'Content-Length': - content_length_header = headerval - if headername == 'Date': - date_header = headerval - if headername == 'Server': - server_header = headerval - if headername == 'Connection': - connection_close_header = headerval.lower() - # replace with properly capitalized version - response_headers[i] = (headername, headerval) - - if content_length_header is None and self.content_length is not None: - content_length_header = str(self.content_length) - self.response_headers.append( - ('Content-Length', content_length_header) - ) - - def close_on_finish(): - if connection_close_header is None: - response_headers.append(('Connection', 'close')) - self.close_on_finish = True - - if version == '1.0': - if connection == 'keep-alive': - if not content_length_header: - close_on_finish() - else: - response_headers.append(('Connection', 'Keep-Alive')) - else: - close_on_finish() - - elif version == '1.1': - if connection == 'close': - close_on_finish() - - if not content_length_header: - response_headers.append(('Transfer-Encoding', 'chunked')) - self.chunked_response = True - if not self.close_on_finish: - close_on_finish() - - # under HTTP 1.1 
keep-alive is default, no need to set the header - else: - raise AssertionError('neither HTTP/1.0 or HTTP/1.1') - - # Set the Server and Date field, if not yet specified. This is needed - # if the server is used as a proxy. - ident = self.channel.server.adj.ident - if not server_header: - response_headers.append(('Server', ident)) - else: - response_headers.append(('Via', ident)) - if not date_header: - response_headers.append(('Date', build_http_date(self.start_time))) - - first_line = 'HTTP/%s %s' % (self.version, self.status) - # NB: sorting headers needs to preserve same-named-header order - # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here; - # rely on stable sort to keep relative position of same-named headers - next_lines = ['%s: %s' % hv for hv in sorted( - self.response_headers, key=lambda x: x[0])] - lines = [first_line] + next_lines - res = '%s\r\n\r\n' % '\r\n'.join(lines) - return tobytes(res) - - def remove_content_length_header(self): - for i, (header_name, header_value) in enumerate(self.response_headers): - if header_name.lower() == 'content-length': - del self.response_headers[i] - - def start(self): - self.start_time = time.time() - - def finish(self): - if not self.wrote_header: - self.write(b'') - if self.chunked_response: - # not self.write, it will chunk it! 
- self.channel.write_soon(b'0\r\n\r\n') - - def write(self, data): - if not self.complete: - raise RuntimeError('start_response was not called before body ' - 'written') - channel = self.channel - if not self.wrote_header: - rh = self.build_response_header() - channel.write_soon(rh) - self.wrote_header = True - if data: - towrite = data - cl = self.content_length - if self.chunked_response: - # use chunked encoding response - towrite = tobytes(hex(len(data))[2:].upper()) + b'\r\n' - towrite += data + b'\r\n' - elif cl is not None: - towrite = data[:cl - self.content_bytes_written] - self.content_bytes_written += len(towrite) - if towrite != data and not self.logged_write_excess: - self.logger.warning( - 'application-written content exceeded the number of ' - 'bytes specified by Content-Length header (%s)' % cl) - self.logged_write_excess = True - if towrite: - channel.write_soon(towrite) - -class ErrorTask(Task): - """ An error task produces an error response - """ - complete = True - - def execute(self): - e = self.request.error - body = '%s\r\n\r\n%s' % (e.reason, e.body) - tag = '\r\n\r\n(generated by waitress)' - body = body + tag - self.status = '%s %s' % (e.code, e.reason) - cl = len(body) - self.content_length = cl - self.response_headers.append(('Content-Length', str(cl))) - self.response_headers.append(('Content-Type', 'text/plain')) - if self.version == '1.1': - connection = self.request.headers.get('CONNECTION', '').lower() - if connection == 'close': - self.response_headers.append(('Connection', 'close')) - # under HTTP 1.1 keep-alive is default, no need to set the header - else: - # HTTP 1.0 - self.response_headers.append(('Connection', 'close')) - self.close_on_finish = True - self.write(tobytes(body)) - -class WSGITask(Task): - """A WSGI task produces a response from a WSGI application. 
- """ - environ = None - - def execute(self): - env = self.get_environment() - - def start_response(status, headers, exc_info=None): - if self.complete and not exc_info: - raise AssertionError("start_response called a second time " - "without providing exc_info.") - if exc_info: - try: - if self.wrote_header: - # higher levels will catch and handle raised exception: - # 1. "service" method in task.py - # 2. "service" method in channel.py - # 3. "handler_thread" method in task.py - reraise(exc_info[0], exc_info[1], exc_info[2]) - else: - # As per WSGI spec existing headers must be cleared - self.response_headers = [] - finally: - exc_info = None - - self.complete = True - - if not status.__class__ is str: - raise AssertionError('status %s is not a string' % status) - if '\n' in status or '\r' in status: - raise ValueError("carriage return/line " - "feed character present in status") - - self.status = status - - # Prepare the headers for output - for k, v in headers: - if not k.__class__ is str: - raise AssertionError( - 'Header name %r is not a string in %r' % (k, (k, v)) - ) - if not v.__class__ is str: - raise AssertionError( - 'Header value %r is not a string in %r' % (v, (k, v)) - ) - - if '\n' in v or '\r' in v: - raise ValueError("carriage return/line " - "feed character present in header value") - if '\n' in k or '\r' in k: - raise ValueError("carriage return/line " - "feed character present in header name") - - kl = k.lower() - if kl == 'content-length': - self.content_length = int(v) - elif kl in hop_by_hop: - raise AssertionError( - '%s is a "hop-by-hop" header; it cannot be used by ' - 'a WSGI application (see PEP 3333)' % k) - - self.response_headers.extend(headers) - - # Return a method used to write the response data. 
- return self.write - - # Call the application to handle the request and write a response - app_iter = self.channel.server.application(env, start_response) - - if app_iter.__class__ is ReadOnlyFileBasedBuffer: - # NB: do not put this inside the below try: finally: which closes - # the app_iter; we need to defer closing the underlying file. It's - # intention that we don't want to call ``close`` here if the - # app_iter is a ROFBB; the buffer (and therefore the file) will - # eventually be closed within channel.py's _flush_some or - # handle_close instead. - cl = self.content_length - size = app_iter.prepare(cl) - if size: - if cl != size: - if cl is not None: - self.remove_content_length_header() - self.content_length = size - self.write(b'') # generate headers - self.channel.write_soon(app_iter) - return - - try: - first_chunk_len = None - for chunk in app_iter: - if first_chunk_len is None: - first_chunk_len = len(chunk) - # Set a Content-Length header if one is not supplied. - # start_response may not have been called until first - # iteration as per PEP, so we must reinterrogate - # self.content_length here - if self.content_length is None: - app_iter_len = None - if hasattr(app_iter, '__len__'): - app_iter_len = len(app_iter) - if app_iter_len == 1: - self.content_length = first_chunk_len - # transmit headers only after first iteration of the iterable - # that returns a non-empty bytestring (PEP 3333) - if chunk: - self.write(chunk) - - cl = self.content_length - if cl is not None: - if self.content_bytes_written != cl: - # close the connection so the client isn't sitting around - # waiting for more data when there are too few bytes - # to service content-length - self.close_on_finish = True - if self.request.command != 'HEAD': - self.logger.warning( - 'application returned too few bytes (%s) ' - 'for specified Content-Length (%s) via app_iter' % ( - self.content_bytes_written, cl), - ) - finally: - if hasattr(app_iter, 'close'): - app_iter.close() - - def 
get_environment(self): - """Returns a WSGI environment.""" - environ = self.environ - if environ is not None: - # Return the cached copy. - return environ - - request = self.request - path = request.path - channel = self.channel - server = channel.server - url_prefix = server.adj.url_prefix - - if path.startswith('/'): - # strip extra slashes at the beginning of a path that starts - # with any number of slashes - path = '/' + path.lstrip('/') - - if url_prefix: - # NB: url_prefix is guaranteed by the configuration machinery to - # be either the empty string or a string that starts with a single - # slash and ends without any slashes - if path == url_prefix: - # if the path is the same as the url prefix, the SCRIPT_NAME - # should be the url_prefix and PATH_INFO should be empty - path = '' - else: - # if the path starts with the url prefix plus a slash, - # the SCRIPT_NAME should be the url_prefix and PATH_INFO should - # the value of path from the slash until its end - url_prefix_with_trailing_slash = url_prefix + '/' - if path.startswith(url_prefix_with_trailing_slash): - path = path[len(url_prefix):] - - environ = {} - environ['REQUEST_METHOD'] = request.command.upper() - environ['SERVER_PORT'] = str(server.effective_port) - environ['SERVER_NAME'] = server.server_name - environ['SERVER_SOFTWARE'] = server.adj.ident - environ['SERVER_PROTOCOL'] = 'HTTP/%s' % self.version - environ['SCRIPT_NAME'] = url_prefix - environ['PATH_INFO'] = path - environ['QUERY_STRING'] = request.query - host = environ['REMOTE_ADDR'] = channel.addr[0] - - headers = dict(request.headers) - if host == server.adj.trusted_proxy: - wsgi_url_scheme = headers.pop('X_FORWARDED_PROTO', - request.url_scheme) - else: - wsgi_url_scheme = request.url_scheme - if wsgi_url_scheme not in ('http', 'https'): - raise ValueError('Invalid X_FORWARDED_PROTO value') - for key, value in headers.items(): - value = value.strip() - mykey = rename_headers.get(key, None) - if mykey is None: - mykey = 'HTTP_%s' % key 
- if mykey not in environ: - environ[mykey] = value - - # the following environment variables are required by the WSGI spec - environ['wsgi.version'] = (1, 0) - environ['wsgi.url_scheme'] = wsgi_url_scheme - environ['wsgi.errors'] = sys.stderr # apps should use the logging module - environ['wsgi.multithread'] = True - environ['wsgi.multiprocess'] = False - environ['wsgi.run_once'] = False - environ['wsgi.input'] = request.get_body_stream() - environ['wsgi.file_wrapper'] = ReadOnlyFileBasedBuffer - - self.environ = environ - return environ diff --git a/waitress/tests/fixtureapps/badcl.py b/waitress/tests/fixtureapps/badcl.py deleted file mode 100644 index 2289a125..00000000 --- a/waitress/tests/fixtureapps/badcl.py +++ /dev/null @@ -1,12 +0,0 @@ -def app(environ, start_response): # pragma: no cover - body = b'abcdefghi' - cl = len(body) - if environ['PATH_INFO'] == '/short_body': - cl = len(body) + 1 - if environ['PATH_INFO'] == '/long_body': - cl = len(body) - 1 - start_response( - '200 OK', - [('Content-Length', str(cl)), ('Content-Type', 'text/plain')] - ) - return [body] diff --git a/waitress/tests/fixtureapps/echo.py b/waitress/tests/fixtureapps/echo.py deleted file mode 100644 index f5fd5d13..00000000 --- a/waitress/tests/fixtureapps/echo.py +++ /dev/null @@ -1,11 +0,0 @@ -def app(environ, start_response): # pragma: no cover - cl = environ.get('CONTENT_LENGTH', None) - if cl is not None: - cl = int(cl) - body = environ['wsgi.input'].read(cl) - cl = str(len(body)) - start_response( - '200 OK', - [('Content-Length', cl), ('Content-Type', 'text/plain')] - ) - return [body] diff --git a/waitress/tests/fixtureapps/error.py b/waitress/tests/fixtureapps/error.py deleted file mode 100644 index cab8ad6e..00000000 --- a/waitress/tests/fixtureapps/error.py +++ /dev/null @@ -1,20 +0,0 @@ -def app(environ, start_response): # pragma: no cover - cl = environ.get('CONTENT_LENGTH', None) - if cl is not None: - cl = int(cl) - body = environ['wsgi.input'].read(cl) - cl = 
str(len(body)) - if environ['PATH_INFO'] == '/before_start_response': - raise ValueError('wrong') - write = start_response( - '200 OK', - [('Content-Length', cl), ('Content-Type', 'text/plain')] - ) - if environ['PATH_INFO'] == '/after_write_cb': - write('abc') - if environ['PATH_INFO'] == '/in_generator': - def foo(): - yield 'abc' - raise ValueError - return foo() - raise ValueError('wrong') diff --git a/waitress/tests/fixtureapps/filewrapper.py b/waitress/tests/fixtureapps/filewrapper.py deleted file mode 100644 index be35b025..00000000 --- a/waitress/tests/fixtureapps/filewrapper.py +++ /dev/null @@ -1,70 +0,0 @@ -import os - -here = os.path.dirname(os.path.abspath(__file__)) -fn = os.path.join(here, 'groundhog1.jpg') - -class KindaFilelike(object): # pragma: no cover - - def __init__(self, bytes): - self.bytes = bytes - - def read(self, n): - bytes = self.bytes[:n] - self.bytes = self.bytes[n:] - return bytes - -def app(environ, start_response): # pragma: no cover - path_info = environ['PATH_INFO'] - if path_info.startswith('/filelike'): - f = open(fn, 'rb') - f.seek(0, 2) - cl = f.tell() - f.seek(0) - if path_info == '/filelike': - headers = [ - ('Content-Length', str(cl)), - ('Content-Type', 'image/jpeg'), - ] - elif path_info == '/filelike_nocl': - headers = [('Content-Type', 'image/jpeg')] - elif path_info == '/filelike_shortcl': - # short content length - headers = [ - ('Content-Length', '1'), - ('Content-Type', 'image/jpeg'), - ] - else: - # long content length (/filelike_longcl) - headers = [ - ('Content-Length', str(cl + 10)), - ('Content-Type', 'image/jpeg'), - ] - else: - data = open(fn, 'rb').read() - cl = len(data) - f = KindaFilelike(data) - if path_info == '/notfilelike': - headers = [ - ('Content-Length', str(len(data))), - ('Content-Type', 'image/jpeg'), - ] - elif path_info == '/notfilelike_nocl': - headers = [('Content-Type', 'image/jpeg')] - elif path_info == '/notfilelike_shortcl': - # short content length - headers = [ - ('Content-Length', 
'1'), - ('Content-Type', 'image/jpeg'), - ] - else: - # long content length (/notfilelike_longcl) - headers = [ - ('Content-Length', str(cl + 10)), - ('Content-Type', 'image/jpeg'), - ] - - start_response( - '200 OK', - headers - ) - return environ['wsgi.file_wrapper'](f, 8192) diff --git a/waitress/tests/fixtureapps/getline.py b/waitress/tests/fixtureapps/getline.py deleted file mode 100644 index 7d8ae5d2..00000000 --- a/waitress/tests/fixtureapps/getline.py +++ /dev/null @@ -1,17 +0,0 @@ -import sys - -if __name__ == '__main__': - try: - from urllib.request import urlopen, URLError - except ImportError: - from urllib2 import urlopen, URLError - - url = sys.argv[1] - headers = {'Content-Type': 'text/plain; charset=utf-8'} - try: - resp = urlopen(url) - line = resp.readline().decode('ascii') # py3 - except URLError: - line = 'failed to read %s' % url - sys.stdout.write(line) - sys.stdout.flush() diff --git a/waitress/tests/fixtureapps/nocl.py b/waitress/tests/fixtureapps/nocl.py deleted file mode 100644 index 05e1d18f..00000000 --- a/waitress/tests/fixtureapps/nocl.py +++ /dev/null @@ -1,24 +0,0 @@ -def chunks(l, n): # pragma: no cover - """ Yield successive n-sized chunks from l. 
- """ - for i in range(0, len(l), n): - yield l[i:i + n] - -def gen(body): # pragma: no cover - for chunk in chunks(body, 10): - yield chunk - -def app(environ, start_response): # pragma: no cover - cl = environ.get('CONTENT_LENGTH', None) - if cl is not None: - cl = int(cl) - body = environ['wsgi.input'].read(cl) - start_response( - '200 OK', - [('Content-Type', 'text/plain')] - ) - if environ['PATH_INFO'] == '/list': - return [body] - if environ['PATH_INFO'] == '/list_lentwo': - return [body[0:1], body[1:]] - return gen(body) diff --git a/waitress/tests/fixtureapps/runner.py b/waitress/tests/fixtureapps/runner.py deleted file mode 100644 index eee0e45f..00000000 --- a/waitress/tests/fixtureapps/runner.py +++ /dev/null @@ -1,5 +0,0 @@ -def app(): # pragma: no cover - return None - -def returns_app(): # pragma: no cover - return app diff --git a/waitress/tests/fixtureapps/sleepy.py b/waitress/tests/fixtureapps/sleepy.py deleted file mode 100644 index 03bd0ab0..00000000 --- a/waitress/tests/fixtureapps/sleepy.py +++ /dev/null @@ -1,14 +0,0 @@ -import time - -def app(environ, start_response): # pragma: no cover - if environ['PATH_INFO'] == '/sleepy': - time.sleep(2) - body = b'sleepy returned' - else: - body = b'notsleepy returned' - cl = str(len(body)) - start_response( - '200 OK', - [('Content-Length', cl), ('Content-Type', 'text/plain')] - ) - return [body] diff --git a/waitress/tests/fixtureapps/toolarge.py b/waitress/tests/fixtureapps/toolarge.py deleted file mode 100644 index 150e9087..00000000 --- a/waitress/tests/fixtureapps/toolarge.py +++ /dev/null @@ -1,8 +0,0 @@ -def app(environ, start_response): # pragma: no cover - body = b'abcdef' - cl = len(body) - start_response( - '200 OK', - [('Content-Length', str(cl)), ('Content-Type', 'text/plain')] - ) - return [body] diff --git a/waitress/tests/fixtureapps/writecb.py b/waitress/tests/fixtureapps/writecb.py deleted file mode 100644 index ac59eb96..00000000 --- a/waitress/tests/fixtureapps/writecb.py +++ 
/dev/null @@ -1,14 +0,0 @@ -def app(environ, start_response): # pragma: no cover - path_info = environ['PATH_INFO'] - if path_info == '/no_content_length': - headers = [] - else: - headers = [('Content-Length', '9')] - write = start_response('200 OK', headers) - if path_info == '/long_body': - write(b'abcdefghij') - elif path_info == '/short_body': - write(b'abcdefgh') - else: - write(b'abcdefghi') - return [] diff --git a/waitress/tests/test_adjustments.py b/waitress/tests/test_adjustments.py deleted file mode 100644 index 9446705d..00000000 --- a/waitress/tests/test_adjustments.py +++ /dev/null @@ -1,294 +0,0 @@ -import sys -import socket - -from waitress.compat import ( - PY2, - WIN, - ) - -if sys.version_info[:2] == (2, 6): # pragma: no cover - import unittest2 as unittest -else: # pragma: no cover - import unittest - -class Test_asbool(unittest.TestCase): - - def _callFUT(self, s): - from waitress.adjustments import asbool - return asbool(s) - - def test_s_is_None(self): - result = self._callFUT(None) - self.assertEqual(result, False) - - def test_s_is_True(self): - result = self._callFUT(True) - self.assertEqual(result, True) - - def test_s_is_False(self): - result = self._callFUT(False) - self.assertEqual(result, False) - - def test_s_is_true(self): - result = self._callFUT('True') - self.assertEqual(result, True) - - def test_s_is_false(self): - result = self._callFUT('False') - self.assertEqual(result, False) - - def test_s_is_yes(self): - result = self._callFUT('yes') - self.assertEqual(result, True) - - def test_s_is_on(self): - result = self._callFUT('on') - self.assertEqual(result, True) - - def test_s_is_1(self): - result = self._callFUT(1) - self.assertEqual(result, True) - -class TestAdjustments(unittest.TestCase): - - def _hasIPv6(self): # pragma: nocover - if not socket.has_ipv6: - return False - - try: - socket.getaddrinfo( - '::1', - 0, - socket.AF_UNSPEC, - socket.SOCK_STREAM, - socket.IPPROTO_TCP, - socket.AI_PASSIVE | socket.AI_ADDRCONFIG - ) 
- - return True - except socket.gaierror as e: - # Check to see what the error is - if e.errno == socket.EAI_ADDRFAMILY: - return False - else: - raise e - - def _makeOne(self, **kw): - from waitress.adjustments import Adjustments - return Adjustments(**kw) - - def test_goodvars(self): - inst = self._makeOne( - host='localhost', - port='8080', - threads='5', - trusted_proxy='192.168.1.1', - url_scheme='https', - backlog='20', - recv_bytes='200', - send_bytes='300', - outbuf_overflow='400', - inbuf_overflow='500', - connection_limit='1000', - cleanup_interval='1100', - channel_timeout='1200', - log_socket_errors='true', - max_request_header_size='1300', - max_request_body_size='1400', - expose_tracebacks='true', - ident='abc', - asyncore_loop_timeout='5', - asyncore_use_poll=True, - unix_socket='/tmp/waitress.sock', - unix_socket_perms='777', - url_prefix='///foo/', - ipv4=True, - ipv6=False, - ) - - self.assertEqual(inst.host, 'localhost') - self.assertEqual(inst.port, 8080) - self.assertEqual(inst.threads, 5) - self.assertEqual(inst.trusted_proxy, '192.168.1.1') - self.assertEqual(inst.url_scheme, 'https') - self.assertEqual(inst.backlog, 20) - self.assertEqual(inst.recv_bytes, 200) - self.assertEqual(inst.send_bytes, 300) - self.assertEqual(inst.outbuf_overflow, 400) - self.assertEqual(inst.inbuf_overflow, 500) - self.assertEqual(inst.connection_limit, 1000) - self.assertEqual(inst.cleanup_interval, 1100) - self.assertEqual(inst.channel_timeout, 1200) - self.assertEqual(inst.log_socket_errors, True) - self.assertEqual(inst.max_request_header_size, 1300) - self.assertEqual(inst.max_request_body_size, 1400) - self.assertEqual(inst.expose_tracebacks, True) - self.assertEqual(inst.asyncore_loop_timeout, 5) - self.assertEqual(inst.asyncore_use_poll, True) - self.assertEqual(inst.ident, 'abc') - self.assertEqual(inst.unix_socket, '/tmp/waitress.sock') - self.assertEqual(inst.unix_socket_perms, 0o777) - self.assertEqual(inst.url_prefix, '/foo') - 
self.assertEqual(inst.ipv4, True) - self.assertEqual(inst.ipv6, False) - - bind_pairs = [ - sockaddr[:2] - for (family, _, _, sockaddr) in inst.listen - if family == socket.AF_INET - ] - - # On Travis, somehow we start listening to two sockets when resolving - # localhost... - self.assertEqual(('127.0.0.1', 8080), bind_pairs[0]) - - def test_goodvar_listen(self): - inst = self._makeOne(listen='127.0.0.1') - - bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen] - - self.assertEqual(bind_pairs, [('127.0.0.1', 8080)]) - - def test_default_listen(self): - inst = self._makeOne() - - bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen] - - self.assertEqual(bind_pairs, [('0.0.0.0', 8080)]) - - def test_multiple_listen(self): - inst = self._makeOne(listen='127.0.0.1:9090 127.0.0.1:8080') - - bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] - - self.assertEqual(bind_pairs, - [('127.0.0.1', 9090), - ('127.0.0.1', 8080)]) - - def test_wildcard_listen(self): - inst = self._makeOne(listen='*:8080') - - bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] - - self.assertTrue(len(bind_pairs) >= 1) - - def test_ipv6_no_port(self): # pragma: nocover - if not self._hasIPv6(): - return - - inst = self._makeOne(listen='[::1]') - - bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] - - self.assertEqual(bind_pairs, [('::1', 8080)]) - - def test_bad_port(self): - self.assertRaises(ValueError, self._makeOne, listen='127.0.0.1:test') - - def test_service_port(self): - if WIN and PY2: # pragma: no cover - # On Windows and Python 2 this is broken, so we raise a ValueError - self.assertRaises( - ValueError, - self._makeOne, - listen='127.0.0.1:http', - ) - return - - inst = self._makeOne(listen='127.0.0.1:http 0.0.0.0:https') - - bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] - - self.assertEqual(bind_pairs, [('127.0.0.1', 80), ('0.0.0.0', 443)]) - - def 
test_dont_mix_host_port_listen(self): - self.assertRaises( - ValueError, - self._makeOne, - host='localhost', - port='8080', - listen='127.0.0.1:8080', - ) - - def test_badvar(self): - self.assertRaises(ValueError, self._makeOne, nope=True) - - def test_ipv4_disabled(self): - self.assertRaises(ValueError, self._makeOne, ipv4=False, listen="127.0.0.1:8080") - - def test_ipv6_disabled(self): - self.assertRaises(ValueError, self._makeOne, ipv6=False, listen="[::]:8080") - -class TestCLI(unittest.TestCase): - - def parse(self, argv): - from waitress.adjustments import Adjustments - return Adjustments.parse_args(argv) - - def test_noargs(self): - opts, args = self.parse([]) - self.assertDictEqual(opts, {'call': False, 'help': False}) - self.assertSequenceEqual(args, []) - - def test_help(self): - opts, args = self.parse(['--help']) - self.assertDictEqual(opts, {'call': False, 'help': True}) - self.assertSequenceEqual(args, []) - - def test_call(self): - opts, args = self.parse(['--call']) - self.assertDictEqual(opts, {'call': True, 'help': False}) - self.assertSequenceEqual(args, []) - - def test_both(self): - opts, args = self.parse(['--call', '--help']) - self.assertDictEqual(opts, {'call': True, 'help': True}) - self.assertSequenceEqual(args, []) - - def test_positive_boolean(self): - opts, args = self.parse(['--expose-tracebacks']) - self.assertDictContainsSubset({'expose_tracebacks': 'true'}, opts) - self.assertSequenceEqual(args, []) - - def test_negative_boolean(self): - opts, args = self.parse(['--no-expose-tracebacks']) - self.assertDictContainsSubset({'expose_tracebacks': 'false'}, opts) - self.assertSequenceEqual(args, []) - - def test_cast_params(self): - opts, args = self.parse([ - '--host=localhost', - '--port=80', - '--unix-socket-perms=777' - ]) - self.assertDictContainsSubset({ - 'host': 'localhost', - 'port': '80', - 'unix_socket_perms': '777', - }, opts) - self.assertSequenceEqual(args, []) - - def test_listen_params(self): - opts, args = self.parse([ 
- '--listen=test:80', - ]) - - self.assertDictContainsSubset({ - 'listen': ' test:80' - }, opts) - self.assertSequenceEqual(args, []) - - def test_multiple_listen_params(self): - opts, args = self.parse([ - '--listen=test:80', - '--listen=test:8080', - ]) - - self.assertDictContainsSubset({ - 'listen': ' test:80 test:8080' - }, opts) - self.assertSequenceEqual(args, []) - - def test_bad_param(self): - import getopt - self.assertRaises(getopt.GetoptError, self.parse, ['--no-host']) diff --git a/waitress/tests/test_compat.py b/waitress/tests/test_compat.py deleted file mode 100644 index b5f66257..00000000 --- a/waitress/tests/test_compat.py +++ /dev/null @@ -1,20 +0,0 @@ -# -*- coding: utf-8 -*- - -import unittest - -class Test_unquote_bytes_to_wsgi(unittest.TestCase): - - def _callFUT(self, v): - from waitress.compat import unquote_bytes_to_wsgi - return unquote_bytes_to_wsgi(v) - - def test_highorder(self): - from waitress.compat import PY3 - val = b'/a%C5%9B' - result = self._callFUT(val) - if PY3: # pragma: no cover - # PEP 3333 urlunquoted-latin1-decoded-bytes - self.assertEqual(result, '/aÅ\x9b') - else: # pragma: no cover - # sanity - self.assertEqual(result, b'/a\xc5\x9b') diff --git a/waitress/tests/test_functional.py b/waitress/tests/test_functional.py deleted file mode 100644 index 59ef4e4d..00000000 --- a/waitress/tests/test_functional.py +++ /dev/null @@ -1,1551 +0,0 @@ -import errno -import logging -import multiprocessing -import os -import socket -import string -import subprocess -import sys -import time -import unittest -from waitress import server -from waitress.compat import ( - httplib, - tobytes -) -from waitress.utilities import cleanup_unix_socket - -dn = os.path.dirname -here = dn(__file__) - -class NullHandler(logging.Handler): # pragma: no cover - """A logging handler that swallows all emitted messages. - """ - def emit(self, record): - pass - -def start_server(app, svr, queue, **kwargs): # pragma: no cover - """Run a fixture application. 
- """ - logging.getLogger('waitress').addHandler(NullHandler()) - svr(app, queue, **kwargs).run() - -class FixtureTcpWSGIServer(server.TcpWSGIServer): - """A version of TcpWSGIServer that relays back what it's bound to. - """ - - family = socket.AF_INET # Testing - - def __init__(self, application, queue, **kw): # pragma: no cover - # Coverage doesn't see this as it's ran in a separate process. - kw['port'] = 0 # Bind to any available port. - super(FixtureTcpWSGIServer, self).__init__(application, **kw) - host, port = self.socket.getsockname() - if os.name == 'nt': - host = '127.0.0.1' - queue.put((host, port)) - -class SubprocessTests(object): - - # For nose: all tests may be ran in separate processes. - _multiprocess_can_split_ = True - - exe = sys.executable - - server = None - - def start_subprocess(self, target, **kw): - # Spawn a server process. - self.queue = multiprocessing.Queue() - self.proc = multiprocessing.Process( - target=start_server, - args=(target, self.server, self.queue), - kwargs=kw, - ) - self.proc.start() - if self.proc.exitcode is not None: # pragma: no cover - raise RuntimeError("%s didn't start" % str(target)) - # Get the socket the server is listening on. - self.bound_to = self.queue.get(timeout=5) - self.sock = self.create_socket() - - def stop_subprocess(self): - if self.proc.exitcode is None: - self.proc.terminate() - self.sock.close() - # This give us one FD back ... 
- self.queue.close() - - def assertline(self, line, status, reason, version): - v, s, r = (x.strip() for x in line.split(None, 2)) - self.assertEqual(s, tobytes(status)) - self.assertEqual(r, tobytes(reason)) - self.assertEqual(v, tobytes(version)) - - def create_socket(self): - return socket.socket(self.server.family, socket.SOCK_STREAM) - - def connect(self): - self.sock.connect(self.bound_to) - - def make_http_connection(self): - raise NotImplementedError # pragma: no cover - - def send_check_error(self, to_send): - self.sock.send(to_send) - -class TcpTests(SubprocessTests): - - server = FixtureTcpWSGIServer - - def make_http_connection(self): - return httplib.HTTPConnection(*self.bound_to) - -class SleepyThreadTests(TcpTests, unittest.TestCase): - # test that sleepy thread doesnt block other requests - - def setUp(self): - from waitress.tests.fixtureapps import sleepy - self.start_subprocess(sleepy.app) - - def tearDown(self): - self.stop_subprocess() - - def test_it(self): - getline = os.path.join(here, 'fixtureapps', 'getline.py') - cmds = ( - [self.exe, getline, 'http://%s:%d/sleepy' % self.bound_to], - [self.exe, getline, 'http://%s:%d/' % self.bound_to] - ) - r, w = os.pipe() - procs = [] - for cmd in cmds: - procs.append(subprocess.Popen(cmd, stdout=w)) - time.sleep(3) - for proc in procs: - if proc.returncode is not None: # pragma: no cover - proc.terminate() - # the notsleepy response should always be first returned (it sleeps - # for 2 seconds, then returns; the notsleepy response should be - # processed in the meantime) - result = os.read(r, 10000) - os.close(r) - os.close(w) - self.assertEqual(result, b'notsleepy returnedsleepy returned') - -class EchoTests(object): - - def setUp(self): - from waitress.tests.fixtureapps import echo - self.start_subprocess(echo.app) - - def tearDown(self): - self.stop_subprocess() - - def test_date_and_server(self): - to_send = ("GET / HTTP/1.0\n" - "Content-Length: 0\n\n") - to_send = tobytes(to_send) - 
self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - self.assertEqual(headers.get('server'), 'waitress') - self.assertTrue(headers.get('date')) - - def test_bad_host_header(self): - # http://corte.si/posts/code/pathod/pythonservers/index.html - to_send = ("GET / HTTP/1.0\n" - " Host: 0\n\n") - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '400', 'Bad Request', 'HTTP/1.0') - self.assertEqual(headers.get('server'), 'waitress') - self.assertTrue(headers.get('date')) - - def test_send_with_body(self): - to_send = ("GET / HTTP/1.0\n" - "Content-Length: 5\n\n") - to_send += 'hello' - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - self.assertEqual(headers.get('content-length'), '5') - self.assertEqual(response_body, b'hello') - - def test_send_empty_body(self): - to_send = ("GET / HTTP/1.0\n" - "Content-Length: 0\n\n") - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - self.assertEqual(headers.get('content-length'), '0') - self.assertEqual(response_body, b'') - - def test_multiple_requests_with_body(self): - for x in range(3): - self.sock = self.create_socket() - self.test_send_with_body() - self.sock.close() - - def test_multiple_requests_without_body(self): - for x in range(3): - self.sock = self.create_socket() - self.test_send_empty_body() - self.sock.close() - - def test_without_crlf(self): - data = "Echo\nthis\r\nplease" - s = tobytes( - "GET / HTTP/1.0\n" - "Connection: close\n" - "Content-Length: %d\n" - "\n" - 
"%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - self.assertEqual(int(headers['content-length']), len(data)) - self.assertEqual(len(response_body), len(data)) - self.assertEqual(response_body, tobytes(data)) - - def test_large_body(self): - # 1024 characters. - body = 'This string has 32 characters.\r\n' * 32 - s = tobytes( - "GET / HTTP/1.0\n" - "Content-Length: %d\n" - "\n" - "%s" % (len(body), body) - ) - self.connect() - self.sock.send(s) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - self.assertEqual(headers.get('content-length'), '1024') - self.assertEqual(response_body, tobytes(body)) - - def test_many_clients(self): - conns = [] - for n in range(50): - h = self.make_http_connection() - h.request("GET", "/", headers={"Accept": "text/plain"}) - conns.append(h) - responses = [] - for h in conns: - response = h.getresponse() - self.assertEqual(response.status, 200) - responses.append(response) - for response in responses: - response.read() - - def test_chunking_request_without_content(self): - header = tobytes( - "GET / HTTP/1.1\n" - "Transfer-Encoding: chunked\n\n" - ) - self.connect() - self.sock.send(header) - self.sock.send(b"0\r\n\r\n") - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - self.assertEqual(response_body, b'') - self.assertEqual(headers['content-length'], '0') - self.assertFalse('transfer-encoding' in headers) - - def test_chunking_request_with_content(self): - control_line = b"20;\r\n" # 20 hex = 32 dec - s = b'This string has 32 characters.\r\n' - expected = s * 12 - header = tobytes( - "GET / HTTP/1.1\n" - "Transfer-Encoding: chunked\n\n" - ) - self.connect() - self.sock.send(header) - fp = self.sock.makefile('rb', 0) - for n in 
range(12): - self.sock.send(control_line) - self.sock.send(s) - self.sock.send(b"0\r\n\r\n") - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - self.assertEqual(response_body, expected) - self.assertEqual(headers['content-length'], str(len(expected))) - self.assertFalse('transfer-encoding' in headers) - - def test_broken_chunked_encoding(self): - control_line = "20;\r\n" # 20 hex = 32 dec - s = 'This string has 32 characters.\r\n' - to_send = "GET / HTTP/1.1\nTransfer-Encoding: chunked\n\n" - to_send += (control_line + s) - # garbage in input - to_send += "GET / HTTP/1.1\nTransfer-Encoding: chunked\n\n" - to_send += (control_line + s) - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - # receiver caught garbage and turned it into a 400 - self.assertline(line, '400', 'Bad Request', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - self.assertEqual(sorted(headers.keys()), - ['content-length', 'content-type', 'date', 'server']) - self.assertEqual(headers['content-type'], 'text/plain') - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_keepalive_http_10(self): - # Handling of Keep-Alive within HTTP 1.0 - data = "Default: Don't keep me alive" - s = tobytes( - "GET / HTTP/1.0\n" - "Content-Length: %d\n" - "\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - connection = response.getheader('Connection', '') - # We sent no Connection: Keep-Alive header - # Connection: close (or no header) is default. 
- self.assertTrue(connection != 'Keep-Alive') - - def test_keepalive_http10_explicit(self): - # If header Connection: Keep-Alive is explicitly sent, - # we want to keept the connection open, we also need to return - # the corresponding header - data = "Keep me alive" - s = tobytes( - "GET / HTTP/1.0\n" - "Connection: Keep-Alive\n" - "Content-Length: %d\n" - "\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - connection = response.getheader('Connection', '') - self.assertEqual(connection, 'Keep-Alive') - - def test_keepalive_http_11(self): - # Handling of Keep-Alive within HTTP 1.1 - - # All connections are kept alive, unless stated otherwise - data = "Default: Keep me alive" - s = tobytes( - "GET / HTTP/1.1\n" - "Content-Length: %d\n" - "\n" - "%s" % (len(data), data)) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - self.assertTrue(response.getheader('connection') != 'close') - - def test_keepalive_http11_explicit(self): - # Explicitly set keep-alive - data = "Default: Keep me alive" - s = tobytes( - "GET / HTTP/1.1\n" - "Connection: keep-alive\n" - "Content-Length: %d\n" - "\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - self.assertTrue(response.getheader('connection') != 'close') - - def test_keepalive_http11_connclose(self): - # specifying Connection: close explicitly - data = "Don't keep me alive" - s = tobytes( - "GET / HTTP/1.1\n" - "Connection: close\n" - "Content-Length: %d\n" - "\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - 
self.assertEqual(response.getheader('connection'), 'close') - -class PipeliningTests(object): - - def setUp(self): - from waitress.tests.fixtureapps import echo - self.start_subprocess(echo.app) - - def tearDown(self): - self.stop_subprocess() - - def test_pipelining(self): - s = ("GET / HTTP/1.0\r\n" - "Connection: %s\r\n" - "Content-Length: %d\r\n" - "\r\n" - "%s") - to_send = b'' - count = 25 - for n in range(count): - body = "Response #%d\r\n" % (n + 1) - if n + 1 < count: - conn = 'keep-alive' - else: - conn = 'close' - to_send += tobytes(s % (conn, len(body), body)) - - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - for n in range(count): - expect_body = tobytes("Response #%d\r\n" % (n + 1)) - line = fp.readline() # status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - headers = parse_headers(fp) - length = int(headers.get('content-length')) or None - response_body = fp.read(length) - self.assertEqual(int(status), 200) - self.assertEqual(length, len(response_body)) - self.assertEqual(response_body, expect_body) - -class ExpectContinueTests(object): - - def setUp(self): - from waitress.tests.fixtureapps import echo - self.start_subprocess(echo.app) - - def tearDown(self): - self.stop_subprocess() - - def test_expect_continue(self): - # specifying Connection: close explicitly - data = "I have expectations" - to_send = tobytes( - "GET / HTTP/1.1\n" - "Connection: close\n" - "Content-Length: %d\n" - "Expect: 100-continue\n" - "\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line = fp.readline() # continue status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - self.assertEqual(int(status), 100) - self.assertEqual(reason, b'Continue') - self.assertEqual(version, b'HTTP/1.1') - fp.readline() # blank line - line = fp.readline() # next status line - version, status, reason = (x.strip() for x in line.split(None, 2)) 
- headers = parse_headers(fp) - length = int(headers.get('content-length')) or None - response_body = fp.read(length) - self.assertEqual(int(status), 200) - self.assertEqual(length, len(response_body)) - self.assertEqual(response_body, tobytes(data)) - -class BadContentLengthTests(object): - - def setUp(self): - from waitress.tests.fixtureapps import badcl - self.start_subprocess(badcl.app) - - def tearDown(self): - self.stop_subprocess() - - def test_short_body(self): - # check to see if server closes connection when body is too short - # for cl header - to_send = tobytes( - "GET /short_body HTTP/1.0\n" - "Connection: Keep-Alive\n" - "Content-Length: 0\n" - "\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line = fp.readline() # status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - headers = parse_headers(fp) - content_length = int(headers.get('content-length')) - response_body = fp.read(content_length) - self.assertEqual(int(status), 200) - self.assertNotEqual(content_length, len(response_body)) - self.assertEqual(len(response_body), content_length - 1) - self.assertEqual(response_body, tobytes('abcdefghi')) - # remote closed connection (despite keepalive header); not sure why - # first send succeeds - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_long_body(self): - # check server doesnt close connection when body is too short - # for cl header - to_send = tobytes( - "GET /long_body HTTP/1.0\n" - "Connection: Keep-Alive\n" - "Content-Length: 0\n" - "\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line = fp.readline() # status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - headers = parse_headers(fp) - content_length = int(headers.get('content-length')) or None - response_body = fp.read(content_length) - self.assertEqual(int(status), 200) - self.assertEqual(content_length, 
len(response_body)) - self.assertEqual(response_body, tobytes('abcdefgh')) - # remote does not close connection (keepalive header) - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line = fp.readline() # status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - headers = parse_headers(fp) - content_length = int(headers.get('content-length')) or None - response_body = fp.read(content_length) - self.assertEqual(int(status), 200) - -class NoContentLengthTests(object): - - def setUp(self): - from waitress.tests.fixtureapps import nocl - self.start_subprocess(nocl.app) - - def tearDown(self): - self.stop_subprocess() - - def test_http10_generator(self): - body = string.ascii_letters - to_send = ("GET / HTTP/1.0\n" - "Connection: Keep-Alive\n" - "Content-Length: %d\n\n" % len(body)) - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - self.assertEqual(headers.get('content-length'), None) - self.assertEqual(headers.get('connection'), 'close') - self.assertEqual(response_body, tobytes(body)) - # remote closed connection (despite keepalive header), because - # generators cannot have a content-length divined - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_http10_list(self): - body = string.ascii_letters - to_send = ("GET /list HTTP/1.0\n" - "Connection: Keep-Alive\n" - "Content-Length: %d\n\n" % len(body)) - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - self.assertEqual(headers['content-length'], str(len(body))) - self.assertEqual(headers.get('connection'), 'Keep-Alive') - self.assertEqual(response_body, tobytes(body)) - # remote keeps connection 
open because it divined the content length - # from a length-1 list - self.sock.send(to_send) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - - def test_http10_listlentwo(self): - body = string.ascii_letters - to_send = ("GET /list_lentwo HTTP/1.0\n" - "Connection: Keep-Alive\n" - "Content-Length: %d\n\n" % len(body)) - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - self.assertEqual(headers.get('content-length'), None) - self.assertEqual(headers.get('connection'), 'close') - self.assertEqual(response_body, tobytes(body)) - # remote closed connection (despite keepalive header), because - # lists of length > 1 cannot have their content length divined - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_http11_generator(self): - body = string.ascii_letters - to_send = ("GET / HTTP/1.1\n" - "Content-Length: %s\n\n" % len(body)) - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb') - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - expected = b'' - for chunk in chunks(body, 10): - expected += tobytes( - '%s\r\n%s\r\n' % (str(hex(len(chunk))[2:].upper()), chunk) - ) - expected += b'0\r\n\r\n' - self.assertEqual(response_body, expected) - # connection is always closed at the end of a chunked response - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_http11_list(self): - body = string.ascii_letters - to_send = ("GET /list HTTP/1.1\n" - "Content-Length: %d\n\n" % len(body)) - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - 
self.assertline(line, '200', 'OK', 'HTTP/1.1') - self.assertEqual(headers['content-length'], str(len(body))) - self.assertEqual(response_body, tobytes(body)) - # remote keeps connection open because it divined the content length - # from a length-1 list - self.sock.send(to_send) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - - def test_http11_listlentwo(self): - body = string.ascii_letters - to_send = ("GET /list_lentwo HTTP/1.1\n" - "Content-Length: %s\n\n" % len(body)) - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb') - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - expected = b'' - for chunk in (body[0], body[1:]): - expected += tobytes( - '%s\r\n%s\r\n' % (str(hex(len(chunk))[2:].upper()), chunk) - ) - expected += b'0\r\n\r\n' - self.assertEqual(response_body, expected) - # connection is always closed at the end of a chunked response - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - -class WriteCallbackTests(object): - - def setUp(self): - from waitress.tests.fixtureapps import writecb - self.start_subprocess(writecb.app) - - def tearDown(self): - self.stop_subprocess() - - def test_short_body(self): - # check to see if server closes connection when body is too short - # for cl header - to_send = tobytes( - "GET /short_body HTTP/1.0\n" - "Connection: Keep-Alive\n" - "Content-Length: 0\n" - "\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - # server trusts the content-length header (5) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, 9) - self.assertNotEqual(cl, len(response_body)) - self.assertEqual(len(response_body), cl - 1) - self.assertEqual(response_body, tobytes('abcdefgh')) - # remote closed connection 
(despite keepalive header) - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_long_body(self): - # check server doesnt close connection when body is too long - # for cl header - to_send = tobytes( - "GET /long_body HTTP/1.0\n" - "Connection: Keep-Alive\n" - "Content-Length: 0\n" - "\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - content_length = int(headers.get('content-length')) or None - self.assertEqual(content_length, 9) - self.assertEqual(content_length, len(response_body)) - self.assertEqual(response_body, tobytes('abcdefghi')) - # remote does not close connection (keepalive header) - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - - def test_equal_body(self): - # check server doesnt close connection when body is equal to - # cl header - to_send = tobytes( - "GET /equal_body HTTP/1.0\n" - "Connection: Keep-Alive\n" - "Content-Length: 0\n" - "\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - content_length = int(headers.get('content-length')) or None - self.assertEqual(content_length, 9) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - self.assertEqual(content_length, len(response_body)) - self.assertEqual(response_body, tobytes('abcdefghi')) - # remote does not close connection (keepalive header) - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - - def test_no_content_length(self): - # wtf happens when there's no content-length - to_send = tobytes( - "GET /no_content_length HTTP/1.0\n" - "Connection: Keep-Alive\n" - "Content-Length: 0\n" - "\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - 
line = fp.readline() # status line - line, headers, response_body = read_http(fp) - content_length = headers.get('content-length') - self.assertEqual(content_length, None) - self.assertEqual(response_body, tobytes('abcdefghi')) - # remote closed connection (despite keepalive header) - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - -class TooLargeTests(object): - - toobig = 1050 - - def setUp(self): - from waitress.tests.fixtureapps import toolarge - self.start_subprocess(toolarge.app, - max_request_header_size=1000, - max_request_body_size=1000) - - def tearDown(self): - self.stop_subprocess() - - def test_request_body_too_large_with_wrong_cl_http10(self): - body = 'a' * self.toobig - to_send = ("GET / HTTP/1.0\n" - "Content-Length: 5\n\n") - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb') - # first request succeeds (content-length 5) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - # server trusts the content-length header; no pipelining, - # so request fulfilled, extra bytes are thrown away - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_wrong_cl_http10_keepalive(self): - body = 'a' * self.toobig - to_send = ("GET / HTTP/1.0\n" - "Content-Length: 5\n" - "Connection: Keep-Alive\n\n") - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb') - # first request succeeds (content-length 5) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - line, headers, response_body = read_http(fp) - self.assertline(line, '431', 'Request Header 
Fields Too Large', - 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_no_cl_http10(self): - body = 'a' * self.toobig - to_send = "GET / HTTP/1.0\n\n" - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - # extra bytes are thrown away (no pipelining), connection closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_no_cl_http10_keepalive(self): - body = 'a' * self.toobig - to_send = "GET / HTTP/1.0\nConnection: Keep-Alive\n\n" - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - # server trusts the content-length header (assumed zero) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - line, headers, response_body = read_http(fp) - # next response overruns because the extra data appears to be - # header data - self.assertline(line, '431', 'Request Header Fields Too Large', - 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_wrong_cl_http11(self): - body = 'a' * self.toobig - to_send = ("GET / HTTP/1.1\n" - "Content-Length: 5\n\n") - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb') 
- # first request succeeds (content-length 5) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - # second response is an error response - line, headers, response_body = read_http(fp) - self.assertline(line, '431', 'Request Header Fields Too Large', - 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_wrong_cl_http11_connclose(self): - body = 'a' * self.toobig - to_send = "GET / HTTP/1.1\nContent-Length: 5\nConnection: close\n\n" - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - # server trusts the content-length header (5) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_no_cl_http11(self): - body = 'a' * self.toobig - to_send = "GET / HTTP/1.1\n\n" - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb') - # server trusts the content-length header (assumed 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - # server assumes pipelined requests due to http/1.1, and the first - # request was assumed c-l 0 because it had no content-length header, - # so entire body looks like the header of the subsequent request - # second response is an error response - line, headers, response_body = read_http(fp) 
- self.assertline(line, '431', 'Request Header Fields Too Large', - 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_with_no_cl_http11_connclose(self): - body = 'a' * self.toobig - to_send = "GET / HTTP/1.1\nConnection: close\n\n" - to_send += body - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - # server trusts the content-length header (assumed 0) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_request_body_too_large_chunked_encoding(self): - control_line = "20;\r\n" # 20 hex = 32 dec - s = 'This string has 32 characters.\r\n' - to_send = "GET / HTTP/1.1\nTransfer-Encoding: chunked\n\n" - repeat = control_line + s - to_send += repeat * ((self.toobig // len(repeat)) + 1) - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - # body bytes counter caught a max_request_body_size overrun - self.assertline(line, '413', 'Request Entity Too Large', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - self.assertEqual(headers['content-type'], 'text/plain') - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - -class InternalServerErrorTests(object): - - def setUp(self): - from waitress.tests.fixtureapps import error - self.start_subprocess(error.app, expose_tracebacks=True) - - def tearDown(self): - self.stop_subprocess() - - def 
test_before_start_response_http_10(self): - to_send = "GET /before_start_response HTTP/1.0\n\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '500', 'Internal Server Error', 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - self.assertTrue(response_body.startswith(b'Internal Server Error')) - self.assertEqual(headers['connection'], 'close') - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_before_start_response_http_11(self): - to_send = "GET /before_start_response HTTP/1.1\n\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '500', 'Internal Server Error', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - self.assertTrue(response_body.startswith(b'Internal Server Error')) - self.assertEqual(sorted(headers.keys()), - ['content-length', 'content-type', 'date', 'server']) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_before_start_response_http_11_close(self): - to_send = tobytes( - "GET /before_start_response HTTP/1.1\n" - "Connection: close\n\n") - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '500', 'Internal Server Error', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - self.assertTrue(response_body.startswith(b'Internal Server Error')) - self.assertEqual(sorted(headers.keys()), - ['connection', 'content-length', 'content-type', 'date', - 'server']) - self.assertEqual(headers['connection'], 'close') - # 
connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_after_start_response_http10(self): - to_send = "GET /after_start_response HTTP/1.0\n\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '500', 'Internal Server Error', 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - self.assertTrue(response_body.startswith(b'Internal Server Error')) - self.assertEqual(sorted(headers.keys()), - ['connection', 'content-length', 'content-type', 'date', - 'server']) - self.assertEqual(headers['connection'], 'close') - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_after_start_response_http11(self): - to_send = "GET /after_start_response HTTP/1.1\n\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '500', 'Internal Server Error', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - self.assertTrue(response_body.startswith(b'Internal Server Error')) - self.assertEqual(sorted(headers.keys()), - ['content-length', 'content-type', 'date', 'server']) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_after_start_response_http11_close(self): - to_send = tobytes( - "GET /after_start_response HTTP/1.1\n" - "Connection: close\n\n") - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '500', 'Internal Server Error', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - 
self.assertTrue(response_body.startswith(b'Internal Server Error')) - self.assertEqual(sorted(headers.keys()), - ['connection', 'content-length', 'content-type', 'date', - 'server']) - self.assertEqual(headers['connection'], 'close') - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_after_write_cb(self): - to_send = "GET /after_write_cb HTTP/1.1\n\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - self.assertEqual(response_body, b'') - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_in_generator(self): - to_send = "GET /in_generator HTTP/1.1\n\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - self.assertEqual(response_body, b'') - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - -class FileWrapperTests(object): - - def setUp(self): - from waitress.tests.fixtureapps import filewrapper - self.start_subprocess(filewrapper.app) - - def tearDown(self): - self.stop_subprocess() - - def test_filelike_http11(self): - to_send = "GET /filelike HTTP/1.1\n\n" - to_send = tobytes(to_send) - - self.connect() - - for t in range(0, 2): - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - ct = headers['content-type'] - self.assertEqual(ct, 'image/jpeg') - self.assertTrue(b'\377\330\377' in response_body) - # connection has not been closed - - def 
test_filelike_nocl_http11(self): - to_send = "GET /filelike_nocl HTTP/1.1\n\n" - to_send = tobytes(to_send) - - self.connect() - - for t in range(0, 2): - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - ct = headers['content-type'] - self.assertEqual(ct, 'image/jpeg') - self.assertTrue(b'\377\330\377' in response_body) - # connection has not been closed - - def test_filelike_shortcl_http11(self): - to_send = "GET /filelike_shortcl HTTP/1.1\n\n" - to_send = tobytes(to_send) - - self.connect() - - for t in range(0, 2): - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, 1) - self.assertEqual(cl, len(response_body)) - ct = headers['content-type'] - self.assertEqual(ct, 'image/jpeg') - self.assertTrue(b'\377' in response_body) - # connection has not been closed - - def test_filelike_longcl_http11(self): - to_send = "GET /filelike_longcl HTTP/1.1\n\n" - to_send = tobytes(to_send) - - self.connect() - - for t in range(0, 2): - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - ct = headers['content-type'] - self.assertEqual(ct, 'image/jpeg') - self.assertTrue(b'\377\330\377' in response_body) - # connection has not been closed - - def test_notfilelike_http11(self): - to_send = "GET /notfilelike HTTP/1.1\n\n" - to_send = tobytes(to_send) - - self.connect() - - for t in range(0, 2): - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - 
cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - ct = headers['content-type'] - self.assertEqual(ct, 'image/jpeg') - self.assertTrue(b'\377\330\377' in response_body) - # connection has not been closed - - def test_notfilelike_nocl_http11(self): - to_send = "GET /notfilelike_nocl HTTP/1.1\n\n" - to_send = tobytes(to_send) - - self.connect() - - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - ct = headers['content-type'] - self.assertEqual(ct, 'image/jpeg') - self.assertTrue(b'\377\330\377' in response_body) - # connection has been closed (no content-length) - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_notfilelike_shortcl_http11(self): - to_send = "GET /notfilelike_shortcl HTTP/1.1\n\n" - to_send = tobytes(to_send) - - self.connect() - - for t in range(0, 2): - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, 1) - self.assertEqual(cl, len(response_body)) - ct = headers['content-type'] - self.assertEqual(ct, 'image/jpeg') - self.assertTrue(b'\377' in response_body) - # connection has not been closed - - def test_notfilelike_longcl_http11(self): - to_send = "GET /notfilelike_longcl HTTP/1.1\n\n" - to_send = tobytes(to_send) - - self.connect() - - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.1') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body) + 10) - ct = headers['content-type'] - self.assertEqual(ct, 'image/jpeg') - self.assertTrue(b'\377\330\377' in response_body) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, 
fp) - - def test_filelike_http10(self): - to_send = "GET /filelike HTTP/1.0\n\n" - to_send = tobytes(to_send) - - self.connect() - - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - ct = headers['content-type'] - self.assertEqual(ct, 'image/jpeg') - self.assertTrue(b'\377\330\377' in response_body) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_filelike_nocl_http10(self): - to_send = "GET /filelike_nocl HTTP/1.0\n\n" - to_send = tobytes(to_send) - - self.connect() - - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - ct = headers['content-type'] - self.assertEqual(ct, 'image/jpeg') - self.assertTrue(b'\377\330\377' in response_body) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_notfilelike_http10(self): - to_send = "GET /notfilelike HTTP/1.0\n\n" - to_send = tobytes(to_send) - - self.connect() - - self.sock.send(to_send) - fp = self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - cl = int(headers['content-length']) - self.assertEqual(cl, len(response_body)) - ct = headers['content-type'] - self.assertEqual(ct, 'image/jpeg') - self.assertTrue(b'\377\330\377' in response_body) - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_notfilelike_nocl_http10(self): - to_send = "GET /notfilelike_nocl HTTP/1.0\n\n" - to_send = tobytes(to_send) - - self.connect() - - self.sock.send(to_send) - fp 
= self.sock.makefile('rb', 0) - line, headers, response_body = read_http(fp) - self.assertline(line, '200', 'OK', 'HTTP/1.0') - ct = headers['content-type'] - self.assertEqual(ct, 'image/jpeg') - self.assertTrue(b'\377\330\377' in response_body) - # connection has been closed (no content-length) - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - -class TcpEchoTests(EchoTests, TcpTests, unittest.TestCase): - pass - -class TcpPipeliningTests(PipeliningTests, TcpTests, unittest.TestCase): - pass - -class TcpExpectContinueTests(ExpectContinueTests, TcpTests, unittest.TestCase): - pass - -class TcpBadContentLengthTests( - BadContentLengthTests, TcpTests, unittest.TestCase): - pass - -class TcpNoContentLengthTests( - NoContentLengthTests, TcpTests, unittest.TestCase): - pass - -class TcpWriteCallbackTests(WriteCallbackTests, TcpTests, unittest.TestCase): - pass - -class TcpTooLargeTests(TooLargeTests, TcpTests, unittest.TestCase): - pass - -class TcpInternalServerErrorTests( - InternalServerErrorTests, TcpTests, unittest.TestCase): - pass - -class TcpFileWrapperTests(FileWrapperTests, TcpTests, unittest.TestCase): - pass - -if hasattr(socket, 'AF_UNIX'): - - class FixtureUnixWSGIServer(server.UnixWSGIServer): - """A version of UnixWSGIServer that relays back what it's bound to. - """ - - family = socket.AF_UNIX # Testing - - def __init__(self, application, queue, **kw): # pragma: no cover - # Coverage doesn't see this as it's ran in a separate process. - # To permit parallel testing, use a PID-dependent socket. 
- kw['unix_socket'] = '/tmp/waitress.test-%d.sock' % os.getpid() - super(FixtureUnixWSGIServer, self).__init__(application, **kw) - queue.put(self.socket.getsockname()) - - class UnixTests(SubprocessTests): - - server = FixtureUnixWSGIServer - - def make_http_connection(self): - return UnixHTTPConnection(self.bound_to) - - def stop_subprocess(self): - super(UnixTests, self).stop_subprocess() - cleanup_unix_socket(self.bound_to) - - def send_check_error(self, to_send): - # Unlike inet domain sockets, Unix domain sockets can trigger a - # 'Broken pipe' error when the socket it closed. - try: - self.sock.send(to_send) - except socket.error as exc: - self.assertEqual(get_errno(exc), errno.EPIPE) - - class UnixEchoTests(EchoTests, UnixTests, unittest.TestCase): - pass - - class UnixPipeliningTests(PipeliningTests, UnixTests, unittest.TestCase): - pass - - class UnixExpectContinueTests( - ExpectContinueTests, UnixTests, unittest.TestCase): - pass - - class UnixBadContentLengthTests( - BadContentLengthTests, UnixTests, unittest.TestCase): - pass - - class UnixNoContentLengthTests( - NoContentLengthTests, UnixTests, unittest.TestCase): - pass - - class UnixWriteCallbackTests( - WriteCallbackTests, UnixTests, unittest.TestCase): - pass - - class UnixTooLargeTests(TooLargeTests, UnixTests, unittest.TestCase): - pass - - class UnixInternalServerErrorTests( - InternalServerErrorTests, UnixTests, unittest.TestCase): - pass - - class UnixFileWrapperTests(FileWrapperTests, UnixTests, unittest.TestCase): - pass - -def parse_headers(fp): - """Parses only RFC2822 headers from a file pointer. - """ - headers = {} - while True: - line = fp.readline() - if line in (b'\r\n', b'\n', b''): - break - line = line.decode('iso-8859-1') - name, value = line.strip().split(':', 1) - headers[name.lower().strip()] = value.lower().strip() - return headers - -class UnixHTTPConnection(httplib.HTTPConnection): - """Patched version of HTTPConnection that uses Unix domain sockets. 
- """ - - def __init__(self, path): - httplib.HTTPConnection.__init__(self, 'localhost') - self.path = path - - def connect(self): - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.connect(self.path) - self.sock = sock - -class ConnectionClosed(Exception): - pass - -# stolen from gevent -def read_http(fp): # pragma: no cover - try: - response_line = fp.readline() - except socket.error as exc: - fp.close() - # errno 104 is ENOTRECOVERABLE, In WinSock 10054 is ECONNRESET - if get_errno(exc) in (errno.ECONNABORTED, errno.ECONNRESET, 104, 10054): - raise ConnectionClosed - raise - if not response_line: - raise ConnectionClosed - - header_lines = [] - while True: - line = fp.readline() - if line in (b'\r\n', b'\n', b''): - break - else: - header_lines.append(line) - headers = dict() - for x in header_lines: - x = x.strip() - if not x: - continue - key, value = x.split(b': ', 1) - key = key.decode('iso-8859-1').lower() - value = value.decode('iso-8859-1') - assert key not in headers, "%s header duplicated" % key - headers[key] = value - - if 'content-length' in headers: - num = int(headers['content-length']) - body = b'' - left = num - while left > 0: - data = fp.read(left) - if not data: - break - body += data - left -= len(data) - else: - # read until EOF - body = fp.read() - - return response_line, headers, body - -# stolen from gevent -def get_errno(exc): # pragma: no cover - """ Get the error code out of socket.error objects. - socket.error in <2.5 does not have errno attribute - socket.error in 3.x does not allow indexing access - e.args[0] works for all. - There are cases when args[0] is not errno. - i.e. http://bugs.python.org/issue6471 - Maybe there are cases when errno is set, but it is not the first argument? - """ - try: - if exc.errno is not None: - return exc.errno - except AttributeError: - pass - try: - return exc.args[0] - except IndexError: - return None - -def chunks(l, n): - """ Yield successive n-sized chunks from l. 
- """ - for i in range(0, len(l), n): - yield l[i:i + n] diff --git a/waitress/tests/test_parser.py b/waitress/tests/test_parser.py deleted file mode 100644 index ecb66060..00000000 --- a/waitress/tests/test_parser.py +++ /dev/null @@ -1,452 +0,0 @@ -############################################################################## -# -# Copyright (c) 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""HTTP Request Parser tests -""" -import unittest - -from waitress.compat import ( - text_, - tobytes, -) - -class TestHTTPRequestParser(unittest.TestCase): - - def setUp(self): - from waitress.parser import HTTPRequestParser - from waitress.adjustments import Adjustments - my_adj = Adjustments() - self.parser = HTTPRequestParser(my_adj) - - def test_get_body_stream_None(self): - self.parser.body_recv = None - result = self.parser.get_body_stream() - self.assertEqual(result.getvalue(), b'') - - def test_get_body_stream_nonNone(self): - body_rcv = DummyBodyStream() - self.parser.body_rcv = body_rcv - result = self.parser.get_body_stream() - self.assertEqual(result, body_rcv) - - def test_received_nonsense_with_double_cr(self): - data = b"""\ -HTTP/1.0 GET /foobar - - -""" - result = self.parser.received(data) - self.assertEqual(result, 22) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.headers, {}) - - def test_received_bad_host_header(self): - from waitress.utilities import BadRequest - data = b"""\ -HTTP/1.0 GET /foobar - Host: foo - - -""" - result = 
self.parser.received(data) - self.assertEqual(result, 33) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.error.__class__, BadRequest) - - def test_received_nonsense_nothing(self): - data = b"""\ - - -""" - result = self.parser.received(data) - self.assertEqual(result, 2) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.headers, {}) - - def test_received_no_doublecr(self): - data = b"""\ -GET /foobar HTTP/8.4 -""" - result = self.parser.received(data) - self.assertEqual(result, 21) - self.assertFalse(self.parser.completed) - self.assertEqual(self.parser.headers, {}) - - def test_received_already_completed(self): - self.parser.completed = True - result = self.parser.received(b'a') - self.assertEqual(result, 0) - - def test_received_cl_too_large(self): - from waitress.utilities import RequestEntityTooLarge - self.parser.adj.max_request_body_size = 2 - data = b"""\ -GET /foobar HTTP/8.4 -Content-Length: 10 - -""" - result = self.parser.received(data) - self.assertEqual(result, 41) - self.assertTrue(self.parser.completed) - self.assertTrue(isinstance(self.parser.error, RequestEntityTooLarge)) - - def test_received_headers_too_large(self): - from waitress.utilities import RequestHeaderFieldsTooLarge - self.parser.adj.max_request_header_size = 2 - data = b"""\ -GET /foobar HTTP/8.4 -X-Foo: 1 -""" - result = self.parser.received(data) - self.assertEqual(result, 30) - self.assertTrue(self.parser.completed) - self.assertTrue(isinstance(self.parser.error, - RequestHeaderFieldsTooLarge)) - - def test_received_body_too_large(self): - from waitress.utilities import RequestEntityTooLarge - self.parser.adj.max_request_body_size = 2 - data = b"""\ -GET /foobar HTTP/1.1 -Transfer-Encoding: chunked -X-Foo: 1 - -20;\r\n -This string has 32 characters\r\n -0\r\n\r\n""" - result = self.parser.received(data) - self.assertEqual(result, 58) - self.parser.received(data[result:]) - self.assertTrue(self.parser.completed) - 
self.assertTrue(isinstance(self.parser.error, - RequestEntityTooLarge)) - - def test_received_error_from_parser(self): - from waitress.utilities import BadRequest - data = b"""\ -GET /foobar HTTP/1.1 -Transfer-Encoding: chunked -X-Foo: 1 - -garbage -""" - # header - result = self.parser.received(data) - # body - result = self.parser.received(data[result:]) - self.assertEqual(result, 8) - self.assertTrue(self.parser.completed) - self.assertTrue(isinstance(self.parser.error, - BadRequest)) - - def test_received_chunked_completed_sets_content_length(self): - data = b"""\ -GET /foobar HTTP/1.1 -Transfer-Encoding: chunked -X-Foo: 1 - -20;\r\n -This string has 32 characters\r\n -0\r\n\r\n""" - result = self.parser.received(data) - self.assertEqual(result, 58) - data = data[result:] - result = self.parser.received(data) - self.assertTrue(self.parser.completed) - self.assertTrue(self.parser.error is None) - self.assertEqual(self.parser.headers['CONTENT_LENGTH'], '32') - - def test_parse_header_gardenpath(self): - data = b"""\ -GET /foobar HTTP/8.4 -foo: bar""" - self.parser.parse_header(data) - self.assertEqual(self.parser.first_line, b'GET /foobar HTTP/8.4') - self.assertEqual(self.parser.headers['FOO'], 'bar') - - def test_parse_header_no_cr_in_headerplus(self): - data = b"GET /foobar HTTP/8.4" - self.parser.parse_header(data) - self.assertEqual(self.parser.first_line, data) - - def test_parse_header_bad_content_length(self): - data = b"GET /foobar HTTP/8.4\ncontent-length: abc" - self.parser.parse_header(data) - self.assertEqual(self.parser.body_rcv, None) - - def test_parse_header_11_te_chunked(self): - # NB: test that capitalization of header value is unimportant - data = b"GET /foobar HTTP/1.1\ntransfer-encoding: ChUnKed" - self.parser.parse_header(data) - self.assertEqual(self.parser.body_rcv.__class__.__name__, - 'ChunkedReceiver') - - def test_parse_header_11_expect_continue(self): - data = b"GET /foobar HTTP/1.1\nexpect: 100-continue" - 
self.parser.parse_header(data) - self.assertEqual(self.parser.expect_continue, True) - - def test_parse_header_connection_close(self): - data = b"GET /foobar HTTP/1.1\nConnection: close\n\n" - self.parser.parse_header(data) - self.assertEqual(self.parser.connection_close, True) - - def test_close_with_body_rcv(self): - body_rcv = DummyBodyStream() - self.parser.body_rcv = body_rcv - self.parser.close() - self.assertTrue(body_rcv.closed) - - def test_close_with_no_body_rcv(self): - self.parser.body_rcv = None - self.parser.close() # doesn't raise - -class Test_split_uri(unittest.TestCase): - - def _callFUT(self, uri): - from waitress.parser import split_uri - (self.proxy_scheme, - self.proxy_netloc, - self.path, - self.query, self.fragment) = split_uri(uri) - - def test_split_uri_unquoting_unneeded(self): - self._callFUT(b'http://localhost:8080/abc def') - self.assertEqual(self.path, '/abc def') - - def test_split_uri_unquoting_needed(self): - self._callFUT(b'http://localhost:8080/abc%20def') - self.assertEqual(self.path, '/abc def') - - def test_split_url_with_query(self): - self._callFUT(b'http://localhost:8080/abc?a=1&b=2') - self.assertEqual(self.path, '/abc') - self.assertEqual(self.query, 'a=1&b=2') - - def test_split_url_with_query_empty(self): - self._callFUT(b'http://localhost:8080/abc?') - self.assertEqual(self.path, '/abc') - self.assertEqual(self.query, '') - - def test_split_url_with_fragment(self): - self._callFUT(b'http://localhost:8080/#foo') - self.assertEqual(self.path, '/') - self.assertEqual(self.fragment, 'foo') - - def test_split_url_https(self): - self._callFUT(b'https://localhost:8080/') - self.assertEqual(self.path, '/') - self.assertEqual(self.proxy_scheme, 'https') - self.assertEqual(self.proxy_netloc, 'localhost:8080') - - def test_split_uri_unicode_error_raises_parsing_error(self): - # See https://github.com/Pylons/waitress/issues/64 - from waitress.parser import ParsingError - # Either pass or throw a ParsingError, just don't throw 
another type of - # exception as that will cause the connection to close badly: - try: - self._callFUT(b'/\xd0') - except ParsingError: - pass - -class Test_get_header_lines(unittest.TestCase): - - def _callFUT(self, data): - from waitress.parser import get_header_lines - return get_header_lines(data) - - def test_get_header_lines(self): - result = self._callFUT(b'slam\nslim') - self.assertEqual(result, [b'slam', b'slim']) - - def test_get_header_lines_folded(self): - # From RFC2616: - # HTTP/1.1 header field values can be folded onto multiple lines if the - # continuation line begins with a space or horizontal tab. All linear - # white space, including folding, has the same semantics as SP. A - # recipient MAY replace any linear white space with a single SP before - # interpreting the field value or forwarding the message downstream. - - # We are just preserving the whitespace that indicates folding. - result = self._callFUT(b'slim\n slam') - self.assertEqual(result, [b'slim slam']) - - def test_get_header_lines_tabbed(self): - result = self._callFUT(b'slam\n\tslim') - self.assertEqual(result, [b'slam\tslim']) - - def test_get_header_lines_malformed(self): - # http://corte.si/posts/code/pathod/pythonservers/index.html - from waitress.parser import ParsingError - self.assertRaises(ParsingError, - self._callFUT, b' Host: localhost\r\n\r\n') - -class Test_crack_first_line(unittest.TestCase): - - def _callFUT(self, line): - from waitress.parser import crack_first_line - return crack_first_line(line) - - def test_crack_first_line_matchok(self): - result = self._callFUT(b'GET / HTTP/1.0') - self.assertEqual(result, (b'GET', b'/', b'1.0')) - - def test_crack_first_line_lowercase_method(self): - from waitress.parser import ParsingError - self.assertRaises(ParsingError, self._callFUT, b'get / HTTP/1.0') - - def test_crack_first_line_nomatch(self): - result = self._callFUT(b'GET / bleh') - self.assertEqual(result, (b'', b'', b'')) - - def 
test_crack_first_line_missing_version(self): - result = self._callFUT(b'GET /') - self.assertEqual(result, (b'GET', b'/', None)) - -class TestHTTPRequestParserIntegration(unittest.TestCase): - - def setUp(self): - from waitress.parser import HTTPRequestParser - from waitress.adjustments import Adjustments - my_adj = Adjustments() - self.parser = HTTPRequestParser(my_adj) - - def feed(self, data): - parser = self.parser - for n in range(100): # make sure we never loop forever - consumed = parser.received(data) - data = data[consumed:] - if parser.completed: - return - raise ValueError('Looping') # pragma: no cover - - def testSimpleGET(self): - data = b"""\ -GET /foobar HTTP/8.4 -FirstName: mickey -lastname: Mouse -content-length: 7 - -Hello. -""" - parser = self.parser - self.feed(data) - self.assertTrue(parser.completed) - self.assertEqual(parser.version, '8.4') - self.assertFalse(parser.empty) - self.assertEqual(parser.headers, - {'FIRSTNAME': 'mickey', - 'LASTNAME': 'Mouse', - 'CONTENT_LENGTH': '7', - }) - self.assertEqual(parser.path, '/foobar') - self.assertEqual(parser.command, 'GET') - self.assertEqual(parser.query, '') - self.assertEqual(parser.proxy_scheme, '') - self.assertEqual(parser.proxy_netloc, '') - self.assertEqual(parser.get_body_stream().getvalue(), b'Hello.\n') - - def testComplexGET(self): - data = b"""\ -GET /foo/a+%2B%2F%C3%A4%3D%26a%3Aint?d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6 HTTP/8.4 -FirstName: mickey -lastname: Mouse -content-length: 10 - -Hello mickey. 
-""" - parser = self.parser - self.feed(data) - self.assertEqual(parser.command, 'GET') - self.assertEqual(parser.version, '8.4') - self.assertFalse(parser.empty) - self.assertEqual(parser.headers, - {'FIRSTNAME': 'mickey', - 'LASTNAME': 'Mouse', - 'CONTENT_LENGTH': '10', - }) - # path should be utf-8 encoded - self.assertEqual(tobytes(parser.path).decode('utf-8'), - text_(b'/foo/a++/\xc3\xa4=&a:int', 'utf-8')) - self.assertEqual(parser.query, - 'd=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6') - self.assertEqual(parser.get_body_stream().getvalue(), b'Hello mick') - - def testProxyGET(self): - data = b"""\ -GET https://example.com:8080/foobar HTTP/8.4 -content-length: 7 - -Hello. -""" - parser = self.parser - self.feed(data) - self.assertTrue(parser.completed) - self.assertEqual(parser.version, '8.4') - self.assertFalse(parser.empty) - self.assertEqual(parser.headers, - {'CONTENT_LENGTH': '7', - }) - self.assertEqual(parser.path, '/foobar') - self.assertEqual(parser.command, 'GET') - self.assertEqual(parser.proxy_scheme, 'https') - self.assertEqual(parser.proxy_netloc, 'example.com:8080') - self.assertEqual(parser.command, 'GET') - self.assertEqual(parser.query, '') - self.assertEqual(parser.get_body_stream().getvalue(), b'Hello.\n') - - def testDuplicateHeaders(self): - # Ensure that headers with the same key get concatenated as per - # RFC2616. - data = b"""\ -GET /foobar HTTP/8.4 -x-forwarded-for: 10.11.12.13 -x-forwarded-for: unknown,127.0.0.1 -X-Forwarded_for: 255.255.255.255 -content-length: 7 - -Hello. -""" - self.feed(data) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.headers, { - 'CONTENT_LENGTH': '7', - 'X_FORWARDED_FOR': - '10.11.12.13, unknown,127.0.0.1', - }) - - def testSpoofedHeadersDropped(self): - data = b"""\ -GET /foobar HTTP/8.4 -x-auth_user: bob -content-length: 7 - -Hello. 
-""" - self.feed(data) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.headers, { - 'CONTENT_LENGTH': '7', - }) - - -class DummyBodyStream(object): - - def getfile(self): - return self - - def getbuf(self): - return self - - def close(self): - self.closed = True diff --git a/waitress/tests/test_receiver.py b/waitress/tests/test_receiver.py deleted file mode 100644 index 707f3284..00000000 --- a/waitress/tests/test_receiver.py +++ /dev/null @@ -1,169 +0,0 @@ -import unittest - -class TestFixedStreamReceiver(unittest.TestCase): - - def _makeOne(self, cl, buf): - from waitress.receiver import FixedStreamReceiver - return FixedStreamReceiver(cl, buf) - - def test_received_remain_lt_1(self): - buf = DummyBuffer() - inst = self._makeOne(0, buf) - result = inst.received('a') - self.assertEqual(result, 0) - self.assertEqual(inst.completed, True) - - def test_received_remain_lte_datalen(self): - buf = DummyBuffer() - inst = self._makeOne(1, buf) - result = inst.received('aa') - self.assertEqual(result, 1) - self.assertEqual(inst.completed, True) - self.assertEqual(inst.completed, 1) - self.assertEqual(inst.remain, 0) - self.assertEqual(buf.data, ['a']) - - def test_received_remain_gt_datalen(self): - buf = DummyBuffer() - inst = self._makeOne(10, buf) - result = inst.received('aa') - self.assertEqual(result, 2) - self.assertEqual(inst.completed, False) - self.assertEqual(inst.remain, 8) - self.assertEqual(buf.data, ['aa']) - - def test_getfile(self): - buf = DummyBuffer() - inst = self._makeOne(10, buf) - self.assertEqual(inst.getfile(), buf) - - def test_getbuf(self): - buf = DummyBuffer() - inst = self._makeOne(10, buf) - self.assertEqual(inst.getbuf(), buf) - - def test___len__(self): - buf = DummyBuffer(['1', '2']) - inst = self._makeOne(10, buf) - self.assertEqual(inst.__len__(), 2) - -class TestChunkedReceiver(unittest.TestCase): - - def _makeOne(self, buf): - from waitress.receiver import ChunkedReceiver - return ChunkedReceiver(buf) - - def 
test_alreadycompleted(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.completed = True - result = inst.received(b'a') - self.assertEqual(result, 0) - self.assertEqual(inst.completed, True) - - def test_received_remain_gt_zero(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.chunk_remainder = 100 - result = inst.received(b'a') - self.assertEqual(inst.chunk_remainder, 99) - self.assertEqual(result, 1) - self.assertEqual(inst.completed, False) - - def test_received_control_line_notfinished(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - result = inst.received(b'a') - self.assertEqual(inst.control_line, b'a') - self.assertEqual(result, 1) - self.assertEqual(inst.completed, False) - - def test_received_control_line_finished_garbage_in_input(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - result = inst.received(b'garbage\n') - self.assertEqual(result, 8) - self.assertTrue(inst.error) - - def test_received_control_line_finished_all_chunks_not_received(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - result = inst.received(b'a;discard\n') - self.assertEqual(inst.control_line, b'') - self.assertEqual(inst.chunk_remainder, 10) - self.assertEqual(inst.all_chunks_received, False) - self.assertEqual(result, 10) - self.assertEqual(inst.completed, False) - - def test_received_control_line_finished_all_chunks_received(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - result = inst.received(b'0;discard\n') - self.assertEqual(inst.control_line, b'') - self.assertEqual(inst.all_chunks_received, True) - self.assertEqual(result, 10) - self.assertEqual(inst.completed, False) - - def test_received_trailer_startswith_crlf(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.all_chunks_received = True - result = inst.received(b'\r\n') - self.assertEqual(result, 2) - self.assertEqual(inst.completed, True) - - def test_received_trailer_startswith_lf(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - 
inst.all_chunks_received = True - result = inst.received(b'\n') - self.assertEqual(result, 1) - self.assertEqual(inst.completed, True) - - def test_received_trailer_not_finished(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.all_chunks_received = True - result = inst.received(b'a') - self.assertEqual(result, 1) - self.assertEqual(inst.completed, False) - - def test_received_trailer_finished(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.all_chunks_received = True - result = inst.received(b'abc\r\n\r\n') - self.assertEqual(inst.trailer, b'abc\r\n\r\n') - self.assertEqual(result, 7) - self.assertEqual(inst.completed, True) - - def test_getfile(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - self.assertEqual(inst.getfile(), buf) - - def test_getbuf(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - self.assertEqual(inst.getbuf(), buf) - - def test___len__(self): - buf = DummyBuffer(['1', '2']) - inst = self._makeOne(buf) - self.assertEqual(inst.__len__(), 2) - -class DummyBuffer(object): - - def __init__(self, data=None): - if data is None: - data = [] - self.data = data - - def append(self, s): - self.data.append(s) - - def getfile(self): - return self - - def __len__(self): - return len(self.data) diff --git a/waitress/utilities.py b/waitress/utilities.py deleted file mode 100644 index 943c92fd..00000000 --- a/waitress/utilities.py +++ /dev/null @@ -1,216 +0,0 @@ -############################################################################## -# -# Copyright (c) 2004 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. 
-# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""Utility functions -""" - -import asyncore -import errno -import logging -import os -import re -import stat -import time -import calendar - -logger = logging.getLogger('waitress') - -def find_double_newline(s): - """Returns the position just after a double newline in the given string.""" - pos1 = s.find(b'\n\r\n') # One kind of double newline - if pos1 >= 0: - pos1 += 3 - pos2 = s.find(b'\n\n') # Another kind of double newline - if pos2 >= 0: - pos2 += 2 - - if pos1 >= 0: - if pos2 >= 0: - return min(pos1, pos2) - else: - return pos1 - else: - return pos2 - -def concat(*args): - return ''.join(args) - -def join(seq, field=' '): - return field.join(seq) - -def group(s): - return '(' + s + ')' - -short_days = ['sun', 'mon', 'tue', 'wed', 'thu', 'fri', 'sat'] -long_days = ['sunday', 'monday', 'tuesday', 'wednesday', - 'thursday', 'friday', 'saturday'] - -short_day_reg = group(join(short_days, '|')) -long_day_reg = group(join(long_days, '|')) - -daymap = {} -for i in range(7): - daymap[short_days[i]] = i - daymap[long_days[i]] = i - -hms_reg = join(3 * [group('[0-9][0-9]')], ':') - -months = ['jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', - 'aug', 'sep', 'oct', 'nov', 'dec'] - -monmap = {} -for i in range(12): - monmap[months[i]] = i + 1 - -months_reg = group(join(months, '|')) - -# From draft-ietf-http-v11-spec-07.txt/3.3.1 -# Sun, 06 Nov 1994 08:49:37 GMT ; RFC 822, updated by RFC 1123 -# Sunday, 06-Nov-94 08:49:37 GMT ; RFC 850, obsoleted by RFC 1036 -# Sun Nov 6 08:49:37 1994 ; ANSI C's asctime() format - -# rfc822 format -rfc822_date = join( - [concat(short_day_reg, ','), # day - group('[0-9][0-9]?'), # date - months_reg, # month - 
group('[0-9]+'), # year - hms_reg, # hour minute second - 'gmt' - ], - ' ' -) - -rfc822_reg = re.compile(rfc822_date) - -def unpack_rfc822(m): - g = m.group - return ( - int(g(4)), # year - monmap[g(3)], # month - int(g(2)), # day - int(g(5)), # hour - int(g(6)), # minute - int(g(7)), # second - 0, - 0, - 0, - ) - -# rfc850 format -rfc850_date = join( - [concat(long_day_reg, ','), - join( - [group('[0-9][0-9]?'), - months_reg, - group('[0-9]+') - ], - '-' - ), - hms_reg, - 'gmt' - ], - ' ' -) - -rfc850_reg = re.compile(rfc850_date) -# they actually unpack the same way -def unpack_rfc850(m): - g = m.group - yr = g(4) - if len(yr) == 2: - yr = '19' + yr - return ( - int(yr), # year - monmap[g(3)], # month - int(g(2)), # day - int(g(5)), # hour - int(g(6)), # minute - int(g(7)), # second - 0, - 0, - 0 - ) - -# parsdate.parsedate - ~700/sec. -# parse_http_date - ~1333/sec. - -weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] -monthname = [None, 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] - -def build_http_date(when): - year, month, day, hh, mm, ss, wd, y, z = time.gmtime(when) - return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( - weekdayname[wd], - day, monthname[month], year, - hh, mm, ss) - -def parse_http_date(d): - d = d.lower() - m = rfc850_reg.match(d) - if m and m.end() == len(d): - retval = int(calendar.timegm(unpack_rfc850(m))) - else: - m = rfc822_reg.match(d) - if m and m.end() == len(d): - retval = int(calendar.timegm(unpack_rfc822(m))) - else: - return 0 - return retval - -class logging_dispatcher(asyncore.dispatcher): - logger = logger - - def log_info(self, message, type='info'): - severity = { - 'info': logging.INFO, - 'warning': logging.WARN, - 'error': logging.ERROR, - } - self.logger.log(severity.get(type, logging.INFO), message) - -def cleanup_unix_socket(path): - try: - st = os.stat(path) - except OSError as exc: - if exc.errno != errno.ENOENT: - raise # pragma: no cover - else: - if 
stat.S_ISSOCK(st.st_mode): - try: - os.remove(path) - except OSError: # pragma: no cover - # avoid race condition error during tests - pass - -class Error(object): - - def __init__(self, body): - self.body = body - -class BadRequest(Error): - code = 400 - reason = 'Bad Request' - -class RequestHeaderFieldsTooLarge(BadRequest): - code = 431 - reason = 'Request Header Fields Too Large' - -class RequestEntityTooLarge(BadRequest): - code = 413 - reason = 'Request Entity Too Large' - -class InternalServerError(Error): - code = 500 - reason = 'Internal Server Error'