Skip to content
Merged
89 changes: 56 additions & 33 deletions src/zeroconf/_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import logging
import sys
from typing import Any, ClassVar, cast
from typing import Any

log = logging.getLogger(__name__.split(".", maxsplit=1)[0])
log.addHandler(logging.NullHandler())
Expand All @@ -39,50 +39,73 @@ def set_logger_level_if_unset() -> None:
set_logger_level_if_unset()


class QuietLogger:
_seen_logs: ClassVar[dict[str, int | tuple]] = {}
_MAX_SEEN_LOGS = 512
_seen_logs: dict[str, None] = {}


def _evict_oldest(seen: dict[str, None]) -> bool:
"""Pop the oldest entry from ``seen``; return False if it raced.

Individual dict ops (``pop`` with a default, ``next``) are atomic
on the free-threaded build, but the compound ``iter`` → ``next``
used to pick the FIFO victim can raise ``RuntimeError`` if
another thread mutates the dict between the two ops. The caller
breaks its drain loop on False so concurrent mutation can't make
it spin.
"""
try:
seen.pop(next(iter(seen)), None)
except (RuntimeError, StopIteration):
return False
return True


def _mark_seen(seen: dict[str, None], key: str) -> bool:
"""Record ``key`` in ``seen`` and return True if it was newly added.

Bounds the dict so callers passing attacker-influenced keys (peer
addresses, packet offsets) cannot grow it without bound. Evicts
the oldest entries on overflow (dict preserves insertion order on
Python 3.7+), so ``_MAX_SEEN_LOGS`` is a recency window.

The dict is shared across all ``Zeroconf`` instances in the
process; on the free-threaded build (3.14t) and under multi-
instance sync use, callers can race the ``len < cap`` check and
both insert, leaving the dict transiently above the cap. The
drain loop runs on every call (steady-state-at-cap hits are a
single ``len`` + compare past the membership check because the
helper short-circuits) so a contention burst is corrected by the
next caller regardless of whether it's a hit or a miss.
"""
inserting = key not in seen
# Hit (``inserting`` is False): drain only if drifted above cap.
# Miss (``inserting`` is True): drain to ``cap - 1`` to make room
# for the new key. Bool subtracts as 0/1 to pick the right limit.
while len(seen) > _MAX_SEEN_LOGS - inserting and _evict_oldest(seen):
pass
if inserting:
seen[key] = None
return inserting


class QuietLogger:
@classmethod
def log_exception_warning(cls, *logger_data: Any) -> None:
exc_info = sys.exc_info()
exc_str = str(exc_info[1])
if exc_str not in cls._seen_logs:
# log at warning level the first time this is seen
cls._seen_logs[exc_str] = exc_info
logger = log.warning
else:
logger = log.debug
first_time = _mark_seen(_seen_logs, str(sys.exc_info()[1]))
logger = log.warning if first_time else log.debug
logger(*(logger_data or ["Exception occurred"]), exc_info=True)

@classmethod
def log_exception_debug(cls, *logger_data: Any) -> None:
log_exc_info = False
exc_info = sys.exc_info()
exc_str = str(exc_info[1])
if exc_str not in cls._seen_logs:
# log the trace only on the first time
cls._seen_logs[exc_str] = exc_info
log_exc_info = True
log.debug(*(logger_data or ["Exception occurred"]), exc_info=log_exc_info)
first_time = _mark_seen(_seen_logs, str(sys.exc_info()[1]))
log.debug(*(logger_data or ["Exception occurred"]), exc_info=first_time)

@classmethod
def log_warning_once(cls, *args: Any) -> None:
msg_str = args[0]
if msg_str not in cls._seen_logs:
cls._seen_logs[msg_str] = 0
logger = log.warning
else:
logger = log.debug
cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1
logger = log.warning if _mark_seen(_seen_logs, args[0]) else log.debug
logger(*args)

@classmethod
def log_exception_once(cls, exc: Exception, *args: Any) -> None:
msg_str = args[0]
if msg_str not in cls._seen_logs:
cls._seen_logs[msg_str] = 0
logger = log.warning
else:
logger = log.debug
cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1
logger = log.warning if _mark_seen(_seen_logs, args[0]) else log.debug
logger(*args, exc_info=exc)
12 changes: 3 additions & 9 deletions src/zeroconf/_protocol/incoming.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
DNSText,
)
from .._exceptions import IncomingDecodeError
from .._logger import log
from .._logger import _mark_seen, log
from .._utils.time import current_time_millis
from ..const import (
_FLAGS_QR_MASK,
Expand All @@ -63,7 +63,7 @@
DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError, RecursionError)


