Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Updated the wsgiref library + test
  • Loading branch information
terryluan12 committed Jan 3, 2026
commit 4c011e6374e385050e8f9fcd16c9f35d4c5626ca
67 changes: 22 additions & 45 deletions Lib/test/test_wsgiref.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unittest import mock
from test import support
from test.support import warnings_helper
from test.support import socket_helper
from test.test_httpservers import NoLogRequestHandler
from unittest import TestCase
from wsgiref.util import setup_testing_defaults
Expand Down Expand Up @@ -80,41 +80,26 @@ def run_amock(app=hello_app, data=b"GET / HTTP/1.0\n\n"):

return out.getvalue(), err.getvalue()

def compare_generic_iter(make_it,match):
"""Utility to compare a generic 2.1/2.2+ iterator with an iterable

If running under Python 2.2+, this tests the iterator using iter()/next(),
as well as __getitem__. 'make_it' must be a function returning a fresh
def compare_generic_iter(make_it, match):
"""Utility to compare a generic iterator with an iterable

This tests the iterator using iter()/next().
'make_it' must be a function returning a fresh
iterator to be tested (since this may test the iterator twice)."""

it = make_it()
n = 0
if not iter(it) is it:
raise AssertionError
for item in match:
if not it[n]==item: raise AssertionError
n+=1
try:
it[n]
except IndexError:
pass
else:
raise AssertionError("Too many items from __getitem__",it)

if not next(it) == item:
raise AssertionError
try:
iter, StopIteration
except NameError:
next(it)
except StopIteration:
pass
else:
# Only test iter mode under 2.2+
it = make_it()
if not iter(it) is it: raise AssertionError
for item in match:
if not next(it) == item: raise AssertionError
try:
next(it)
except StopIteration:
pass
else:
raise AssertionError("Too many items from .__next__()", it)
raise AssertionError("Too many items from .__next__()", it)


class IntegrationTests(TestCase):
Expand Down Expand Up @@ -152,7 +137,7 @@ def test_environ(self):
def test_request_length(self):
out, err = run_amock(data=b"GET " + (b"x" * 65537) + b" HTTP/1.0\n\n")
self.assertEqual(out.splitlines()[0],
b"HTTP/1.0 414 Request-URI Too Long")
b"HTTP/1.0 414 URI Too Long")

def test_validated_hello(self):
out, err = run_amock(validator(hello_app))
Expand Down Expand Up @@ -264,7 +249,7 @@ def app(environ, start_response):
class WsgiHandler(NoLogRequestHandler, WSGIRequestHandler):
pass

server = make_server(support.HOST, 0, app, handler_class=WsgiHandler)
server = make_server(socket_helper.HOST, 0, app, handler_class=WsgiHandler)
self.addCleanup(server.server_close)
interrupted = threading.Event()

Expand Down Expand Up @@ -339,7 +324,6 @@ def checkReqURI(self,uri,query=1,**kw):
util.setup_testing_defaults(kw)
self.assertEqual(util.request_uri(kw,query),uri)

@warnings_helper.ignore_warnings(category=DeprecationWarning)
def checkFW(self,text,size,match):

def make_it(text=text,size=size):
Expand All @@ -358,15 +342,6 @@ def make_it(text=text,size=size):
it.close()
self.assertTrue(it.filelike.closed)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_filewrapper_getitem_deprecation(self):
wrapper = util.FileWrapper(StringIO('foobar'), 3)
with self.assertWarnsRegex(DeprecationWarning,
r'Use iterator protocol instead'):
# This should have returned 'bar'.
self.assertEqual(wrapper[1], 'foo')

def testSimpleShifts(self):
self.checkShift('','/', '', '/', '')
self.checkShift('','/x', 'x', '/x', '')
Expand Down Expand Up @@ -473,6 +448,10 @@ def testHopByHop(self):
for alt in hop, hop.title(), hop.upper(), hop.lower():
self.assertFalse(util.is_hop_by_hop(alt))

@unittest.expectedFailure # TODO: RUSTPYTHON
def test_filewrapper_getitem_deprecation(self):
return super().test_filewrapper_getitem_deprecation()

class HeaderTests(TestCase):

def testMappingInterface(self):
Expand Down Expand Up @@ -581,7 +560,7 @@ def testEnviron(self):
# Test handler.environ as a dict
expected = {}
setup_testing_defaults(expected)
# Handler inherits os_environ variables which are not overriden
# Handler inherits os_environ variables which are not overridden
# by SimpleHandler.add_cgi_vars() (SimpleHandler.base_env)
for key, value in os_environ.items():
if key not in expected:
Expand Down Expand Up @@ -821,8 +800,7 @@ def flush(self):
b"Hello, world!",
written)

# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def testClientConnectionTerminations(self):
environ = {"SERVER_PROTOCOL": "HTTP/1.0"}
for exception in (
Expand All @@ -841,8 +819,7 @@ def write(self, b):

self.assertFalse(stderr.getvalue())

# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def testDontResetInternalStateOnException(self):
class CustomException(ValueError):
pass
Expand Down
2 changes: 2 additions & 0 deletions Lib/wsgiref/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
* validate -- validation wrapper that sits between an app and a server
to detect errors in either

* types -- collection of WSGI-related types for static type checking

To-Do:

* cgi_gateway -- Run WSGI apps under CGI (pending a deployment standard)
Expand Down
38 changes: 27 additions & 11 deletions Lib/wsgiref/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ def run(self, application):
self.setup_environ()
self.result = application(self.environ, self.start_response)
self.finish_response()
except (ConnectionAbortedError, BrokenPipeError, ConnectionResetError):
# We expect the client to close the connection abruptly from time
# to time.
return
except:
try:
self.handle_error()
Expand Down Expand Up @@ -179,7 +183,16 @@ def finish_response(self):
for data in self.result:
self.write(data)
self.finish_content()
finally:
except:
# Call close() on the iterable returned by the WSGI application
# in case of an exception.
if hasattr(self.result, 'close'):
self.result.close()
raise
else:
# We only call close() when no exception is raised, because it
# will set status, result, headers, and environ fields to None.
# See bpo-29183 for more details.
self.close()


Expand Down Expand Up @@ -215,8 +228,7 @@ def start_response(self, status, headers,exc_info=None):
if exc_info:
try:
if self.headers_sent:
# Re-raise original exception if headers sent
raise exc_info[0](exc_info[1]).with_traceback(exc_info[2])
raise
finally:
exc_info = None # avoid dangling circular ref
elif self.headers is not None:
Expand All @@ -225,18 +237,25 @@ def start_response(self, status, headers,exc_info=None):
self.status = status
self.headers = self.headers_class(headers)
status = self._convert_string_type(status, "Status")
assert len(status)>=4,"Status must be at least 4 characters"
assert status[:3].isdigit(), "Status message must begin w/3-digit code"
assert status[3]==" ", "Status message must have a space after code"
self._validate_status(status)

if __debug__:
for name, val in headers:
name = self._convert_string_type(name, "Header name")
val = self._convert_string_type(val, "Header value")
assert not is_hop_by_hop(name),"Hop-by-hop headers not allowed"
assert not is_hop_by_hop(name),\
f"Hop-by-hop header, '{name}: {val}', not allowed"

return self.write

def _validate_status(self, status):
if len(status) < 4:
raise AssertionError("Status must be at least 4 characters")
if not status[:3].isdigit():
raise AssertionError("Status message must begin w/3-digit code")
if status[3] != " ":
raise AssertionError("Status message must have a space after code")

def _convert_string_type(self, value, title):
"""Convert/check value type."""
if type(value) is str:
Expand Down Expand Up @@ -456,10 +475,7 @@ def _write(self,data):
from warnings import warn
warn("SimpleHandler.stdout.write() should not do partial writes",
DeprecationWarning)
while True:
data = data[result:]
if not data:
break
while data := data[result:]:
result = self.stdout.write(data)

def _flush(self):
Expand Down
7 changes: 2 additions & 5 deletions Lib/wsgiref/simple_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,6 @@ def get_environ(self):

env['PATH_INFO'] = urllib.parse.unquote(path, 'iso-8859-1')
env['QUERY_STRING'] = query

host = self.address_string()
if host != self.client_address[0]:
env['REMOTE_HOST'] = host
env['REMOTE_ADDR'] = self.client_address[0]

if self.headers.get('content-type') is None:
Expand Down Expand Up @@ -127,7 +123,8 @@ def handle(self):
return

handler = ServerHandler(
self.rfile, self.wfile, self.get_stderr(), self.get_environ()
self.rfile, self.wfile, self.get_stderr(), self.get_environ(),
multithread=False,
)
handler.request_handler = self # backpointer for logging
handler.run(self.server.get_app())
Expand Down
54 changes: 54 additions & 0 deletions Lib/wsgiref/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""WSGI-related types for static type checking"""

