Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions Lib/test/ssltests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Convenience test module to run all of the OpenSSL-related tests in the
# standard library.

import ssl
import sys
import subprocess

TESTS = [
'test_asyncio', 'test_ensurepip.py', 'test_ftplib', 'test_hashlib',
'test_hmac', 'test_httplib', 'test_imaplib',
'test_poplib', 'test_ssl', 'test_smtplib', 'test_smtpnet',
'test_urllib2_localnet', 'test_venv', 'test_xmlrpc'
]

def run_regrtests(*extra_args):
print(ssl.OPENSSL_VERSION)
args = [
sys.executable,
'-Werror', '-bb', # turn warnings into exceptions
'-m', 'test',
]
if not extra_args:
args.extend([
'-r', # randomize
'-w', # re-run failed tests with -v
'-u', 'network', # use network
'-u', 'urlfetch', # download test vectors
'-j', '0' # use multiple CPUs
])
else:
args.extend(extra_args)
args.extend(TESTS)
result = subprocess.call(args)
sys.exit(result)

if __name__ == '__main__':
run_regrtests(*sys.argv[1:])
254 changes: 232 additions & 22 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Test the support for SSL and sockets

import contextlib
import sys
import unittest
import unittest.mock
Expand Down Expand Up @@ -47,9 +48,20 @@

PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = socket_helper.HOST
IS_AWS_LC = "AWS-LC" in ssl.OPENSSL_VERSION
IS_OPENSSL_3_0_0 = ssl.OPENSSL_VERSION_INFO >= (3, 0, 0)
PY_SSL_DEFAULT_CIPHERS = sysconfig.get_config_var('PY_SSL_DEFAULT_CIPHERS')

HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
requires_keylog = unittest.skipUnless(
HAS_KEYLOG, 'test requires OpenSSL 1.1.1 with keylog callback')
CAN_SET_KEYLOG = HAS_KEYLOG and os.name != "nt"
requires_keylog_setter = unittest.skipUnless(
CAN_SET_KEYLOG,
"cannot set 'keylog_filename' on Windows"
)


PROTOCOL_TO_TLS_VERSION = {}
for proto, ver in (
("PROTOCOL_SSLv3", "SSLv3"),
Expand Down Expand Up @@ -258,26 +270,67 @@ def utc_offset(): #NOTE: ignore issues like #1647654
)


def test_wrap_socket(sock, *,
cert_reqs=ssl.CERT_NONE, ca_certs=None,
ciphers=None, certfile=None, keyfile=None,
**kwargs):
if not kwargs.get("server_side"):
kwargs["server_hostname"] = SIGNED_CERTFILE_HOSTNAME
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
else:
def make_test_context(
*,
server_side=False,
check_hostname=None,
cert_reqs=ssl.CERT_NONE,
ca_certs=None, certfile=None, keyfile=None,
ciphers=None,
min_version=None, max_version=None,
):
if server_side:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
if cert_reqs is not None:
else:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

if check_hostname is None:
if cert_reqs == ssl.CERT_NONE:
context.check_hostname = False
else:
context.check_hostname = check_hostname

if cert_reqs is not None:
context.verify_mode = cert_reqs

if ca_certs is not None:
context.load_verify_locations(ca_certs)
if certfile is not None or keyfile is not None:
context.load_cert_chain(certfile, keyfile)

if ciphers is not None:
context.set_ciphers(ciphers)
return context.wrap_socket(sock, **kwargs)

if min_version is not None:
context.minimum_version = min_version
if max_version is not None:
context.maximum_version = max_version

return context


def test_wrap_socket(
sock,
*,
server_side=False,
check_hostname=None,
cert_reqs=ssl.CERT_NONE,
ca_certs=None, certfile=None, keyfile=None,
ciphers=None,
min_version=None, max_version=None,
**kwargs,
):
context = make_test_context(
server_side=server_side,
check_hostname=check_hostname,
cert_reqs=cert_reqs,
ca_certs=ca_certs, certfile=certfile, keyfile=keyfile,
ciphers=ciphers,
min_version=min_version, max_version=max_version,
)
if not server_side:
kwargs.setdefault("server_hostname", SIGNED_CERTFILE_HOSTNAME)
return context.wrap_socket(sock, server_side=server_side, **kwargs)


USE_SAME_TEST_CONTEXT = False
Expand Down Expand Up @@ -317,6 +370,20 @@ def testing_context(server_cert=SIGNED_CERTFILE, *, server_chain=True):
return client_context, server_context, hostname


def do_ssl_object_handshake(sslobject, outgoing, max_retry=25):
"""Call do_handshake() on the sslobject and return the sent data.

If do_handshake() fails more than *max_retry* times, return None.
"""
data, attempt = None, 0
while not data and attempt < max_retry:
with contextlib.suppress(ssl.SSLWantReadError):
sslobject.do_handshake()
data = outgoing.read()
attempt += 1
return data


