Skip to content

Commit 0ad3f37

Browse files
authored
fix: bound DNSCache record count to prevent unbounded LAN-driven growth (#1718)
1 parent 0ff3c6b commit 0ad3f37

5 files changed

Lines changed: 374 additions & 31 deletions

File tree

src/zeroconf/_cache.pxd

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ cdef object _UNIQUE_RECORD_TYPES
1919
cdef unsigned int _TYPE_PTR
2020
cdef cython.uint _ONE_SECOND
2121
cdef unsigned int _MIN_SCHEDULED_RECORD_EXPIRATION
22+
cdef unsigned int _MAX_CACHE_RECORDS
2223

2324

2425
@cython.locals(record_cache=dict)
@@ -31,6 +32,7 @@ cdef class DNSCache:
3132
cdef public cython.dict service_cache
3233
cdef public list _expire_heap
3334
cdef public dict _expirations
35+
cdef public unsigned int _total_records
3436

3537
cpdef bint async_add_records(self, object entries)
3638

@@ -60,10 +62,17 @@ cdef class DNSCache:
6062
service_store=cython.dict,
6163
service_record=DNSService,
6264
when=object,
63-
new=bint
65+
new=bint,
66+
is_new=bint
6467
)
6568
cdef bint _async_add(self, DNSRecord record)
6669

70+
@cython.locals(record=DNSRecord, when_record=tuple)
71+
cdef void _async_evict_oldest(self)
72+
73+
@cython.locals(expire_heap_len="unsigned int")
74+
cdef void _maybe_rebuild_heap(self)
75+
6776
@cython.locals(service_record=DNSService)
6877
cdef void _async_remove(self, DNSRecord record)
6978

src/zeroconf/_cache.py

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
DNSText,
3838
)
3939
from ._utils.time import current_time_millis
40-
from .const import _ONE_SECOND, _TYPE_PTR
40+
from .const import _MAX_CACHE_RECORDS, _ONE_SECOND, _TYPE_PTR
4141

4242
_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService)
4343
_UniqueRecordsType = DNSAddress | DNSHinfo | DNSPointer | DNSText | DNSService
@@ -72,6 +72,7 @@ def __init__(self) -> None:
7272
self._expire_heap: list[tuple[float, DNSRecord]] = []
7373
self._expirations: dict[DNSRecord, float] = {}
7474
self.service_cache: _DNSRecordCacheType = {}
75+
self._total_records: int = 0
7576

7677
# Functions prefixed with async_ are NOT threadsafe and must
7778
# be run in the event loop.
@@ -89,15 +90,34 @@ def _async_add(self, record: _DNSRecord) -> bool:
8990
# replaces any existing records that are __eq__ to each other which
9091
# removes the risk that accessing the cache from the wrong
9192
# direction would return the old incorrect entry.
92-
if (store := self.cache.get(record.key)) is None:
93+
store = self.cache.get(record.key)
94+
is_new = store is None or record not in store
95+
# Bound total cache size; evict closest-to-expiration entry to
96+
# make room before inserting a new record. Prevents a LAN-local
97+
# flood of unique-name records from growing the cache without
98+
# bound (RFC 6762 §10 advisory caching, defense-in-depth).
99+
if is_new and self._total_records >= _MAX_CACHE_RECORDS:
100+
self._async_evict_oldest()
101+
# The victim may have been the last record under
102+
# ``record.key``, in which case ``_remove_key`` deleted
103+
# the bucket. Re-fetch before creating below.
104+
store = self.cache.get(record.key)
105+
if store is None:
93106
store = self.cache[record.key] = {}
94-
new = record not in store and not isinstance(record, DNSNsec)
107+
new = is_new and not isinstance(record, DNSNsec)
108+
if is_new:
109+
self._total_records += 1
95110
store[record] = record
96111
when = record.created + (record.ttl * 1000)
97112
if self._expirations.get(record) != when:
98-
# Avoid adding duplicates to the heap
99113
heappush(self._expire_heap, (when, record))
100114
self._expirations[record] = when
115+
# Re-adds of an existing record with a new TTL push a fresh
116+
# entry but leave the prior tuple behind as stale, so a peer
117+
# that just replays cached records can grow ``_expire_heap``
118+
# without ever tripping the cap. Rebuild when stale entries
119+
# dominate.
120+
self._maybe_rebuild_heap()
101121

102122
if isinstance(record, DNSService):
103123
service_record = record
@@ -106,6 +126,28 @@ def _async_add(self, record: _DNSRecord) -> bool:
106126
service_store[service_record] = service_record
107127
return new
108128