from collections.abc import Callable, Iterable, Iterator
from types import TracebackType
from typing import Any, Protocol, TypeAlias

__all__ = [
"StartResponse",
"WSGIEnvironment",
"WSGIApplication",
"InputStream",
"ErrorStream",
"FileWrapper",
]

_ExcInfo: TypeAlias = tuple[type[BaseException], BaseException, TracebackType]
_OptExcInfo: TypeAlias = _ExcInfo | tuple[None, None, None]

class StartResponse(Protocol):
"""start_response() callable as defined in PEP 3333"""
def __call__(
self,
status: str,
headers: list[tuple[str, str]],
exc_info: _OptExcInfo | None = ...,
/,
) -> Callable[[bytes], object]: ...

WSGIEnvironment: TypeAlias = dict[str, Any]
WSGIApplication: TypeAlias = Callable[[WSGIEnvironment, StartResponse],
Iterable[bytes]]

class InputStream(Protocol):
"""WSGI input stream as defined in PEP 3333"""
def read(self, size: int = ..., /) -> bytes: ...
def readline(self, size: int = ..., /) -> bytes: ...
def readlines(self, hint: int = ..., /) -> list[bytes]: ...
def __iter__(self) -> Iterator[bytes]: ...

class ErrorStream(Protocol):
"""WSGI error stream as defined in PEP 3333"""
def flush(self) -> object: ...
def write(self, s: str, /) -> object: ...
def writelines(self, seq: list[str], /) -> object: ...

