Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
cdc6ada
feat: implement heapq for tracking cache expire times
bdraco Jan 8, 2025
965271a
chore: make __lt__ work
bdraco Jan 8, 2025
2bb8479
fix: fixes
bdraco Jan 8, 2025
46dde62
Update tests/services/test_browser.py
bdraco Jan 8, 2025
867e26a
Update tests/services/test_info.py
bdraco Jan 8, 2025
8ecab53
Apply suggestions from code review
bdraco Jan 8, 2025
6212ab7
Update tests/test_handlers.py
bdraco Jan 8, 2025
a401c21
fix: fixes
bdraco Jan 8, 2025
3af4e44
Merge remote-tracking branch 'origin/heapq' into heapq
bdraco Jan 8, 2025
8f9f47d
fix: safer
bdraco Jan 8, 2025
2986304
fix: safer
bdraco Jan 8, 2025
2f0b372
fix: safer
bdraco Jan 8, 2025
ead3379
fix: safer
bdraco Jan 8, 2025
04095e4
fix: fixes
bdraco Jan 8, 2025
473a02b
fix: make it clear its not public
bdraco Jan 8, 2025
1daa996
fix: revert
bdraco Jan 8, 2025
1881db9
fix: revert
bdraco Jan 8, 2025
fe9957c
fix: remove reset ttl
bdraco Jan 8, 2025
ca5a6e4
fix: reorder to pop first
bdraco Jan 8, 2025
3379b88
fix: correct record update logic
bdraco Jan 8, 2025
e59bd24
Update src/zeroconf/_handlers/record_manager.py
bdraco Jan 8, 2025
56bf5a7
chore: add coverage for cleanup
bdraco Jan 8, 2025
70d6d05
Merge remote-tracking branch 'origin/master' into heapq
bdraco Jan 8, 2025
885a439
fix: fixes
bdraco Jan 8, 2025
77a4574
Merge remote-tracking branch 'origin/master' into heapq
bdraco Jan 8, 2025
89d0ba8
Merge branch 'master' into heapq
bdraco Jan 8, 2025
f6f238b
Update src/zeroconf/_cache.py
bdraco Jan 8, 2025
4e5163e
chore: coverage
bdraco Jan 8, 2025
b979e46
Merge remote-tracking branch 'origin/heapq' into heapq
bdraco Jan 8, 2025
e3ba478
chore: add comment
bdraco Jan 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/zeroconf/_cache.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ from ._dns cimport (
DNSText,
)

cdef object heappop
cdef object heappush
cdef object heapify

cdef object _UNIQUE_RECORD_TYPES
cdef unsigned int _TYPE_PTR
cdef cython.uint _ONE_SECOND
cdef unsigned int _MIN_SCHEDULED_RECORD_EXPIRATION

@cython.locals(
record_cache=dict,
Expand All @@ -26,6 +30,8 @@ cdef class DNSCache:

cdef public cython.dict cache
cdef public cython.dict service_cache
cdef public list _expire_heap
cdef public dict _expirations

cpdef bint async_add_records(self, object entries)

Expand Down Expand Up @@ -65,7 +71,8 @@ cdef class DNSCache:

@cython.locals(
store=cython.dict,
service_record=DNSService
service_record=DNSService,
when=object
)
cdef bint _async_add(self, DNSRecord record)

Expand Down Expand Up @@ -95,3 +102,10 @@ cdef class DNSCache:
now=double
)
cpdef current_entry_with_name_and_alias(self, str name, str alias)

cpdef void _async_set_created_ttl(
self,
DNSRecord record,
double now,
cython.float ttl
)
62 changes: 60 additions & 2 deletions src/zeroconf/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
USA
"""

from heapq import heapify, heappop, heappush
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast

from ._dns import (
Expand All @@ -43,6 +44,11 @@
_float = float
_int = int

# The minimum number of scheduled record expirations before we start cleaning up
# the expiration heap. This is a performance optimization to avoid cleaning up the
# heap too often when there are only a few scheduled expirations.
_MIN_SCHEDULED_RECORD_EXPIRATION = 100


def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None:
"""Remove a key from a DNSRecord cache
Expand All @@ -60,6 +66,8 @@ class DNSCache:

