Skip to content

Commit 95561e2

Browse files
authored
fix: bound _seen_logs and stop retaining exc_info (#1717)
1 parent 65b22cb commit 95561e2

5 files changed

Lines changed: 214 additions & 47 deletions

File tree

src/zeroconf/_logger.py

Lines changed: 56 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import logging
2727
import sys
28-
from typing import Any, ClassVar, cast
28+
from typing import Any
2929

3030
log = logging.getLogger(__name__.split(".", maxsplit=1)[0])
3131
log.addHandler(logging.NullHandler())
@@ -39,50 +39,73 @@ def set_logger_level_if_unset() -> None:
3939
set_logger_level_if_unset()
4040

4141

42-
class QuietLogger:
43-
_seen_logs: ClassVar[dict[str, int | tuple]] = {}
42+
_MAX_SEEN_LOGS = 512
43+
_seen_logs: dict[str, None] = {}
44+
45+
46+
def _evict_oldest(seen: dict[str, None]) -> bool:
47+
"""Pop the oldest entry from ``seen``; return False if it raced.
48+
49+
Individual dict ops (``pop`` with a default, ``next``) are atomic
50+
on the free-threaded build, but the compound ``iter`` → ``next``
51+
used to pick the FIFO victim can raise ``RuntimeError`` if
52+
another thread mutates the dict between the two ops. The caller
53+
breaks its drain loop on False so concurrent mutation can't make
54+
it spin.
55+
"""
56+
try:
57+
seen.pop(next(iter(seen)), None)
58+
except (RuntimeError, StopIteration):
59+
return False
60+
return True
61+
62+
63+
def _mark_seen(seen: dict[str, None], key: str) -> bool:
64+
"""Record ``key`` in ``seen`` and return True if it was newly added.
65+
66+
Bounds the dict so callers passing attacker-influenced keys (peer
67+
addresses, packet offsets) cannot grow it without bound. Evicts
68+
the oldest entries on overflow (dict preserves insertion order on
69+
Python 3.7+), so ``_MAX_SEEN_LOGS`` is a recency window.
70+
71+
The dict is shared across all ``Zeroconf`` instances in the
72+
process; on the free-threaded build (3.14t) and under multi-
73+
instance sync use, callers can race the ``len < cap`` check and
74+
both insert, leaving the dict transiently above the cap. The
75+
drain loop runs on every call (steady-state-at-cap hits are a
76+
single ``len`` + compare past the membership check because the
77+
helper short-circuits) so a contention burst is corrected by the
78+
next caller regardless of whether it's a hit or a miss.
79+
"""
80+
inserting = key not in seen
81+
# Hit (``inserting`` is False): drain only if drifted above cap.
82+
# Miss (``inserting`` is True): drain to ``cap - 1`` to make room
83+
# for the new key. Bool subtracts as 0/1 to pick the right limit.
84+
while len(seen) > _MAX_SEEN_LOGS - inserting and _evict_oldest(seen):
85+
pass
86+
if inserting:
87+
seen[key] = None
88+
return inserting
89+
4490

91+
class QuietLogger:
4592
@classmethod
4693
def log_exception_warning(cls, *logger_data: Any) -> None:
47-
exc_info = sys.exc_info()
48-
exc_str = str(exc_info[1])
49-
if exc_str not in cls._seen_logs:
50-
# log at warning level the first time this is seen
51-
cls._seen_logs[exc_str] = exc_info
52-
logger = log.warning
53-
else:
54-
logger = log.debug
94+
first_time = _mark_seen(_seen_logs, str(sys.exc_info()[1]))
95+
logger = log.warning if first_time else log.debug
5596
logger(*(logger_data or ["Exception occurred"]), exc_info=True)
5697

5798
@classmethod
5899
def log_exception_debug(cls, *logger_data: Any) -> None:
59-
log_exc_info = False
60-
exc_info = sys.exc_info()
61-
exc_str = str(exc_info[1])
62-
if exc_str not in cls._seen_logs:
63-
# log the trace only on the first time
64-
cls._seen_logs[exc_str] = exc_info
65-
log_exc_info = True
66-
log.debug(*(logger_data or ["Exception occurred"]), exc_info=log_exc_info)
100+
first_time = _mark_seen(_seen_logs, str(sys.exc_info()[1]))
101+
log.debug(*(logger_data or ["Exception occurred"]), exc_info=first_time)
67102

68103
@classmethod
69104
def log_warning_once(cls, *args: Any) -> None:
70-
msg_str = args[0]
71-
if msg_str not in cls._seen_logs:
72-
cls._seen_logs[msg_str] = 0
73-
logger = log.warning
74-
else:
75-
logger = log.debug
76-
cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1
105+
logger = log.warning if _mark_seen(_seen_logs, args[0]) else log.debug
77106
logger(*args)
78107

79108
@classmethod
80109
def log_exception_once(cls, exc: Exception, *args: Any) -> None:
81-
msg_str = args[0]
82-
if msg_str not in cls._seen_logs:
83-
cls._seen_logs[msg_str] = 0
84-
logger = log.warning
85-
else:
86-
logger = log.debug
87-
cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1
110+
logger = log.warning if _mark_seen(_seen_logs, args[0]) else log.debug
88111
logger(*args, exc_info=exc)

src/zeroconf/_protocol/incoming.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
DNSText,
3838
)
3939
from .._exceptions import IncomingDecodeError
40-
from .._logger import log
40+
from .._logger import _mark_seen, log
4141
from .._utils.time import current_time_millis
4242
from ..const import (
4343
_FLAGS_QR_MASK,
@@ -63,7 +63,7 @@
6363
DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError, RecursionError)
6464

