Skip to content

Commit 88aa610

Browse files
authored
Fix cache handling of records with different TTLs (python-zeroconf#729)
- There should only be one unique record in the cache at a time as having multiple unique records will different TTLs in the cache can result in unexpected behavior since some functions returned all matching records and some fetched from the right side of the list to return the newest record. Intead we now store the records in a dict to ensure that the newest record always replaces the same unique record and we never have a source of truth problem determining the TTL of a record from the cache.
1 parent ceb79bd commit 88aa610

2 files changed

Lines changed: 79 additions & 51 deletions

File tree

tests/test_cache.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ def test_order(self):
3131
record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
3232
record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
3333
cache = r.DNSCache()
34-
cache.add(record1)
35-
cache.add(record2)
34+
cache.add_records([record1, record2])
3635
entry = r.DNSEntry('a', const._TYPE_SOA, const._CLASS_IN)
3736
cached_record = cache.get(entry)
3837
assert cached_record == record2
@@ -46,13 +45,11 @@ def test_adding_same_record_to_cache_different_ttls(self):
4645
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
4746
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a')
4847
cache = r.DNSCache()
49-
cache.add(record1)
50-
cache.add(record2)
48+
cache.add_records([record1, record2])
5149
entry = r.DNSEntry(record2)
5250
cached_record = cache.get(entry)
5351
assert cached_record == record2
5452

55-
@unittest.skip('This bug in the implementation needs to be fixed.')
5653
def test_adding_same_record_to_cache_different_ttls(self):
5754
"""Verify we only get one record back.
5855
@@ -64,34 +61,26 @@ def test_adding_same_record_to_cache_different_ttls(self):
6461
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
6562
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a')
6663
cache = r.DNSCache()
67-
cache.add(record1)
68-
cache.add(record2)
64+
cache.add_records([record1, record2])
6965
cached_records = cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN)
7066
assert cached_records == [record2]
7167

7268
def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self):
7369
record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
7470
record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
7571
cache = r.DNSCache()
76-
cache.add(record1)
77-
cache.add(record2)
72+
cache.add_records([record1, record2])
7873
assert 'a' in cache.cache
79-
cache.remove(record1)
80-
cache.remove(record2)
74+
cache.remove_records([record1, record2])
8175
assert 'a' not in cache.cache
8276

83-
def test_cache_empty_multiple_calls_does_not_throw(self):
77+
def test_cache_empty_multiple_calls(self):
8478
record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
8579
record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
8680
cache = r.DNSCache()
87-
cache.add(record1)
88-
cache.add(record2)
81+
cache.add_records([record1, record2])
8982
assert 'a' in cache.cache
90-
cache.remove(record1)
91-
cache.remove(record2)
92-
# Ensure multiple removes does not throw
93-
cache.remove(record1)
94-
cache.remove(record2)
83+
cache.remove_records([record1, record2])
9584
assert 'a' not in cache.cache
9685

9786

zeroconf/_cache.py

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -27,46 +27,83 @@
2727
from .const import _TYPE_PTR
2828

2929

30+
_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]]
31+
32+
33+
def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None:
34+
"""Remove a key from a DNSRecord cache
35+
36+
This function must be run in from event loop.
37+
"""
38+
del cache[key][entry]
39+
if not cache[key]:
40+
del cache[key]
41+
42+
3043
class DNSCache:
3144
"""A cache of DNS entries."""
3245

3346
def __init__(self) -> None:
34-
self.cache: Dict[str, List[DNSRecord]] = {}
35-
self.service_cache: Dict[str, List[DNSRecord]] = {}
47+
self.cache: _DNSRecordCacheType = {}
48+
self.service_cache: _DNSRecordCacheType = {}
49+
50+
# Functions prefixed with are NOT threadsafe and must
51+
# be run in the event loop.
3652

3753
def add(self, entry: DNSRecord) -> None:
38-
"""Adds an entry"""
39-
# Insert last in list, get will return newest entry
40-
# iteration will result in last update winning
41-
self.cache.setdefault(entry.key, []).append(entry)
54+
"""Adds an entry.
55+
56+
This function must be run in from event loop.
57+
"""
58+
# Previously storage of records was implemented as a list
59+
# instead a dict. Since DNSRecords are now hashable, the implementation
60+
# uses a dict to ensure that adding a new record to the cache
61+
# replaces any existing records that are __eq__ to each other which
62+
# removes the risk that accessing the cache from the wrong
63+
# direction would return the old incorrect entry.
64+
self.cache.setdefault(entry.key, {})[entry] = entry
4265
if isinstance(entry, DNSService):
43-
self.service_cache.setdefault(entry.server, []).append(entry)
66+
self.service_cache.setdefault(entry.server, {})[entry] = entry
4467

4568
def add_records(self, entries: Iterable[DNSRecord]) -> None:
46-
"""Add multiple records."""
69+
"""Add multiple records.
70+
71+
This function must be run in from event loop.
72+
"""
4773
for entry in entries:
4874
self.add(entry)
4975

5076
def remove(self, entry: DNSRecord) -> None:
51-
"""Removes an entry."""
77+
"""Removes an entry.
78+
79+
This function must be run in from event loop.
80+
"""
5281
if isinstance(entry, DNSService):
53-
DNSCache.remove_key(self.service_cache, entry.server, entry)
54-
DNSCache.remove_key(self.cache, entry.key, entry)
82+
_remove_key(self.service_cache, entry.server, entry)
83+
_remove_key(self.cache, entry.key, entry)
5584

5685
def remove_records(self, entries: Iterable[DNSRecord]) -> None:
57-
"""Remove multiple records."""
86+
"""Remove multiple records.
87+
88+
This function must be run in from event loop.
89+
"""
5890
for entry in entries:
5991
self.remove(entry)
6092

61-
@staticmethod
62-
def remove_key(cache: dict, key: str, entry: DNSRecord) -> None:
63-
"""Forgiving remove of a cache key."""
64-
try:
65-
cache[key].remove(entry)
66-
if not cache[key]:
67-
del cache[key]
68-
except (KeyError, ValueError):
69-
pass
93+
def expire(self, now: float) -> Iterable[DNSRecord]:
94+
"""Purge expired entries from the cache.
95+
96+
This function must be run in from event loop.
97+
"""
98+
for name in self.names():
99+
for record in self.entries_with_name(name):
100+
if record.is_expired(now):
101+
self.remove(record)
102+
yield record
103+
104+
# The below functions are threadsafe and do not need to be run in the
105+
# event loop, however they all make copies so they significantly
106+
# inefficent
70107

71108
def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
72109
"""Gets an entry by key. Will return None if there is no
@@ -77,7 +114,17 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
77114
return None
78115

79116
def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]:
80-
"""Gets the first matching entry by details. Returns None if no entries match."""
117+
"""Gets the first matching entry by details. Returns None if no entries match.
118+
119+
Calling this function is not recommended as it will only
120+
return one record even if there are multiple entries.
121+
122+
For example if there are multiple A or AAAA addresses this
123+
function will return the last one that was added to the cache
124+
which may not be the one you expect.
125+
126+
Use get_all_by_details instead.
127+
"""
81128
return self.get(DNSEntry(name, type_, class_))
82129

83130
def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]:
@@ -87,11 +134,11 @@ def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSReco
87134

88135
def entries_with_server(self, server: str) -> List[DNSRecord]:
89136
"""Returns a list of entries whose server matches the name."""
90-
return self.service_cache.get(server, [])[:]
137+
return list(self.service_cache.get(server, {}))
91138

92139
def entries_with_name(self, name: str) -> List[DNSRecord]:
93140
"""Returns a list of entries whose key matches the name."""
94-
return self.cache.get(name.lower(), [])[:]
141+
return list(self.cache.get(name.lower(), {}))
95142

96143
def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]:
97144
now = current_time_millis()
@@ -107,11 +154,3 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D
107154
def names(self) -> List[str]:
108155
"""Return a copy of the list of current cache names."""
109156
return list(self.cache)
110-
111-
def expire(self, now: float) -> Iterable[DNSRecord]:
112-
"""Purge expired entries from the cache."""
113-
for name in self.names():
114-
for record in self.entries_with_name(name):
115-
if record.is_expired(now):
116-
self.remove(record)
117-
yield record

0 commit comments

Comments
 (0)