From ff21b97e9b636b1b649da749f53144cee340075f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 12:55:17 -1000 Subject: [PATCH 01/23] wip --- zeroconf/_cache.py | 75 ++++++++++++++++++++++++++++++---- zeroconf/_handlers.py | 8 ++-- zeroconf/_services/__init__.py | 8 ++-- 3 files changed, 76 insertions(+), 15 deletions(-) diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index 2e07a7a47..28870cf54 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -20,12 +20,14 @@ USA """ -from typing import Dict, Iterable, List, Optional, cast +from typing import Dict, Iterable, List, Optional, Union, cast -from ._dns import DNSEntry, DNSPointer, DNSRecord, DNSService +from ._dns import DNSAddress, DNSEntry, DNSHinfo, DNSPointer, DNSRecord, DNSService, DNSText from ._utils.time import current_time_millis from .const import _TYPE_PTR +_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService) + _DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]] @@ -101,18 +103,77 @@ def expire(self, now: float) -> Iterable[DNSRecord]: self.remove(record) yield record + def async_get(self, entry: DNSEntry) -> Optional[DNSRecord]: + """Gets an entry by key. Will return None if there is no + matching entry. + + This function is not threadsafe and must be called from + the event loop. + """ + if isinstance(entry, _UNIQUE_RECORD_TYPES): + return self._lookup_unique_entry_threadsafe(entry) + return self._async_get(entry) + + def async_get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: + """Gets all matching entries by details. + + This function is not threadsafe and must be called from + the event loop. + """ + match_entry = DNSEntry(name, type_, class_) + return [entry for entry in self.cache.get(match_entry.key, []) if match_entry.__eq__(entry)] + + def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]: + """Returns a dict of entries whose key matches the name. + + This function is not threadsafe and must be called from + the event loop. + """ + return self.cache.get(name.lower(), {}) + + def async_entries_with_server(self, name: str) -> Dict[DNSRecord, DNSRecord]: + """Returns a dict of entries whose key matches the server. + + This function is not threadsafe and must be called from + the event loop. + """ + return self.service_cache.get(name.lower(), {}) + + def _async_get(self, entry: DNSEntry): + """Search a dict of entries by making a copy of it first. + + This function is not threadsafe and must be called from + the event loop. + """ + for cached_entry in reversed(self.cache.get(entry.key, [])): + if entry.__eq__(cached_entry): + return cached_entry + return None + # The below functions are threadsafe and do not need to be run in the # event loop, however they all make copies so they significantly # inefficent - def get(self, entry: DNSEntry) -> Optional[DNSRecord]: - """Gets an entry by key. Will return None if there is no - matching entry.""" - for cached_entry in reversed(self.entries_with_name(entry.key)): + def _lookup_unique_entry_threadsafe( + self, entry: Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService] + ) -> Optional[Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService]]: + """Lookup a unique entry threadsafe.""" + return self.cache.get(entry.key, {}).get(entry) + + def _get_threadsafe(self, entry: DNSEntry): + """Search a dict of entries by making a copy of it first.""" + for cached_entry in reversed(list(self.cache.get(entry.key, []))): if entry.__eq__(cached_entry): return cached_entry return None + def get(self, entry: DNSEntry) -> Optional[DNSRecord]: + """Gets an entry by key. Will return None if there is no + matching entry.""" + if isinstance(entry, _UNIQUE_RECORD_TYPES): + return self._lookup_unique_entry_threadsafe(entry) + return self._get_threadsafe(entry) + def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]: """Gets the first matching entry by details. Returns None if no entries match. @@ -130,7 +191,7 @@ def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSReco def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: """Gets all matching entries by details.""" match_entry = DNSEntry(name, type_, class_) - return [entry for entry in self.entries_with_name(name) if match_entry.__eq__(entry)] + return [entry for entry in list(self.cache.get(match_entry.key, [])) if match_entry.__eq__(entry)] def entries_with_server(self, server: str) -> List[DNSRecord]: """Returns a list of entries whose server matches the name.""" diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index b5279654f..b46d1be64 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -141,7 +141,7 @@ def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: SHOULD instead multicast the response so as to keep all the peer caches up to date """ - maybe_entry = self._cache.get(record) + maybe_entry = self._cache.async_get(record) return bool(maybe_entry and maybe_entry.is_recent(self._now)) def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: @@ -149,7 +149,7 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: Protect the network against excessive packet flooding https://datatracker.ietf.org/doc/html/rfc6762#section-14 """ - maybe_entry = self._cache.get(record) + maybe_entry = self._cache.async_get(record) return bool(maybe_entry and self._now - maybe_entry.created < 1000) @@ -320,7 +320,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: # Since unique is set, all old records with that name, rrtype, # and rrclass that were received more than one second ago are declared # invalid, and marked to expire from the cache in one second. - for entry in self.cache.get_all_by_details(record.name, record.type, record.class_): + for entry in self.cache.async_get_all_by_details(record.name, record.type, record.class_): if entry == record: updated = False if record.created - entry.created > 1000 and entry not in msg.answers: @@ -328,7 +328,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: entry.set_created_ttl(now, 1) expired = record.is_expired(now) - maybe_entry = self.cache.get(record) + maybe_entry = self.cache.async_get(record) if not expired: if maybe_entry is not None: maybe_entry.reset_ttl(record) diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 80fdd5f73..8ab1baafb 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -316,7 +316,7 @@ def _enqueue_callback( ): self._pending_handlers[key] = state_change - def _process_record_update(self, now: float, record: DNSRecord) -> None: + def _async_process_record_update(self, now: float, record: DNSRecord) -> None: """Process a single record update from a batch of updates.""" expired = record.is_expired(now) @@ -340,12 +340,12 @@ def _process_record_update(self, now: float, record: DNSRecord) -> None: return # If its expired or already exists in the cache it cannot be updated. - if expired or self.zc.cache.get(record): + if expired or self.zc.cache.async_get(record): return if isinstance(record, DNSAddress): # Iterate through the DNSCache and callback any services that use this address - for service in self.zc.cache.entries_with_server(record.name): + for service in self.zc.cache.async_entries_with_server(record.name): type_ = self._record_matching_type(service) if type_: self._enqueue_callback(ServiceStateChange.Updated, type_, service.name) @@ -367,7 +367,7 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSReco This method will be run in the event loop. """ for record in records: - self._process_record_update(now, record) + self._async_process_record_update(now, record) def async_update_records_complete(self) -> None: """Called when a record update has completed for all handlers. From 6386318de7ccd55047eff5b3c45588c649e6ebe2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 13:02:22 -1000 Subject: [PATCH 02/23] typing --- zeroconf/_cache.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index 870c665ea..e9082d892 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -27,8 +27,7 @@ from .const import _TYPE_PTR _UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService) - - +_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService] _DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]] @@ -139,7 +138,7 @@ def async_entries_with_server(self, name: str) -> Dict[DNSRecord, DNSRecord]: """ return self.service_cache.get(name.lower(), {}) - def _async_get(self, entry: DNSEntry): + def _async_get(self, entry: DNSEntry) -> Optional[DNSRecord]: """Search a dict of entries by making a copy of it first. This function is not threadsafe and must be called from @@ -154,13 +153,11 @@ def _async_get(self, entry: DNSEntry): # event loop, however they all make copies so they significantly # inefficent - def _lookup_unique_entry_threadsafe( - self, entry: Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService] - ) -> Optional[Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService]]: + def _lookup_unique_entry_threadsafe(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: """Lookup a unique entry threadsafe.""" return self.cache.get(entry.key, {}).get(entry) - def _get_threadsafe(self, entry: DNSEntry): + def _get_threadsafe(self, entry: DNSEntry) -> Optional[DNSRecord]: """Search a dict of entries by making a copy of it first.""" for cached_entry in reversed(list(self.cache.get(entry.key, []))): if entry.__eq__(cached_entry): From 8102ff0ebc3f9b481f9572119665df9d9ce450fd Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 13:13:38 -1000 Subject: [PATCH 03/23] typing --- tests/test_cache.py | 43 +++++++++++++++++++++++++++++++++++++++++++ zeroconf/_cache.py | 14 +++++++------- zeroconf/_core.py | 2 +- 3 files changed, 51 insertions(+), 8 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 7c75866bc..bbef0c15f 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -84,6 +84,49 @@ def test_cache_empty_multiple_calls(self): assert 'a' not in cache.cache +class TestDNSAsyncCacheAPI(unittest.TestCase): + def test_async_get(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert cache.async_get(record1) == record1 + assert cache.async_get(record2) == record2 + + def test_async_get_all_by_details(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.async_get_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == set( + [record1, record2] + ) + + def test_async_entries_with_server(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.async_entries_with_server('ab')) == set([record1, record2]) + assert set(cache.async_entries_with_server('AB')) == set([record1, record2]) + + def test_async_entries_with_name(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.async_entries_with_name('irrelevant')) == set([record1, record2]) + assert set(cache.async_entries_with_name('Irrelevant')) == set([record1, record2]) + + # These functions have been seen in other projects so # we try to maintain a stable API for all the threadsafe getters class TestDNSCacheAPI(unittest.TestCase): diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index e9082d892..b1572847e 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -20,6 +20,8 @@ USA """ +import itertools + from typing import Dict, Iterable, List, Optional, Union, cast from ._dns import DNSAddress, DNSEntry, DNSHinfo, DNSPointer, DNSRecord, DNSService, DNSText @@ -91,16 +93,14 @@ def async_remove_records(self, entries: Iterable[DNSRecord]) -> None: for entry in entries: self._async_remove(entry) - def async_expire(self, now: float) -> Iterable[DNSRecord]: + def async_expire(self, now: float) -> List[DNSRecord]: """Purge expired entries from the cache. This function must be run in from event loop. """ - for name in self.names(): - for record in self.entries_with_name(name): - if record.is_expired(now): - self._async_remove(record) - yield record + expired = [record for record in itertools.chain(*self.cache.values()) if record.is_expired(now)] + self.async_remove_records(expired) + return expired def async_get(self, entry: DNSEntry) -> Optional[DNSRecord]: """Gets an entry by key. Will return None if there is no @@ -144,7 +144,7 @@ def _async_get(self, entry: DNSEntry) -> Optional[DNSRecord]: This function is not threadsafe and must be called from the event loop. """ - for cached_entry in reversed(self.cache.get(entry.key, [])): + for cached_entry in self.cache.get(entry.key, []): if entry.__eq__(cached_entry): return cached_entry return None diff --git a/zeroconf/_core.py b/zeroconf/_core.py index e5c92ce3f..a7910591a 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -146,7 +146,7 @@ async def _async_cache_cleanup(self) -> None: """Periodic cache cleanup.""" while not self.zc.done: now = current_time_millis() - self.zc.record_manager.async_updates(now, list(self.zc.cache.async_expire(now))) + self.zc.record_manager.async_updates(now, self.zc.cache.async_expire(now)) self.zc.record_manager.async_updates_complete() await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL)) From 4a1dde5e61e7a748f339f444ea4644727776a5c7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 13:20:17 -1000 Subject: [PATCH 04/23] typing --- zeroconf/_cache.py | 1 - 1 file changed, 1 deletion(-) diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index b1572847e..acf0bd772 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -21,7 +21,6 @@ """ import itertools - from typing import Dict, Iterable, List, Optional, Union, cast from ._dns import DNSAddress, DNSEntry, DNSHinfo, DNSPointer, DNSRecord, DNSService, DNSText From 291ec0fc7c565c4aca005170541860eb84ac9fd8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 13:21:39 -1000 Subject: [PATCH 05/23] fix unreachable --- zeroconf/_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index acf0bd772..470737135 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -182,7 +182,7 @@ def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSReco Use get_all_by_details instead. """ - return self.get(DNSEntry(name, type_, class_)) + return self._get_threadsafe(DNSEntry(name, type_, class_)) def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: """Gets all matching entries by details.""" From 587a3aed5532b2ba41f7b150e022894067ff9aaa Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 13:31:10 -1000 Subject: [PATCH 06/23] remove unreachable --- tests/test_cache.py | 6 +++--- zeroconf/_cache.py | 8 +++----- zeroconf/_handlers.py | 6 +++--- zeroconf/_services/__init__.py | 2 +- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index bbef0c15f..f5d415002 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -85,13 +85,13 @@ def test_cache_empty_multiple_calls(self): class TestDNSAsyncCacheAPI(unittest.TestCase): - def test_async_get(self): + def test_async_get_unique(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') cache = r.DNSCache() cache.async_add_records([record1, record2]) - assert cache.async_get(record1) == record1 - assert cache.async_get(record2) == record2 + assert cache.async_get_unique(record1) == record1 + assert cache.async_get_unique(record2) == record2 def test_async_get_all_by_details(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index 470737135..b4f8dd8ca 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -101,16 +101,14 @@ def async_expire(self, now: float) -> List[DNSRecord]: self.async_remove_records(expired) return expired - def async_get(self, entry: DNSEntry) -> Optional[DNSRecord]: - """Gets an entry by key. Will return None if there is no + def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: + """Gets a unique entry by key. Will return None if there is no matching entry. This function is not threadsafe and must be called from the event loop. """ - if isinstance(entry, _UNIQUE_RECORD_TYPES): - return self._lookup_unique_entry_threadsafe(entry) - return self._async_get(entry) + return self._lookup_unique_entry_threadsafe(entry) def async_get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: """Gets all matching entries by details. diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index e0406f857..a44c9e4c1 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -141,7 +141,7 @@ def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: SHOULD instead multicast the response so as to keep all the peer caches up to date """ - maybe_entry = self._cache.async_get(record) + maybe_entry = self._cache.async_get_unique(record) return bool(maybe_entry and maybe_entry.is_recent(self._now)) def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: @@ -149,7 +149,7 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: Protect the network against excessive packet flooding https://datatracker.ietf.org/doc/html/rfc6762#section-14 """ - maybe_entry = self._cache.async_get(record) + maybe_entry = self._cache.async_get_unique(record) return bool(maybe_entry and self._now - maybe_entry.created < 1000) @@ -328,7 +328,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: entry.set_created_ttl(now, 1) expired = record.is_expired(now) - maybe_entry = self.cache.async_get(record) + maybe_entry = self.cache.async_get_unique(record) if not expired: if maybe_entry is not None: maybe_entry.reset_ttl(record) diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 8ab1baafb..7f4a3656f 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -340,7 +340,7 @@ def _async_process_record_update(self, now: float, record: DNSRecord) -> None: return # If its expired or already exists in the cache it cannot be updated. - if expired or self.zc.cache.async_get(record): + if expired or self.zc.cache.async_get_unique(record): return if isinstance(record, DNSAddress): From 3ccb133693605c437aae9570e7e56a66fe300a72 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 13:32:31 -1000 Subject: [PATCH 07/23] remove unreachable --- tests/test_cache.py | 1 + zeroconf/_cache.py | 11 ----------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index f5d415002..63e48a581 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -137,6 +137,7 @@ def test_get(self): cache.async_add_records([record1, record2]) assert cache.get(record1) == record1 assert cache.get(record2) == record2 + assert cache.get(r.DNSEntry('a', const._TYPE_A, const._CLASS_IN)) == record2 def test_get_by_details(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index b4f8dd8ca..9fd9e0e61 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -135,17 +135,6 @@ def async_entries_with_server(self, name: str) -> Dict[DNSRecord, DNSRecord]: """ return self.service_cache.get(name.lower(), {}) - def _async_get(self, entry: DNSEntry) -> Optional[DNSRecord]: - """Search a dict of entries by making a copy of it first. - - This function is not threadsafe and must be called from - the event loop. - """ - for cached_entry in self.cache.get(entry.key, []): - if entry.__eq__(cached_entry): - return cached_entry - return None - # The below functions are threadsafe and do not need to be run in the # event loop, however they all make copies so they significantly # inefficent From 4db808d38950deac82ba6f84bcbf44e332520a18 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 13:39:25 -1000 Subject: [PATCH 08/23] typing --- zeroconf/_handlers.py | 10 +++++----- zeroconf/_services/__init__.py | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index a44c9e4c1..b581c6675 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -21,9 +21,9 @@ """ import itertools -from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union +from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast -from ._cache import DNSCache +from ._cache import _UniqueRecordsType, DNSCache from ._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord from ._logger import log from ._protocol import DNSIncoming, DNSOutgoing @@ -141,7 +141,7 @@ def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: SHOULD instead multicast the response so as to keep all the peer caches up to date """ - maybe_entry = self._cache.async_get_unique(record) + maybe_entry = self._cache.async_get_unique(cast(_UniqueRecordsType, record)) return bool(maybe_entry and maybe_entry.is_recent(self._now)) def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: @@ -149,7 +149,7 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: Protect the network against excessive packet flooding https://datatracker.ietf.org/doc/html/rfc6762#section-14 """ - maybe_entry = self._cache.async_get_unique(record) + maybe_entry = self._cache.async_get_unique(cast(_UniqueRecordsType, record)) return bool(maybe_entry and self._now - maybe_entry.created < 1000) @@ -328,7 +328,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: entry.set_created_ttl(now, 1) expired = record.is_expired(now) - maybe_entry = self.cache.async_get_unique(record) + maybe_entry = self.cache.async_get_unique(cast(_UniqueRecordsType, record)) if not expired: if maybe_entry is not None: maybe_entry.reset_ttl(record) diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 7f4a3656f..306b69990 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -27,6 +27,7 @@ from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast +from .._cache import _UniqueRecordsType from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText from .._exceptions import BadTypeInNameException from .._protocol import DNSOutgoing @@ -340,7 +341,7 @@ def _async_process_record_update(self, now: float, record: DNSRecord) -> None: return # If its expired or already exists in the cache it cannot be updated. - if expired or self.zc.cache.async_get_unique(record): + if expired or self.zc.cache.async_get_unique(cast(_UniqueRecordsType, record)): return if isinstance(record, DNSAddress): From 059d1bc721841a759160d61f3c8093494debed42 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 13:42:38 -1000 Subject: [PATCH 09/23] flake8 --- zeroconf/_handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index b581c6675..29c9ded91 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -23,7 +23,7 @@ import itertools from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast -from ._cache import _UniqueRecordsType, DNSCache +from ._cache import DNSCache, _UniqueRecordsType from ._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord from ._logger import log from ._protocol import DNSIncoming, DNSOutgoing From 97bfd5424f5c98b4bf2585838d98c583706f35b6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 14:16:20 -1000 Subject: [PATCH 10/23] tweaks --- zeroconf/_cache.py | 38 ++++++++++++++++++++++++-------------- zeroconf/_dns.py | 11 +++++------ zeroconf/_handlers.py | 30 +++++++++++++++++------------- 3 files changed, 46 insertions(+), 33 deletions(-) diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index 9fd9e0e61..3cf9ff2f9 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -23,7 +23,16 @@ import itertools from typing import Dict, Iterable, List, Optional, Union, cast -from ._dns import DNSAddress, DNSEntry, DNSHinfo, DNSPointer, DNSRecord, DNSService, DNSText +from ._dns import ( + DNSAddress, + DNSEntry, + DNSHinfo, + DNSPointer, + DNSRecord, + DNSService, + DNSText, + dns_entry_matches, +) from ._utils.time import current_time_millis from .const import _TYPE_PTR @@ -116,8 +125,8 @@ def async_get_all_by_details(self, name: str, type_: int, class_: int) -> List[D This function is not threadsafe and must be called from the event loop. """ - match_entry = DNSEntry(name, type_, class_) - return [entry for entry in self.cache.get(match_entry.key, []) if match_entry.__eq__(entry)] + key = name.lower() + return [entry for entry in self.cache.get(key, []) if dns_entry_matches(entry, key, type_, class_)] def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]: """Returns a dict of entries whose key matches the name. @@ -143,19 +152,15 @@ def _lookup_unique_entry_threadsafe(self, entry: _UniqueRecordsType) -> Optional """Lookup a unique entry threadsafe.""" return self.cache.get(entry.key, {}).get(entry) - def _get_threadsafe(self, entry: DNSEntry) -> Optional[DNSRecord]: - """Search a dict of entries by making a copy of it first.""" - for cached_entry in reversed(list(self.cache.get(entry.key, []))): - if entry.__eq__(cached_entry): - return cached_entry - return None - def get(self, entry: DNSEntry) -> Optional[DNSRecord]: """Gets an entry by key. Will return None if there is no matching entry.""" if isinstance(entry, _UNIQUE_RECORD_TYPES): return self._lookup_unique_entry_threadsafe(entry) - return self._get_threadsafe(entry) + for cached_entry in reversed(list(self.cache.get(entry.key, []))): + if entry.__eq__(cached_entry): + return cached_entry + return None def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]: """Gets the first matching entry by details. Returns None if no entries match. @@ -169,12 +174,17 @@ def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSReco Use get_all_by_details instead. """ - return self._get_threadsafe(DNSEntry(name, type_, class_)) + key = name.lower() + for cached_entry in reversed(list(self.cache.get(key, []))): + if dns_entry_matches(cached_entry, key, type_, class_): + return cached_entry def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: """Gets all matching entries by details.""" - match_entry = DNSEntry(name, type_, class_) - return [entry for entry in list(self.cache.get(match_entry.key, [])) if match_entry.__eq__(entry)] + key = name.lower() + return [ + entry for entry in list(self.cache.get(key, [])) if dns_entry_matches(entry, key, type_, class_) + ] def entries_with_server(self, server: str) -> List[DNSRecord]: """Returns a list of entries whose server matches the name.""" diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 5b7fe70fe..1abbb5dae 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -49,6 +49,10 @@ from ._protocol import DNSIncoming, DNSOutgoing # pylint: disable=cyclic-import +def dns_entry_matches(record: 'DNSEntry', key: str, type_: int, class_: int) -> bool: + return key == record.key and type_ == record.type and class_ == record.class_ + + class DNSEntry: """A DNS entry""" @@ -66,12 +70,7 @@ def _entry_tuple(self) -> Tuple[str, int, int]: def __eq__(self, other: Any) -> bool: """Equality test on key (lowercase name), type, and class""" - return ( - self.key == other.key - and self.type == other.type - and self.class_ == other.class_ - and isinstance(other, DNSEntry) - ) + return dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry) @staticmethod def get_class_(class_: int) -> str: diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 29c9ded91..99c0330c9 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -312,20 +312,10 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: removes: List[DNSRecord] = [] now = msg.now for record in msg.answers: - - updated = True - if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - # rfc6762#section-10.2 para 2 - # Since unique is set, all old records with that name, rrtype, - # and rrclass that were received more than one second ago are declared - # invalid, and marked to expire from the cache in one second. - for entry in self.cache.async_get_all_by_details(record.name, record.type, record.class_): - if entry == record: - updated = False - if record.created - entry.created > 1000 and entry not in msg.answers: - # Expire in 1s - entry.set_created_ttl(now, 1) + updated = self._async_process_unique(record, msg.answers, now) + else: + updated = True expired = record.is_expired(now) maybe_entry = self.cache.async_get_unique(cast(_UniqueRecordsType, record)) @@ -369,6 +359,20 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: self.cache.async_remove_records(removes) self.async_updates_complete() + def _async_process_unique(self, record: DNSRecord, answers: List[DNSRecord], now: float) -> bool: + # rfc6762#section-10.2 para 2 + # Since unique is set, all old records with that name, rrtype, + # and rrclass that were received more than one second ago are declared + # invalid, and marked to expire from the cache in one second. + updated = True + for entry in self.cache.async_get_all_by_details(record.name, record.type, record.class_): + if entry == record: + updated = False + if record.created - entry.created > 1000 and entry not in answers: + # Expire in 1s + entry.set_created_ttl(now, 1) + return updated + def add_listener( self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] ) -> None: From 44740bdd670964b7143610622374584165cdbfd5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 14:18:00 -1000 Subject: [PATCH 11/23] mypy --- zeroconf/_cache.py | 1 + 1 file changed, 1 insertion(+) diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index 3cf9ff2f9..f94a323cb 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -178,6 +178,7 @@ def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSReco for cached_entry in reversed(list(self.cache.get(key, []))): if dns_entry_matches(cached_entry, key, type_, class_): return cached_entry + return None def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: """Gets all matching entries by details.""" From 821bd9eaa578f739422a8303424963570a1e8c1a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 14:33:24 -1000 Subject: [PATCH 12/23] fixes --- zeroconf/_handlers.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 99c0330c9..64ad08e32 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -313,13 +313,10 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: now = msg.now for record in msg.answers: if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - updated = self._async_process_unique(record, msg.answers, now) - else: - updated = True + self._async_process_unique(record, msg.answers, now) - expired = record.is_expired(now) maybe_entry = self.cache.async_get_unique(cast(_UniqueRecordsType, record)) - if not expired: + if not record.is_expired(now): if maybe_entry is not None: maybe_entry.reset_ttl(record) else: @@ -327,8 +324,10 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: address_adds.append(record) else: other_adds.append(record) - if updated: + if not maybe_entry: updates.append(record) + # This is likely a goodbye since the record is + # expired and exists in the cache elif maybe_entry is not None: updates.append(record) removes.append(record) @@ -359,19 +358,15 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: self.cache.async_remove_records(removes) self.async_updates_complete() - def _async_process_unique(self, record: DNSRecord, answers: List[DNSRecord], now: float) -> bool: + def _async_process_unique(self, record: DNSRecord, answers: List[DNSRecord], now: float) -> None: # rfc6762#section-10.2 para 2 # Since unique is set, all old records with that name, rrtype, # and rrclass that were received more than one second ago are declared # invalid, and marked to expire from the cache in one second. - updated = True for entry in self.cache.async_get_all_by_details(record.name, record.type, record.class_): - if entry == record: - updated = False if record.created - entry.created > 1000 and entry not in answers: # Expire in 1s entry.set_created_ttl(now, 1) - return updated def add_listener( self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] From 66f263eca1feb78c916d0d2f48e1448551273aa2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 14:38:54 -1000 Subject: [PATCH 13/23] fixes --- zeroconf/_handlers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 64ad08e32..59878fa79 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -324,8 +324,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: address_adds.append(record) else: other_adds.append(record) - if not maybe_entry: - updates.append(record) + updates.append(record) # This is likely a goodbye since the record is # expired and exists in the cache elif maybe_entry is not None: From ca8ae38ab456d857afe992b8579988bd4f24d5a3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 14:48:34 -1000 Subject: [PATCH 14/23] tweaks --- zeroconf/_dns.py | 7 +++++++ zeroconf/_handlers.py | 22 ++++++++++++---------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 1abbb5dae..e656bc519 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -422,3 +422,10 @@ def suppresses(self, record: DNSRecord) -> bool: self._lookup = {record: record for record in self._records} other = self._lookup.get(record) return bool(other and other.ttl > (record.ttl / 2)) + + def __contains__(self, record: DNSRecord) -> bool: + """Returns true if the rrset contains the record.""" + if self._lookup is None: + # Build the hash table so we can lookup the record independent of the ttl + self._lookup = {record: record for record in self._records} + return record in self._lookup diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 59878fa79..b1b7f5d15 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -311,9 +311,10 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: other_adds: List[DNSRecord] = [] removes: List[DNSRecord] = [] now = msg.now + answers_rrset = DNSRRSet(msg.answers) for record in msg.answers: if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - self._async_process_unique(record, msg.answers, now) + self._async_process_unique(record, answers_rrset, now) maybe_entry = self.cache.async_get_unique(cast(_UniqueRecordsType, record)) if not record.is_expired(now): @@ -331,10 +332,8 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: updates.append(record) removes.append(record) - if not updates and not address_adds and not other_adds and not removes: - return - - self.async_updates(now, updates) + if updates: + self.async_updates(now, updates) # The cache adds must be processed AFTER we trigger # the updates since we compare existing data # with the new data and updating the cache @@ -350,20 +349,23 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: # zc.get_service_info will see the cached value # but ONLY after all the record updates have been # processsed. - self.cache.async_add_records(itertools.chain(address_adds, other_adds)) + if other_adds or address_adds: + self.cache.async_add_records(itertools.chain(address_adds, other_adds)) # Removes are processed last since # ServiceInfo could generate an un-needed query # because the data was not yet populated. - self.cache.async_remove_records(removes) - self.async_updates_complete() + if removes: + self.cache.async_remove_records(removes) + if updates: + self.async_updates_complete() - def _async_process_unique(self, record: DNSRecord, answers: List[DNSRecord], now: float) -> None: + def _async_process_unique(self, record: DNSRecord, answers_rrset: DNSRRSet, now: float) -> None: # rfc6762#section-10.2 para 2 # Since unique is set, all old records with that name, rrtype, # and rrclass that were received more than one second ago are declared # invalid, and marked to expire from the cache in one second. for entry in self.cache.async_get_all_by_details(record.name, record.type, record.class_): - if record.created - entry.created > 1000 and entry not in answers: + if record.created - entry.created > 1000 and entry not in answers_rrset: # Expire in 1s entry.set_created_ttl(now, 1) From 1ed0ae5861446a350eb4d255f41a60842bc63ffc Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 14:56:42 -1000 Subject: [PATCH 15/23] test cache miss --- tests/test_cache.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_cache.py b/tests/test_cache.py index 63e48a581..57385242c 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -138,6 +138,7 @@ def test_get(self): assert cache.get(record1) == record1 assert cache.get(record2) == record2 assert cache.get(r.DNSEntry('a', const._TYPE_A, const._CLASS_IN)) == record2 + assert cache.get(r.DNSEntry('notthere', const._TYPE_A, const._CLASS_IN)) is None def test_get_by_details(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') From 1778ddefb788155f4340eac799a2600ef82cbbb2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 15:07:16 -1000 Subject: [PATCH 16/23] collapse --- tests/test_cache.py | 6 ++---- zeroconf/_cache.py | 16 +++++++--------- zeroconf/_handlers.py | 2 +- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 57385242c..a4d4c979e 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -93,14 +93,12 @@ def test_async_get_unique(self): assert cache.async_get_unique(record1) == record1 assert cache.async_get_unique(record2) == record2 - def test_async_get_all_by_details(self): + def test_async_all_by_details(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') cache = r.DNSCache() cache.async_add_records([record1, record2]) - assert set(cache.async_get_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == set( - [record1, record2] - ) + assert set(cache.async_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == set([record1, record2]) def test_async_entries_with_server(self): record1 = r.DNSService( diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index f94a323cb..24b6a2337 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -21,7 +21,7 @@ """ import itertools -from typing import Dict, Iterable, List, Optional, Union, cast +from typing import Dict, Iterable, Iterator, List, Optional, Union, cast from ._dns import ( DNSAddress, @@ -117,16 +117,18 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: This function is not threadsafe and must be called from the event loop. """ - return self._lookup_unique_entry_threadsafe(entry) + return self.cache.get(entry.key, {}).get(entry) - def async_get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: + def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[DNSRecord]: """Gets all matching entries by details. This function is not threadsafe and must be called from the event loop. """ key = name.lower() - return [entry for entry in self.cache.get(key, []) if dns_entry_matches(entry, key, type_, class_)] + for entry in self.cache.get(key, []): + if dns_entry_matches(entry, key, type_, class_): + yield entry def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]: """Returns a dict of entries whose key matches the name. @@ -148,15 +150,11 @@ def async_entries_with_server(self, name: str) -> Dict[DNSRecord, DNSRecord]: # event loop, however they all make copies so they significantly # inefficent - def _lookup_unique_entry_threadsafe(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: - """Lookup a unique entry threadsafe.""" - return self.cache.get(entry.key, {}).get(entry) - def get(self, entry: DNSEntry) -> Optional[DNSRecord]: """Gets an entry by key. Will return None if there is no matching entry.""" if isinstance(entry, _UNIQUE_RECORD_TYPES): - return self._lookup_unique_entry_threadsafe(entry) + return self.cache.get(entry.key, {}).get(entry) for cached_entry in reversed(list(self.cache.get(entry.key, []))): if entry.__eq__(cached_entry): return cached_entry diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index b1b7f5d15..05bad9a99 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -364,7 +364,7 @@ def _async_process_unique(self, record: DNSRecord, answers_rrset: DNSRRSet, now: # Since unique is set, all old records with that name, rrtype, # and rrclass that were received more than one second ago are declared # invalid, and marked to expire from the cache in one second. - for entry in self.cache.async_get_all_by_details(record.name, record.type, record.class_): + for entry in self.cache.async_all_by_details(record.name, record.type, record.class_): if record.created - entry.created > 1000 and entry not in answers_rrset: # Expire in 1s entry.set_created_ttl(now, 1) From da0ebb625bee8e54c2070e85c7bdecfe5d53b466 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 15:14:39 -1000 Subject: [PATCH 17/23] tweaks --- zeroconf/_handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 05bad9a99..4daf867c8 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -365,7 +365,7 @@ def _async_process_unique(self, record: DNSRecord, answers_rrset: DNSRRSet, now: # and rrclass that were received more than one second ago are declared # invalid, and marked to expire from the cache in one second. for entry in self.cache.async_all_by_details(record.name, record.type, record.class_): - if record.created - entry.created > 1000 and entry not in answers_rrset: + if record.created - entry.created > 1000 and entry != record and entry not in answers_rrset: # Expire in 1s entry.set_created_ttl(now, 1) From e3949792b4c4600178e53bbcb15d5a779c10933a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 15:35:31 -1000 Subject: [PATCH 18/23] tweak --- zeroconf/_handlers.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 4daf867c8..af21ba580 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -311,10 +311,11 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: other_adds: List[DNSRecord] = [] removes: List[DNSRecord] = [] now = msg.now - answers_rrset = DNSRRSet(msg.answers) + unique_types: Set[Tuple[str, int, int]] = set() + for record in msg.answers: if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - self._async_process_unique(record, answers_rrset, now) + unique_types.add((record.name, record.type, record.class_)) maybe_entry = self.cache.async_get_unique(cast(_UniqueRecordsType, record)) if not record.is_expired(now): @@ -332,6 +333,9 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: updates.append(record) removes.append(record) + if unique_types: + self._async_mark_unique_cached_records_older_than_1s_to_expire(unique_types, msg.answers, now) + if updates: self.async_updates(now, updates) # The cache adds must be processed AFTER we trigger @@ -359,15 +363,19 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: if updates: self.async_updates_complete() - def _async_process_unique(self, record: DNSRecord, answers_rrset: DNSRRSet, now: float) -> None: + def _async_mark_unique_cached_records_older_than_1s_to_expire( + self, unique_types: Set[Tuple[str, int, int]], answers: List[DNSRecord], now: float + ) -> None: # rfc6762#section-10.2 para 2 # Since unique is set, all old records with that name, rrtype, # and rrclass that were received more than one second ago are declared # invalid, and marked to expire from the cache in one second. - for entry in self.cache.async_all_by_details(record.name, record.type, record.class_): - if record.created - entry.created > 1000 and entry != record and entry not in answers_rrset: - # Expire in 1s - entry.set_created_ttl(now, 1) + answers_rrset = DNSRRSet(answers) + for name, type_, class_ in unique_types: + for entry in self.cache.async_all_by_details(name, type_, class_): + if now - entry.created > 1000 and entry not in answers_rrset: + # Expire in 1s + entry.set_created_ttl(now, 1) def add_listener( self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] From 4b3612850706c64dc4cd6c70b5505e3c3ba20e1c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 15:51:09 -1000 Subject: [PATCH 19/23] adjust --- tests/test_handlers.py | 52 ++++++++++++++++++++++++++++++++++++++++++ zeroconf/_handlers.py | 4 ++++ 2 files changed, 56 insertions(+) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index f9e7639ea..d1015501f 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -834,3 +834,55 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): # unregister zc.registry.remove(info) await aiozc.async_close() + + +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_cache_flush_bit(): + """Test that the cache flush bit sets the TTL to one for matching records.""" + # instantiate a zeroconf instance + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + + type_ = "_cacheflush._tcp.local." + name = "knownname" + registration_name = "%s.%s" % (name, type_) + desc = {'path': '/~paulsm/'} + server_name = "server-uu1.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + a_record = info.dns_addresses()[0] + zc.cache.async_add_records([info.dns_pointer(), a_record, info.dns_text(), info.dns_service()]) + + info.addresses = [socket.inet_aton("10.0.1.5"), socket.inet_aton("10.0.1.6")] + new_records = info.dns_addresses() + for new_record in new_records: + assert new_record.unique is True + + original_a_record = zc.cache.async_get_unique(a_record) + # Do the run within 1s to verify the original record is not going to be expired + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) + for answer in new_records: + out.add_answer_at_time(answer, 0) + for packet in out.packets(): + zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) + assert zc.cache.async_get_unique(a_record) is original_a_record + assert original_a_record.ttl != 1 + for record in new_records: + assert zc.cache.async_get_unique(record) is not None + + original_a_record.created = current_time_millis() - 1001 + + # Do the run within 1s to verify the original record is not going to be expired + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) + for answer in new_records: + out.add_answer_at_time(answer, 0) + for packet in out.packets(): + zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) + assert original_a_record.ttl == 1 + for record in new_records: + assert zc.cache.async_get_unique(record) is not None + + await aiozc.async_close() diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index af21ba580..302768137 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -333,6 +333,10 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: updates.append(record) removes.append(record) + import pprint + + pprint.pprint(unique_types) + if unique_types: self._async_mark_unique_cached_records_older_than_1s_to_expire(unique_types, msg.answers, now) From 07061b6b2734cf90efb20672fbcc41bae25c483a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 15:52:07 -1000 Subject: [PATCH 20/23] adjust --- zeroconf/_handlers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 302768137..af21ba580 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -333,10 +333,6 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: updates.append(record) removes.append(record) - import pprint - - pprint.pprint(unique_types) - if unique_types: self._async_mark_unique_cached_records_older_than_1s_to_expire(unique_types, msg.answers, now) From 2cc6053e3d2d9bad4f8449f82ee603e869d06660 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 16:09:02 -1000 Subject: [PATCH 21/23] coverage --- tests/test_handlers.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index d1015501f..ddd8ffa47 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -4,6 +4,7 @@ """ Unit tests for zeroconf._handlers """ +import asyncio import logging import pytest import socket @@ -885,4 +886,32 @@ async def test_cache_flush_bit(): for record in new_records: assert zc.cache.async_get_unique(record) is not None + cached_records = [zc.cache.async_get_unique(record) for record in new_records] + for record in cached_records: + record.created = current_time_millis() - 1001 + + fresh_address = socket.inet_aton("4.4.4.4") + info.addresses = [fresh_address] + # Do the run within 1s to verify the two new records get marked as expired + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) + for answer in info.dns_addresses(): + out.add_answer_at_time(answer, 0) + for packet in out.packets(): + zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) + for record in cached_records: + assert record.ttl == 1 + + for entry in zc.cache.async_all_by_details(server_name, const._TYPE_A, const._CLASS_IN): + if entry.address == fresh_address: + assert entry.ttl > 1 + else: + assert entry.ttl == 1 + + # Wait for the ttl 1 records to expire + await asyncio.sleep(1.01) + + loaded_info = r.ServiceInfo(type_, registration_name) + loaded_info.load_from_cache(zc) + assert loaded_info.addresses == info.addresses + await aiozc.async_close() From 25fe6c1e8811c1cab5597b968904d9c9f3f37ea5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 16:11:17 -1000 Subject: [PATCH 22/23] coverage --- zeroconf/_handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index af21ba580..66b8862f2 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -373,7 +373,7 @@ def _async_mark_unique_cached_records_older_than_1s_to_expire( answers_rrset = DNSRRSet(answers) for name, type_, class_ in unique_types: for entry in self.cache.async_all_by_details(name, type_, class_): - if now - entry.created > 1000 and entry not in answers_rrset: + if (now - entry.created > 1000) and entry not in answers_rrset: # Expire in 1s entry.set_created_ttl(now, 1) From a22c5890943962bd1fbe899bd69ca0c519961ed6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 16:17:21 -1000 Subject: [PATCH 23/23] cover --- tests/test_cache.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index a4d4c979e..4b3a8a18e 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -131,11 +131,13 @@ class TestDNSCacheAPI(unittest.TestCase): def test_get(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + record3 = r.DNSAddress('a', const._TYPE_AAAA, const._CLASS_IN, 1, b'ipv6') cache = r.DNSCache() - cache.async_add_records([record1, record2]) + cache.async_add_records([record1, record2, record3]) assert cache.get(record1) == record1 assert cache.get(record2) == record2 assert cache.get(r.DNSEntry('a', const._TYPE_A, const._CLASS_IN)) == record2 + assert cache.get(r.DNSEntry('a', const._TYPE_AAAA, const._CLASS_IN)) == record3 assert cache.get(r.DNSEntry('notthere', const._TYPE_A, const._CLASS_IN)) is None def test_get_by_details(self):