class BasicSocketTests(unittest.TestCase):

def test_constants(self):
Expand Down Expand Up @@ -698,6 +765,7 @@ def test_dealloc_warn(self):
support.gc_collect()
self.assertIn(r, str(cm.warning.args[0]))

@unittest.expectedFailureIf(sys.platform == "android", "TODO: RUSTPYTHON; TypeError: path should be string, bytes, os.PathLike or integer, not NoneType")
def test_get_default_verify_paths(self):
paths = ssl.get_default_verify_paths()
self.assertEqual(len(paths), 6)
Expand Down Expand Up @@ -1035,7 +1103,6 @@ def test_hostname_checks_common_name(self):
with self.assertRaises(AttributeError):
ctx.hostname_checks_common_name = True

@unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: <TLSVersion.TLSv1_2: 771> not found in {<TLSVersion.TLSv1: 769>, <TLSVersion.TLSv1_1: 770>, <TLSVersion.SSLv3: 768>}
@ignore_deprecation
def test_min_max_version(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
Expand Down Expand Up @@ -1089,7 +1156,12 @@ def test_min_max_version(self):
ctx.maximum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
self.assertIn(
ctx.maximum_version,
{ssl.TLSVersion.TLSv1, ssl.TLSVersion.TLSv1_1, ssl.TLSVersion.SSLv3}
{
ssl.TLSVersion.TLSv1,
ssl.TLSVersion.TLSv1_1,
ssl.TLSVersion.TLSv1_2,
ssl.TLSVersion.SSLv3,
}
)

ctx.minimum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
Expand Down Expand Up @@ -1410,6 +1482,50 @@ def dummycallback(sock, servername, ctx):
ctx.set_servername_callback(None)
ctx.set_servername_callback(dummycallback)

@unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: Expected 'mock' to not have been called. Called 1 times.
def test_sni_callback_on_dead_references(self):
# See https://github.com/python/cpython/issues/146080.
c_ctx = make_test_context()
c_inc, c_out = ssl.MemoryBIO(), ssl.MemoryBIO()
client = c_ctx.wrap_bio(c_inc, c_out, server_hostname=SIGNED_CERTFILE_HOSTNAME)

def sni_callback(sock, servername, ctx): pass
sni_callback = unittest.mock.Mock(wraps=sni_callback)
s_ctx = make_test_context(server_side=True, certfile=SIGNED_CERTFILE)
s_ctx.set_servername_callback(sni_callback)

s_inc, s_out = ssl.MemoryBIO(), ssl.MemoryBIO()
server = s_ctx.wrap_bio(s_inc, s_out, server_side=True)
server_impl = server._sslobj

# Perform the handshake on the client side first.
data = do_ssl_object_handshake(client, c_out)
sni_callback.assert_not_called()
if data is None:
self.skipTest("cannot establish a handshake from the client")
s_inc.write(data)
sni_callback.assert_not_called()
# Delete the server object before it starts doing its handshake
# and ensure that we did not call the SNI callback yet.
del server
gc.collect()
# Try to continue the server's handshake by directly using
# the internal SSL object. The latter is a weak reference
# stored in the server context and has now a dead owner.
with self.assertRaises(ssl.SSLError) as cm:
server_impl.do_handshake()
# The SNI C callback raised an exception before calling our callback.
sni_callback.assert_not_called()

# In AWS-LC, any handshake failures reports SSL_R_PARSE_TLSEXT,
# while OpenSSL uses SSL_R_CALLBACK_FAILED on SNI callback failures.
if IS_AWS_LC:
libssl_error_reason = "PARSE_TLSEXT"
else:
libssl_error_reason = "callback failed"
self.assertIn(libssl_error_reason, str(cm.exception))
self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_SSL)

@unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: <SSLContext(protocol=17)> is not None
def test_sni_callback_refcycle(self):
# Reference cycles through the servername callback are detected
Expand All @@ -1423,6 +1539,59 @@ def dummycallback(sock, servername, ctx, cycle=ctx):
gc.collect()
self.assertIs(wr(), None)

@unittest.skipUnless(support.Py_GIL_DISABLED,
"test is only useful if the GIL is disabled")
@threading_helper.requires_working_threading()
def test_sni_callback_race(self):
# Replacing sni_callback while handshakes are in-flight must not
# crash (use-after-free on the callback in free-threaded builds).
client_ctx, server_ctx, hostname = testing_context()

server_ctx.sni_callback = lambda *a: None
done = threading.Event()

