diff --git a/Lib/http/__init__.py b/Lib/http/__init__.py index d4334cc88f9..bf8d7d68868 100644 --- a/Lib/http/__init__.py +++ b/Lib/http/__init__.py @@ -2,6 +2,7 @@ __all__ = ['HTTPStatus'] + class HTTPStatus(IntEnum): """HTTP status codes and reason phrases @@ -15,6 +16,11 @@ class HTTPStatus(IntEnum): * RFC 7238: Permanent Redirect * RFC 2295: Transparent Content Negotiation in HTTP * RFC 2774: An HTTP Extension Framework + * RFC 7725: An HTTP Status Code to Report Legal Obstacles + * RFC 7540: Hypertext Transfer Protocol Version 2 (HTTP/2) + * RFC 2324: Hyper Text Coffee Pot Control Protocol (HTCPCP/1.0) + * RFC 8297: An HTTP Status Code for Indicating Hints + * RFC 8470: Using Early Data in HTTP """ def __new__(cls, value, phrase, description=''): obj = int.__new__(cls, value) @@ -29,6 +35,7 @@ def __new__(cls, value, phrase, description=''): SWITCHING_PROTOCOLS = (101, 'Switching Protocols', 'Switching to new protocol; obey Upgrade header') PROCESSING = 102, 'Processing' + EARLY_HINTS = 103, 'Early Hints' # success OK = 200, 'OK', 'Request fulfilled, document follows' @@ -58,7 +65,7 @@ def __new__(cls, value, phrase, description=''): TEMPORARY_REDIRECT = (307, 'Temporary Redirect', 'Object moved temporarily -- see URI list') PERMANENT_REDIRECT = (308, 'Permanent Redirect', - 'Object moved temporarily -- see URI list') + 'Object moved permanently -- see URI list') # client error BAD_REQUEST = (400, 'Bad Request', @@ -98,9 +105,14 @@ def __new__(cls, value, phrase, description=''): 'Cannot satisfy request range') EXPECTATION_FAILED = (417, 'Expectation Failed', 'Expect condition could not be satisfied') + IM_A_TEAPOT = (418, 'I\'m a Teapot', + 'Server refuses to brew coffee because it is a teapot.') + MISDIRECTED_REQUEST = (421, 'Misdirected Request', + 'Server is not able to produce a response') UNPROCESSABLE_ENTITY = 422, 'Unprocessable Entity' LOCKED = 423, 'Locked' FAILED_DEPENDENCY = 424, 'Failed Dependency' + TOO_EARLY = 425, 'Too Early' UPGRADE_REQUIRED = 
426, 'Upgrade Required' PRECONDITION_REQUIRED = (428, 'Precondition Required', 'The origin server requires the request to be conditional') @@ -111,6 +123,10 @@ def __new__(cls, value, phrase, description=''): 'Request Header Fields Too Large', 'The server is unwilling to process the request because its header ' 'fields are too large') + UNAVAILABLE_FOR_LEGAL_REASONS = (451, + 'Unavailable For Legal Reasons', + 'The server is denying access to the ' + 'resource as a consequence of a legal demand') # server errors INTERNAL_SERVER_ERROR = (500, 'Internal Server Error', diff --git a/Lib/http/client.py b/Lib/http/client.py index a8e59b95616..a6ab135b2c3 100644 --- a/Lib/http/client.py +++ b/Lib/http/client.py @@ -70,12 +70,13 @@ import email.parser import email.message +import errno import http import io -import os import re import socket -import collections +import sys +import collections.abc from urllib.parse import urlsplit # HTTPMessage, parse_headers(), and the HTTP status code constants are @@ -106,9 +107,6 @@ # Mapping status codes to official W3C names responses = {v: v.phrase for v in http.HTTPStatus.__members__.values()} -# maximal amount of data to read at one time in _safe_read -MAXAMOUNT = 1048576 - # maximal line length when calling readline(). _MAXLINE = 65536 _MAXHEADERS = 100 @@ -141,6 +139,20 @@ _is_legal_header_name = re.compile(rb'[^:\s][^:\r\n]*').fullmatch _is_illegal_header_value = re.compile(rb'\n(?![ \t])|\r(?![ \t\n])').search +# These characters are not allowed within HTTP URL paths. +# See https://tools.ietf.org/html/rfc3986#section-3.3 and the +# https://tools.ietf.org/html/rfc3986#appendix-A pchar definition. +# Prevents CVE-2019-9740. Includes control characters such as \r\n. +# We don't restrict chars above \x7f as putrequest() limits us to ASCII. 
+_contains_disallowed_url_pchar_re = re.compile('[\x00-\x20\x7f]') +# Arguably only these _should_ be allowed: +# _is_allowed_url_pchars_re = re.compile(r"^[/!$&'()*+,;=:@%a-zA-Z0-9._~-]+$") +# We are more lenient for assumed real world compatibility purposes. + +# These characters are not allowed within HTTP method names +# to prevent http header injection. +_contains_disallowed_method_pchar_re = re.compile('[\x00-\x1f]') + # We always set the Content-Length header for these methods because some # servers will otherwise respond with a 411 _METHODS_EXPECTING_BODY = {'PATCH', 'POST', 'PUT'} @@ -191,15 +203,11 @@ def getallmatchingheaders(self, name): lst.append(line) return lst -def parse_headers(fp, _class=HTTPMessage): - """Parses only RFC2822 headers from a file pointer. - - email Parser wants to see strings rather than bytes. - But a TextIOWrapper around self.rfile would buffer too many bytes - from the stream, bytes which we later need to read as bytes. - So we read the correct bytes here, as bytes, for email Parser - to parse. +def _read_headers(fp): + """Reads potential header lines into a list from a file pointer. + Length of line is limited by _MAXLINE, and number of + headers is limited by _MAXHEADERS. """ headers = [] while True: @@ -211,6 +219,19 @@ def parse_headers(fp, _class=HTTPMessage): raise HTTPException("got more than %d headers" % _MAXHEADERS) if line in (b'\r\n', b'\n', b''): break + return headers + +def parse_headers(fp, _class=HTTPMessage): + """Parses only RFC2822 headers from a file pointer. + + email Parser wants to see strings rather than bytes. + But a TextIOWrapper around self.rfile would buffer too many bytes + from the stream, bytes which we later need to read as bytes. + So we read the correct bytes here, as bytes, for email Parser + to parse. 
+ + """ + headers = _read_headers(fp) hstring = b''.join(headers).decode('iso-8859-1') return email.parser.Parser(_class=_class).parsestr(hstring) @@ -298,15 +319,10 @@ def begin(self): if status != CONTINUE: break # skip the header from the 100 response - while True: - skip = self.fp.readline(_MAXLINE + 1) - if len(skip) > _MAXLINE: - raise LineTooLong("header line") - skip = skip.strip() - if not skip: - break - if self.debuglevel > 0: - print("header:", skip) + skipped_headers = _read_headers(self.fp) + if self.debuglevel > 0: + print("headers:", skipped_headers) + del skipped_headers self.code = self.status = status self.reason = reason.strip() @@ -321,8 +337,8 @@ def begin(self): self.headers = self.msg = parse_headers(self.fp) if self.debuglevel > 0: - for hdr in self.headers: - print("header:", hdr, end=" ") + for hdr, val in self.headers.items(): + print("header:", hdr + ":", val) # are we using the chunked-style of transfer encoding? tr_enc = self.headers.get("transfer-encoding") @@ -339,9 +355,6 @@ def begin(self): # NOTE: RFC 2616, S4.4, #3 says we ignore this if tr_enc is "chunked" self.length = None length = self.headers.get("content-length") - - # are we using the chunked-style of transfer encoding? - tr_enc = self.headers.get("transfer-encoding") if length and not self.chunked: try: self.length = int(length) @@ -372,7 +385,6 @@ def _check_close(self): if self.version == 11: # An HTTP/1.1 proxy is assumed to stay open unless # explicitly closed. 
- conn = self.headers.get("connection") if conn and "close" in conn.lower(): return True return False @@ -443,18 +455,25 @@ def read(self, amt=None): self._close_conn() return b"" + if self.chunked: + return self._read_chunked(amt) + if amt is not None: - # Amount is given, implement using readinto - b = bytearray(amt) - n = self.readinto(b) - return memoryview(b)[:n].tobytes() + if self.length is not None and amt > self.length: + # clip the read to the "end of response" + amt = self.length + s = self.fp.read(amt) + if not s and amt: + # Ideally, we would raise IncompleteRead if the content-length + # wasn't satisfied, but it might break compatibility. + self._close_conn() + elif self.length is not None: + self.length -= len(s) + if not self.length: + self._close_conn() + return s else: # Amount is not given (unbounded read) so we must check self.length - # and self.chunked - - if self.chunked: - return self._readall_chunked() - if self.length is None: s = self.fp.read() else: @@ -540,7 +559,7 @@ def _get_chunk_left(self): chunk_left = self.chunk_left if not chunk_left: # Can be 0 or None if chunk_left is not None: - # We are at the end of chunk. 
dicard chunk end + # We are at the end of chunk, discard chunk end self._safe_read(2) # toss the CRLF at the end of the chunk try: chunk_left = self._read_next_chunk_size() @@ -555,7 +574,7 @@ def _get_chunk_left(self): self.chunk_left = chunk_left return chunk_left - def _readall_chunked(self): + def _read_chunked(self, amt=None): assert self.chunked != _UNKNOWN value = [] try: @@ -563,7 +582,15 @@ def _readall_chunked(self): chunk_left = self._get_chunk_left() if chunk_left is None: break + + if amt is not None and amt <= chunk_left: + value.append(self._safe_read(amt)) + self.chunk_left = chunk_left - amt + break + value.append(self._safe_read(chunk_left)) + if amt is not None: + amt -= chunk_left self.chunk_left = 0 return b''.join(value) except IncompleteRead: @@ -594,43 +621,24 @@ def _readinto_chunked(self, b): raise IncompleteRead(bytes(b[0:total_bytes])) def _safe_read(self, amt): - """Read the number of bytes requested, compensating for partial reads. - - Normally, we have a blocking socket, but a read() can be interrupted - by a signal (resulting in a partial read). - - Note that we cannot distinguish between EOF and an interrupt when zero - bytes have been read. IncompleteRead() will be raised in this - situation. + """Read the number of bytes requested. This function should be used when bytes "should" be present for reading. If the bytes are truly not available (due to EOF), then the IncompleteRead exception can be used to detect the problem. 
""" - s = [] - while amt > 0: - chunk = self.fp.read(min(amt, MAXAMOUNT)) - if not chunk: - raise IncompleteRead(b''.join(s), amt) - s.append(chunk) - amt -= len(chunk) - return b"".join(s) + data = self.fp.read(amt) + if len(data) < amt: + raise IncompleteRead(data, amt-len(data)) + return data def _safe_readinto(self, b): """Same as _safe_read, but for reading into a buffer.""" - total_bytes = 0 - mvb = memoryview(b) - while total_bytes < len(b): - if MAXAMOUNT < len(mvb): - temp_mvb = mvb[0:MAXAMOUNT] - n = self.fp.readinto(temp_mvb) - else: - n = self.fp.readinto(mvb) - if not n: - raise IncompleteRead(bytes(mvb[0:total_bytes]), len(b)) - mvb = mvb[n:] - total_bytes += n - return total_bytes + amt = len(b) + n = self.fp.readinto(b) + if n < amt: + raise IncompleteRead(bytes(b[:n]), amt-n) + return n def read1(self, n=-1): """Read with at most one underlying system call. If at least one @@ -642,14 +650,7 @@ def read1(self, n=-1): return self._read1_chunked(n) if self.length is not None and (n < 0 or n > self.length): n = self.length - try: - result = self.fp.read1(n) - except ValueError: - if n >= 0: - raise - # some implementations, like BufferedReader, don't support -1 - # Read an arbitrarily selected largeish chunk. 
- result = self.fp.read1(16*1024) + result = self.fp.read1(n) if not result and n: self._close_conn() elif self.length is not None: @@ -834,9 +835,10 @@ def _get_content_length(body, method): return None def __init__(self, host, port=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, - source_address=None): + source_address=None, blocksize=8192): self.timeout = timeout self.source_address = source_address + self.blocksize = blocksize self.sock = None self._buffer = [] self.__response = None @@ -848,6 +850,8 @@ def __init__(self, host, port=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, (self.host, self.port) = self._get_hostport(host, port) + self._validate_host(self.host) + # This is stored as an instance variable to allow unit # tests to replace it with a suitable mockup self._create_connection = socket.create_connection @@ -860,7 +864,7 @@ def set_tunnel(self, host, port=None, headers=None): the endpoint passed to `set_tunnel`. This done by sending an HTTP CONNECT request to the proxy server when the connection is established. - This method must be called before the HTML connection has been + This method must be called before the HTTP connection has been established. 
The headers argument should be a mapping of extra HTTP headers to send @@ -900,23 +904,24 @@ def set_debuglevel(self, level): self.debuglevel = level def _tunnel(self): - connect_str = "CONNECT %s:%d HTTP/1.0\r\n" % (self._tunnel_host, - self._tunnel_port) - connect_bytes = connect_str.encode("ascii") - self.send(connect_bytes) + connect = b"CONNECT %s:%d HTTP/1.0\r\n" % ( + self._tunnel_host.encode("ascii"), self._tunnel_port) + headers = [connect] for header, value in self._tunnel_headers.items(): - header_str = "%s: %s\r\n" % (header, value) - header_bytes = header_str.encode("latin-1") - self.send(header_bytes) - self.send(b'\r\n') + headers.append(f"{header}: {value}\r\n".encode("latin-1")) + headers.append(b"\r\n") + # Making a single send() call instead of one per line encourages + # the host OS to use a more optimal packet size instead of + # potentially emitting a series of small packets. + self.send(b"".join(headers)) + del headers response = self.response_class(self.sock, method=self._method) (version, code, message) = response._read_status() if code != http.HTTPStatus.OK: self.close() - raise OSError("Tunnel connection failed: %d %s" % (code, - message.strip())) + raise OSError(f"Tunnel connection failed: {code} {message.strip()}") while True: line = response.fp.readline(_MAXLINE + 1) if len(line) > _MAXLINE: @@ -932,9 +937,15 @@ def _tunnel(self): def connect(self): """Connect to the host and port specified in __init__.""" + sys.audit("http.client.connect", self, self.host, self.port) self.sock = self._create_connection( (self.host,self.port), self.timeout, self.source_address) - self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + # Might fail in OSs that don't implement TCP_NODELAY + try: + self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + except OSError as e: + if e.errno != errno.ENOPROTOOPT: + raise if self._tunnel_host: self._tunnel() @@ -967,7 +978,6 @@ def send(self, data): if self.debuglevel > 0: print("send:", 
repr(data)) - blocksize = 8192 if hasattr(data, "read") : if self.debuglevel > 0: print("sendIng a read()able") @@ -975,17 +985,19 @@ def send(self, data): if encode and self.debuglevel > 0: print("encoding file using iso-8859-1") while 1: - datablock = data.read(blocksize) + datablock = data.read(self.blocksize) if not datablock: break if encode: datablock = datablock.encode("iso-8859-1") + sys.audit("http.client.send", self, datablock) self.sock.sendall(datablock) return + sys.audit("http.client.send", self, data) try: self.sock.sendall(data) except TypeError: - if isinstance(data, collections.Iterable): + if isinstance(data, collections.abc.Iterable): for d in data: self.sock.sendall(d) else: @@ -1000,14 +1012,13 @@ def _output(self, s): self._buffer.append(s) def _read_readable(self, readable): - blocksize = 8192 if self.debuglevel > 0: print("sendIng a read()able") encode = self._is_textIO(readable) if encode and self.debuglevel > 0: print("encoding file using iso-8859-1") while True: - datablock = readable.read(blocksize) + datablock = readable.read(self.blocksize) if not datablock: break if encode: @@ -1107,14 +1118,17 @@ def putrequest(self, method, url, skip_host=False, else: raise CannotSendRequest(self.__state) - # Save the method we use, we need it later in the response phase + self._validate_method(method) + + # Save the method for use later in the response phase self._method = method - if not url: - url = '/' + + url = url or '/' + self._validate_path(url) + request = '%s %s %s' % (method, url, self._http_vsn_str) - # Non-ASCII characters should have been eliminated earlier - self._output(request.encode('ascii')) + self._output(self._encode_request(request)) if self._http_vsn == 11: # Issue some standard headers for better HTTP/1.1 compliance @@ -1192,6 +1206,35 @@ def putrequest(self, method, url, skip_host=False, # For HTTP/1.0, the server will assume "not chunked" pass + def _encode_request(self, request): + # ASCII also helps prevent 
CVE-2019-9740. + return request.encode('ascii') + + def _validate_method(self, method): + """Validate a method name for putrequest.""" + # prevent http header injection + match = _contains_disallowed_method_pchar_re.search(method) + if match: + raise ValueError( + f"method can't contain control characters. {method!r} " + f"(found at least {match.group()!r})") + + def _validate_path(self, url): + """Validate a url for putrequest.""" + # Prevent CVE-2019-9740. + match = _contains_disallowed_url_pchar_re.search(url) + if match: + raise InvalidURL(f"URL can't contain control characters. {url!r} " + f"(found at least {match.group()!r})") + + def _validate_host(self, host): + """Validate a host so it doesn't contain control characters.""" + # Prevent CVE-2019-18348. + match = _contains_disallowed_url_pchar_re.search(host) + if match: + raise InvalidURL(f"URL can't contain control characters. {host!r} " + f"(found at least {match.group()!r})") + def putheader(self, header, *values): """Send a request header line to the server. 
@@ -1362,9 +1405,10 @@ class HTTPSConnection(HTTPConnection): def __init__(self, host, port=None, key_file=None, cert_file=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None, *, context=None, - check_hostname=None): + check_hostname=None, blocksize=8192): super(HTTPSConnection, self).__init__(host, port, timeout, - source_address) + source_address, + blocksize=blocksize) if (key_file is not None or cert_file is not None or check_hostname is not None): import warnings @@ -1375,6 +1419,12 @@ def __init__(self, host, port=None, key_file=None, cert_file=None, self.cert_file = cert_file if context is None: context = ssl._create_default_https_context() + # send ALPN extension to indicate HTTP/1.1 protocol + if self._http_vsn == 11: + context.set_alpn_protocols(['http/1.1']) + # enable PHA for TLS 1.3 connections if available + if context.post_handshake_auth is not None: + context.post_handshake_auth = True will_verify = context.verify_mode != ssl.CERT_NONE if check_hostname is None: check_hostname = context.check_hostname @@ -1383,8 +1433,13 @@ def __init__(self, host, port=None, key_file=None, cert_file=None, "either CERT_OPTIONAL or CERT_REQUIRED") if key_file or cert_file: context.load_cert_chain(cert_file, key_file) + # cert and key file means the user wants to authenticate. + # enable TLS 1.3 PHA implicitly even for custom contexts. + if context.post_handshake_auth is not None: + context.post_handshake_auth = True self._context = context - self._check_hostname = check_hostname + if check_hostname is not None: + self._context.check_hostname = check_hostname def connect(self): "Connect to a host on a given (SSL) port." 
@@ -1398,13 +1453,6 @@ def connect(self): self.sock = self._context.wrap_socket(self.sock, server_hostname=server_hostname) - if not self._context.check_hostname and self._check_hostname: - try: - ssl.match_hostname(self.sock.getpeercert(), server_hostname) - except Exception: - self.sock.shutdown(socket.SHUT_RDWR) - self.sock.close() - raise __all__.append("HTTPSConnection") @@ -1442,8 +1490,7 @@ def __repr__(self): e = '' return '%s(%i bytes read%s)' % (self.__class__.__name__, len(self.partial), e) - def __str__(self): - return repr(self) + __str__ = object.__str__ class ImproperConnectionState(HTTPException): pass diff --git a/Lib/http/cookiejar.py b/Lib/http/cookiejar.py index adf956d66a0..685f6a0b976 100644 --- a/Lib/http/cookiejar.py +++ b/Lib/http/cookiejar.py @@ -28,6 +28,7 @@ __all__ = ['Cookie', 'CookieJar', 'CookiePolicy', 'DefaultCookiePolicy', 'FileCookieJar', 'LWPCookieJar', 'LoadError', 'MozillaCookieJar'] +import os import copy import datetime import re @@ -52,10 +53,18 @@ def _debug(*args): logger = logging.getLogger("http.cookiejar") return logger.debug(*args) - +HTTPONLY_ATTR = "HTTPOnly" +HTTPONLY_PREFIX = "#HttpOnly_" DEFAULT_HTTP_PORT = str(http.client.HTTP_PORT) +NETSCAPE_MAGIC_RGX = re.compile("#( Netscape)? HTTP Cookie File") MISSING_FILENAME_TEXT = ("a filename was not supplied (nor was the CookieJar " "instance initialised with one)") +NETSCAPE_HEADER_TEXT = """\ +# Netscape HTTP Cookie File +# http://curl.haxx.se/rfc/cookie_spec.html +# This is a generated file! Do not edit. + +""" def _warn_unhandled_exception(): # There are a few catch-all except: statements in this module, for @@ -216,10 +225,14 @@ def _str2time(day, mon, yr, hr, min, sec, tz): (?::(\d\d))? # optional seconds )? # optional clock \s* - ([-+]?\d{2,4}|(?![APap][Mm]\b)[A-Za-z]+)? # timezone + (?: + ([-+]?\d{2,4}|(?![APap][Mm]\b)[A-Za-z]+) # timezone + \s* + )? + (?: + \(\w+\) # ASCII representation of timezone in parens. \s* - (?:\(\w+\))? 
# ASCII representation of timezone in parens. - \s*$""", re.X | re.ASCII) + )?$""", re.X | re.ASCII) def http2time(text): """Returns time in seconds since epoch of time represented by a string. @@ -289,9 +302,11 @@ def http2time(text): (?::?(\d\d(?:\.\d*)?))? # optional seconds (and fractional) )? # optional clock \s* - ([-+]?\d\d?:?(:?\d\d)? - |Z|z)? # timezone (Z is "zero meridian", i.e. GMT) - \s*$""", re.X | re. ASCII) + (?: + ([-+]?\d\d?:?(:?\d\d)? + |Z|z) # timezone (Z is "zero meridian", i.e. GMT) + \s* + )?$""", re.X | re. ASCII) def iso2time(text): """ As for http2time, but parses the ISO 8601 formats: @@ -881,6 +896,7 @@ def __init__(self, strict_ns_domain=DomainLiberal, strict_ns_set_initial_dollar=False, strict_ns_set_path=False, + secure_protocols=("https", "wss") ): """Constructor arguments should be passed as keyword arguments only.""" self.netscape = netscape @@ -893,6 +909,7 @@ def __init__(self, self.strict_ns_domain = strict_ns_domain self.strict_ns_set_initial_dollar = strict_ns_set_initial_dollar self.strict_ns_set_path = strict_ns_set_path + self.secure_protocols = secure_protocols if blocked_domains is not None: self._blocked_domains = tuple(blocked_domains) @@ -993,7 +1010,7 @@ def set_ok_path(self, cookie, request): req_path = request_path(request) if ((cookie.version > 0 or (cookie.version == 0 and self.strict_ns_set_path)) and - not req_path.startswith(cookie.path)): + not self.path_return_ok(cookie.path, request)): _debug(" path attribute %s is not a prefix of request " "path %s", cookie.path, req_path) return False @@ -1119,7 +1136,7 @@ def return_ok_verifiability(self, cookie, request): return True def return_ok_secure(self, cookie, request): - if cookie.secure and request.type != "https": + if cookie.secure and request.type not in self.secure_protocols: _debug(" secure cookie with non-secure request") return False return True @@ -1148,6 +1165,11 @@ def return_ok_domain(self, cookie, request): req_host, erhn = eff_request_host(request) 
domain = cookie.domain + if domain and not domain.startswith("."): + dotdomain = "." + domain + else: + dotdomain = domain + # strict check of non-domain cookies: Mozilla does this, MSIE5 doesn't if (cookie.version == 0 and (self.strict_ns_domain & self.DomainStrictNonDomain) and @@ -1160,7 +1182,7 @@ def return_ok_domain(self, cookie, request): _debug(" effective request-host name %s does not domain-match " "RFC 2965 cookie domain %s", erhn, domain) return False - if cookie.version == 0 and not ("."+erhn).endswith(domain): + if cookie.version == 0 and not ("."+erhn).endswith(dotdomain): _debug(" request-host %s does not match Netscape cookie domain " "%s", req_host, domain) return False @@ -1174,7 +1196,11 @@ def domain_return_ok(self, domain, request): req_host = "."+req_host if not erhn.startswith("."): erhn = "."+erhn - if not (req_host.endswith(domain) or erhn.endswith(domain)): + if domain and not domain.startswith("."): + dotdomain = "." + domain + else: + dotdomain = domain + if not (req_host.endswith(dotdomain) or erhn.endswith(dotdomain)): #_debug(" request domain %s does not match cookie domain %s", # req_host, domain) return False @@ -1191,11 +1217,15 @@ def domain_return_ok(self, domain, request): def path_return_ok(self, path, request): _debug("- checking cookie path=%s", path) req_path = request_path(request) - if not req_path.startswith(path): - _debug(" %s does not path-match %s", req_path, path) - return False - return True + pathlen = len(path) + if req_path == path: + return True + elif (req_path.startswith(path) and + (path.endswith("/") or req_path[pathlen:pathlen+1] == "/")): + return True + _debug(" %s does not path-match %s", req_path, path) + return False def vals_sorted_by_key(adict): keys = sorted(adict.keys()) @@ -1580,6 +1610,7 @@ def make_cookies(self, response, request): headers = response.info() rfc2965_hdrs = headers.get_all("Set-Cookie2", []) ns_hdrs = headers.get_all("Set-Cookie", []) + self._policy._now = self._now = 
int(time.time()) rfc2965 = self._policy.rfc2965 netscape = self._policy.netscape @@ -1659,8 +1690,6 @@ def extract_cookies(self, response, request): _debug("extract_cookies: %s", response.info()) self._cookies_lock.acquire() try: - self._policy._now = self._now = int(time.time()) - for cookie in self.make_cookies(response, request): if self._policy.set_ok(cookie, request): _debug(" setting cookie: %s", cookie) @@ -1763,10 +1792,7 @@ def __init__(self, filename=None, delayload=False, policy=None): """ CookieJar.__init__(self, policy) if filename is not None: - try: - filename+"" - except: - raise ValueError("filename must be string-like") + filename = os.fspath(filename) self.filename = filename self.delayload = bool(delayload) @@ -1989,19 +2015,11 @@ class MozillaCookieJar(FileCookieJar): header by default (Mozilla can cope with that). """ - magic_re = re.compile("#( Netscape)? HTTP Cookie File") - header = """\ -# Netscape HTTP Cookie File -# http://curl.haxx.se/rfc/cookie_spec.html -# This is a generated file! Do not edit. 
- -""" def _really_load(self, f, filename, ignore_discard, ignore_expires): now = time.time() - magic = f.readline() - if not self.magic_re.search(magic): + if not NETSCAPE_MAGIC_RGX.match(f.readline()): raise LoadError( "%r does not look like a Netscape format cookies file" % filename) @@ -2009,8 +2027,17 @@ def _really_load(self, f, filename, ignore_discard, ignore_expires): try: while 1: line = f.readline() + rest = {} + if line == "": break + # httponly is a cookie flag as defined in rfc6265 + # when encoded in a netscape cookie file, + # the line is prepended with "#HttpOnly_" + if line.startswith(HTTPONLY_PREFIX): + rest[HTTPONLY_ATTR] = "" + line = line[len(HTTPONLY_PREFIX):] + # last field may be absent, so keep any trailing tab if line.endswith("\n"): line = line[:-1] @@ -2048,7 +2075,7 @@ def _really_load(self, f, filename, ignore_discard, ignore_expires): discard, None, None, - {}) + rest) if not ignore_discard and c.discard: continue if not ignore_expires and c.is_expired(now): @@ -2068,16 +2095,17 @@ def save(self, filename=None, ignore_discard=False, ignore_expires=False): else: raise ValueError(MISSING_FILENAME_TEXT) with open(filename, "w") as f: - f.write(self.header) + f.write(NETSCAPE_HEADER_TEXT) now = time.time() for cookie in self: + domain = cookie.domain if not ignore_discard and cookie.discard: continue if not ignore_expires and cookie.is_expired(now): continue if cookie.secure: secure = "TRUE" else: secure = "FALSE" - if cookie.domain.startswith("."): initial_dot = "TRUE" + if domain.startswith("."): initial_dot = "TRUE" else: initial_dot = "FALSE" if cookie.expires is not None: expires = str(cookie.expires) @@ -2092,7 +2120,9 @@ def save(self, filename=None, ignore_discard=False, ignore_expires=False): else: name = cookie.name value = cookie.value + if cookie.has_nonstandard_attr(HTTPONLY_ATTR): + domain = HTTPONLY_PREFIX + domain f.write( - "\t".join([cookie.domain, initial_dot, cookie.path, + "\t".join([domain, initial_dot, cookie.path, 
secure, expires, name, value])+ "\n") diff --git a/Lib/http/cookies.py b/Lib/http/cookies.py index be3b080aa3d..35ac2dc6ae2 100644 --- a/Lib/http/cookies.py +++ b/Lib/http/cookies.py @@ -131,6 +131,7 @@ # import re import string +import types __all__ = ["CookieError", "BaseCookie", "SimpleCookie"] @@ -138,12 +139,6 @@ _semispacejoin = '; '.join _spacejoin = ' '.join -def _warn_deprecated_setter(setter): - import warnings - msg = ('The .%s setter is deprecated. The attribute will be read-only in ' - 'future releases. Please use the set() method instead.' % setter) - warnings.warn(msg, DeprecationWarning, stacklevel=3) - # # Define an exception visible to External modules # @@ -262,8 +257,7 @@ class Morsel(dict): In a cookie, each such pair may have several attributes, so this class is used to keep the attributes associated with the appropriate key,value pair. This class also includes a coded_value attribute, which is used to hold - the network representation of the value. This is most useful when Python - objects are pickled for network transit. + the network representation of the value. 
""" # RFC 2109 lists these attributes as reserved: # path comment domain @@ -287,6 +281,7 @@ class Morsel(dict): "secure" : "Secure", "httponly" : "HttpOnly", "version" : "Version", + "samesite" : "SameSite", } _flags = {'secure', 'httponly'} @@ -303,29 +298,14 @@ def __init__(self): def key(self): return self._key - @key.setter - def key(self, key): - _warn_deprecated_setter('key') - self._key = key - @property def value(self): return self._value - @value.setter - def value(self, value): - _warn_deprecated_setter('value') - self._value = value - @property def coded_value(self): return self._coded_value - @coded_value.setter - def coded_value(self, coded_value): - _warn_deprecated_setter('coded_value') - self._coded_value = coded_value - def __setitem__(self, K, V): K = K.lower() if not K in self._reserved: @@ -366,14 +346,7 @@ def update(self, values): def isReservedKey(self, K): return K.lower() in self._reserved - def set(self, key, val, coded_val, LegalChars=_LegalChars): - if LegalChars != _LegalChars: - import warnings - warnings.warn( - 'LegalChars parameter is deprecated, ignored and will ' - 'be removed in future versions.', DeprecationWarning, - stacklevel=2) - + def set(self, key, val, coded_val): if key.lower() in self._reserved: raise CookieError('Attempt to set a reserved key %r' % (key,)) if not _is_legal_key(key): @@ -436,6 +409,8 @@ def OutputString(self, attrs=None): append("%s=%s" % (self._reserved[key], _getdate(value))) elif key == "max-age" and isinstance(value, int): append("%s=%d" % (self._reserved[key], value)) + elif key == "comment" and isinstance(value, str): + append("%s=%s" % (self._reserved[key], _quote(value))) elif key in self._flags: if value: append(str(self._reserved[key])) @@ -445,6 +420,8 @@ def OutputString(self, attrs=None): # Return the result return _semispacejoin(result) + __class_getitem__ = classmethod(types.GenericAlias) + # # Pattern for finding cookie diff --git a/Lib/http/server.py b/Lib/http/server.py index 
e12e45bfc38..58abadf7377 100644 --- a/Lib/http/server.py +++ b/Lib/http/server.py @@ -83,10 +83,12 @@ __version__ = "0.6" __all__ = [ - "HTTPServer", "BaseHTTPRequestHandler", + "HTTPServer", "ThreadingHTTPServer", "BaseHTTPRequestHandler", "SimpleHTTPRequestHandler", "CGIHTTPRequestHandler", ] +import copy +import datetime import email.utils import html import http.client @@ -101,8 +103,6 @@ import sys import time import urllib.parse -import copy -import argparse from http import HTTPStatus @@ -139,6 +139,10 @@ def server_bind(self): self.server_port = port +class ThreadingHTTPServer(socketserver.ThreadingMixIn, HTTPServer): + daemon_threads = True + + class BaseHTTPRequestHandler(socketserver.StreamRequestHandler): """HTTP request handler base class. @@ -267,8 +271,8 @@ def parse_request(self): are in self.command, self.path, self.request_version and self.headers. - Return True for success, False for failure; on failure, an - error is sent back. + Return True for success, False for failure; on failure, any relevant + error response has already been sent back. 
""" self.command = None # set in case of error on the first line @@ -278,10 +282,13 @@ def parse_request(self): requestline = requestline.rstrip('\r\n') self.requestline = requestline words = requestline.split() - if len(words) == 3: - command, path, version = words + if len(words) == 0: + return False + + if len(words) >= 3: # Enough to determine protocol version + version = words[-1] try: - if version[:5] != 'HTTP/': + if not version.startswith('HTTP/'): raise ValueError base_version_number = version.split('/', 1)[1] version_number = base_version_number.split(".") @@ -306,22 +313,22 @@ def parse_request(self): HTTPStatus.HTTP_VERSION_NOT_SUPPORTED, "Invalid HTTP version (%s)" % base_version_number) return False - elif len(words) == 2: - command, path = words + self.request_version = version + + if not 2 <= len(words) <= 3: + self.send_error( + HTTPStatus.BAD_REQUEST, + "Bad request syntax (%r)" % requestline) + return False + command, path = words[:2] + if len(words) == 2: self.close_connection = True if command != 'GET': self.send_error( HTTPStatus.BAD_REQUEST, "Bad HTTP/0.9 request type (%r)" % command) return False - elif not words: - return False - else: - self.send_error( - HTTPStatus.BAD_REQUEST, - "Bad request syntax (%r)" % requestline) - return False - self.command, self.path, self.request_version = command, path, version + self.command, self.path = command, path # Examine the headers and look for a Connection directive. try: @@ -405,7 +412,7 @@ def handle_one_request(self): method = getattr(self, mname) method() self.wfile.flush() #actually send the response if not already done. - except socket.timeout as e: + except TimeoutError as e: #a read or a write timed out. 
Discard this connection self.log_error("Request timed out: %r", e) self.close_connection = True @@ -466,7 +473,7 @@ def send_error(self, code, message=None, explain=None): }) body = content.encode('UTF-8', 'replace') self.send_header("Content-Type", self.error_content_type) - self.send_header('Content-Length', int(len(body))) + self.send_header('Content-Length', str(len(body))) self.end_headers() if self.command != 'HEAD' and body: @@ -630,6 +637,18 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler): """ server_version = "SimpleHTTP/" + __version__ + extensions_map = _encodings_map_default = { + '.gz': 'application/gzip', + '.Z': 'application/octet-stream', + '.bz2': 'application/x-bzip2', + '.xz': 'application/x-xz', + } + + def __init__(self, *args, directory=None, **kwargs): + if directory is None: + directory = os.getcwd() + self.directory = os.fspath(directory) + super().__init__(*args, **kwargs) def do_GET(self): """Serve a GET request.""" @@ -668,6 +687,7 @@ def send_head(self): parts[3], parts[4]) new_url = urllib.parse.urlunsplit(new_parts) self.send_header("Location", new_url) + self.send_header("Content-Length", "0") self.end_headers() return None for index in "index.html", "index.htm": @@ -678,17 +698,55 @@ def send_head(self): else: return self.list_directory(path) ctype = self.guess_type(path) + # check for trailing "/" which should return 404. 
See Issue17324 + # The test for this was added in test_httpserver.py + # However, some OS platforms accept a trailingSlash as a filename + # See discussion on python-dev and Issue34711 regarding + # parseing and rejection of filenames with a trailing slash + if path.endswith("/"): + self.send_error(HTTPStatus.NOT_FOUND, "File not found") + return None try: f = open(path, 'rb') except OSError: self.send_error(HTTPStatus.NOT_FOUND, "File not found") return None + try: + fs = os.fstat(f.fileno()) + # Use browser cache if possible + if ("If-Modified-Since" in self.headers + and "If-None-Match" not in self.headers): + # compare If-Modified-Since and time of last file modification + try: + ims = email.utils.parsedate_to_datetime( + self.headers["If-Modified-Since"]) + except (TypeError, IndexError, OverflowError, ValueError): + # ignore ill-formed values + pass + else: + if ims.tzinfo is None: + # obsolete format with no timezone, cf. + # https://tools.ietf.org/html/rfc7231#section-7.1.1.1 + ims = ims.replace(tzinfo=datetime.timezone.utc) + if ims.tzinfo is datetime.timezone.utc: + # compare to UTC datetime of last modification + last_modif = datetime.datetime.fromtimestamp( + fs.st_mtime, datetime.timezone.utc) + # remove microseconds, like in If-Modified-Since + last_modif = last_modif.replace(microsecond=0) + + if last_modif <= ims: + self.send_response(HTTPStatus.NOT_MODIFIED) + self.end_headers() + f.close() + return None + self.send_response(HTTPStatus.OK) self.send_header("Content-type", ctype) - fs = os.fstat(f.fileno()) self.send_header("Content-Length", str(fs[6])) - self.send_header("Last-Modified", self.date_time_string(fs.st_mtime)) + self.send_header("Last-Modified", + self.date_time_string(fs.st_mtime)) self.end_headers() return f except: @@ -773,7 +831,7 @@ def translate_path(self, path): path = posixpath.normpath(path) words = path.split('/') words = filter(None, words) - path = os.getcwd() + path = self.directory for word in words: if 
os.path.dirname(word) or word in (os.curdir, os.pardir): # Ignore components that are not a simple file/directory name @@ -813,25 +871,16 @@ def guess_type(self, path): slow) to look inside the data to make a better guess. """ - base, ext = posixpath.splitext(path) if ext in self.extensions_map: return self.extensions_map[ext] ext = ext.lower() if ext in self.extensions_map: return self.extensions_map[ext] - else: - return self.extensions_map[''] - - if not mimetypes.inited: - mimetypes.init() # try to read system mime.types - extensions_map = mimetypes.types_map.copy() - extensions_map.update({ - '': 'application/octet-stream', # Default - '.py': 'text/plain', - '.c': 'text/plain', - '.h': 'text/plain', - }) + guess, _ = mimetypes.guess_type(path) + if guess: + return guess + return 'application/octet-stream' # Utilities for CGIHTTPRequestHandler @@ -962,8 +1011,10 @@ def is_cgi(self): """ collapsed_path = _url_collapse_path(self.path) dir_sep = collapsed_path.find('/', 1) - head, tail = collapsed_path[:dir_sep], collapsed_path[dir_sep+1:] - if head in self.cgi_directories: + while dir_sep > 0 and not collapsed_path[:dir_sep] in self.cgi_directories: + dir_sep = collapsed_path.find('/', dir_sep+1) + if dir_sep > 0: + head, tail = collapsed_path[:dir_sep], collapsed_path[dir_sep+1:] self.cgi_info = head, tail return True return False @@ -1040,8 +1091,7 @@ def run_cgi(self): env['PATH_INFO'] = uqrest env['PATH_TRANSLATED'] = self.translate_path(uqrest) env['SCRIPT_NAME'] = scriptname - if query: - env['QUERY_STRING'] = query + env['QUERY_STRING'] = query env['REMOTE_ADDR'] = self.client_address[0] authorization = self.headers.get("authorization") if authorization: @@ -1071,12 +1121,7 @@ def run_cgi(self): referer = self.headers.get('referer') if referer: env['HTTP_REFERER'] = referer - accept = [] - for line in self.headers.getallmatchingheaders('accept'): - if line[:1] in "\t\n\r ": - accept.append(line.strip()) - else: - accept = accept + line[7:].split(',') + 
accept = self.headers.get_all('accept', ()) env['HTTP_ACCEPT'] = ','.join(accept) ua = self.headers.get('user-agent') if ua: @@ -1112,8 +1157,9 @@ def run_cgi(self): while select.select([self.rfile], [], [], 0)[0]: if not self.rfile.read(1): break - if sts: - self.log_error("CGI script exit status %#x", sts) + exitcode = os.waitstatus_to_exitcode(sts) + if exitcode: + self.log_error(f"CGI script exit code {exitcode}") return # Child try: @@ -1172,20 +1218,33 @@ def run_cgi(self): self.log_message("CGI script exited OK") +def _get_best_family(*address): + infos = socket.getaddrinfo( + *address, + type=socket.SOCK_STREAM, + flags=socket.AI_PASSIVE, + ) + family, type, proto, canonname, sockaddr = next(iter(infos)) + return family, sockaddr + + def test(HandlerClass=BaseHTTPRequestHandler, - ServerClass=HTTPServer, protocol="HTTP/1.0", port=8000, bind=""): + ServerClass=ThreadingHTTPServer, + protocol="HTTP/1.0", port=8000, bind=None): """Test the HTTP request handler class. This runs an HTTP server on port 8000 (or the port argument). """ - server_address = (bind, port) - + ServerClass.address_family, addr = _get_best_family(bind, port) HandlerClass.protocol_version = protocol - with ServerClass(server_address, HandlerClass) as httpd: - sa = httpd.socket.getsockname() - serve_message = "Serving HTTP on {host} port {port} (http://{host}:{port}/) ..." - print(serve_message.format(host=sa[0], port=sa[1])) + with ServerClass(addr, HandlerClass) as httpd: + host, port = httpd.socket.getsockname()[:2] + url_host = f'[{host}]' if ':' in host else host + print( + f"Serving HTTP on {host} port {port} " + f"(http://{url_host}:{port}/) ..." 
+ ) try: httpd.serve_forever() except KeyboardInterrupt: @@ -1193,19 +1252,44 @@ def test(HandlerClass=BaseHTTPRequestHandler, sys.exit(0) if __name__ == '__main__': + import argparse + import contextlib + parser = argparse.ArgumentParser() parser.add_argument('--cgi', action='store_true', - help='Run as CGI Server') - parser.add_argument('--bind', '-b', default='', metavar='ADDRESS', - help='Specify alternate bind address ' - '[default: all interfaces]') - parser.add_argument('port', action='store', - default=8000, type=int, + help='run as CGI server') + parser.add_argument('--bind', '-b', metavar='ADDRESS', + help='specify alternate bind address ' + '(default: all interfaces)') + parser.add_argument('--directory', '-d', default=os.getcwd(), + help='specify alternate directory ' + '(default: current directory)') + parser.add_argument('port', action='store', default=8000, type=int, nargs='?', - help='Specify alternate port [default: 8000]') + help='specify alternate port (default: 8000)') args = parser.parse_args() if args.cgi: handler_class = CGIHTTPRequestHandler else: handler_class = SimpleHTTPRequestHandler - test(HandlerClass=handler_class, port=args.port, bind=args.bind) + + # ensure dual-stack is not disabled; ref #38907 + class DualStackServer(ThreadingHTTPServer): + + def server_bind(self): + # suppress exception when protocol is IPv4 + with contextlib.suppress(Exception): + self.socket.setsockopt( + socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + return super().server_bind() + + def finish_request(self, request, client_address): + self.RequestHandlerClass(request, client_address, self, + directory=args.directory) + + test( + HandlerClass=handler_class, + ServerClass=DualStackServer, + port=args.port, + bind=args.bind, + ) diff --git a/Lib/test/test_http_cookiejar.py b/Lib/test/test_http_cookiejar.py index 1a7c3e0e975..ba594079cd8 100644 --- a/Lib/test/test_http_cookiejar.py +++ b/Lib/test/test_http_cookiejar.py @@ -4,6 +4,7 @@ import re import test.support 
from test.support import os_helper +from test.support import warnings_helper import time import unittest import urllib.request @@ -335,8 +336,6 @@ def test_constructor_with_str(self): c = LWPCookieJar(filename) self.assertEqual(c.filename, filename) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_constructor_with_path_like(self): filename = pathlib.Path(os_helper.TESTFN) c = LWPCookieJar(filename) @@ -346,8 +345,6 @@ def test_constructor_with_none(self): c = LWPCookieJar(None) self.assertIsNone(c.filename) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_constructor_with_other_types(self): class A: pass @@ -446,8 +443,6 @@ class CookieTests(unittest.TestCase): ## just the 7 special TLD's listed in their spec. And folks rely on ## that... - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_domain_return_ok(self): # test optimization: .domain_return_ok() should filter out most # domains in the CookieJar before we try to access them (because that @@ -603,14 +598,12 @@ def test_ns_parser_special_names(self): self.assertIn('expires', cookies) self.assertIn('version', cookies) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_expires(self): # if expires is in future, keep cookie... 
c = CookieJar() future = time2netscape(time.time()+3600) - with test.warnings_helper.check_no_warnings(self): + with warnings_helper.check_no_warnings(self): headers = [f"Set-Cookie: FOO=BAR; path=/; expires={future}"] req = urllib.request.Request("http://www.coyote.com/") res = FakeResponse(headers, "http://www.coyote.com/") @@ -753,8 +746,6 @@ def test_request_path(self): req = urllib.request.Request("http://www.example.com") self.assertEqual(request_path(req), "/") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_path_prefix_match(self): pol = DefaultCookiePolicy() strict_ns_path_pol = DefaultCookiePolicy(strict_ns_set_path=True) @@ -1006,8 +997,6 @@ def test_domain_allow(self): c.add_cookie_header(req) self.assertFalse(req.has_header("Cookie")) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_domain_block(self): pol = DefaultCookiePolicy( rfc2965=True, blocked_domains=[".acme.com"]) @@ -1098,8 +1087,6 @@ def test_secure(self): c._cookies["www.acme.com"]["/"]["foo2"].secure, "secure cookie registered non-secure") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_secure_block(self): pol = DefaultCookiePolicy() c = CookieJar(policy=pol) @@ -1128,8 +1115,6 @@ def test_secure_block(self): c.add_cookie_header(req) self.assertFalse(req.has_header("Cookie")) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_custom_secure_protocols(self): pol = DefaultCookiePolicy(secure_protocols=["foos"]) c = CookieJar(policy=pol) @@ -1790,6 +1775,10 @@ def test_mozilla(self): interact_netscape(c, "http://www.foo.com/", "fooc=bar; Domain=www.foo.com; %s" % expires) + for cookie in c: + if cookie.name == "foo1": + cookie.set_nonstandard_attr("HTTPOnly", "") + def save_and_restore(cj, ignore_discard): try: cj.save(ignore_discard=ignore_discard) @@ -1804,6 +1793,7 @@ def save_and_restore(cj, ignore_discard): new_c = save_and_restore(c, True) self.assertEqual(len(new_c), 6) # none discarded self.assertIn("name='foo1', value='bar'", repr(new_c)) + 
self.assertIn("rest={'HTTPOnly': ''}", repr(new_c)) new_c = save_and_restore(c, False) self.assertEqual(len(new_c), 4) # 2 of them discarded on save @@ -1932,14 +1922,5 @@ def test_session_cookies(self): self.assertNotEqual(counter["session_before"], 0) -def test_main(verbose=None): - test.support.run_unittest( - DateTimeTests, - HeaderTests, - CookieTests, - FileCookieJarTests, - LWPCookieTests, - ) - if __name__ == "__main__": - test_main(verbose=True) + unittest.main() diff --git a/Lib/test/test_http_cookies.py b/Lib/test/test_http_cookies.py new file mode 100644 index 00000000000..6072c7e15e9 --- /dev/null +++ b/Lib/test/test_http_cookies.py @@ -0,0 +1,487 @@ +# Simple test suite for http/cookies.py + +import copy +from test.support import run_unittest, run_doctest +import unittest +from http import cookies +import pickle + + +class CookieTests(unittest.TestCase): + + def test_basic(self): + cases = [ + {'data': 'chips=ahoy; vienna=finger', + 'dict': {'chips':'ahoy', 'vienna':'finger'}, + 'repr': "", + 'output': 'Set-Cookie: chips=ahoy\nSet-Cookie: vienna=finger'}, + + {'data': 'keebler="E=mc2; L=\\"Loves\\"; fudge=\\012;"', + 'dict': {'keebler' : 'E=mc2; L="Loves"; fudge=\012;'}, + 'repr': '''''', + 'output': 'Set-Cookie: keebler="E=mc2; L=\\"Loves\\"; fudge=\\012;"'}, + + # Check illegal cookies that have an '=' char in an unquoted value + {'data': 'keebler=E=mc2', + 'dict': {'keebler' : 'E=mc2'}, + 'repr': "", + 'output': 'Set-Cookie: keebler=E=mc2'}, + + # Cookies with ':' character in their name. Though not mentioned in + # RFC, servers / browsers allow it. 
+ + {'data': 'key:term=value:term', + 'dict': {'key:term' : 'value:term'}, + 'repr': "", + 'output': 'Set-Cookie: key:term=value:term'}, + + # issue22931 - Adding '[' and ']' as valid characters in cookie + # values as defined in RFC 6265 + { + 'data': 'a=b; c=[; d=r; f=h', + 'dict': {'a':'b', 'c':'[', 'd':'r', 'f':'h'}, + 'repr': "", + 'output': '\n'.join(( + 'Set-Cookie: a=b', + 'Set-Cookie: c=[', + 'Set-Cookie: d=r', + 'Set-Cookie: f=h' + )) + } + ] + + for case in cases: + C = cookies.SimpleCookie() + C.load(case['data']) + self.assertEqual(repr(C), case['repr']) + self.assertEqual(C.output(sep='\n'), case['output']) + for k, v in sorted(case['dict'].items()): + self.assertEqual(C[k].value, v) + + def test_load(self): + C = cookies.SimpleCookie() + C.load('Customer="WILE_E_COYOTE"; Version=1; Path=/acme') + + self.assertEqual(C['Customer'].value, 'WILE_E_COYOTE') + self.assertEqual(C['Customer']['version'], '1') + self.assertEqual(C['Customer']['path'], '/acme') + + self.assertEqual(C.output(['path']), + 'Set-Cookie: Customer="WILE_E_COYOTE"; Path=/acme') + self.assertEqual(C.js_output(), r""" + + """) + self.assertEqual(C.js_output(['path']), r""" + + """) + + def test_extended_encode(self): + # Issue 9824: some browsers don't follow the standard; we now + # encode , and ; to keep them from tripping up. 
+ C = cookies.SimpleCookie() + C['val'] = "some,funky;stuff" + self.assertEqual(C.output(['val']), + 'Set-Cookie: val="some\\054funky\\073stuff"') + + def test_special_attrs(self): + # 'expires' + C = cookies.SimpleCookie('Customer="WILE_E_COYOTE"') + C['Customer']['expires'] = 0 + # can't test exact output, it always depends on current date/time + self.assertTrue(C.output().endswith('GMT')) + + # loading 'expires' + C = cookies.SimpleCookie() + C.load('Customer="W"; expires=Wed, 01 Jan 2010 00:00:00 GMT') + self.assertEqual(C['Customer']['expires'], + 'Wed, 01 Jan 2010 00:00:00 GMT') + C = cookies.SimpleCookie() + C.load('Customer="W"; expires=Wed, 01 Jan 98 00:00:00 GMT') + self.assertEqual(C['Customer']['expires'], + 'Wed, 01 Jan 98 00:00:00 GMT') + + # 'max-age' + C = cookies.SimpleCookie('Customer="WILE_E_COYOTE"') + C['Customer']['max-age'] = 10 + self.assertEqual(C.output(), + 'Set-Cookie: Customer="WILE_E_COYOTE"; Max-Age=10') + + def test_set_secure_httponly_attrs(self): + C = cookies.SimpleCookie('Customer="WILE_E_COYOTE"') + C['Customer']['secure'] = True + C['Customer']['httponly'] = True + self.assertEqual(C.output(), + 'Set-Cookie: Customer="WILE_E_COYOTE"; HttpOnly; Secure') + + def test_samesite_attrs(self): + samesite_values = ['Strict', 'Lax', 'strict', 'lax'] + for val in samesite_values: + with self.subTest(val=val): + C = cookies.SimpleCookie('Customer="WILE_E_COYOTE"') + C['Customer']['samesite'] = val + self.assertEqual(C.output(), + 'Set-Cookie: Customer="WILE_E_COYOTE"; SameSite=%s' % val) + + C = cookies.SimpleCookie() + C.load('Customer="WILL_E_COYOTE"; SameSite=%s' % val) + self.assertEqual(C['Customer']['samesite'], val) + + def test_secure_httponly_false_if_not_present(self): + C = cookies.SimpleCookie() + C.load('eggs=scrambled; Path=/bacon') + self.assertFalse(C['eggs']['httponly']) + self.assertFalse(C['eggs']['secure']) + + def test_secure_httponly_true_if_present(self): + # Issue 16611 + C = cookies.SimpleCookie() + 
C.load('eggs=scrambled; httponly; secure; Path=/bacon') + self.assertTrue(C['eggs']['httponly']) + self.assertTrue(C['eggs']['secure']) + + def test_secure_httponly_true_if_have_value(self): + # This isn't really valid, but demonstrates what the current code + # is expected to do in this case. + C = cookies.SimpleCookie() + C.load('eggs=scrambled; httponly=foo; secure=bar; Path=/bacon') + self.assertTrue(C['eggs']['httponly']) + self.assertTrue(C['eggs']['secure']) + # Here is what it actually does; don't depend on this behavior. These + # checks are testing backward compatibility for issue 16611. + self.assertEqual(C['eggs']['httponly'], 'foo') + self.assertEqual(C['eggs']['secure'], 'bar') + + def test_extra_spaces(self): + C = cookies.SimpleCookie() + C.load('eggs = scrambled ; secure ; path = bar ; foo=foo ') + self.assertEqual(C.output(), + 'Set-Cookie: eggs=scrambled; Path=bar; Secure\r\nSet-Cookie: foo=foo') + + def test_quoted_meta(self): + # Try cookie with quoted meta-data + C = cookies.SimpleCookie() + C.load('Customer="WILE_E_COYOTE"; Version="1"; Path="/acme"') + self.assertEqual(C['Customer'].value, 'WILE_E_COYOTE') + self.assertEqual(C['Customer']['version'], '1') + self.assertEqual(C['Customer']['path'], '/acme') + + self.assertEqual(C.output(['path']), + 'Set-Cookie: Customer="WILE_E_COYOTE"; Path=/acme') + self.assertEqual(C.js_output(), r""" + + """) + self.assertEqual(C.js_output(['path']), r""" + + """) + + def test_invalid_cookies(self): + # Accepting these could be a security issue + C = cookies.SimpleCookie() + for s in (']foo=x', '[foo=x', 'blah]foo=x', 'blah[foo=x', + 'Set-Cookie: foo=bar', 'Set-Cookie: foo', + 'foo=bar; baz', 'baz; foo=bar', + 'secure;foo=bar', 'Version=1;foo=bar'): + C.load(s) + self.assertEqual(dict(C), {}) + self.assertEqual(C.output(), '') + + def test_pickle(self): + rawdata = 'Customer="WILE_E_COYOTE"; Path=/acme; Version=1' + expected_output = 'Set-Cookie: %s' % rawdata + + C = cookies.SimpleCookie() + 
C.load(rawdata) + self.assertEqual(C.output(), expected_output) + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + C1 = pickle.loads(pickle.dumps(C, protocol=proto)) + self.assertEqual(C1.output(), expected_output) + + def test_illegal_chars(self): + rawdata = "a=b; c,d=e" + C = cookies.SimpleCookie() + with self.assertRaises(cookies.CookieError): + C.load(rawdata) + + def test_comment_quoting(self): + c = cookies.SimpleCookie() + c['foo'] = '\N{COPYRIGHT SIGN}' + self.assertEqual(str(c['foo']), 'Set-Cookie: foo="\\251"') + c['foo']['comment'] = 'comment \N{COPYRIGHT SIGN}' + self.assertEqual( + str(c['foo']), + 'Set-Cookie: foo="\\251"; Comment="comment \\251"' + ) + + +class MorselTests(unittest.TestCase): + """Tests for the Morsel object.""" + + def test_defaults(self): + morsel = cookies.Morsel() + self.assertIsNone(morsel.key) + self.assertIsNone(morsel.value) + self.assertIsNone(morsel.coded_value) + self.assertEqual(morsel.keys(), cookies.Morsel._reserved.keys()) + for key, val in morsel.items(): + self.assertEqual(val, '', key) + + def test_reserved_keys(self): + M = cookies.Morsel() + # tests valid and invalid reserved keys for Morsels + for i in M._reserved: + # Test that all valid keys are reported as reserved and set them + self.assertTrue(M.isReservedKey(i)) + M[i] = '%s_value' % i + for i in M._reserved: + # Test that valid key values come out fine + self.assertEqual(M[i], '%s_value' % i) + for i in "the holy hand grenade".split(): + # Test that invalid keys raise CookieError + self.assertRaises(cookies.CookieError, + M.__setitem__, i, '%s_value' % i) + + def test_setter(self): + M = cookies.Morsel() + # tests the .set method to set keys and their values + for i in M._reserved: + # Makes sure that all reserved keys can't be set this way + self.assertRaises(cookies.CookieError, + M.set, i, '%s_value' % i, '%s_value' % i) + for i in "thou cast _the- !holy! ^hand| +*grenade~".split(): + # Try typical use case. 
Setting decent values. + # Check output and js_output. + M['path'] = '/foo' # Try a reserved key as well + M.set(i, "%s_val" % i, "%s_coded_val" % i) + self.assertEqual(M.key, i) + self.assertEqual(M.value, "%s_val" % i) + self.assertEqual(M.coded_value, "%s_coded_val" % i) + self.assertEqual( + M.output(), + "Set-Cookie: %s=%s; Path=/foo" % (i, "%s_coded_val" % i)) + expected_js_output = """ + + """ % (i, "%s_coded_val" % i) + self.assertEqual(M.js_output(), expected_js_output) + for i in ["foo bar", "foo@bar"]: + # Try some illegal characters + self.assertRaises(cookies.CookieError, + M.set, i, '%s_value' % i, '%s_value' % i) + + def test_set_properties(self): + morsel = cookies.Morsel() + with self.assertRaises(AttributeError): + morsel.key = '' + with self.assertRaises(AttributeError): + morsel.value = '' + with self.assertRaises(AttributeError): + morsel.coded_value = '' + + def test_eq(self): + base_case = ('key', 'value', '"value"') + attribs = { + 'path': '/', + 'comment': 'foo', + 'domain': 'example.com', + 'version': 2, + } + morsel_a = cookies.Morsel() + morsel_a.update(attribs) + morsel_a.set(*base_case) + morsel_b = cookies.Morsel() + morsel_b.update(attribs) + morsel_b.set(*base_case) + self.assertTrue(morsel_a == morsel_b) + self.assertFalse(morsel_a != morsel_b) + cases = ( + ('key', 'value', 'mismatch'), + ('key', 'mismatch', '"value"'), + ('mismatch', 'value', '"value"'), + ) + for case_b in cases: + with self.subTest(case_b): + morsel_b = cookies.Morsel() + morsel_b.update(attribs) + morsel_b.set(*case_b) + self.assertFalse(morsel_a == morsel_b) + self.assertTrue(morsel_a != morsel_b) + + morsel_b = cookies.Morsel() + morsel_b.update(attribs) + morsel_b.set(*base_case) + morsel_b['comment'] = 'bar' + self.assertFalse(morsel_a == morsel_b) + self.assertTrue(morsel_a != morsel_b) + + # test mismatched types + self.assertFalse(cookies.Morsel() == 1) + self.assertTrue(cookies.Morsel() != 1) + self.assertFalse(cookies.Morsel() == '') + 
self.assertTrue(cookies.Morsel() != '') + items = list(cookies.Morsel().items()) + self.assertFalse(cookies.Morsel() == items) + self.assertTrue(cookies.Morsel() != items) + + # morsel/dict + morsel = cookies.Morsel() + morsel.set(*base_case) + morsel.update(attribs) + self.assertTrue(morsel == dict(morsel)) + self.assertFalse(morsel != dict(morsel)) + + def test_copy(self): + morsel_a = cookies.Morsel() + morsel_a.set('foo', 'bar', 'baz') + morsel_a.update({ + 'version': 2, + 'comment': 'foo', + }) + morsel_b = morsel_a.copy() + self.assertIsInstance(morsel_b, cookies.Morsel) + self.assertIsNot(morsel_a, morsel_b) + self.assertEqual(morsel_a, morsel_b) + + morsel_b = copy.copy(morsel_a) + self.assertIsInstance(morsel_b, cookies.Morsel) + self.assertIsNot(morsel_a, morsel_b) + self.assertEqual(morsel_a, morsel_b) + + def test_setitem(self): + morsel = cookies.Morsel() + morsel['expires'] = 0 + self.assertEqual(morsel['expires'], 0) + morsel['Version'] = 2 + self.assertEqual(morsel['version'], 2) + morsel['DOMAIN'] = 'example.com' + self.assertEqual(morsel['domain'], 'example.com') + + with self.assertRaises(cookies.CookieError): + morsel['invalid'] = 'value' + self.assertNotIn('invalid', morsel) + + def test_setdefault(self): + morsel = cookies.Morsel() + morsel.update({ + 'domain': 'example.com', + 'version': 2, + }) + # this shouldn't override the default value + self.assertEqual(morsel.setdefault('expires', 'value'), '') + self.assertEqual(morsel['expires'], '') + self.assertEqual(morsel.setdefault('Version', 1), 2) + self.assertEqual(morsel['version'], 2) + self.assertEqual(morsel.setdefault('DOMAIN', 'value'), 'example.com') + self.assertEqual(morsel['domain'], 'example.com') + + with self.assertRaises(cookies.CookieError): + morsel.setdefault('invalid', 'value') + self.assertNotIn('invalid', morsel) + + def test_update(self): + attribs = {'expires': 1, 'Version': 2, 'DOMAIN': 'example.com'} + # test dict update + morsel = cookies.Morsel() + 
morsel.update(attribs) + self.assertEqual(morsel['expires'], 1) + self.assertEqual(morsel['version'], 2) + self.assertEqual(morsel['domain'], 'example.com') + # test iterable update + morsel = cookies.Morsel() + morsel.update(list(attribs.items())) + self.assertEqual(morsel['expires'], 1) + self.assertEqual(morsel['version'], 2) + self.assertEqual(morsel['domain'], 'example.com') + # test iterator update + morsel = cookies.Morsel() + morsel.update((k, v) for k, v in attribs.items()) + self.assertEqual(morsel['expires'], 1) + self.assertEqual(morsel['version'], 2) + self.assertEqual(morsel['domain'], 'example.com') + + with self.assertRaises(cookies.CookieError): + morsel.update({'invalid': 'value'}) + self.assertNotIn('invalid', morsel) + self.assertRaises(TypeError, morsel.update) + self.assertRaises(TypeError, morsel.update, 0) + + def test_pickle(self): + morsel_a = cookies.Morsel() + morsel_a.set('foo', 'bar', 'baz') + morsel_a.update({ + 'version': 2, + 'comment': 'foo', + }) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + morsel_b = pickle.loads(pickle.dumps(morsel_a, proto)) + self.assertIsInstance(morsel_b, cookies.Morsel) + self.assertEqual(morsel_b, morsel_a) + self.assertEqual(str(morsel_b), str(morsel_a)) + + def test_repr(self): + morsel = cookies.Morsel() + self.assertEqual(repr(morsel), '') + self.assertEqual(str(morsel), 'Set-Cookie: None=None') + morsel.set('key', 'val', 'coded_val') + self.assertEqual(repr(morsel), '') + self.assertEqual(str(morsel), 'Set-Cookie: key=coded_val') + morsel.update({ + 'path': '/', + 'comment': 'foo', + 'domain': 'example.com', + 'max-age': 0, + 'secure': 0, + 'version': 1, + }) + self.assertEqual(repr(morsel), + '') + self.assertEqual(str(morsel), + 'Set-Cookie: key=coded_val; Comment=foo; Domain=example.com; ' + 'Max-Age=0; Path=/; Version=1') + morsel['secure'] = True + morsel['httponly'] = 1 + self.assertEqual(repr(morsel), + '') + self.assertEqual(str(morsel), + 'Set-Cookie: 
key=coded_val; Comment=foo; Domain=example.com; ' + 'HttpOnly; Max-Age=0; Path=/; Secure; Version=1') + + morsel = cookies.Morsel() + morsel.set('key', 'val', 'coded_val') + morsel['expires'] = 0 + self.assertRegex(repr(morsel), + r'') + self.assertRegex(str(morsel), + r'Set-Cookie: key=coded_val; ' + r'expires=\w+, \d+ \w+ \d+ \d+:\d+:\d+ \w+') + +def test_main(): + run_unittest(CookieTests, MorselTests) + run_doctest(cookies) + +if __name__ == '__main__': + test_main() diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index d5a6c54d8eb..8f095d52ac4 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -10,6 +10,7 @@ import warnings import unittest +from unittest import mock TestCase = unittest.TestCase from test import support @@ -17,6 +18,7 @@ from test.support import socket_helper from test.support import warnings_helper + here = os.path.dirname(__file__) # Self-signed cert file for 'localhost' CERT_localhost = os.path.join(here, 'keycert.pem') @@ -349,8 +351,6 @@ def test_invalid_headers(self): with self.assertRaisesRegex(ValueError, 'Invalid header'): conn.putheader(name, value) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_headers_debuglevel(self): body = ( b'HTTP/1.1 200 OK\r\n' @@ -370,8 +370,6 @@ def test_headers_debuglevel(self): class HttpMethodTests(TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_invalid_method_names(self): methods = ( 'GET\r', @@ -601,6 +599,33 @@ def test_partial_readintos(self): resp.close() self.assertTrue(resp.closed) + def test_partial_reads_past_end(self): + # if we have Content-Length, clip reads to the end + body = "HTTP/1.1 200 Ok\r\nContent-Length: 4\r\n\r\nText" + sock = FakeSocket(body) + resp = client.HTTPResponse(sock) + resp.begin() + self.assertEqual(resp.read(10), b'Text') + self.assertTrue(resp.isclosed()) + self.assertFalse(resp.closed) + resp.close() + self.assertTrue(resp.closed) + + def test_partial_readintos_past_end(self): + # if we have 
Content-Length, clip readintos to the end + body = "HTTP/1.1 200 Ok\r\nContent-Length: 4\r\n\r\nText" + sock = FakeSocket(body) + resp = client.HTTPResponse(sock) + resp.begin() + b = bytearray(10) + n = resp.readinto(b) + self.assertEqual(n, 4) + self.assertEqual(bytes(b)[:4], b'Text') + self.assertTrue(resp.isclosed()) + self.assertFalse(resp.closed) + resp.close() + self.assertTrue(resp.closed) + def test_partial_reads_no_content_length(self): # when no length is present, the socket should be gracefully closed when # all data was read @@ -808,8 +833,6 @@ def body(): conn.request('GET', '/foo', body(), {'Content-Length': '11'}) self.assertEqual(sock.data, expected) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_blocksize_request(self): """Check that request() respects the configured block size.""" blocksize = 8 # For easy debugging. @@ -822,8 +845,6 @@ def test_blocksize_request(self): body = sock.data.split(b"\r\n\r\n", 1)[1] self.assertEqual(body, expected) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_blocksize_send(self): """Check that send() respects the configured block size.""" blocksize = 8 # For easy debugging. @@ -1014,6 +1035,19 @@ def test_overflowing_header_line(self): resp = client.HTTPResponse(FakeSocket(body)) self.assertRaises(client.LineTooLong, resp.begin) + def test_overflowing_header_limit_after_100(self): + body = ( + 'HTTP/1.1 100 OK\r\n' + 'r\n' * 32768 + ) + resp = client.HTTPResponse(FakeSocket(body)) + with self.assertRaises(client.HTTPException) as cm: + resp.begin() + # We must assert more because other reasonable errors that we + # do not want can also be HTTPException derived. 
+ self.assertIn('got more than ', str(cm.exception)) + self.assertIn('headers', str(cm.exception)) + def test_overflowing_chunked_line(self): body = ( 'HTTP/1.1 200 OK\r\n' @@ -1216,8 +1250,6 @@ def _validate_host(self, url): # invalid URL as the value of the "Host:" header conn.putrequest('GET', '/', skip_host=1) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_putrequest_override_encoding(self): """ It should be possible to override the default encoding @@ -1417,12 +1449,12 @@ def readline(self, limit): class OfflineTest(TestCase): def test_all(self): # Documented objects defined in the module should be in __all__ - expected = {"responses"} # White-list documented dict() object + expected = {"responses"} # Allowlist documented dict() object # HTTPMessage, parse_headers(), and the HTTP status code constants are # intentionally omitted for simplicity - blacklist = {"HTTPMessage", "parse_headers"} + denylist = {"HTTPMessage", "parse_headers"} for name in dir(client): - if name.startswith("_") or name in blacklist: + if name.startswith("_") or name in denylist: continue module_object = getattr(client, name) if getattr(module_object, "__module__", None) == "http.client": @@ -1432,8 +1464,6 @@ def test_all(self): def test_responses(self): self.assertEqual(client.responses[client.NOT_FOUND], "Not Found") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_client_constants(self): # Make sure we don't break backward compatibility with 3.4 expected = [ @@ -1474,6 +1504,7 @@ def test_client_constants(self): 'UNSUPPORTED_MEDIA_TYPE', 'REQUESTED_RANGE_NOT_SATISFIABLE', 'EXPECTATION_FAILED', + 'IM_A_TEAPOT', 'MISDIRECTED_REQUEST', 'UNPROCESSABLE_ENTITY', 'LOCKED', @@ -1492,6 +1523,8 @@ def test_client_constants(self): 'INSUFFICIENT_STORAGE', 'NOT_EXTENDED', 'NETWORK_AUTHENTICATION_REQUIRED', + 'EARLY_HINTS', + 'TOO_EARLY' ] for const in expected: with self.subTest(constant=const): @@ -1647,7 +1680,6 @@ def test_100_close(self): self.assertEqual(conn.connections, 
2) -@unittest.skip("TODO: RUSTPYTHON") class HTTPSTest(TestCase): def setUp(self): @@ -1663,6 +1695,8 @@ def test_attributes(self): h = client.HTTPSConnection(HOST, TimeoutTest.PORT, timeout=30) self.assertEqual(h.timeout, 30) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_networked(self): # Default settings: requires a valid cert from a trusted CA import ssl @@ -1700,6 +1734,8 @@ def test_networked_trusted_by_default_cert(self): h.close() self.assertIn('text/html', content_type) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_networked_good_cert(self): # We feed the server's cert as a validating cert import ssl @@ -1733,6 +1769,8 @@ def test_networked_good_cert(self): h.close() self.assertIn('nginx', server_string) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_networked_bad_cert(self): # We feed a "CA" cert that is unrelated to the server's cert import ssl @@ -1745,6 +1783,8 @@ def test_networked_bad_cert(self): h.request('GET', '/') self.assertEqual(exc_info.exception.reason, 'CERTIFICATE_VERIFY_FAILED') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_local_unknown_cert(self): # The custom cert isn't known to the default trust bundle import ssl @@ -1767,6 +1807,8 @@ def test_local_good_hostname(self): self.addCleanup(resp.close) self.assertEqual(resp.status, 404) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_local_bad_hostname(self): # The (valid) cert doesn't validate the HTTP hostname import ssl @@ -1776,14 +1818,14 @@ def test_local_bad_hostname(self): h = client.HTTPSConnection('localhost', server.port, context=context) with self.assertRaises(ssl.CertificateError): h.request('GET', '/') - # Samwarnings_helper.check_warningshostname=True + # Same with explicit check_hostname=True with warnings_helper.check_warnings(('', DeprecationWarning)): h = client.HTTPSConnection('localhost', server.port, context=context, check_hostname=True) with self.assertRaises(ssl.CertificateError): h.request('GET', '/') # With 
check_hostname=False, the mismatching is ignored - contewarnings_helper.check_warningslse + context.check_hostname = False with warnings_helper.check_warnings(('', DeprecationWarning)): h = client.HTTPSConnection('localhost', server.port, context=context, check_hostname=False) @@ -1802,7 +1844,7 @@ def test_local_bad_hostname(self): resp.close() h.close() # Passing check_hostname to HTTPSConnection should override the - # conwarnings_helper.check_warnings + # context's setting. with warnings_helper.check_warnings(('', DeprecationWarning)): h = client.HTTPSConnection('localhost', server.port, context=context, check_hostname=True) @@ -1920,9 +1962,9 @@ def test_bytes_body(self): def test_text_file_body(self): self.addCleanup(os_helper.unlink, os_helper.TESTFN) - with open(os_helper.TESTFN, "w") as f: + with open(os_helper.TESTFN, "w", encoding="utf-8") as f: f.write("body") - with open(os_helper.TESTFN) as f: + with open(os_helper.TESTFN, encoding="utf-8") as f: self.conn.request("PUT", "/url", f) message, f = self.get_headers_and_fp() self.assertEqual("text/plain", message.get_content_type()) @@ -2033,6 +2075,23 @@ def test_connect_with_tunnel(self): # This test should be removed when CONNECT gets the HTTP/1.1 blessing self.assertNotIn(b'Host: proxy.com', self.conn.sock.data) + def test_tunnel_connect_single_send_connection_setup(self): + """Regresstion test for https://bugs.python.org/issue43332.""" + with mock.patch.object(self.conn, 'send') as mock_send: + self.conn.set_tunnel('destination.com') + self.conn.connect() + self.conn.request('GET', '/') + mock_send.assert_called() + # Likely 2, but this test only cares about the first. 
+ self.assertGreater( + len(mock_send.mock_calls), 1, + msg=f'unexpected number of send calls: {mock_send.mock_calls}') + proxy_setup_data_sent = mock_send.mock_calls[0][1][0] + self.assertIn(b'CONNECT destination.com', proxy_setup_data_sent) + self.assertTrue( + proxy_setup_data_sent.endswith(b'\r\n\r\n'), + msg=f'unexpected proxy data sent {proxy_setup_data_sent!r}') + def test_connect_put_request(self): self.conn.set_tunnel('destination.com') self.conn.request('PUT', '/', '') diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py index 28392ccec39..d31582c0db3 100644 --- a/Lib/test/test_httpservers.py +++ b/Lib/test/test_httpservers.py @@ -3,7 +3,7 @@ Written by Cody A.W. Somerville , Josip Dzolonga, and Michael Otteneder for the 2007/08 GHOP contest. """ - +from collections import OrderedDict from http.server import BaseHTTPRequestHandler, HTTPServer, \ SimpleHTTPRequestHandler, CGIHTTPRequestHandler from http import server, HTTPStatus @@ -14,11 +14,12 @@ import re import base64 import ntpath +import pathlib import shutil import email.message import email.utils import html -import http.client +import http, http.client import urllib.parse import tempfile import time @@ -427,6 +428,7 @@ def test_get(self): self.check_status_and_reason(response, HTTPStatus.OK) response = self.request(self.base_url) self.check_status_and_reason(response, HTTPStatus.MOVED_PERMANENTLY) + self.assertEqual(response.getheader("Content-Length"), "0") response = self.request(self.base_url + '/?hi=2') self.check_status_and_reason(response, HTTPStatus.OK) response = self.request(self.base_url + '?hi=1') @@ -463,8 +465,6 @@ def test_head(self): self.assertEqual(response.getheader('content-type'), 'application/octet-stream') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_browser_cache(self): """Check that when a request to /test is sent with the request header If-Modified-Since set to date of last modification, the server returns @@ -542,7 +542,7 @@ def 
test_html_escape_filename(self): fullpath = os.path.join(self.tempdir, filename) try: - open(fullpath, 'w').close() + open(fullpath, 'wb').close() except OSError: raise unittest.SkipTest('Can not create file %s on current file ' 'system' % filename) @@ -589,8 +589,25 @@ def test_html_escape_filename(self): print(os.environ["%s"]) """ +cgi_file6 = """\ +#!%s +import os + +print("X-ambv: was here") +print("Content-type: text/html") +print() +print("
")
+for k, v in os.environ.items():
+    try:
+        k.encode('ascii')
+        v.encode('ascii')
+    except UnicodeEncodeError:
+        continue  # see: BPO-44647
+    print(f"{k}={v}")
+print("
") +""" + -@unittest.skipIf(sys.platform == "win32", "TODO: RUSTPYTHON, teardown errors and universal newline failures") @unittest.skipIf(hasattr(os, 'geteuid') and os.geteuid() == 0, "This test can't be run reliably as root (issue #13308).") class CGIHTTPServerTestCase(BaseTestCase): @@ -598,6 +615,8 @@ class request_handler(NoLogRequestHandler, CGIHTTPRequestHandler): pass linesep = os.linesep.encode('ascii') + # TODO: RUSTPYTHON + linesep = b'\n' def setUp(self): BaseTestCase.setUp(self) @@ -605,18 +624,26 @@ def setUp(self): self.parent_dir = tempfile.mkdtemp() self.cgi_dir = os.path.join(self.parent_dir, 'cgi-bin') self.cgi_child_dir = os.path.join(self.cgi_dir, 'child-dir') + self.sub_dir_1 = os.path.join(self.parent_dir, 'sub') + self.sub_dir_2 = os.path.join(self.sub_dir_1, 'dir') + self.cgi_dir_in_sub_dir = os.path.join(self.sub_dir_2, 'cgi-bin') os.mkdir(self.cgi_dir) os.mkdir(self.cgi_child_dir) + os.mkdir(self.sub_dir_1) + os.mkdir(self.sub_dir_2) + os.mkdir(self.cgi_dir_in_sub_dir) self.nocgi_path = None self.file1_path = None self.file2_path = None self.file3_path = None self.file4_path = None + self.file5_path = None # The shebang line should be pure ASCII: use symlink if possible. # See issue #7668. 
self._pythonexe_symlink = None - if os_helper.can_symlink(): + # TODO: RUSTPYTHON; dl_nt not supported yet + if os_helper.can_symlink() and sys.platform != 'win32': self.pythonexe = os.path.join(self.parent_dir, 'python') self._pythonexe_symlink = support.PythonSymlink(self.pythonexe).__enter__() else: @@ -632,7 +659,7 @@ def setUp(self): self.skipTest("Python executable path is not encodable to utf-8") self.nocgi_path = os.path.join(self.parent_dir, 'nocgi.py') - with open(self.nocgi_path, 'w') as fp: + with open(self.nocgi_path, 'w', encoding='utf-8') as fp: fp.write(cgi_file1 % self.pythonexe) os.chmod(self.nocgi_path, 0o777) @@ -656,6 +683,16 @@ def setUp(self): file4.write(cgi_file4 % (self.pythonexe, 'QUERY_STRING')) os.chmod(self.file4_path, 0o777) + self.file5_path = os.path.join(self.cgi_dir_in_sub_dir, 'file5.py') + with open(self.file5_path, 'w', encoding='utf-8') as file5: + file5.write(cgi_file1 % self.pythonexe) + os.chmod(self.file5_path, 0o777) + + self.file6_path = os.path.join(self.cgi_dir, 'file6.py') + with open(self.file6_path, 'w', encoding='utf-8') as file6: + file6.write(cgi_file6 % self.pythonexe) + os.chmod(self.file6_path, 0o777) + os.chdir(self.parent_dir) def tearDown(self): @@ -673,8 +710,15 @@ def tearDown(self): os.remove(self.file3_path) if self.file4_path: os.remove(self.file4_path) + if self.file5_path: + os.remove(self.file5_path) + if self.file6_path: + os.remove(self.file6_path) os.rmdir(self.cgi_child_dir) os.rmdir(self.cgi_dir) + os.rmdir(self.cgi_dir_in_sub_dir) + os.rmdir(self.sub_dir_2) + os.rmdir(self.sub_dir_1) os.rmdir(self.parent_dir) finally: BaseTestCase.tearDown(self) @@ -793,12 +837,40 @@ def test_query_with_continuous_slashes(self): 'text/html', HTTPStatus.OK), (res.read(), res.getheader('Content-type'), res.status)) + def test_cgi_path_in_sub_directories(self): + try: + CGIHTTPRequestHandler.cgi_directories.append('/sub/dir/cgi-bin') + res = self.request('/sub/dir/cgi-bin/file5.py') + self.assertEqual( + (b'Hello 
World' + self.linesep, 'text/html', HTTPStatus.OK), + (res.read(), res.getheader('Content-type'), res.status)) + finally: + CGIHTTPRequestHandler.cgi_directories.remove('/sub/dir/cgi-bin') + + def test_accept(self): + browser_accept = \ + 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8' + tests = ( + ((('Accept', browser_accept),), browser_accept), + ((), ''), + # Hack case to get two values for the one header + ((('Accept', 'text/html'), ('ACCEPT', 'text/plain')), + 'text/html,text/plain'), + ) + for headers, expected in tests: + headers = OrderedDict(headers) + with self.subTest(headers): + res = self.request('/cgi-bin/file6.py', 'GET', headers=headers) + self.assertEqual(http.HTTPStatus.OK, res.status) + expected = f"HTTP_ACCEPT={expected}".encode('ascii') + self.assertIn(expected, res.read()) + class SocketlessRequestHandler(SimpleHTTPRequestHandler): - def __init__(self, *args, **kwargs): + def __init__(self, directory=None): request = mock.Mock() request.makefile.return_value = BytesIO() - super().__init__(request, None, None) + super().__init__(request, None, None, directory=directory) self.get_called = False self.protocol_version = "HTTP/1.1" @@ -896,8 +968,6 @@ def test_http_0_9(self): self.assertEqual(result[0], b'Data\r\n') self.verify_get_called() - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_extra_space(self): result = self.send_typical_request( b'GET /spaced out HTTP/1.1\r\n' @@ -1075,49 +1145,99 @@ def test_date_time_string(self): class SimpleHTTPRequestHandlerTestCase(unittest.TestCase): """ Test url parsing """ def setUp(self): - self.translated = os.getcwd() - self.translated = os.path.join(self.translated, 'filename') - self.handler = SocketlessRequestHandler() + self.translated_1 = os.path.join(os.getcwd(), 'filename') + self.translated_2 = os.path.join('foo', 'filename') + self.translated_3 = os.path.join('bar', 'filename') + self.handler_1 = SocketlessRequestHandler() + self.handler_2 = 
SocketlessRequestHandler(directory='foo') + self.handler_3 = SocketlessRequestHandler(directory=pathlib.PurePath('bar')) def test_query_arguments(self): - path = self.handler.translate_path('/filename') - self.assertEqual(path, self.translated) - path = self.handler.translate_path('/filename?foo=bar') - self.assertEqual(path, self.translated) - path = self.handler.translate_path('/filename?a=b&spam=eggs#zot') - self.assertEqual(path, self.translated) + path = self.handler_1.translate_path('/filename') + self.assertEqual(path, self.translated_1) + path = self.handler_2.translate_path('/filename') + self.assertEqual(path, self.translated_2) + path = self.handler_3.translate_path('/filename') + self.assertEqual(path, self.translated_3) + + path = self.handler_1.translate_path('/filename?foo=bar') + self.assertEqual(path, self.translated_1) + path = self.handler_2.translate_path('/filename?foo=bar') + self.assertEqual(path, self.translated_2) + path = self.handler_3.translate_path('/filename?foo=bar') + self.assertEqual(path, self.translated_3) + + path = self.handler_1.translate_path('/filename?a=b&spam=eggs#zot') + self.assertEqual(path, self.translated_1) + path = self.handler_2.translate_path('/filename?a=b&spam=eggs#zot') + self.assertEqual(path, self.translated_2) + path = self.handler_3.translate_path('/filename?a=b&spam=eggs#zot') + self.assertEqual(path, self.translated_3) def test_start_with_double_slash(self): - path = self.handler.translate_path('//filename') - self.assertEqual(path, self.translated) - path = self.handler.translate_path('//filename?foo=bar') - self.assertEqual(path, self.translated) + path = self.handler_1.translate_path('//filename') + self.assertEqual(path, self.translated_1) + path = self.handler_2.translate_path('//filename') + self.assertEqual(path, self.translated_2) + path = self.handler_3.translate_path('//filename') + self.assertEqual(path, self.translated_3) + + path = self.handler_1.translate_path('//filename?foo=bar') + 
self.assertEqual(path, self.translated_1) + path = self.handler_2.translate_path('//filename?foo=bar') + self.assertEqual(path, self.translated_2) + path = self.handler_3.translate_path('//filename?foo=bar') + self.assertEqual(path, self.translated_3) def test_windows_colon(self): with support.swap_attr(server.os, 'path', ntpath): - path = self.handler.translate_path('c:c:c:foo/filename') + path = self.handler_1.translate_path('c:c:c:foo/filename') + path = path.replace(ntpath.sep, os.sep) + self.assertEqual(path, self.translated_1) + path = self.handler_2.translate_path('c:c:c:foo/filename') path = path.replace(ntpath.sep, os.sep) - self.assertEqual(path, self.translated) + self.assertEqual(path, self.translated_2) + path = self.handler_3.translate_path('c:c:c:foo/filename') + path = path.replace(ntpath.sep, os.sep) + self.assertEqual(path, self.translated_3) - path = self.handler.translate_path('\\c:../filename') + path = self.handler_1.translate_path('\\c:../filename') + path = path.replace(ntpath.sep, os.sep) + self.assertEqual(path, self.translated_1) + path = self.handler_2.translate_path('\\c:../filename') path = path.replace(ntpath.sep, os.sep) - self.assertEqual(path, self.translated) + self.assertEqual(path, self.translated_2) + path = self.handler_3.translate_path('\\c:../filename') + path = path.replace(ntpath.sep, os.sep) + self.assertEqual(path, self.translated_3) - path = self.handler.translate_path('c:\\c:..\\foo/filename') + path = self.handler_1.translate_path('c:\\c:..\\foo/filename') + path = path.replace(ntpath.sep, os.sep) + self.assertEqual(path, self.translated_1) + path = self.handler_2.translate_path('c:\\c:..\\foo/filename') path = path.replace(ntpath.sep, os.sep) - self.assertEqual(path, self.translated) + self.assertEqual(path, self.translated_2) + path = self.handler_3.translate_path('c:\\c:..\\foo/filename') + path = path.replace(ntpath.sep, os.sep) + self.assertEqual(path, self.translated_3) - path = 
self.handler.translate_path('c:c:foo\\c:c:bar/filename') + path = self.handler_1.translate_path('c:c:foo\\c:c:bar/filename') + path = path.replace(ntpath.sep, os.sep) + self.assertEqual(path, self.translated_1) + path = self.handler_2.translate_path('c:c:foo\\c:c:bar/filename') path = path.replace(ntpath.sep, os.sep) - self.assertEqual(path, self.translated) + self.assertEqual(path, self.translated_2) + path = self.handler_3.translate_path('c:c:foo\\c:c:bar/filename') + path = path.replace(ntpath.sep, os.sep) + self.assertEqual(path, self.translated_3) class MiscTestCase(unittest.TestCase): def test_all(self): expected = [] - blacklist = {'executable', 'nobody_uid', 'test'} + denylist = {'executable', 'nobody_uid', 'test'} for name in dir(server): - if name.startswith('_') or name in blacklist: + if name.startswith('_') or name in denylist: continue module_object = getattr(server, name) if getattr(module_object, '__module__', None) == 'http.server': @@ -1140,8 +1260,6 @@ def mock_server_class(self): ), ) - # TODO: RUSTPYTHON - @unittest.expectedFailure @mock.patch('builtins.print') def test_server_test_unspec(self, _): mock_server = self.mock_server_class() @@ -1151,8 +1269,6 @@ def test_server_test_unspec(self, _): (socket.AF_INET6, socket.AF_INET), ) - # TODO: RUSTPYTHON - @unittest.expectedFailure @mock.patch('builtins.print') def test_server_test_localhost(self, _): mock_server = self.mock_server_class() @@ -1174,8 +1290,6 @@ def test_server_test_localhost(self, _): "127.0.0.1", ) - # TODO: RUSTPYTHON - @unittest.expectedFailure @mock.patch('builtins.print') def test_server_test_ipv6(self, _): for bind in self.ipv6_addrs: @@ -1183,8 +1297,6 @@ def test_server_test_ipv6(self, _): server.test(ServerClass=mock_server, bind=bind) self.assertEqual(mock_server.address_family, socket.AF_INET6) - # TODO: RUSTPYTHON - @unittest.expectedFailure @mock.patch('builtins.print') def test_server_test_ipv4(self, _): for bind in self.ipv4_addrs: @@ -1193,21 +1305,9 @@ def 
test_server_test_ipv4(self, _): self.assertEqual(mock_server.address_family, socket.AF_INET) -def test_main(verbose=None): - cwd = os.getcwd() - try: - support.run_unittest( - RequestHandlerLoggingTestCase, - BaseHTTPRequestHandlerTestCase, - BaseHTTPServerTestCase, - SimpleHTTPServerTestCase, - CGIHTTPServerTestCase, - SimpleHTTPRequestHandlerTestCase, - MiscTestCase, - ScriptTestCase - ) - finally: - os.chdir(cwd) +def setUpModule(): + unittest.addModuleCleanup(os.chdir, os.getcwd()) + if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py index f0886af91f6..d640fe3143c 100644 --- a/Lib/test/test_urllib.py +++ b/Lib/test/test_urllib.py @@ -372,8 +372,6 @@ def test_willclose(self): finally: self.unfakehttp() - # TODO: RUSTPYTHON - @unittest.expectedFailure @unittest.skipUnless(ssl, "ssl module required") def test_url_path_with_control_char_rejected(self): for char_no in list(range(0, 0x21)) + [0x7f]: @@ -401,8 +399,6 @@ def test_url_path_with_control_char_rejected(self): finally: self.unfakehttp() - # TODO: RUSTPYTHON - @unittest.expectedFailure @unittest.skipUnless(ssl, "ssl module required") def test_url_path_with_newline_header_injection_rejected(self): self.fakehttp(b"HTTP/1.1 200 OK\r\n\r\nHello.") @@ -429,8 +425,6 @@ def test_url_path_with_newline_header_injection_rejected(self): finally: self.unfakehttp() - # TODO: RUSTPYTHON - @unittest.expectedFailure @unittest.skipUnless(ssl, "ssl module required") def test_url_host_with_control_char_rejected(self): for char_no in list(range(0, 0x21)) + [0x7f]: @@ -448,8 +442,6 @@ def test_url_host_with_control_char_rejected(self): finally: self.unfakehttp() - # TODO: RUSTPYTHON - @unittest.expectedFailure @unittest.skipUnless(ssl, "ssl module required") def test_url_host_with_newline_header_injection_rejected(self): self.fakehttp(b"HTTP/1.1 200 OK\r\n\r\nHello.") diff --git a/common/src/lib.rs b/common/src/lib.rs index feca8846e81..db891ebda04 100644 --- 
a/common/src/lib.rs +++ b/common/src/lib.rs @@ -22,6 +22,8 @@ pub mod rc; pub mod refcount; pub mod static_cell; pub mod str; +#[cfg(windows)] +pub mod windows; pub mod vendored { pub use ascii; diff --git a/common/src/windows.rs b/common/src/windows.rs new file mode 100644 index 00000000000..e1f296c941d --- /dev/null +++ b/common/src/windows.rs @@ -0,0 +1,33 @@ +use std::{ + ffi::{OsStr, OsString}, + os::windows::ffi::{OsStrExt, OsStringExt}, +}; + +pub trait ToWideString { + fn to_wide(&self) -> Vec; + fn to_wides_with_nul(&self) -> Vec; +} +impl ToWideString for T +where + T: AsRef, +{ + fn to_wide(&self) -> Vec { + self.as_ref().encode_wide().collect() + } + fn to_wides_with_nul(&self) -> Vec { + self.as_ref().encode_wide().chain(Some(0)).collect() + } +} + +pub trait FromWideString +where + Self: Sized, +{ + fn from_wides_until_nul(wide: &[u16]) -> Self; +} +impl FromWideString for OsString { + fn from_wides_until_nul(wide: &[u16]) -> OsString { + let len = wide.iter().take_while(|&&c| c != 0).count(); + OsString::from_wide(&wide[..len]) + } +} diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 3d824b2ae95..4f10db8be94 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -113,9 +113,9 @@ schannel = "0.1.19" widestring = "0.5.1" [target.'cfg(windows)'.dependencies.windows] -version = "0.39" +version = "0.39.0" features = [ - "Win32_UI_Shell", + "Win32_UI_Shell", "Win32_System_LibraryLoader", "Win32_Foundation" ] [target.'cfg(windows)'.dependencies.winapi] diff --git a/vm/src/stdlib/winapi.rs b/vm/src/stdlib/winapi.rs index c28faf43bf7..e9eceb0f227 100644 --- a/vm/src/stdlib/winapi.rs +++ b/vm/src/stdlib/winapi.rs @@ -5,6 +5,7 @@ pub(crate) use _winapi::make_module; mod _winapi { use crate::{ builtins::PyStrRef, + common::windows::ToWideString, convert::ToPyException, function::{ArgMapping, ArgSequence, OptionalArg}, stdlib::os::errno_err, @@ -16,6 +17,11 @@ mod _winapi { fileapi, handleapi, namedpipeapi, processenv, processthreadsapi, synchapi, winbase, 
winnt::HANDLE, }; + use windows::{ + core::PCWSTR, + Win32::Foundation::{HINSTANCE, MAX_PATH}, + Win32::System::LibraryLoader::{GetModuleFileNameW, LoadLibraryW}, + }; #[pyattr] use winapi::{ @@ -402,4 +408,29 @@ mod _winapi { }) .map(drop) } + + // TODO: ctypes.LibraryLoader.LoadLibrary + #[allow(dead_code)] + fn LoadLibrary(path: PyStrRef, vm: &VirtualMachine) -> PyResult { + let path = path.as_str().to_wides_with_nul(); + let handle = unsafe { LoadLibraryW(PCWSTR::from_raw(path.as_ptr())).unwrap() }; + if handle.is_invalid() { + return Err(vm.new_runtime_error("LoadLibrary failed".to_owned())); + } + Ok(handle.0) + } + + #[pyfunction] + fn GetModuleFileName(handle: isize, vm: &VirtualMachine) -> PyResult { + let mut path: Vec = vec![0; MAX_PATH as usize]; + let handle = HINSTANCE(handle); + + let length = unsafe { GetModuleFileNameW(handle, &mut path) }; + if length == 0 { + return Err(vm.new_runtime_error("GetModuleFileName failed".to_owned())); + } + + let (path, _) = path.split_at(length as usize); + Ok(String::from_utf16(&path).unwrap()) + } }