6565

66-
_seen_logs: dict[str, int | tuple] = {}
66+
_seen_logs: dict[str, None] = {}
6767
_str = str
6868
_int = int
6969

@@ -182,13 +182,7 @@ def _initial_parse(self) -> None:
182182

183183
@classmethod
184184
def _log_exception_debug(cls, *logger_data: Any) -> None:
185-
log_exc_info = False
186-
exc_info = sys.exc_info()
187-
exc_str = str(exc_info[1])
188-
if exc_str not in _seen_logs:
189-
# log the trace only on the first time
190-
_seen_logs[exc_str] = exc_info
191-
log_exc_info = True
185+
log_exc_info = _mark_seen(_seen_logs, str(sys.exc_info()[1]))
192186
log.debug(*(logger_data or ["Exception occurred"]), exc_info=log_exc_info)
193187

194188
def answers(self) -> list[DNSRecord]:

tests/benchmarks/test_mark_seen.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Benchmark for _logger._mark_seen."""
2+
3+
from __future__ import annotations
4+
5+
from pytest_codspeed import BenchmarkFixture
6+
7+
from zeroconf._logger import _MAX_SEEN_LOGS, _mark_seen
8+
9+
10+
def test_mark_seen_hit(benchmark: BenchmarkFixture) -> None:
11+
"""Benchmark the cache-hit path (same key repeated)."""
12+
seen: dict[str, None] = {"warm": None}
13+
14+
@benchmark
15+
def _hit() -> None:
16+
for _ in range(1000):
17+
_mark_seen(seen, "warm")
18+
19+
20+
def test_mark_seen_fill(benchmark: BenchmarkFixture) -> None:
21+
"""Benchmark filling from empty up to the cap (no evictions)."""
22+
keys = [f"key-{i}" for i in range(_MAX_SEEN_LOGS)]
23+
24+
@benchmark
25+
def _fill() -> None:
26+
seen: dict[str, None] = {}
27+
for k in keys:
28+
_mark_seen(seen, k)
29+
30+
31+
def test_mark_seen_churn(benchmark: BenchmarkFixture) -> None:
32+
"""Benchmark sustained eviction (every call past the cap drops oldest)."""
33+
keys = [f"churn-{i}" for i in range(_MAX_SEEN_LOGS * 4)]
34+
35+
@benchmark
36+
def _churn() -> None:
37+
seen: dict[str, None] = {}
38+
for k in keys:
39+
_mark_seen(seen, k)

tests/test_logger.py

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import logging
66
from unittest.mock import call, patch
77

8-
from zeroconf._logger import QuietLogger, set_logger_level_if_unset
8+
from zeroconf import _logger
9+
from zeroconf._logger import _MAX_SEEN_LOGS, QuietLogger, _mark_seen, set_logger_level_if_unset
910

1011

1112
def test_loading_logger():
@@ -25,7 +26,7 @@ def test_loading_logger():
2526

2627
def test_log_warning_once():
2728
"""Test we only log with warning level once."""
28-
QuietLogger._seen_logs = {}
29+
_logger._seen_logs.clear()
2930
quiet_logger = QuietLogger()
3031
with (
3132
patch("zeroconf._logger.log.warning") as mock_log_warning,
@@ -48,7 +49,7 @@ def test_log_warning_once():
4849

4950
def test_log_exception_warning():
5051
"""Test we only log with warning level once."""
51-
QuietLogger._seen_logs = {}
52+
_logger._seen_logs.clear()
5253
quiet_logger = QuietLogger()
5354
with (
5455
patch("zeroconf._logger.log.warning") as mock_log_warning,
@@ -71,7 +72,7 @@ def test_log_exception_warning():
7172

7273
def test_llog_exception_debug():
7374
"""Test we only log with a trace once."""
74-
QuietLogger._seen_logs = {}
75+
_logger._seen_logs.clear()
7576
quiet_logger = QuietLogger()
7677
with patch("zeroconf._logger.log.debug") as mock_log_debug:
7778
quiet_logger.log_exception_debug("the exception")
@@ -84,9 +85,85 @@ def test_llog_exception_debug():
8485
assert mock_log_debug.mock_calls == [call("the exception", exc_info=False)]
8586

8687

88+
def test_mark_seen_absorbs_runtime_error_during_eviction() -> None:
89+
"""Concurrent mutation can make ``iter(seen)`` raise ``RuntimeError``.
90+
91+
Free-threaded (3.14t) and multi-instance sync callers share
92+
``_seen_logs``; if another thread mutates it between ``iter()``
93+
and ``next()`` the iterator raises ``RuntimeError``.
94+
``_mark_seen`` must absorb that and still insert the new key.
95+
"""
96+
97+
class RacyDict(dict[str, None]):
98+
def __iter__(self): # type: ignore[override]
99+
raise RuntimeError("dictionary changed size during iteration")
100+
101+
seen: dict[str, None] = RacyDict()
102+
for i in range(_MAX_SEEN_LOGS):
103+
seen[f"k-{i}"] = None
104+
assert _mark_seen(seen, "new-key") is True
105+
assert "new-key" in seen
106+
107+
108+
def test_mark_seen_drains_drift_above_cap() -> None:
109+
"""``_mark_seen`` drains a drifted-over-cap dict back to the cap.
110+
111+
Concurrent inserts on the free-threaded build can leave the dict
112+
transiently above ``_MAX_SEEN_LOGS`` (e.g. two threads both passed
113+
the ``len < cap`` check and both inserted). The next non-racing
114+
call must drain the accumulated overshoot, not just evict one
115+
entry — otherwise the cap silently inflates with thread count.
116+
"""
117+
seen: dict[str, None] = {}
118+
drift = 10
119+
for i in range(_MAX_SEEN_LOGS + drift):
120+
seen[f"k-{i}"] = None
121+
assert len(seen) == _MAX_SEEN_LOGS + drift
122+
assert _mark_seen(seen, "new-key") is True
123+
assert len(seen) == _MAX_SEEN_LOGS
124+
assert "new-key" in seen
125+
for i in range(drift + 1):
126+
assert f"k-{i}" not in seen
127+
128+
129+
def test_mark_seen_drains_drift_on_hit_path() -> None:
130+
"""``_mark_seen`` drains drift even when ``key`` is already cached.
131+
132+
A hit-heavy workload after a contention burst (e.g. the same
133+
exception text deduplicated repeatedly) must still correct the
134+
overshoot — otherwise the dict can sit permanently above the cap
135+
until a miss happens to come along.
136+
"""
137+
seen: dict[str, None] = {}
138+
drift = 10
139+
for i in range(_MAX_SEEN_LOGS + drift):
140+
seen[f"k-{i}"] = None
141+
# Hit on a non-oldest key — survives the drift drain.
142+
hit_key = f"k-{_MAX_SEEN_LOGS}"
143+
assert _mark_seen(seen, hit_key) is False
144+
assert len(seen) == _MAX_SEEN_LOGS
145+
assert hit_key in seen
146+
for i in range(drift):
147+
assert f"k-{i}" not in seen
148+
149+
150+
def test_seen_logs_is_bounded() -> None:
151+
"""``_seen_logs`` stays at the cap and evicts oldest-first (FIFO)."""
152+
_logger._seen_logs.clear()
153+
overflow = 5
154+
with patch("zeroconf._logger.log.warning"), patch("zeroconf._logger.log.debug"):
155+
for i in range(_MAX_SEEN_LOGS + overflow):
156+
QuietLogger.log_warning_once(f"warning-{i}")
157+
assert len(_logger._seen_logs) == _MAX_SEEN_LOGS
158+
for i in range(overflow):
159+
assert f"warning-{i}" not in _logger._seen_logs
160+
for i in range(_MAX_SEEN_LOGS, _MAX_SEEN_LOGS + overflow):
161+
assert f"warning-{i}" in _logger._seen_logs
162+
163+
87164
def test_log_exception_once():
88165
"""Test we only log with warning level once."""
89-
QuietLogger._seen_logs = {}
166+
_logger._seen_logs.clear()
90167
quiet_logger = QuietLogger()
91168
exc = Exception()
92169
with (

tests/test_protocol.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import zeroconf as r
1616
from zeroconf import DNSHinfo, DNSIncoming, DNSText, const, current_time_millis
17+
from zeroconf._logger import _MAX_SEEN_LOGS
18+
from zeroconf._protocol import incoming as _incoming_module
1719

1820
from . import has_working_ipv6
1921

@@ -962,6 +964,38 @@ def test_dns_compression_generic_failure(caplog):
962964
assert "Received invalid packet from ('1.2.3.4', 5353)" in caplog.text
963965

964966

967+
def test_seen_logs_is_bounded():
968+
"""Corrupt packets from varying peers fill ``_seen_logs`` exactly to the cap."""
969+
packet = (
970+
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01"
971+
b"\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05-\x0c\x00\x01\x80\x01\x00\x00"
972+
b"\x00\x01\x00\x04\xc0\xa8\xd0\x06"
973+
)
974+
overflow = 5
975+
_incoming_module._seen_logs.clear()
976+
# Snapshot the actual key the parser inserted per port. This is whatever
977+
# ``str(exc_info()[1])`` produces today — the test stays agnostic to the
978+
# exception text format so a future normalization of the message (see
979+
# the discussion on #1714) doesn't break the assertions, while still
980+
# pinning that the parser exception path actually entered the dict.
981+
keys_per_port: list[str] = []
982+
for port in range(_MAX_SEEN_LOGS + overflow):
983+
r.DNSIncoming(packet, ("1.2.3.4", port))
984+
keys_per_port.append(next(reversed(_incoming_module._seen_logs)))
985+
# Bound is hit exactly.
986+
assert len(_incoming_module._seen_logs) == _MAX_SEEN_LOGS
987+
# Each port produced a distinct dedup key — a regression that dropped
988+
# the per-packet-varying component (e.g. self.source) from the exception
989+
# text would collapse all 517 calls to one key and fail this.
990+
assert len(set(keys_per_port)) == _MAX_SEEN_LOGS + overflow
991+
# FIFO eviction by key identity (no substring matching on the message
992+
# format): the earliest ports' keys are gone, the latest ports' remain.
993+
for port in range(overflow):
994+
assert keys_per_port[port] not in _incoming_module._seen_logs
995+
for port in range(_MAX_SEEN_LOGS, _MAX_SEEN_LOGS + overflow):
996+
assert keys_per_port[port] in _incoming_module._seen_logs
997+
998+
965999
def test_label_length_attack():
9661000
"""Test our wire parser does not loop forever when the name exceeds 253 chars."""
9671001
packet = (

0 commit comments

Comments
 (0)