Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/zeroconf/_cache.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ from ._dns cimport (

cdef object _UNIQUE_RECORD_TYPES
cdef object _TYPE_PTR
cdef object _ONE_SECOND

cdef _remove_key(cython.dict cache, object key, DNSRecord record)

Expand All @@ -22,9 +23,19 @@ cdef class DNSCache:
cdef public cython.dict cache
cdef public cython.dict service_cache

@cython.locals(
records=cython.dict,
record=DNSRecord,
)
cdef _async_all_by_details(self, object name, object type_, object class_)

cdef _async_add(self, DNSRecord record)

cdef _async_remove(self, DNSRecord record)

@cython.locals(
record=DNSRecord,
)
cdef _async_mark_unique_records_older_than_1s_to_expire(self, object unique_types, object answers, object now)

cdef _dns_record_matches(DNSRecord record, object key, object type_, object class_)
47 changes: 39 additions & 8 deletions src/zeroconf/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"""

import itertools
from typing import Dict, Iterable, Iterator, List, Optional, Union, cast
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast

from ._dns import (
DNSAddress,
Expand All @@ -34,13 +34,15 @@
DNSText,
)
from ._utils.time import current_time_millis
from .const import _TYPE_PTR
from .const import _ONE_SECOND, _TYPE_PTR

_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService)
_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService]
_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]]
_DNSRecord = DNSRecord
_str = str
_float = float
_int = int


def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None:
Expand Down Expand Up @@ -134,19 +136,29 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:
return None
return store.get(entry)

def async_all_by_details(self, name: _str, type_: int, class_: int) -> Iterator[DNSRecord]:
def async_all_by_details(self, name: _str, type_: int, class_: int) -> Iterable[DNSRecord]:
"""Gets all matching entries by details.

This function is not threadsafe and must be called from
This function is not thread-safe and must be called from
the event loop.
"""
return self._async_all_by_details(name, type_, class_)

def _async_all_by_details(self, name: _str, type_: int, class_: int) -> List[DNSRecord]:
"""Gets all matching entries by details.

This function is not thread-safe and must be called from
the event loop.
"""
key = name.lower()
records = self.cache.get(key)
matches: List[DNSRecord] = []
if records is None:
return
for entry in records:
if _dns_record_matches(entry, key, type_, class_):
yield entry
return matches
for record in records:
if _dns_record_matches(record, key, type_, class_):
matches.append(record)
return matches

def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]:
"""Returns a dict of entries whose key matches the name.
Expand Down Expand Up @@ -226,6 +238,25 @@ def names(self) -> List[str]:
"""Return a copy of the list of current cache names."""
return list(self.cache)

def async_mark_unique_records_older_than_1s_to_expire(
self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float
) -> None:
self._async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now)

def _async_mark_unique_records_older_than_1s_to_expire(
self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float
) -> None:
# rfc6762#section-10.2 para 2
# Since unique is set, all old records with that name, rrtype,
# and rrclass that were received more than one second ago are declared
# invalid, and marked to expire from the cache in one second.
answers_rrset = set(answers)
for name, type_, class_ in unique_types:
for record in self._async_all_by_details(name, type_, class_):
if (now - record.created > _ONE_SECOND) and record not in answers_rrset:
# Expire in 1s
record.set_created_ttl(now, 1)


def _dns_record_matches(record: _DNSRecord, key: _str, type_: int, class_: int) -> bool:
return key == record.key and type_ == record.type and class_ == record.class_
17 changes: 1 addition & 16 deletions src/zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Expand Down Expand Up @@ -421,7 +420,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
removes.add(record)

if unique_types:
self._async_mark_unique_cached_records_older_than_1s_to_expire(unique_types, msg.answers, now)
self.cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, msg.answers, now)

if updates:
self.async_updates(now, updates)
Expand Down Expand Up @@ -451,20 +450,6 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
if updates:
self.async_updates_complete(new)

def _async_mark_unique_cached_records_older_than_1s_to_expire(
self, unique_types: Set[Tuple[str, int, int]], answers: Iterable[DNSRecord], now: float
) -> None:
# rfc6762#section-10.2 para 2
# Since unique is set, all old records with that name, rrtype,
# and rrclass that were received more than one second ago are declared
# invalid, and marked to expire from the cache in one second.
answers_rrset = set(answers)
for name, type_, class_ in unique_types:
for entry in self.cache.async_all_by_details(name, type_, class_):
if (now - entry.created > _ONE_SECOND) and entry not in answers_rrset:
# Expire in 1s
entry.set_created_ttl(now, 1)

def async_add_listener(
self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]]
) -> None:
Expand Down