From e4be118ccc231d01b07795c0c4b046195d9dc11e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 12:08:10 -1000 Subject: [PATCH] Fix cache handling of records with different TTLs - There should only be one unique record in the cache at a time as having multiple unique records with different TTLs in the cache can result in unexpected behavior since some functions returned all matching records and some fetched from the right side of the list to return the newest record. Instead we now store the records in a dict to ensure that the newest record always replaces the same unique record and we never have a source of truth problem determining the TTL of a record from the cache. --- tests/test_cache.py | 27 ++++-------- zeroconf/_cache.py | 103 ++++++++++++++++++++++++++++++-------------- 2 files changed, 79 insertions(+), 51 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index aa6acf6c2..19033b5c9 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -31,8 +31,7 @@ def test_order(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) entry = r.DNSEntry('a', const._TYPE_SOA, const._CLASS_IN) cached_record = cache.get(entry) assert cached_record == record2 @@ -46,13 +45,11 @@ def test_adding_same_record_to_cache_different_ttls(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) entry = r.DNSEntry(record2) cached_record = cache.get(entry) assert cached_record == record2 - @unittest.skip('This bug in the implementation needs to be fixed.') def test_adding_same_record_to_cache_different_ttls(self): """Verify we only get one record back. 
@@ -64,8 +61,7 @@ def test_adding_same_record_to_cache_different_ttls(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) cached_records = cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN) assert cached_records == [record2] @@ -73,25 +69,18 @@ def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) assert 'a' in cache.cache - cache.remove(record1) - cache.remove(record2) + cache.remove_records([record1, record2]) assert 'a' not in cache.cache - def test_cache_empty_multiple_calls_does_not_throw(self): + def test_cache_empty_multiple_calls(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) assert 'a' in cache.cache - cache.remove(record1) - cache.remove(record2) - # Ensure multiple removes does not throw - cache.remove(record1) - cache.remove(record2) + cache.remove_records([record1, record2]) assert 'a' not in cache.cache diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index 135b1884e..2e07a7a47 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -27,46 +27,83 @@ from .const import _TYPE_PTR +_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]] + + +def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None: + """Remove a key from a DNSRecord cache + + This function must be run from the event loop. 
+ """ + del cache[key][entry] + if not cache[key]: + del cache[key] + + class DNSCache: """A cache of DNS entries.""" def __init__(self) -> None: - self.cache: Dict[str, List[DNSRecord]] = {} - self.service_cache: Dict[str, List[DNSRecord]] = {} + self.cache: _DNSRecordCacheType = {} + self.service_cache: _DNSRecordCacheType = {} + + # The functions below are NOT threadsafe and must + # be run in the event loop. def add(self, entry: DNSRecord) -> None: - """Adds an entry""" - # Insert last in list, get will return newest entry - # iteration will result in last update winning - self.cache.setdefault(entry.key, []).append(entry) + """Adds an entry. + + This function must be run from the event loop. + """ + # Previously storage of records was implemented as a list + # instead of a dict. Since DNSRecords are now hashable, the implementation + # uses a dict to ensure that adding a new record to the cache + # replaces any existing records that are __eq__ to each other which + # removes the risk that accessing the cache from the wrong + # direction would return the old incorrect entry. + self.cache.setdefault(entry.key, {})[entry] = entry if isinstance(entry, DNSService): - self.service_cache.setdefault(entry.server, []).append(entry) + self.service_cache.setdefault(entry.server, {})[entry] = entry def add_records(self, entries: Iterable[DNSRecord]) -> None: - """Add multiple records.""" + """Add multiple records. + + This function must be run from the event loop. + """ for entry in entries: self.add(entry) def remove(self, entry: DNSRecord) -> None: - """Removes an entry.""" + """Removes an entry. + + This function must be run from the event loop. 
+ """ if isinstance(entry, DNSService): - DNSCache.remove_key(self.service_cache, entry.server, entry) - DNSCache.remove_key(self.cache, entry.key, entry) + _remove_key(self.service_cache, entry.server, entry) + _remove_key(self.cache, entry.key, entry) def remove_records(self, entries: Iterable[DNSRecord]) -> None: - """Remove multiple records.""" + """Remove multiple records. + + This function must be run from the event loop. + """ for entry in entries: self.remove(entry) - @staticmethod - def remove_key(cache: dict, key: str, entry: DNSRecord) -> None: - """Forgiving remove of a cache key.""" - try: - cache[key].remove(entry) - if not cache[key]: - del cache[key] - except (KeyError, ValueError): - pass + def expire(self, now: float) -> Iterable[DNSRecord]: + """Purge expired entries from the cache. + + This function must be run from the event loop. + """ + for name in self.names(): + for record in self.entries_with_name(name): + if record.is_expired(now): + self.remove(record) + yield record + + # The below functions are threadsafe and do not need to be run in the + # event loop, however they all make copies so they are significantly + # inefficient def get(self, entry: DNSEntry) -> Optional[DNSRecord]: """Gets an entry by key. Will return None if there is no @@ -77,7 +114,17 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]: return None def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]: - """Gets the first matching entry by details. Returns None if no entries match.""" + """Gets the first matching entry by details. Returns None if no entries match. + + Calling this function is not recommended as it will only + return one record even if there are multiple entries. + + For example if there are multiple A or AAAA addresses this + function will return the last one that was added to the cache + which may not be the one you expect. + + Use get_all_by_details instead. 
+ """ return self.get(DNSEntry(name, type_, class_)) def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: @@ -87,11 +134,11 @@ def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSReco def entries_with_server(self, server: str) -> List[DNSRecord]: """Returns a list of entries whose server matches the name.""" - return self.service_cache.get(server, [])[:] + return list(self.service_cache.get(server, {})) def entries_with_name(self, name: str) -> List[DNSRecord]: """Returns a list of entries whose key matches the name.""" - return self.cache.get(name.lower(), [])[:] + return list(self.cache.get(name.lower(), {})) def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: now = current_time_millis() @@ -107,11 +154,3 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D def names(self) -> List[str]: """Return a copy of the list of current cache names.""" return list(self.cache) - - def expire(self, now: float) -> Iterable[DNSRecord]: - """Purge expired entries from the cache.""" - for name in self.names(): - for record in self.entries_with_name(name): - if record.is_expired(now): - self.remove(record) - yield record