diff --git a/src/zeroconf/_logger.py b/src/zeroconf/_logger.py index 0d734dfd..99990cf6 100644 --- a/src/zeroconf/_logger.py +++ b/src/zeroconf/_logger.py @@ -25,7 +25,7 @@ import logging import sys -from typing import Any, ClassVar, cast +from typing import Any log = logging.getLogger(__name__.split(".", maxsplit=1)[0]) log.addHandler(logging.NullHandler()) @@ -39,50 +39,73 @@ def set_logger_level_if_unset() -> None: set_logger_level_if_unset() -class QuietLogger: - _seen_logs: ClassVar[dict[str, int | tuple]] = {} +_MAX_SEEN_LOGS = 512 +_seen_logs: dict[str, None] = {} + + +def _evict_oldest(seen: dict[str, None]) -> bool: + """Pop the oldest entry from ``seen``; return False if it raced. + + Individual dict ops (``pop`` with a default, ``next``) are atomic + on the free-threaded build, but the compound ``iter`` → ``next`` + used to pick the FIFO victim can raise ``RuntimeError`` if + another thread mutates the dict between the two ops. The caller + breaks its drain loop on False so concurrent mutation can't make + it spin. + """ + try: + seen.pop(next(iter(seen)), None) + except (RuntimeError, StopIteration): + return False + return True + + +def _mark_seen(seen: dict[str, None], key: str) -> bool: + """Record ``key`` in ``seen`` and return True if it was newly added. + + Bounds the dict so callers passing attacker-influenced keys (peer + addresses, packet offsets) cannot grow it without bound. Evicts + the oldest entries on overflow (dict preserves insertion order on + Python 3.7+), so ``_MAX_SEEN_LOGS`` is a recency window. + + The dict is shared across all ``Zeroconf`` instances in the + process; on the free-threaded build (3.14t) and under multi- + instance sync use, callers can race the ``len < cap`` check and + both insert, leaving the dict transiently above the cap. The + drain loop runs on every call (steady-state-at-cap hits are a + single ``len`` + compare past the membership check because the + helper short-circuits) so a contention burst is corrected by the + next caller regardless of whether it's a hit or a miss. + """ + inserting = key not in seen + # Hit (``inserting`` is False): drain only if drifted above cap. + # Miss (``inserting`` is True): drain to ``cap - 1`` to make room + # for the new key. Bool subtracts as 0/1 to pick the right limit. + while len(seen) > _MAX_SEEN_LOGS - inserting and _evict_oldest(seen): + pass + if inserting: + seen[key] = None + return inserting + +class QuietLogger: @classmethod def log_exception_warning(cls, *logger_data: Any) -> None: - exc_info = sys.exc_info() - exc_str = str(exc_info[1]) - if exc_str not in cls._seen_logs: - # log at warning level the first time this is seen - cls._seen_logs[exc_str] = exc_info - logger = log.warning - else: - logger = log.debug + first_time = _mark_seen(_seen_logs, str(sys.exc_info()[1])) + logger = log.warning if first_time else log.debug logger(*(logger_data or ["Exception occurred"]), exc_info=True) @classmethod def log_exception_debug(cls, *logger_data: Any) -> None: - log_exc_info = False - exc_info = sys.exc_info() - exc_str = str(exc_info[1]) - if exc_str not in cls._seen_logs: - # log the trace only on the first time - cls._seen_logs[exc_str] = exc_info - log_exc_info = True - log.debug(*(logger_data or ["Exception occurred"]), exc_info=log_exc_info) + first_time = _mark_seen(_seen_logs, str(sys.exc_info()[1])) + log.debug(*(logger_data or ["Exception occurred"]), exc_info=first_time) @classmethod def log_warning_once(cls, *args: Any) -> None: - msg_str = args[0] - if msg_str not in cls._seen_logs: - cls._seen_logs[msg_str] = 0 - logger = log.warning - else: - logger = log.debug - cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1 + logger = log.warning if _mark_seen(_seen_logs, args[0]) else log.debug logger(*args) @classmethod def log_exception_once(cls, exc: Exception, *args: Any) -> None: - msg_str = args[0] - if msg_str not in cls._seen_logs: - cls._seen_logs[msg_str] = 0 - logger = log.warning - else: - logger = log.debug - cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1 + logger = log.warning if _mark_seen(_seen_logs, args[0]) else log.debug logger(*args, exc_info=exc) diff --git a/src/zeroconf/_protocol/incoming.py b/src/zeroconf/_protocol/incoming.py index d772f470..ffbbb59f 100644 --- a/src/zeroconf/_protocol/incoming.py +++ b/src/zeroconf/_protocol/incoming.py @@ -37,7 +37,7 @@ DNSText, ) from .._exceptions import IncomingDecodeError -from .._logger import log +from .._logger import _mark_seen, log from .._utils.time import current_time_millis from ..const import ( _FLAGS_QR_MASK, @@ -63,7 +63,7 @@ DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError, RecursionError) -_seen_logs: dict[str, int | tuple] = {} +_seen_logs: dict[str, None] = {} _str = str _int = int @@ -182,13 +182,7 @@ def _initial_parse(self) -> None: @classmethod def _log_exception_debug(cls, *logger_data: Any) -> None: - log_exc_info = False - exc_info = sys.exc_info() - exc_str = str(exc_info[1]) - if exc_str not in _seen_logs: - # log the trace only on the first time - _seen_logs[exc_str] = exc_info - log_exc_info = True + log_exc_info = _mark_seen(_seen_logs, str(sys.exc_info()[1])) log.debug(*(logger_data or ["Exception occurred"]), exc_info=log_exc_info) def answers(self) -> list[DNSRecord]: diff --git a/tests/benchmarks/test_mark_seen.py b/tests/benchmarks/test_mark_seen.py new file mode 100644 index 00000000..4f82da8c --- /dev/null +++ b/tests/benchmarks/test_mark_seen.py @@ -0,0 +1,39 @@ +"""Benchmark for _logger._mark_seen.""" + +from __future__ import annotations + +from pytest_codspeed import BenchmarkFixture + +from zeroconf._logger import _MAX_SEEN_LOGS, _mark_seen + + +def test_mark_seen_hit(benchmark: BenchmarkFixture) -> None: + """Benchmark the cache-hit path (same key repeated).""" + seen: dict[str, None] = {"warm": None} + + @benchmark + def _hit() -> None: + for _ in range(1000): + _mark_seen(seen, "warm") + + +def test_mark_seen_fill(benchmark: BenchmarkFixture) -> None: + """Benchmark filling from empty up to the cap (no evictions).""" + keys = [f"key-{i}" for i in range(_MAX_SEEN_LOGS)] + + @benchmark + def _fill() -> None: + seen: dict[str, None] = {} + for k in keys: + _mark_seen(seen, k) + + +def test_mark_seen_churn(benchmark: BenchmarkFixture) -> None: + """Benchmark sustained eviction (every call past the cap drops oldest).""" + keys = [f"churn-{i}" for i in range(_MAX_SEEN_LOGS * 4)] + + @benchmark + def _churn() -> None: + seen: dict[str, None] = {} + for k in keys: + _mark_seen(seen, k) diff --git a/tests/test_logger.py b/tests/test_logger.py index 4e09aa3b..8042e49c 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -5,7 +5,8 @@ import logging from unittest.mock import call, patch -from zeroconf._logger import QuietLogger, set_logger_level_if_unset +from zeroconf import _logger +from zeroconf._logger import _MAX_SEEN_LOGS, QuietLogger, _mark_seen, set_logger_level_if_unset def test_loading_logger(): @@ -25,7 +26,7 @@ def test_loading_logger(): def test_log_warning_once(): """Test we only log with warning level once.""" - QuietLogger._seen_logs = {} + _logger._seen_logs.clear() quiet_logger = QuietLogger() with ( patch("zeroconf._logger.log.warning") as mock_log_warning, @@ -48,7 +49,7 @@ def test_log_warning_once(): def test_log_exception_warning(): """Test we only log with warning level once.""" - QuietLogger._seen_logs = {} + _logger._seen_logs.clear() quiet_logger = QuietLogger() with ( patch("zeroconf._logger.log.warning") as mock_log_warning, @@ -71,7 +72,7 @@ def test_log_exception_warning(): def test_llog_exception_debug(): """Test we only log with a trace once.""" - QuietLogger._seen_logs = {} + _logger._seen_logs.clear() quiet_logger = QuietLogger() with patch("zeroconf._logger.log.debug") as mock_log_debug: quiet_logger.log_exception_debug("the exception") @@ -84,9 +85,85 @@ def test_llog_exception_debug(): assert mock_log_debug.mock_calls == [call("the exception", exc_info=False)] +def test_mark_seen_absorbs_runtime_error_during_eviction() -> None: + """Concurrent mutation can make ``iter(seen)`` raise ``RuntimeError``. + + Free-threaded (3.14t) and multi-instance sync callers share + ``_seen_logs``; if another thread mutates it between ``iter()`` + and ``next()`` the iterator raises ``RuntimeError``. + ``_mark_seen`` must absorb that and still insert the new key. + """ + + class RacyDict(dict[str, None]): + def __iter__(self): # type: ignore[override] + raise RuntimeError("dictionary changed size during iteration") + + seen: dict[str, None] = RacyDict() + for i in range(_MAX_SEEN_LOGS): + seen[f"k-{i}"] = None + assert _mark_seen(seen, "new-key") is True + assert "new-key" in seen + + +def test_mark_seen_drains_drift_above_cap() -> None: + """``_mark_seen`` drains a drifted-over-cap dict back to the cap. + + Concurrent inserts on the free-threaded build can leave the dict + transiently above ``_MAX_SEEN_LOGS`` (e.g. two threads both passed + the ``len < cap`` check and both inserted). The next non-racing + call must drain the accumulated overshoot, not just evict one + entry — otherwise the cap silently inflates with thread count. + """ + seen: dict[str, None] = {} + drift = 10 + for i in range(_MAX_SEEN_LOGS + drift): + seen[f"k-{i}"] = None + assert len(seen) == _MAX_SEEN_LOGS + drift + assert _mark_seen(seen, "new-key") is True + assert len(seen) == _MAX_SEEN_LOGS + assert "new-key" in seen + for i in range(drift + 1): + assert f"k-{i}" not in seen + + +def test_mark_seen_drains_drift_on_hit_path() -> None: + """``_mark_seen`` drains drift even when ``key`` is already cached. + + A hit-heavy workload after a contention burst (e.g. the same + exception text deduplicated repeatedly) must still correct the + overshoot — otherwise the dict can sit permanently above the cap + until a miss happens to come along. + """ + seen: dict[str, None] = {} + drift = 10 + for i in range(_MAX_SEEN_LOGS + drift): + seen[f"k-{i}"] = None + # Hit on a non-oldest key — survives the drift drain. + hit_key = f"k-{_MAX_SEEN_LOGS}" + assert _mark_seen(seen, hit_key) is False + assert len(seen) == _MAX_SEEN_LOGS + assert hit_key in seen + for i in range(drift): + assert f"k-{i}" not in seen + + +def test_seen_logs_is_bounded() -> None: + """``_seen_logs`` stays at the cap and evicts oldest-first (FIFO).""" + _logger._seen_logs.clear() + overflow = 5 + with patch("zeroconf._logger.log.warning"), patch("zeroconf._logger.log.debug"): + for i in range(_MAX_SEEN_LOGS + overflow): + QuietLogger.log_warning_once(f"warning-{i}") + assert len(_logger._seen_logs) == _MAX_SEEN_LOGS + for i in range(overflow): + assert f"warning-{i}" not in _logger._seen_logs + for i in range(_MAX_SEEN_LOGS, _MAX_SEEN_LOGS + overflow): + assert f"warning-{i}" in _logger._seen_logs + + def test_log_exception_once(): """Test we only log with warning level once.""" - QuietLogger._seen_logs = {} + _logger._seen_logs.clear() quiet_logger = QuietLogger() exc = Exception() with ( diff --git a/tests/test_protocol.py b/tests/test_protocol.py index bac2b447..782b77aa 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -14,6 +14,8 @@ import zeroconf as r from zeroconf import DNSHinfo, DNSIncoming, DNSText, const, current_time_millis +from zeroconf._logger import _MAX_SEEN_LOGS +from zeroconf._protocol import incoming as _incoming_module from . import has_working_ipv6 @@ -962,6 +964,38 @@ def test_dns_compression_generic_failure(caplog): assert "Received invalid packet from ('1.2.3.4', 5353)" in caplog.text +def test_seen_logs_is_bounded(): + """Corrupt packets from varying peers fill ``_seen_logs`` exactly to the cap.""" + packet = ( + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01" + b"\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05-\x0c\x00\x01\x80\x01\x00\x00" + b"\x00\x01\x00\x04\xc0\xa8\xd0\x06" + ) + overflow = 5 + _incoming_module._seen_logs.clear() + # Snapshot the actual key the parser inserted per port. This is whatever + # ``str(exc_info()[1])`` produces today — the test stays agnostic to the + # exception text format so a future normalization of the message (see + # the discussion on #1714) doesn't break the assertions, while still + # pinning that the parser exception path actually entered the dict. + keys_per_port: list[str] = [] + for port in range(_MAX_SEEN_LOGS + overflow): + r.DNSIncoming(packet, ("1.2.3.4", port)) + keys_per_port.append(next(reversed(_incoming_module._seen_logs))) + # Bound is hit exactly. + assert len(_incoming_module._seen_logs) == _MAX_SEEN_LOGS + # Each port produced a distinct dedup key — a regression that dropped + # the per-packet-varying component (e.g. self.source) from the exception + # text would collapse all 517 calls to one key and fail this. + assert len(set(keys_per_port)) == _MAX_SEEN_LOGS + overflow + # FIFO eviction by key identity (no substring matching on the message + # format): the earliest ports' keys are gone, the latest ports' remain. + for port in range(overflow): + assert keys_per_port[port] not in _incoming_module._seen_logs + for port in range(_MAX_SEEN_LOGS, _MAX_SEEN_LOGS + overflow): + assert keys_per_port[port] in _incoming_module._seen_logs + + def test_label_length_attack(): """Test our wire parser does not loop forever when the name exceeds 253 chars.""" packet = (