class _Readable(Protocol):
def read(self, size: int = ..., /) -> bytes: ...
# Optional: def close(self) -> object: ...

class FileWrapper(Protocol):
"""WSGI file wrapper as defined in PEP 3333"""
def __call__(
self, file: _Readable, block_size: int = ..., /,
) -> Iterable[bytes]: ...
14 changes: 4 additions & 10 deletions Lib/wsgiref/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

__all__ = [
'FileWrapper', 'guess_scheme', 'application_uri', 'request_uri',
'shift_path_info', 'setup_testing_defaults',
'shift_path_info', 'setup_testing_defaults', 'is_hop_by_hop',
]


Expand All @@ -17,12 +17,6 @@ def __init__(self, filelike, blksize=8192):
if hasattr(filelike,'close'):
self.close = filelike.close

def __getitem__(self,key):
data = self.filelike.read(self.blksize)
if data:
return data
raise IndexError

def __iter__(self):
return self

Expand Down Expand Up @@ -155,9 +149,9 @@ def setup_testing_defaults(environ):


_hoppish = {
'connection':1, 'keep-alive':1, 'proxy-authenticate':1,
'proxy-authorization':1, 'te':1, 'trailers':1, 'transfer-encoding':1,
'upgrade':1
'connection', 'keep-alive', 'proxy-authenticate',
'proxy-authorization', 'te', 'trailers', 'transfer-encoding',
'upgrade'
}.__contains__

def is_hop_by_hop(header_name):
Expand Down
15 changes: 5 additions & 10 deletions Lib/wsgiref/validate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org)
# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php
# Also licenced under the Apache License, 2.0: http://opensource.org/licenses/apache2.0.php
# Licensed under the MIT license: https://opensource.org/licenses/mit-license.php
# Also licenced under the Apache License, 2.0: https://opensource.org/licenses/apache2.0.php
# Licensed to PSF under a Contributor Agreement
"""
Middleware to check for obedience to the WSGI specification.
Expand Down Expand Up @@ -77,7 +77,7 @@

* That wsgi.input is used properly:

- .read() is called with zero or one argument
- .read() is called with exactly one argument

- That it returns a string

Expand Down Expand Up @@ -137,7 +137,7 @@ def validator(application):

"""
When applied between a WSGI server and a WSGI application, this
middleware will check for WSGI compliancy on a number of levels.
middleware will check for WSGI compliance on a number of levels.
This middleware does not modify the request or response in any
way, but will raise an AssertionError if anything seems off
(except for a failure to close the application iterator, which
Expand Down Expand Up @@ -214,10 +214,7 @@ def readlines(self, *args):
return lines

def __iter__(self):
while 1:
line = self.readline()
if not line:
return
while line := self.readline():
yield line

def close(self):
Expand Down Expand Up @@ -390,7 +387,6 @@ def check_headers(headers):
assert_(type(headers) is list,
"Headers (%r) must be of type list: %r"
% (headers, type(headers)))
header_names = {}
for item in headers:
assert_(type(item) is tuple,
"Individual headers (%r) must be of type tuple: %r"
Expand All @@ -403,7 +399,6 @@ def check_headers(headers):
"The Status header cannot be used; it conflicts with CGI "
"script, and HTTP status is not given through headers "
"(value: %r)." % value)
header_names[name.lower()] = None
assert_('\n' not in name and ':' not in name,
"Header names may not contain ':' or '\\n': %r" % name)
assert_(header_re.search(name), "Bad header name: %r" % name)
Expand Down