_seen_logs: dict[str, int | tuple] = {}
_seen_logs: dict[str, None] = {}
_str = str
_int = int

Expand Down Expand Up @@ -182,13 +182,7 @@ def _initial_parse(self) -> None:

@classmethod
def _log_exception_debug(cls, *logger_data: Any) -> None:
log_exc_info = False
exc_info = sys.exc_info()
exc_str = str(exc_info[1])
if exc_str not in _seen_logs:
# log the trace only on the first time
_seen_logs[exc_str] = exc_info
log_exc_info = True
log_exc_info = _mark_seen(_seen_logs, str(sys.exc_info()[1]))
log.debug(*(logger_data or ["Exception occurred"]), exc_info=log_exc_info)

def answers(self) -> list[DNSRecord]:
Expand Down
39 changes: 39 additions & 0 deletions tests/benchmarks/test_mark_seen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Benchmark for _logger._mark_seen."""

from __future__ import annotations

from pytest_codspeed import BenchmarkFixture

from zeroconf._logger import _MAX_SEEN_LOGS, _mark_seen


def test_mark_seen_hit(benchmark: BenchmarkFixture) -> None:
"""Benchmark the cache-hit path (same key repeated)."""
seen: dict[str, None] = {"warm": None}

@benchmark
def _hit() -> None:
for _ in range(1000):
_mark_seen(seen, "warm")


def test_mark_seen_fill(benchmark: BenchmarkFixture) -> None:
"""Benchmark filling from empty up to the cap (no evictions)."""
keys = [f"key-{i}" for i in range(_MAX_SEEN_LOGS)]

@benchmark
def _fill() -> None:
seen: dict[str, None] = {}
for k in keys:
_mark_seen(seen, k)


def test_mark_seen_churn(benchmark: BenchmarkFixture) -> None:
"""Benchmark sustained eviction (every call past the cap drops oldest)."""
keys = [f"churn-{i}" for i in range(_MAX_SEEN_LOGS * 4)]

@benchmark
def _churn() -> None:
seen: dict[str, None] = {}
for k in keys:
_mark_seen(seen, k)
87 changes: 82 additions & 5 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import logging
from unittest.mock import call, patch

from zeroconf._logger import QuietLogger, set_logger_level_if_unset
from zeroconf import _logger
from zeroconf._logger import _MAX_SEEN_LOGS, QuietLogger, _mark_seen, set_logger_level_if_unset


def test_loading_logger():
Expand All @@ -25,7 +26,7 @@ def test_loading_logger():

def test_log_warning_once():
"""Test we only log with warning level once."""
QuietLogger._seen_logs = {}
_logger._seen_logs.clear()
quiet_logger = QuietLogger()
with (
patch("zeroconf._logger.log.warning") as mock_log_warning,
Expand All @@ -48,7 +49,7 @@ def test_log_warning_once():

def test_log_exception_warning():
"""Test we only log with warning level once."""
QuietLogger._seen_logs = {}
_logger._seen_logs.clear()
quiet_logger = QuietLogger()
with (
patch("zeroconf._logger.log.warning") as mock_log_warning,
Expand All @@ -71,7 +72,7 @@ def test_log_exception_warning():

def test_llog_exception_debug():
"""Test we only log with a trace once."""
QuietLogger._seen_logs = {}
_logger._seen_logs.clear()
quiet_logger = QuietLogger()
with patch("zeroconf._logger.log.debug") as mock_log_debug:
quiet_logger.log_exception_debug("the exception")
Expand All @@ -84,9 +85,85 @@ def test_llog_exception_debug():
assert mock_log_debug.mock_calls == [call("the exception", exc_info=False)]


def test_mark_seen_absorbs_runtime_error_during_eviction() -> None:
"""Concurrent mutation can make ``iter(seen)`` raise ``RuntimeError``.

Free-threaded (3.14t) and multi-instance sync callers share
``_seen_logs``; if another thread mutates it between ``iter()``
and ``next()`` the iterator raises ``RuntimeError``.
``_mark_seen`` must absorb that and still insert the new key.
"""

class RacyDict(dict[str, None]):
def __iter__(self): # type: ignore[override]
raise RuntimeError("dictionary changed size during iteration")

seen: dict[str, None] = RacyDict()
for i in range(_MAX_SEEN_LOGS):
seen[f"k-{i}"] = None
assert _mark_seen(seen, "new-key") is True
assert "new-key" in seen


def test_mark_seen_drains_drift_above_cap() -> None:
"""``_mark_seen`` drains a drifted-over-cap dict back to the cap.

