diff --git a/tests/test_cache.py b/tests/test_cache.py index 7c75866bc..4b3a8a18e 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -84,16 +84,61 @@ def test_cache_empty_multiple_calls(self): assert 'a' not in cache.cache +class TestDNSAsyncCacheAPI(unittest.TestCase): + def test_async_get_unique(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert cache.async_get_unique(record1) == record1 + assert cache.async_get_unique(record2) == record2 + + def test_async_all_by_details(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.async_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == set([record1, record2]) + + def test_async_entries_with_server(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.async_entries_with_server('ab')) == set([record1, record2]) + assert set(cache.async_entries_with_server('AB')) == set([record1, record2]) + + def test_async_entries_with_name(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.async_entries_with_name('irrelevant')) == set([record1, record2]) + assert set(cache.async_entries_with_name('Irrelevant')) == set([record1, record2]) + + # These functions have been seen in other projects so # we try to maintain a stable API for all the threadsafe getters class TestDNSCacheAPI(unittest.TestCase): def test_get(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + record3 = r.DNSAddress('a', const._TYPE_AAAA, const._CLASS_IN, 1, b'ipv6') cache = r.DNSCache() - cache.async_add_records([record1, record2]) + cache.async_add_records([record1, record2, record3]) assert cache.get(record1) == record1 assert cache.get(record2) == record2 + assert cache.get(r.DNSEntry('a', const._TYPE_A, const._CLASS_IN)) == record2 + assert cache.get(r.DNSEntry('a', const._TYPE_AAAA, const._CLASS_IN)) == record3 + assert cache.get(r.DNSEntry('notthere', const._TYPE_A, const._CLASS_IN)) is None def test_get_by_details(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 92d95fa2f..ddd8ffa47 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -862,17 +862,17 @@ async def test_cache_flush_bit(): for new_record in new_records: assert new_record.unique is True - original_a_record = zc.cache.get(a_record) + original_a_record = zc.cache.async_get_unique(a_record) # Do the run within 1s to verify the original record is not going to be expired out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) for answer in new_records: out.add_answer_at_time(answer, 0) for packet in out.packets(): zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) - assert zc.cache.get(a_record) is original_a_record + assert zc.cache.async_get_unique(a_record) is original_a_record assert original_a_record.ttl != 1 for record in new_records: - assert zc.cache.get(record) is not None + assert zc.cache.async_get_unique(record) is not None original_a_record.created = current_time_millis() - 1001 @@ -884,9 +884,9 @@ async def test_cache_flush_bit(): zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) assert original_a_record.ttl == 1 for record in new_records: - assert zc.cache.get(record) is not None + assert zc.cache.async_get_unique(record) is not None - cached_records = [zc.cache.get(record) for record in new_records] + cached_records = [zc.cache.async_get_unique(record) for record in new_records] for record in cached_records: record.created = current_time_millis() - 1001 @@ -901,7 +901,7 @@ async def test_cache_flush_bit(): for record in cached_records: assert record.ttl == 1 - for entry in zc.cache.get_all_by_details(server_name, const._TYPE_A, const._CLASS_IN): + for entry in zc.cache.async_all_by_details(server_name, const._TYPE_A, const._CLASS_IN): if entry.address == fresh_address: assert entry.ttl > 1 else: diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index 12e4aa649..24b6a2337 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -20,13 +20,24 @@ USA """ -from typing import Dict, Iterable, List, Optional, cast - -from ._dns import DNSEntry, DNSPointer, DNSRecord, DNSService +import itertools +from typing import Dict, Iterable, Iterator, List, Optional, Union, cast + +from ._dns import ( + DNSAddress, + DNSEntry, + DNSHinfo, + DNSPointer, + DNSRecord, + DNSService, + DNSText, + dns_entry_matches, +) from ._utils.time import current_time_millis from .const import _TYPE_PTR - +_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService) +_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService] _DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]] @@ -90,16 +101,50 @@ def async_remove_records(self, entries: Iterable[DNSRecord]) -> None: for entry in entries: self._async_remove(entry) - def async_expire(self, now: float) -> Iterable[DNSRecord]: + def async_expire(self, now: float) -> List[DNSRecord]: """Purge expired entries from the cache. This function must be run in from event loop. """ - for name in self.names(): - for record in self.entries_with_name(name): - if record.is_expired(now): - self._async_remove(record) - yield record + expired = [record for record in itertools.chain(*self.cache.values()) if record.is_expired(now)] + self.async_remove_records(expired) + return expired + + def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: + """Gets a unique entry by key. Will return None if there is no + matching entry. + + This function is not threadsafe and must be called from + the event loop. + """ + return self.cache.get(entry.key, {}).get(entry) + + def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[DNSRecord]: + """Gets all matching entries by details. + + This function is not threadsafe and must be called from + the event loop. + """ + key = name.lower() + for entry in self.cache.get(key, []): + if dns_entry_matches(entry, key, type_, class_): + yield entry + + def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]: + """Returns a dict of entries whose key matches the name. + + This function is not threadsafe and must be called from + the event loop. + """ + return self.cache.get(name.lower(), {}) + + def async_entries_with_server(self, name: str) -> Dict[DNSRecord, DNSRecord]: + """Returns a dict of entries whose key matches the server. + + This function is not threadsafe and must be called from + the event loop. + """ + return self.service_cache.get(name.lower(), {}) # The below functions are threadsafe and do not need to be run in the # event loop, however they all make copies so they significantly @@ -108,7 +153,9 @@ def async_expire(self, now: float) -> Iterable[DNSRecord]: def get(self, entry: DNSEntry) -> Optional[DNSRecord]: """Gets an entry by key. Will return None if there is no matching entry.""" - for cached_entry in reversed(self.entries_with_name(entry.key)): + if isinstance(entry, _UNIQUE_RECORD_TYPES): + return self.cache.get(entry.key, {}).get(entry) + for cached_entry in reversed(list(self.cache.get(entry.key, []))): if entry.__eq__(cached_entry): return cached_entry return None @@ -125,12 +172,18 @@ def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSReco Use get_all_by_details instead. """ - return self.get(DNSEntry(name, type_, class_)) + key = name.lower() + for cached_entry in reversed(list(self.cache.get(key, []))): + if dns_entry_matches(cached_entry, key, type_, class_): + return cached_entry + return None def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: """Gets all matching entries by details.""" - match_entry = DNSEntry(name, type_, class_) - return [entry for entry in self.entries_with_name(name) if match_entry.__eq__(entry)] + key = name.lower() + return [ + entry for entry in list(self.cache.get(key, [])) if dns_entry_matches(entry, key, type_, class_) + ] def entries_with_server(self, server: str) -> List[DNSRecord]: """Returns a list of entries whose server matches the name.""" diff --git a/zeroconf/_core.py b/zeroconf/_core.py index e5c92ce3f..a7910591a 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -146,7 +146,7 @@ async def _async_cache_cleanup(self) -> None: """Periodic cache cleanup.""" while not self.zc.done: now = current_time_millis() - self.zc.record_manager.async_updates(now, list(self.zc.cache.async_expire(now))) + self.zc.record_manager.async_updates(now, self.zc.cache.async_expire(now)) self.zc.record_manager.async_updates_complete() await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL)) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 66892d52f..e656bc519 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -49,6 +49,10 @@ from ._protocol import DNSIncoming, DNSOutgoing # pylint: disable=cyclic-import +def dns_entry_matches(record: 'DNSEntry', key: str, type_: int, class_: int) -> bool: + return key == record.key and type_ == record.type and class_ == record.class_ + + class DNSEntry: """A DNS entry""" @@ -66,12 +70,7 @@ def _entry_tuple(self) -> Tuple[str, int, int]: def __eq__(self, other: Any) -> bool: """Equality test on key (lowercase name), type, and class""" - return ( - self.key == other.key - and self.type == other.type - and self.class_ == other.class_ - and isinstance(other, DNSEntry) - ) + return dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry) @staticmethod def get_class_(class_: int) -> str: diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 03495b4e5..66b8862f2 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -21,9 +21,9 @@ """ import itertools -from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union +from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast -from ._cache import DNSCache +from ._cache import DNSCache, _UniqueRecordsType from ._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord from ._logger import log from ._protocol import DNSIncoming, DNSOutgoing @@ -141,7 +141,7 @@ def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: SHOULD instead multicast the response so as to keep all the peer caches up to date """ - maybe_entry = self._cache.get(record) + maybe_entry = self._cache.async_get_unique(cast(_UniqueRecordsType, record)) return bool(maybe_entry and maybe_entry.is_recent(self._now)) def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: @@ -149,7 +149,7 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: Protect the network against excessive packet flooding https://datatracker.ietf.org/doc/html/rfc6762#section-14 """ - maybe_entry = self._cache.get(record) + maybe_entry = self._cache.async_get_unique(cast(_UniqueRecordsType, record)) return bool(maybe_entry and self._now - maybe_entry.created < 1000) @@ -317,7 +317,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 unique_types.add((record.name, record.type, record.class_)) - maybe_entry = self.cache.get(record) + maybe_entry = self.cache.async_get_unique(cast(_UniqueRecordsType, record)) if not record.is_expired(now): if maybe_entry is not None: maybe_entry.reset_ttl(record) @@ -372,7 +372,7 @@ def _async_mark_unique_cached_records_older_than_1s_to_expire( # invalid, and marked to expire from the cache in one second. answers_rrset = DNSRRSet(answers) for name, type_, class_ in unique_types: - for entry in self.cache.get_all_by_details(name, type_, class_): + for entry in self.cache.async_all_by_details(name, type_, class_): if (now - entry.created > 1000) and entry not in answers_rrset: # Expire in 1s entry.set_created_ttl(now, 1) diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 80fdd5f73..306b69990 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -27,6 +27,7 @@ from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast +from .._cache import _UniqueRecordsType from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText from .._exceptions import BadTypeInNameException from .._protocol import DNSOutgoing @@ -316,7 +317,7 @@ def _enqueue_callback( ): self._pending_handlers[key] = state_change - def _process_record_update(self, now: float, record: DNSRecord) -> None: + def _async_process_record_update(self, now: float, record: DNSRecord) -> None: """Process a single record update from a batch of updates.""" expired = record.is_expired(now) @@ -340,12 +341,12 @@ def _process_record_update(self, now: float, record: DNSRecord) -> None: return # If its expired or already exists in the cache it cannot be updated. - if expired or self.zc.cache.get(record): + if expired or self.zc.cache.async_get_unique(cast(_UniqueRecordsType, record)): return if isinstance(record, DNSAddress): # Iterate through the DNSCache and callback any services that use this address - for service in self.zc.cache.entries_with_server(record.name): + for service in self.zc.cache.async_entries_with_server(record.name): type_ = self._record_matching_type(service) if type_: self._enqueue_callback(ServiceStateChange.Updated, type_, service.name) @@ -367,7 +368,7 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSReco This method will be run in the event loop. """ for record in records: - self._process_record_update(now, record) + self._async_process_record_update(now, record) def async_update_records_complete(self) -> None: """Called when a record update has completed for all handlers.