Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,69 @@ def dummycallback(sock, servername, ctx, cycle=ctx):
gc.collect()
self.assertIs(wr(), None)

@unittest.skipUnless(support.Py_GIL_DISABLED,
"test is only useful if the GIL is disabled")
@threading_helper.requires_working_threading()
def test_sni_callback_race(self):
# Replacing sni_callback while handshakes are in-flight must not
# crash (use-after-free on the callback in free-threaded builds).
client_ctx, server_ctx, hostname = testing_context()

def make_callback(n):
def sni_cb(_ssl_obj, _servername, _ctx):
if n == -1 and _servername == "":
raise AssertionError("unreachable")
return None
return sni_cb

server_ctx.sni_callback = make_callback(0)
done = threading.Event()

def do_handshakes():
while not done.is_set():
c_in = ssl.MemoryBIO()
c_out = ssl.MemoryBIO()
s_in = ssl.MemoryBIO()
s_out = ssl.MemoryBIO()
client = client_ctx.wrap_bio(
c_in, c_out, server_hostname=hostname)
server = server_ctx.wrap_bio(s_in, s_out, server_side=True)
for _ in range(50):
try:
client.do_handshake()
except ssl.SSLWantReadError:
pass
except ssl.SSLError:
break
if c_out.pending:
s_in.write(c_out.read())
try:
server.do_handshake()
except ssl.SSLWantReadError:
pass
except ssl.SSLError:
break
if s_out.pending:
c_in.write(s_out.read())

def toggle_callback():
i = 0
while not done.is_set():
server_ctx.sni_callback = make_callback(i)
server_ctx.sni_callback = None
server_ctx.sni_callback = make_callback(-i)
i += 1

workers = max(4, (os.cpu_count() or 4) * 2)
threads = [threading.Thread(target=do_handshakes)
for _ in range(workers)]
threads.append(threading.Thread(target=toggle_callback))

with threading_helper.catch_threading_exception() as cm:
with threading_helper.start_threads(threads):
done.set()
self.assertIsNone(cm.exc_value)

def test_cert_store_stats(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.assertEqual(ctx.cert_store_stats(),
Expand Down
Loading