diff --git a/tests/test_cache.py b/tests/test_cache.py index aa6acf6c2..19033b5c9 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -31,8 +31,7 @@ def test_order(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) entry = r.DNSEntry('a', const._TYPE_SOA, const._CLASS_IN) cached_record = cache.get(entry) assert cached_record == record2 @@ -46,13 +45,11 @@ def test_adding_same_record_to_cache_different_ttls(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) entry = r.DNSEntry(record2) cached_record = cache.get(entry) assert cached_record == record2 - @unittest.skip('This bug in the implementation needs to be fixed.') def test_adding_same_record_to_cache_different_ttls(self): """Verify we only get one record back. @@ -64,8 +61,7 @@ def test_adding_same_record_to_cache_different_ttls(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) cached_records = cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN) assert cached_records == [record2] @@ -73,25 +69,18 @@ def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) assert 'a' in cache.cache - cache.remove(record1) - cache.remove(record2) + cache.remove_records([record1, record2]) assert 'a' not in cache.cache - def test_cache_empty_multiple_calls_does_not_throw(self): + def test_cache_empty_multiple_calls(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) assert 'a' in cache.cache - cache.remove(record1) - cache.remove(record2) - # Ensure multiple removes does not throw - cache.remove(record1) - cache.remove(record2) + cache.remove_records([record1, record2]) assert 'a' not in cache.cache diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index 135b1884e..2e07a7a47 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -27,46 +27,83 @@ from .const import _TYPE_PTR +_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]] + + +def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None: + """Remove a key from a DNSRecord cache + + This function must be run in from event loop. + """ + del cache[key][entry] + if not cache[key]: + del cache[key] + + class DNSCache: """A cache of DNS entries.""" def __init__(self) -> None: - self.cache: Dict[str, List[DNSRecord]] = {} - self.service_cache: Dict[str, List[DNSRecord]] = {} + self.cache: _DNSRecordCacheType = {} + self.service_cache: _DNSRecordCacheType = {} + + # Functions prefixed with are NOT threadsafe and must + # be run in the event loop. def add(self, entry: DNSRecord) -> None: - """Adds an entry""" - # Insert last in list, get will return newest entry - # iteration will result in last update winning - self.cache.setdefault(entry.key, []).append(entry) + """Adds an entry. + + This function must be run in from event loop. + """ + # Previously storage of records was implemented as a list + # instead a dict. Since DNSRecords are now hashable, the implementation + # uses a dict to ensure that adding a new record to the cache + # replaces any existing records that are __eq__ to each other which + # removes the risk that accessing the cache from the wrong + # direction would return the old incorrect entry. + self.cache.setdefault(entry.key, {})[entry] = entry if isinstance(entry, DNSService): - self.service_cache.setdefault(entry.server, []).append(entry) + self.service_cache.setdefault(entry.server, {})[entry] = entry def add_records(self, entries: Iterable[DNSRecord]) -> None: - """Add multiple records.""" + """Add multiple records. + + This function must be run in from event loop. + """ for entry in entries: self.add(entry) def remove(self, entry: DNSRecord) -> None: - """Removes an entry.""" + """Removes an entry. + + This function must be run in from event loop. + """ if isinstance(entry, DNSService): - DNSCache.remove_key(self.service_cache, entry.server, entry) - DNSCache.remove_key(self.cache, entry.key, entry) + _remove_key(self.service_cache, entry.server, entry) + _remove_key(self.cache, entry.key, entry) def remove_records(self, entries: Iterable[DNSRecord]) -> None: - """Remove multiple records.""" + """Remove multiple records. + + This function must be run in from event loop. + """ for entry in entries: self.remove(entry) - @staticmethod - def remove_key(cache: dict, key: str, entry: DNSRecord) -> None: - """Forgiving remove of a cache key.""" - try: - cache[key].remove(entry) - if not cache[key]: - del cache[key] - except (KeyError, ValueError): - pass + def expire(self, now: float) -> Iterable[DNSRecord]: + """Purge expired entries from the cache. + + This function must be run in from event loop. + """ + for name in self.names(): + for record in self.entries_with_name(name): + if record.is_expired(now): + self.remove(record) + yield record + + # The below functions are threadsafe and do not need to be run in the + # event loop, however they all make copies so they significantly + # inefficent def get(self, entry: DNSEntry) -> Optional[DNSRecord]: """Gets an entry by key. Will return None if there is no @@ -77,7 +114,17 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]: return None def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]: - """Gets the first matching entry by details. Returns None if no entries match.""" + """Gets the first matching entry by details. Returns None if no entries match. + + Calling this function is not recommended as it will only + return one record even if there are multiple entries. + + For example if there are multiple A or AAAA addresses this + function will return the last one that was added to the cache + which may not be the one you expect. + + Use get_all_by_details instead. + """ return self.get(DNSEntry(name, type_, class_)) def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: @@ -87,11 +134,11 @@ def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSReco def entries_with_server(self, server: str) -> List[DNSRecord]: """Returns a list of entries whose server matches the name.""" - return self.service_cache.get(server, [])[:] + return list(self.service_cache.get(server, {})) def entries_with_name(self, name: str) -> List[DNSRecord]: """Returns a list of entries whose key matches the name.""" - return self.cache.get(name.lower(), [])[:] + return list(self.cache.get(name.lower(), {})) def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: now = current_time_millis() @@ -107,11 +154,3 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D def names(self) -> List[str]: """Return a copy of the list of current cache names.""" return list(self.cache) - - def expire(self, now: float) -> Iterable[DNSRecord]: - """Purge expired entries from the cache.""" - for name in self.names(): - for record in self.entries_with_name(name): - if record.is_expired(now): - self.remove(record) - yield record