Skip to content

Commit ebe1ab2

Browse files
committed
fix: make _mark_seen safe under free-threaded contention
Address review feedback on #1717. The previous eviction body — ``del seen[next(iter(seen))]`` — relied on the GIL to serialize the compound iter/next/del. Under the free-threaded build (3.14t) and under multi-instance sync use where multiple ``Zeroconf`` instances share the module-level ``_seen_logs``, callers can race:

- Two threads pop the same victim — ``del`` raises ``KeyError`` on the loser
- One thread mutates the dict between another's ``iter()`` and ``next()`` — the iterator's size-change check raises ``RuntimeError`` ("dictionary changed size during iteration")

Switch to ``seen.pop(next(iter(seen), None), None)`` so the pop is idempotent (no ``KeyError``) and the ``None`` default passed to ``next`` handles the empty-dict edge case (no ``StopIteration``), and wrap the iter/next in a ``try/except RuntimeError`` so concurrent mutation during eviction is absorbed. Worst case under contention is a transient overshoot of the cap by one entry per racing thread, which clears as soon as the contention does.

Add ``test_mark_seen_absorbs_runtime_error_during_eviction``, which substitutes a dict subclass whose ``__iter__`` always raises, proving the new ``except`` branch lets the insert still complete.

Also tighten ``tests/test_protocol.py::test_seen_logs_is_bounded`` per Kōan's review comment: previously it asserted ``<= _MAX_SEEN_LOGS``, which would still pass if a future refactor collapsed the per-port keys to a single dedup string. It now asserts ``== _MAX_SEEN_LOGS`` and verifies port 0's exception text is gone while the highest port's is still present — this pins both the bound and that the parser exception path actually enters the dict with per-source-unique keys.
1 parent 9a64353 commit ebe1ab2

3 files changed

Lines changed: 50 additions & 5 deletions

File tree

src/zeroconf/_logger.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,27 @@ def _mark_seen(seen: dict[str, None], key: str) -> bool:
5050
addresses, packet offsets) cannot grow it without bound. Evicts
5151
the oldest entry per overflow (dict preserves insertion order on
5252
Python 3.7+), so ``_MAX_SEEN_LOGS`` is a recency window.
53+
54+
The dict is shared across all ``Zeroconf`` instances in the
55+
process; on the free-threaded build (3.14t) and under multi-
56+
instance sync use, callers can race. Individual dict operations
57+
(``in``, ``__setitem__``, ``pop``, ``len``) are atomic in CPython
58+
3.13+ FT and don't crash, but the compound ``iter`` → ``next``
59+
used to find the FIFO victim can raise ``RuntimeError`` if
60+
another thread mutates the dict between the two ops. Catch and
61+
skip — the other thread is already shrinking the set, so missing
62+
one eviction here just lets the cap drift up by one entry per
63+
racing thread until contention clears.
5364
"""
5465
if key in seen:
5566
return False
5667
if len(seen) >= _MAX_SEEN_LOGS:
57-
del seen[next(iter(seen))]
68+
try:
69+
oldest = next(iter(seen), None)
70+
except RuntimeError:
71+
oldest = None
72+
if oldest is not None:
73+
seen.pop(oldest, None)
5874
seen[key] = None
5975
return True
6076

tests/test_logger.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from unittest.mock import call, patch
77

88
from zeroconf import _logger
9-
from zeroconf._logger import _MAX_SEEN_LOGS, QuietLogger, set_logger_level_if_unset
9+
from zeroconf._logger import _MAX_SEEN_LOGS, QuietLogger, _mark_seen, set_logger_level_if_unset
1010

1111

1212
def test_loading_logger():
@@ -85,6 +85,26 @@ def test_llog_exception_debug():
8585
assert mock_log_debug.mock_calls == [call("the exception", exc_info=False)]
8686

8787

88+
def test_mark_seen_absorbs_runtime_error_during_eviction() -> None:
89+
"""Concurrent mutation can make ``iter(seen)`` raise ``RuntimeError``.
90+
91+
Free-threaded (3.14t) and multi-instance sync callers share
92+
``_seen_logs``; if another thread mutates it between ``iter()``
93+
and ``next()`` the iterator raises ``RuntimeError``.
94+
``_mark_seen`` must absorb that and still insert the new key.
95+
"""
96+
97+
class RacyDict(dict[str, None]):
98+
def __iter__(self): # type: ignore[override]
99+
raise RuntimeError("dictionary changed size during iteration")
100+
101+
seen: dict[str, None] = RacyDict()
102+
for i in range(_MAX_SEEN_LOGS):
103+
seen[f"k-{i}"] = None
104+
assert _mark_seen(seen, "new-key") is True
105+
assert "new-key" in seen
106+
107+
88108
def test_seen_logs_is_bounded() -> None:
89109
"""``_seen_logs`` stays at the cap and evicts oldest-first (FIFO)."""
90110
_logger._seen_logs.clear()

tests/test_protocol.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -965,16 +965,25 @@ def test_dns_compression_generic_failure(caplog):
965965

966966

967967
def test_seen_logs_is_bounded():
968-
"""Corrupt packets from varying peers must not grow _seen_logs without bound."""
968+
"""Corrupt packets from varying peers fill ``_seen_logs`` exactly to the cap."""
969969
packet = (
970970
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01"
971971
b"\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05-\x0c\x00\x01\x80\x01\x00\x00"
972972
b"\x00\x01\x00\x04\xc0\xa8\xd0\x06"
973973
)
974+
overflow = 5
974975
_incoming_module._seen_logs.clear()
975-
for port in range(_MAX_SEEN_LOGS + 5):
976+
for port in range(_MAX_SEEN_LOGS + overflow):
976977
r.DNSIncoming(packet, ("1.2.3.4", port))
977-
assert len(_incoming_module._seen_logs) <= _MAX_SEEN_LOGS
978+
# Bound is hit exactly — confirms the parser exception path actually
979+
# entered the dict with a per-port-unique key; a future change that
980+
# dropped self.source from the exception text would collapse to a
981+
# single dedup key and fail this assertion.
982+
assert len(_incoming_module._seen_logs) == _MAX_SEEN_LOGS
983+
# FIFO eviction: the earliest port's exception string is gone, the
984+
# latest port's is still present.
985+
assert not any("'1.2.3.4', 0)" in k for k in _incoming_module._seen_logs)
986+
assert any(f"'1.2.3.4', {_MAX_SEEN_LOGS + overflow - 1})" in k for k in _incoming_module._seen_logs)
978987

979988

980989
def test_label_length_attack():

0 commit comments

Comments
 (0)