def __init__(self) -> None:
self.cache: _DNSRecordCacheType = {}
self._expire_heap: List[Tuple[float, DNSRecord]] = []
self._expirations: Dict[DNSRecord, float] = {}
self.service_cache: _DNSRecordCacheType = {}

# Functions prefixed with async_ are NOT threadsafe and must
Expand All @@ -81,6 +89,12 @@ def _async_add(self, record: _DNSRecord) -> bool:
store = self.cache.setdefault(record.key, {})
new = record not in store and not isinstance(record, DNSNsec)
store[record] = record
when = record.created + (record.ttl * 1000)
if self._expirations.get(record) != when:
# Avoid adding duplicates to the heap
heappush(self._expire_heap, (when, record))
self._expirations[record] = when

if isinstance(record, DNSService):
Comment thread
bdraco marked this conversation as resolved.
service_record = record
self.service_cache.setdefault(record.server_key, {})[service_record] = service_record
Expand Down Expand Up @@ -108,6 +122,7 @@ def _async_remove(self, record: _DNSRecord) -> None:
service_record = record
_remove_key(self.service_cache, service_record.server_key, service_record)
_remove_key(self.cache, record.key, record)
self._expirations.pop(record, None)

def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
Comment thread
bdraco marked this conversation as resolved.
"""Remove multiple records.
Expand All @@ -121,8 +136,44 @@ def async_expire(self, now: _float) -> List[DNSRecord]:
"""Purge expired entries from the cache.

This function must be run in from event loop.