def do_handshakes():
while not done.is_set():
c_in = ssl.MemoryBIO()
c_out = ssl.MemoryBIO()
s_in = ssl.MemoryBIO()
s_out = ssl.MemoryBIO()
client = client_ctx.wrap_bio(
c_in, c_out, server_hostname=hostname)
server = server_ctx.wrap_bio(s_in, s_out, server_side=True)
for _ in range(50):
try:
client.do_handshake()
except ssl.SSLWantReadError:
pass
except ssl.SSLError:
break
if c_out.pending:
s_in.write(c_out.read())
try:
server.do_handshake()
except ssl.SSLWantReadError:
pass
except ssl.SSLError:
break
if s_out.pending:
c_in.write(s_out.read())

def toggle_callback():
while not done.is_set():
server_ctx.sni_callback = lambda *a: None
server_ctx.sni_callback = None

workers = max(4, (os.cpu_count() or 4) * 2)
threads = [threading.Thread(target=do_handshakes)
for _ in range(workers)]
threads.append(threading.Thread(target=toggle_callback))

with threading_helper.catch_threading_exception() as cm:
with threading_helper.start_threads(threads):
done.set()
self.assertIsNone(cm.exc_value)

def test_cert_store_stats(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.assertEqual(ctx.cert_store_stats(),
Expand Down Expand Up @@ -1665,6 +1834,39 @@ def test_num_tickest(self):
with self.assertRaises(ValueError):
ctx.num_tickets = 1

@support.cpython_only
def test_refcycle_msg_callback(self):
# See https://github.com/python/cpython/issues/142516.
ctx = make_test_context()
def msg_callback(*args, _=ctx, **kwargs): ...
ctx._msg_callback = msg_callback

@support.cpython_only
@requires_keylog_setter
def test_refcycle_keylog_filename(self):
# See https://github.com/python/cpython/issues/142516.
self.addCleanup(os_helper.unlink, os_helper.TESTFN)
ctx = make_test_context()
class KeylogFilename(str): ...
ctx.keylog_filename = KeylogFilename(os_helper.TESTFN)
ctx.keylog_filename._ = ctx

@support.cpython_only
@unittest.skipUnless(ssl.HAS_PSK, 'requires TLS-PSK')
def test_refcycle_psk_client_callback(self):
# See https://github.com/python/cpython/issues/142516.
ctx = make_test_context()
def psk_client_callback(*args, _=ctx, **kwargs): ...
ctx.set_psk_client_callback(psk_client_callback)

@support.cpython_only
@unittest.skipUnless(ssl.HAS_PSK, 'requires TLS-PSK')
def test_refcycle_psk_server_callback(self):
# See https://github.com/python/cpython/issues/142516.
ctx = make_test_context(server_side=True)
def psk_server_callback(*args, _=ctx, **kwargs): ...
ctx.set_psk_server_callback(psk_server_callback)


class SSLErrorTests(unittest.TestCase):

Expand Down Expand Up @@ -4922,10 +5124,6 @@ def test_internal_chain_server(self):
self.assertEqual(res, b'\x02\n')


HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
requires_keylog = unittest.skipUnless(
HAS_KEYLOG, 'test requires OpenSSL 1.1.1 with keylog callback')

class TestSSLDebug(unittest.TestCase):

def keylog_lines(self, fname=os_helper.TESTFN):
Expand Down Expand Up @@ -5164,15 +5362,27 @@ def non_linux_skip_if_other_okay_error(self, err):
return # Expect the full test setup to always work on Linux.
if (isinstance(err, ConnectionResetError) or
(isinstance(err, OSError) and err.errno == errno.EINVAL) or
re.search('wrong.version.number', str(getattr(err, "reason", "")), re.I)):
re.search(
# Matches the following error messages:
# '[SSL: WRONG_VERSION_NUMBER] wrong version number (_ssl.c:1123)'
# '[SSL: RECORD_LAYER_FAILURE] record layer failure (_ssl.c:1109)'
# '[SSL: HTTP_REQUEST] http request (_ssl.c:1143)'
r'wrong.version.number|record.layer.failure|http.request',
str(getattr(err, "reason", "")),
re.IGNORECASE,
)
):
# On Windows the TCP RST leads to a ConnectionResetError
# (ECONNRESET) which Linux doesn't appear to surface to userspace.
# If wrap_socket() winds up on the "if connected:" path and doing
# the actual wrapping... we get an SSLError from OpenSSL. Typically
# WRONG_VERSION_NUMBER. While appropriate, neither is the scenario
# we're specifically trying to test. The way this test is written
# is known to work on Linux. We'll skip it anywhere else that it
# does not present as doing so.
# the actual wrapping... we get an SSLError from OpenSSL. This is
# typically WRONG_VERSION_NUMBER. The same happens on iOS, but
# RECORD_LAYER_FAILURE or HTTP_REQUEST is the error.
#
# While appropriate, these scenarios aren't what we're specifically
# trying to test. The way this test is written is known to work on
# Linux. We'll skip it anywhere else that it does not present as
# doing so.
try:
self.skipTest(f"Could not recreate conditions on {sys.platform}:"
f" {err=}")
Expand Down
Loading