Concurrent inserts on the free-threaded build can leave the dict
transiently above ``_MAX_SEEN_LOGS`` (e.g. two threads both passed
the ``len < cap`` check and both inserted). The next non-racing
call must drain the accumulated overshoot, not just evict one
entry — otherwise the cap silently inflates with thread count.
"""
seen: dict[str, None] = {}
drift = 10
for i in range(_MAX_SEEN_LOGS + drift):
seen[f"k-{i}"] = None
assert len(seen) == _MAX_SEEN_LOGS + drift
assert _mark_seen(seen, "new-key") is True
assert len(seen) == _MAX_SEEN_LOGS
assert "new-key" in seen
for i in range(drift + 1):
assert f"k-{i}" not in seen


def test_mark_seen_drains_drift_on_hit_path() -> None:
"""``_mark_seen`` drains drift even when ``key`` is already cached.

A hit-heavy workload after a contention burst (e.g. the same
exception text deduplicated repeatedly) must still correct the
overshoot — otherwise the dict can sit permanently above the cap
until a miss happens to come along.
"""
seen: dict[str, None] = {}
drift = 10
for i in range(_MAX_SEEN_LOGS + drift):
seen[f"k-{i}"] = None
# Hit on a non-oldest key — survives the drift drain.
hit_key = f"k-{_MAX_SEEN_LOGS}"
assert _mark_seen(seen, hit_key) is False
assert len(seen) == _MAX_SEEN_LOGS
assert hit_key in seen
for i in range(drift):
assert f"k-{i}" not in seen


def test_seen_logs_is_bounded() -> None:
"""``_seen_logs`` stays at the cap and evicts oldest-first (FIFO)."""
_logger._seen_logs.clear()
overflow = 5
with patch("zeroconf._logger.log.warning"), patch("zeroconf._logger.log.debug"):
for i in range(_MAX_SEEN_LOGS + overflow):
QuietLogger.log_warning_once(f"warning-{i}")
assert len(_logger._seen_logs) == _MAX_SEEN_LOGS
for i in range(overflow):
assert f"warning-{i}" not in _logger._seen_logs
for i in range(_MAX_SEEN_LOGS, _MAX_SEEN_LOGS + overflow):
assert f"warning-{i}" in _logger._seen_logs


def test_log_exception_once():
"""Test we only log with warning level once."""
QuietLogger._seen_logs = {}
_logger._seen_logs.clear()
quiet_logger = QuietLogger()
exc = Exception()
with (
Expand Down
34 changes: 34 additions & 0 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import zeroconf as r
from zeroconf import DNSHinfo, DNSIncoming, DNSText, const, current_time_millis
from zeroconf._logger import _MAX_SEEN_LOGS
from zeroconf._protocol import incoming as _incoming_module

from . import has_working_ipv6

Expand Down Expand Up @@ -962,6 +964,38 @@ def test_dns_compression_generic_failure(caplog):
assert "Received invalid packet from ('1.2.3.4', 5353)" in caplog.text


def test_seen_logs_is_bounded():
"""Corrupt packets from varying peers fill ``_seen_logs`` exactly to the cap."""
packet = (
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01"
b"\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05-\x0c\x00\x01\x80\x01\x00\x00"
b"\x00\x01\x00\x04\xc0\xa8\xd0\x06"
)
overflow = 5
_incoming_module._seen_logs.clear()
# Snapshot the actual key the parser inserted per port. This is whatever
# ``str(exc_info()[1])`` produces today — the test stays agnostic to the
# exception text format so a future normalization of the message (see
# the discussion on #1714) doesn't break the assertions, while still
# pinning that the parser exception path actually entered the dict.
Comment thread
bdraco marked this conversation as resolved.
keys_per_port: list[str] = []
for port in range(_MAX_SEEN_LOGS + overflow):
r.DNSIncoming(packet, ("1.2.3.4", port))
keys_per_port.append(next(reversed(_incoming_module._seen_logs)))
# Bound is hit exactly.
assert len(_incoming_module._seen_logs) == _MAX_SEEN_LOGS
# Each port produced a distinct dedup key — a regression that dropped
# the per-packet-varying component (e.g. self.source) from the exception
# text would collapse all 517 calls to one key and fail this.
assert len(set(keys_per_port)) == _MAX_SEEN_LOGS + overflow
# FIFO eviction by key identity (no substring matching on the message
# format): the earliest ports' keys are gone, the latest ports' remain.
for port in range(overflow):
assert keys_per_port[port] not in _incoming_module._seen_logs
for port in range(_MAX_SEEN_LOGS, _MAX_SEEN_LOGS + overflow):
assert keys_per_port[port] in _incoming_module._seen_logs


def test_label_length_attack():
"""Test our wire parser does not loop forever when the name exceeds 253 chars."""
packet = (
Expand Down
Loading