diff --git a/src/zeroconf/_cache.pxd b/src/zeroconf/_cache.pxd index ea436be70..07eeb8079 100644 --- a/src/zeroconf/_cache.pxd +++ b/src/zeroconf/_cache.pxd @@ -13,6 +13,7 @@ from ._dns cimport ( cdef object _UNIQUE_RECORD_TYPES cdef object _TYPE_PTR +cdef object _ONE_SECOND cdef _remove_key(cython.dict cache, object key, DNSRecord record) @@ -22,9 +23,19 @@ cdef class DNSCache: cdef public cython.dict cache cdef public cython.dict service_cache + @cython.locals( + records=cython.dict, + record=DNSRecord, + ) + cdef _async_all_by_details(self, object name, object type_, object class_) + cdef _async_add(self, DNSRecord record) cdef _async_remove(self, DNSRecord record) + @cython.locals( + record=DNSRecord, + ) + cdef _async_mark_unique_records_older_than_1s_to_expire(self, object unique_types, object answers, object now) cdef _dns_record_matches(DNSRecord record, object key, object type_, object class_) diff --git a/src/zeroconf/_cache.py b/src/zeroconf/_cache.py index 49f92f911..505143b3f 100644 --- a/src/zeroconf/_cache.py +++ b/src/zeroconf/_cache.py @@ -21,7 +21,7 @@ """ import itertools -from typing import Dict, Iterable, Iterator, List, Optional, Union, cast +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast from ._dns import ( DNSAddress, @@ -34,13 +34,15 @@ DNSText, ) from ._utils.time import current_time_millis -from .const import _TYPE_PTR +from .const import _ONE_SECOND, _TYPE_PTR _UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService) _UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService] _DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]] _DNSRecord = DNSRecord _str = str +_float = float +_int = int def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None: @@ -134,19 +136,29 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: return None return store.get(entry) - def async_all_by_details(self, name: _str, type_: int, class_: int) -> Iterator[DNSRecord]: + def async_all_by_details(self, name: _str, type_: int, class_: int) -> Iterable[DNSRecord]: """Gets all matching entries by details. - This function is not threadsafe and must be called from + This function is not thread-safe and must be called from + the event loop. + """ + return self._async_all_by_details(name, type_, class_) + + def _async_all_by_details(self, name: _str, type_: int, class_: int) -> List[DNSRecord]: + """Gets all matching entries by details. + + This function is not thread-safe and must be called from the event loop. """ key = name.lower() records = self.cache.get(key) + matches: List[DNSRecord] = [] if records is None: - return - for entry in records: - if _dns_record_matches(entry, key, type_, class_): - yield entry + return matches + for record in records: + if _dns_record_matches(record, key, type_, class_): + matches.append(record) + return matches def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]: """Returns a dict of entries whose key matches the name. @@ -226,6 +238,25 @@ def names(self) -> List[str]: """Return a copy of the list of current cache names.""" return list(self.cache) + def async_mark_unique_records_older_than_1s_to_expire( + self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float + ) -> None: + self._async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now) + + def _async_mark_unique_records_older_than_1s_to_expire( + self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float + ) -> None: + # rfc6762#section-10.2 para 2 + # Since unique is set, all old records with that name, rrtype, + # and rrclass that were received more than one second ago are declared + # invalid, and marked to expire from the cache in one second. + answers_rrset = set(answers) + for name, type_, class_ in unique_types: + for record in self._async_all_by_details(name, type_, class_): + if (now - record.created > _ONE_SECOND) and record not in answers_rrset: + # Expire in 1s + record.set_created_ttl(now, 1) + def _dns_record_matches(record: _DNSRecord, key: _str, type_: int, class_: int) -> bool: return key == record.key and type_ == record.type and class_ == record.class_ diff --git a/src/zeroconf/_handlers.py b/src/zeroconf/_handlers.py index 38a1b034b..fb5ed7c71 100644 --- a/src/zeroconf/_handlers.py +++ b/src/zeroconf/_handlers.py @@ -26,7 +26,6 @@ from typing import ( TYPE_CHECKING, Dict, - Iterable, List, NamedTuple, Optional, @@ -421,7 +420,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: removes.add(record) if unique_types: - self._async_mark_unique_cached_records_older_than_1s_to_expire(unique_types, msg.answers, now) + self.cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, msg.answers, now) if updates: self.async_updates(now, updates) @@ -451,20 +450,6 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: if updates: self.async_updates_complete(new) - def _async_mark_unique_cached_records_older_than_1s_to_expire( - self, unique_types: Set[Tuple[str, int, int]], answers: Iterable[DNSRecord], now: float - ) -> None: - # rfc6762#section-10.2 para 2 - # Since unique is set, all old records with that name, rrtype, - # and rrclass that were received more than one second ago are declared - # invalid, and marked to expire from the cache in one second. - answers_rrset = set(answers) - for name, type_, class_ in unique_types: - for entry in self.cache.async_all_by_details(name, type_, class_): - if (now - entry.created > _ONE_SECOND) and entry not in answers_rrset: - # Expire in 1s - entry.set_created_ttl(now, 1) - def async_add_listener( self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] ) -> None: