Skip to content
Merged
11 changes: 10 additions & 1 deletion src/zeroconf/_cache.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ cdef object _UNIQUE_RECORD_TYPES
cdef unsigned int _TYPE_PTR
cdef cython.uint _ONE_SECOND
cdef unsigned int _MIN_SCHEDULED_RECORD_EXPIRATION
cdef unsigned int _MAX_CACHE_RECORDS


@cython.locals(record_cache=dict)
Expand All @@ -31,6 +32,7 @@ cdef class DNSCache:
cdef public cython.dict service_cache
cdef public list _expire_heap
cdef public dict _expirations
cdef public unsigned int _total_records

cpdef bint async_add_records(self, object entries)

Expand Down Expand Up @@ -60,10 +62,17 @@ cdef class DNSCache:
service_store=cython.dict,
service_record=DNSService,
when=object,
new=bint
new=bint,
is_new=bint
)
cdef bint _async_add(self, DNSRecord record)

@cython.locals(record=DNSRecord, when_record=tuple)
cdef void _async_evict_oldest(self)

@cython.locals(expire_heap_len="unsigned int")
cdef void _maybe_rebuild_heap(self)

@cython.locals(service_record=DNSService)
cdef void _async_remove(self, DNSRecord record)

Expand Down
79 changes: 51 additions & 28 deletions src/zeroconf/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
DNSText,
)
from ._utils.time import current_time_millis
from .const import _ONE_SECOND, _TYPE_PTR
from .const import _MAX_CACHE_RECORDS, _ONE_SECOND, _TYPE_PTR

_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService)
_UniqueRecordsType = DNSAddress | DNSHinfo | DNSPointer | DNSText | DNSService
Expand Down Expand Up @@ -72,6 +72,7 @@ def __init__(self) -> None:
self._expire_heap: list[tuple[float, DNSRecord]] = []
self._expirations: dict[DNSRecord, float] = {}
self.service_cache: _DNSRecordCacheType = {}
self._total_records: int = 0

# Functions prefixed with async_ are NOT threadsafe and must
# be run in the event loop.
Expand All @@ -89,15 +90,34 @@ def _async_add(self, record: _DNSRecord) -> bool:
# replaces any existing records that are __eq__ to each other which
# removes the risk that accessing the cache from the wrong
# direction would return the old incorrect entry.
if (store := self.cache.get(record.key)) is None:
store = self.cache.get(record.key)
is_new = store is None or record not in store
# Bound total cache size; evict closest-to-expiration entry to
# make room before inserting a new record. Prevents a LAN-local
# flood of unique-name records from growing the cache without
# bound (RFC 6762 §10 advisory caching, defense-in-depth).
if is_new and self._total_records >= _MAX_CACHE_RECORDS:
self._async_evict_oldest()
Comment thread
bdraco marked this conversation as resolved.
# The victim may have been the last record under
# ``record.key``, in which case ``_remove_key`` deleted
# the bucket. Re-fetch before creating below.
store = self.cache.get(record.key)
if store is None:
store = self.cache[record.key] = {}
new = record not in store and not isinstance(record, DNSNsec)
new = is_new and not isinstance(record, DNSNsec)
if is_new:
self._total_records += 1
store[record] = record
Comment thread
bdraco marked this conversation as resolved.
when = record.created + (record.ttl * 1000)
if self._expirations.get(record) != when:
# Avoid adding duplicates to the heap
heappush(self._expire_heap, (when, record))
self._expirations[record] = when
# Re-adds of an existing record with a new TTL push a fresh
# entry but leave the prior tuple behind as stale, so a peer
# that just replays cached records can grow ``_expire_heap``
# without ever tripping the cap. Rebuild when stale entries
# dominate.
self._maybe_rebuild_heap()

if isinstance(record, DNSService):
service_record = record
Expand All @@ -106,6 +126,28 @@ def _async_add(self, record: _DNSRecord) -> bool:
service_store[service_record] = service_record
return new

def _async_evict_oldest(self) -> None:
    """Evict the single live record that is nearest to its expiration.

    Entries are popped from ``_expire_heap`` in expiration order. A popped
    tuple whose timestamp no longer equals the value stored for that record
    in ``_expirations`` is a leftover and is simply discarded. The first
    tuple that still matches identifies the victim, which is handed to
    ``_async_remove``. If the heap runs dry first, nothing is removed.
    """
    # ``when_record`` and ``record`` are the local names that
    # ``_cache.pxd`` types through ``cython.locals``; keep them as-is.
    while self._expire_heap:
        when_record = heappop(self._expire_heap)
        record = when_record[1]
        if self._expirations.get(record) == when_record[0]:
            # Live entry found: evict exactly one record and stop.
            self._async_remove(record)
            break

def _maybe_rebuild_heap(self) -> None:
    """Compact ``_expire_heap`` once leftover tuples outnumber live ones.

    The heap is rebuilt only when it is longer than
    ``_MIN_SCHEDULED_RECORD_EXPIRATION`` and also more than twice the size
    of ``_expirations``. Rebuilding keeps just the tuples whose timestamp
    still equals the value recorded for that record in ``_expirations``.
    """
    # ``expire_heap_len`` is the local name typed in ``_cache.pxd``.
    expire_heap_len = len(self._expire_heap)
    if expire_heap_len <= _MIN_SCHEDULED_RECORD_EXPIRATION:
        # Too small to be worth the cost of a rebuild.
        return
    if expire_heap_len <= len(self._expirations) * 2:
        # Leftover tuples do not yet dominate the heap.
        return
    self._expire_heap = [
        entry for entry in self._expire_heap if self._expirations.get(entry[1]) == entry[0]
    ]
    # Filtering does not preserve the heap invariant; restore it.
    heapify(self._expire_heap)