129+
def _async_evict_oldest(self) -> None:
130+
"""Drop the closest-to-expiration record to make room for a new one."""
131+
while self._expire_heap:
132+
when_record = heappop(self._expire_heap)
133+
record = when_record[1]
134+
if self._expirations.get(record) != when_record[0]:
135+
continue
136+
self._async_remove(record)
137+
return
138+
139+
def _maybe_rebuild_heap(self) -> None:
140+
"""Rebuild ``_expire_heap`` when stale entries dominate live ones."""
141+
expire_heap_len = len(self._expire_heap)
142+
if (
143+
expire_heap_len > _MIN_SCHEDULED_RECORD_EXPIRATION
144+
and expire_heap_len > len(self._expirations) * 2
145+
):
146+
self._expire_heap = [
147+
entry for entry in self._expire_heap if self._expirations.get(entry[1]) == entry[0]
148+
]
149+
heapify(self._expire_heap)
150+
109151
def async_add_records(self, entries: Iterable[DNSRecord]) -> bool:
110152
"""Add multiple records.
111153
@@ -129,6 +171,7 @@ def _async_remove(self, record: _DNSRecord) -> None:
129171
_remove_key(self.service_cache, service_record.server_key, service_record)
130172
_remove_key(self.cache, record.key, record)
131173
self._expirations.pop(record, None)
174+
self._total_records -= 1
132175

133176
def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
134177
"""Remove multiple records.
@@ -145,43 +188,23 @@ def async_expire(self, now: _float) -> list[DNSRecord]:
145188
146189
:param now: The current time in milliseconds.
147190
"""
148-
if not (expire_heap_len := len(self._expire_heap)):
191+
if not self._expire_heap:
149192
return []
150193

151194
expired: list[DNSRecord] = []
152-
# Find any expired records and add them to the to-delete list
153195
while self._expire_heap:
154196
when_record = self._expire_heap[0]
155197
when = when_record[0]
156198
if when > now:
157199
break
158200
heappop(self._expire_heap)
159-
# Check if the record hasn't been re-added to the heap
160-
# with a different expiration time as it will be removed
161-
# later when it reaches the top of the heap and its
162-
# expiration time is met.
201+
# Skip entries left behind by a TTL re-add; the live tuple is
202+
# later in the heap and will be removed when it reaches the top.
163203
record = when_record[1]
164204
if self._expirations.get(record) == when:
165205
expired.append(record)
166206

167-
# If the expiration heap grows larger than the number expirations
168-
# times two, we clean it up to avoid keeping expired entries in
169-
# the heap and consuming memory. We guard this with a minimum
170-
# threshold to avoid cleaning up the heap too often when there are
171-
# only a few scheduled expirations.
172-
if (
173-
expire_heap_len > _MIN_SCHEDULED_RECORD_EXPIRATION
174-
and expire_heap_len > len(self._expirations) * 2
175-
):
176-
# Remove any expired entries from the expiration heap
177-
# that do not match the expiration time in the expirations
178-
# as it means the record has been re-added to the heap
179-
# with a different expiration time.
180-
self._expire_heap = [
181-
entry for entry in self._expire_heap if self._expirations.get(entry[1]) == entry[0]
182-
]
183-
heapify(self._expire_heap)
184-
207+
self._maybe_rebuild_heap()
185208
self.async_remove_records(expired)
186209
return expired
187210

src/zeroconf/const.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@
5959
# level of rate limit and safe guards so we use 1/4 of the recommended value
6060
_DNS_PTR_MIN_TTL = 1125
6161

62+
# Upper bound on the number of records the DNSCache will hold before it
63+
# starts evicting the closest-to-expiration entry to make room for new
64+
# arrivals. Bounds the memory a malicious LAN peer can force the cache
65+
# to retain by multicasting many unique-name records.
66+
_MAX_CACHE_RECORDS = 10000
67+
6268
_DNS_PACKET_HEADER_LEN = 12
6369

6470
_MAX_MSG_TYPICAL = 1460 # unused
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Benchmark for the DNSCache record-count bound + overflow eviction."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Iterator
6+
from itertools import count
7+
8+
from pytest_codspeed import BenchmarkFixture
9+
10+
from zeroconf import DNSAddress, DNSCache, current_time_millis
11+
from zeroconf.const import _CLASS_IN, _MAX_CACHE_RECORDS, _TYPE_A
12+
13+
14+
def _make_records(count_: int, now: float, prefix: str = "bench") -> list[DNSAddress]:
15+
return [
16+
DNSAddress(
17+
f"{prefix}-{i}.local.",
18+
_TYPE_A,
19+
_CLASS_IN,
20+
120,
21+
bytes(((i >> 24) & 0xFF, (i >> 16) & 0xFF, (i >> 8) & 0xFF, i & 0xFF)),
22+
created=now + i,
23+
)
24+
for i in range(count_)
25+
]
26+
27+
28+
def _unbounded_records(now: float, prefix: str = "evict") -> Iterator[DNSAddress]:
29+
"""Unbounded generator of unique-name DNSAddress records."""
30+
for i in count():
31+
yield DNSAddress(
32+
f"{prefix}-{i}.local.",
33+
_TYPE_A,
34+
_CLASS_IN,
35+
120,
36+
bytes(((i >> 24) & 0xFF, (i >> 16) & 0xFF, (i >> 8) & 0xFF, i & 0xFF)),
37+
created=now + i,
38+
)
39+
40+
41+
def test_cache_add_below_cap(benchmark: BenchmarkFixture) -> None:
42+
"""Adding records while the cache is well below the cap (no eviction)."""
43+
now = current_time_millis()
44+
records = _make_records(1000, now)
45+
46+
@benchmark
47+
def _add() -> None:
48+
cache = DNSCache()
49+
cache.async_add_records(records)
50+
51+
52+
def test_cache_add_at_cap_evicts(benchmark: BenchmarkFixture) -> None:
53+
"""Steady-state add at the cap: every measured insert forces one eviction.
54+
55+
Pre-fills the cache to ``_MAX_CACHE_RECORDS`` outside the timed body so
56+
only the eviction-path adds are measured. Each benchmark iteration
57+
pulls one fresh unique record from an unbounded generator, keeping the
58+
cache permanently at the cap. The generator avoids the iteration-count
59+
cap that a pre-built pool would impose for very fast operations.
60+
"""
61+
now = current_time_millis()
62+
cache = DNSCache()
63+
cache.async_add_records(_make_records(_MAX_CACHE_RECORDS, now, prefix="fill"))
64+
pool = _unbounded_records(now + _MAX_CACHE_RECORDS)
65+
66+
@benchmark
67+
def _evict_one() -> None:
68+
cache.async_add_records([next(pool)])

0 commit comments

Comments
 (0)