Comment thread
bdraco marked this conversation as resolved.
:param now: The current time in milliseconds.
"""
expired = [record for records in self.cache.values() for record in records if record.is_expired(now)]
if not (expire_heap_len := len(self._expire_heap)):
return []

expired: List[DNSRecord] = []
# Find any expired records and add them to the to-delete list
while self._expire_heap:
when, record = self._expire_heap[0]
if when > now:
break
heappop(self._expire_heap)
# Check if the record hasn't been re-added to the heap
# with a different expiration time as it will be removed
# later when it reaches the top of the heap and its
# expiration time is met.
if self._expirations.get(record) == when:
expired.append(record)

# If the expiration heap grows larger than the number expirations
# times two, we clean it up to avoid keeping expired entries in
# the heap and consuming memory. We guard this with a minimum
# threshold to avoid cleaning up the heap too often when there are
# only a few scheduled expirations.
if (
Comment thread
bdraco marked this conversation as resolved.
expire_heap_len > _MIN_SCHEDULED_RECORD_EXPIRATION
Comment thread
bdraco marked this conversation as resolved.
and expire_heap_len > len(self._expirations) * 2
):
# Remove any expired entries from the expiration heap
# that do not match the expiration time in the expirations
# as it means the record has been re-added to the heap
# with a different expiration time.
self._expire_heap = [
entry for entry in self._expire_heap if self._expirations.get(entry[1]) == entry[0]
]
heapify(self._expire_heap)

self.async_remove_records(expired)
return expired

Expand Down Expand Up @@ -256,4 +307,11 @@ def async_mark_unique_records_older_than_1s_to_expire(
created_double = record.created
if (now - created_double > _ONE_SECOND) and record not in answers_rrset:
# Expire in 1s
record.set_created_ttl(now, 1)
self._async_set_created_ttl(record, now, 1)

def _async_set_created_ttl(self, record: DNSRecord, now: _float, ttl: _float) -> None:
"""Set the created time and ttl of a record."""
# It would be better if we made a copy instead of mutating the record
# in place, but records currently don't have a copy method.
record._set_created_ttl(now, ttl)
Comment thread
bdraco marked this conversation as resolved.
self._async_add(record)
4 changes: 1 addition & 3 deletions src/zeroconf/_dns.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ cdef class DNSRecord(DNSEntry):

cpdef bint is_recent(self, double now)

cpdef reset_ttl(self, DNSRecord other)

cpdef set_created_ttl(self, double now, cython.float ttl)
cdef _set_created_ttl(self, double now, cython.float ttl)

cdef class DNSAddress(DNSRecord):

Expand Down
12 changes: 6 additions & 6 deletions src/zeroconf/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use
"""Abstract method"""
raise AbstractMethodException

def __lt__(self, other: "DNSRecord") -> bool:
return self.ttl < other.ttl

def suppressed_by(self, msg: "DNSIncoming") -> bool:
"""Returns true if any answer in a message can suffice for the
information held in this record."""
Expand Down Expand Up @@ -222,13 +225,10 @@ def is_recent(self, now: _float) -> bool:
"""Returns true if the record more than one quarter of its TTL remaining."""
return self.created + (_RECENT_TIME_MS * self.ttl) > now

def reset_ttl(self, other) -> None: # type: ignore[no-untyped-def]
"""Sets this record's TTL and created time to that of
another record."""
self.set_created_ttl(other.created, other.ttl)

def set_created_ttl(self, created: _float, ttl: Union[float, int]) -> None:
def _set_created_ttl(self, created: _float, ttl: Union[float, int]) -> None:
"""Set the created and ttl of a record."""
Comment thread
bdraco marked this conversation as resolved.
# It would be better if we made a copy instead of mutating the record
# in place, but records currently don't have a copy method.
self.created = created
self.ttl = ttl

Expand Down
12 changes: 5 additions & 7 deletions src/zeroconf/_handlers/record_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
record,
_DNS_PTR_MIN_TTL,
)
record.set_created_ttl(record.created, _DNS_PTR_MIN_TTL)
# Safe because the record is never in the cache yet
record._set_created_ttl(record.created, _DNS_PTR_MIN_TTL)

if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2
unique_types.add((record.name, record_type, record.class_))
Expand All @@ -113,13 +114,10 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:

maybe_entry = cache.async_get_unique(record)
if not record.is_expired(now):
if maybe_entry is not None:
maybe_entry.reset_ttl(record)
if record_type in _ADDRESS_RECORD_TYPES:
address_adds.append(record)
else:
if record_type in _ADDRESS_RECORD_TYPES:
address_adds.append(record)
else:
other_adds.append(record)
other_adds.append(record)
rec_update = RecordUpdate.__new__(RecordUpdate)
rec_update._fast_init(record, maybe_entry)
updates.append(rec_update)
Expand Down
4 changes: 2 additions & 2 deletions tests/services/test_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,9 +1509,9 @@ def update_service(self, zc, type_, name) -> None: # type: ignore[no-untyped-de
)
# Force the ttl to be 1 second
now = current_time_millis()
for cache_record in zc.cache.cache.values():
for cache_record in list(zc.cache.cache.values()):
for record in cache_record:
record.set_created_ttl(now, 1)
zc.cache._async_set_created_ttl(record, now, 1)

time.sleep(0.3)
info.port = 400
Expand Down
2 changes: 1 addition & 1 deletion tests/services/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def test_service_info_rejects_expired_records(self):
ttl,
b"\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==",
)
expired_record.set_created_ttl(1000, 1)
zc.cache._async_set_created_ttl(expired_record, 1000, 1)
info.async_update_records(zc, now, [RecordUpdate(expired_record, None)])
assert info.properties[b"ci"] == b"2"
zc.close()
Expand Down
123 changes: 123 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import logging
import unittest
import unittest.mock
from heapq import heapify, heappop

import pytest

import zeroconf as r
from zeroconf import const
Expand Down Expand Up @@ -358,3 +361,123 @@ def test_async_get_unique_returns_newest_record():
assert record is record2
record = cache.async_get_unique(record2)
assert record is record2


@pytest.mark.asyncio
async def test_cache_heap_cleanup() -> None:
"""Test that the heap gets cleaned up when there are many old expirations."""
cache = r.DNSCache()
# The heap should not be cleaned up when there are less than 100 expiration changes
min_records_to_cleanup = 100
now = r.current_time_millis()
name = "heap.local."
ttl_seconds = 100
ttl_millis = ttl_seconds * 1000

for i in range(min_records_to_cleanup):
record = r.DNSAddress(name, const._TYPE_A, const._CLASS_IN, ttl_seconds, b"1", created=now + i)
cache.async_add_records([record])

assert len(cache._expire_heap) == min_records_to_cleanup
assert len(cache.async_entries_with_name(name)) == 1

# Now that we reached the minimum number of cookies to cleanup,
# add one more cookie to trigger the cleanup
record = r.DNSAddress(
name, const._TYPE_A, const._CLASS_IN, ttl_seconds, b"1", created=now + min_records_to_cleanup
)
expected_expire_time = record.created + ttl_millis
cache.async_add_records([record])
assert len(cache.async_entries_with_name(name)) == 1
entry = next(iter(cache.async_entries_with_name(name)))
assert (entry.created + ttl_millis) == expected_expire_time
assert entry is record

# Verify that the heap has been cleaned up
assert len(cache.async_entries_with_name(name)) == 1
cache.async_expire(now)

heap_copy = cache._expire_heap.copy()
heapify(heap_copy)
# Ensure heap order is maintained
assert cache._expire_heap == heap_copy

# The heap should have been cleaned up
assert len(cache._expire_heap) == 1
assert len(cache.async_entries_with_name(name)) == 1

entry = next(iter(cache.async_entries_with_name(name)))
assert entry is record

assert (entry.created + ttl_millis) == expected_expire_time

cache.async_expire(expected_expire_time)
assert not cache.async_entries_with_name(name), cache._expire_heap


@pytest.mark.asyncio
async def test_cache_heap_multi_name_cleanup() -> None:
"""Test cleanup with multiple names."""
cache = r.DNSCache()
# The heap should not be cleaned up when there are less than 100 expiration changes
min_records_to_cleanup = 100
now = r.current_time_millis()
name = "heap.local."
name2 = "heap2.local."
ttl_seconds = 100
ttl_millis = ttl_seconds * 1000

for i in range(min_records_to_cleanup):
record = r.DNSAddress(name, const._TYPE_A, const._CLASS_IN, ttl_seconds, b"1", created=now + i)
cache.async_add_records([record])
expected_expire_time = record.created + ttl_millis

for i in range(5):
record = r.DNSAddress(
name2, const._TYPE_A, const._CLASS_IN, ttl_seconds, bytes((i,)), created=now + i
)
cache.async_add_records([record])

assert len(cache._expire_heap) == min_records_to_cleanup + 5
assert len(cache.async_entries_with_name(name)) == 1
assert len(cache.async_entries_with_name(name2)) == 5

cache.async_expire(now)
# The heap and expirations should have been cleaned up
assert len(cache._expire_heap) == 1 + 5
assert len(cache._expirations) == 1 + 5

cache.async_expire(expected_expire_time)
assert not cache.async_entries_with_name(name), cache._expire_heap


@pytest.mark.asyncio
async def test_cache_heap_pops_order() -> None:
"""Test cache heap is popped in order."""
cache = r.DNSCache()
# The heap should not be cleaned up when there are less than 100 expiration changes
min_records_to_cleanup = 100
now = r.current_time_millis()
name = "heap.local."
name2 = "heap2.local."
ttl_seconds = 100

for i in range(min_records_to_cleanup):
record = r.DNSAddress(name, const._TYPE_A, const._CLASS_IN, ttl_seconds, b"1", created=now + i)
cache.async_add_records([record])

for i in range(5):
record = r.DNSAddress(
name2, const._TYPE_A, const._CLASS_IN, ttl_seconds, bytes((i,)), created=now + i
)
cache.async_add_records([record])

assert len(cache._expire_heap) == min_records_to_cleanup + 5
assert len(cache.async_entries_with_name(name)) == 1
assert len(cache.async_entries_with_name(name2)) == 5

start_ts = 0.0
while cache._expire_heap:
ts, _ = heappop(cache._expire_heap)
assert ts >= start_ts
start_ts = ts
Loading