def async_add_records(self, entries: Iterable[DNSRecord]) -> bool:
"""Add multiple records.

Expand All @@ -129,6 +171,7 @@ def _async_remove(self, record: _DNSRecord) -> None:
_remove_key(self.service_cache, service_record.server_key, service_record)
_remove_key(self.cache, record.key, record)
self._expirations.pop(record, None)
self._total_records -= 1

def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
"""Remove multiple records.
Expand All @@ -145,43 +188,23 @@ def async_expire(self, now: _float) -> list[DNSRecord]:

:param now: The current time in milliseconds.
"""
if not (expire_heap_len := len(self._expire_heap)):
if not self._expire_heap:
return []

expired: list[DNSRecord] = []
# Find any expired records and add them to the to-delete list
while self._expire_heap:
when_record = self._expire_heap[0]
when = when_record[0]
if when > now:
break
heappop(self._expire_heap)
# Check if the record hasn't been re-added to the heap
# with a different expiration time as it will be removed
# later when it reaches the top of the heap and its
# expiration time is met.
# Skip entries left behind by a TTL re-add; the live tuple is
# later in the heap and will be removed when it reaches the top.
record = when_record[1]
if self._expirations.get(record) == when:
expired.append(record)

# If the expiration heap grows larger than the number expirations
# times two, we clean it up to avoid keeping expired entries in
# the heap and consuming memory. We guard this with a minimum
# threshold to avoid cleaning up the heap too often when there are
# only a few scheduled expirations.
if (
expire_heap_len > _MIN_SCHEDULED_RECORD_EXPIRATION
and expire_heap_len > len(self._expirations) * 2
):
# Remove any expired entries from the expiration heap
# that do not match the expiration time in the expirations
# as it means the record has been re-added to the heap
# with a different expiration time.
self._expire_heap = [
entry for entry in self._expire_heap if self._expirations.get(entry[1]) == entry[0]
]
heapify(self._expire_heap)

self._maybe_rebuild_heap()
self.async_remove_records(expired)
return expired

Expand Down
6 changes: 6 additions & 0 deletions src/zeroconf/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@
# level of rate limit and safe guards so we use 1/4 of the recommended value
_DNS_PTR_MIN_TTL = 1125

# Maximum number of records the DNSCache holds at once. When the cache
# already contains this many records, adding a record that is not yet
# cached first evicts the entry closest to expiration to make room.
# This bounds the memory a malicious LAN peer can force the cache to
# retain by multicasting many unique-name records.
_MAX_CACHE_RECORDS = 10000

_DNS_PACKET_HEADER_LEN = 12

_MAX_MSG_TYPICAL = 1460 # unused
Expand Down
68 changes: 68 additions & 0 deletions tests/benchmarks/test_cache_bound.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Benchmark for the DNSCache record-count bound + overflow eviction."""

from __future__ import annotations

from collections.abc import Iterator
from itertools import count

from pytest_codspeed import BenchmarkFixture

from zeroconf import DNSAddress, DNSCache, current_time_millis
from zeroconf.const import _CLASS_IN, _MAX_CACHE_RECORDS, _TYPE_A


def _make_records(count_: int, now: float, prefix: str = "bench") -> list[DNSAddress]:
    """Build ``count_`` A records named ``{prefix}-{index}.local.``.

    Each record has a TTL of 120, an address made from the low 32 bits of
    its index in big-endian order, and a creation time of ``now + index``.
    """
    records: list[DNSAddress] = []
    for index in range(count_):
        # Same four octets as shifting and masking each byte by hand.
        address = (index & 0xFFFFFFFF).to_bytes(4, "big")
        records.append(
            DNSAddress(
                f"{prefix}-{index}.local.",
                _TYPE_A,
                _CLASS_IN,
                120,
                address,
                created=now + index,
            )
        )
    return records


def _unbounded_records(now: float, prefix: str = "evict") -> Iterator[DNSAddress]:
    """Yield an endless stream of A records named ``{prefix}-{index}.local.``.

    Each record has a TTL of 120, an address made from the low 32 bits of
    its index in big-endian order, and a creation time of ``now + index``.
    """
    for index in count():
        # Masking keeps the address at four octets however far the
        # counter runs.
        address = (index & 0xFFFFFFFF).to_bytes(4, "big")
        name = f"{prefix}-{index}.local."
        yield DNSAddress(name, _TYPE_A, _CLASS_IN, 120, address, created=now + index)


def test_cache_add_below_cap(benchmark: BenchmarkFixture) -> None:
    """Benchmark filling a fresh cache with 1000 records (no eviction)."""
    # The records are built once, outside the timed body.
    records = _make_records(1000, current_time_millis())

    def _add() -> None:
        # A new cache on every run keeps the record count far below the cap.
        DNSCache().async_add_records(records)

    benchmark(_add)


def test_cache_add_at_cap_evicts(benchmark: BenchmarkFixture) -> None:
    """Benchmark inserts into a full cache, where each add evicts one record.

    The cache is filled to ``_MAX_CACHE_RECORDS`` before timing starts, so
    the measured body only exercises the eviction path. Every run takes one
    previously unseen record from an endless generator, which leaves the
    cache sitting at the cap. Using a generator rather than a pre-built
    list means the number of benchmark runs is not limited by pool size.
    """
    now = current_time_millis()
    cache = DNSCache()
    fill = _make_records(_MAX_CACHE_RECORDS, now, prefix="fill")
    cache.async_add_records(fill)
    # Start the stream's creation times after those of the fill records.
    pool = _unbounded_records(now + _MAX_CACHE_RECORDS)

    def _evict_one() -> None:
        cache.async_add_records([next(pool)])

    benchmark(_evict_one)
Loading
Loading