Skip to content

Commit e05055c

Browse files
authored
fix: ensure cache does not return stale created and ttl values (#1469)
1 parent afd4517 commit e05055c

3 files changed

Lines changed: 105 additions & 13 deletions

File tree

src/zeroconf/_cache.pxd

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,13 @@ cdef class DNSCache:
4747
)
4848
cpdef list async_all_by_details(self, str name, unsigned int type_, unsigned int class_)
4949

50-
cpdef cython.dict async_entries_with_name(self, str name)
50+
cpdef list async_entries_with_name(self, str name)
5151

52-
cpdef cython.dict async_entries_with_server(self, str name)
52+
cpdef list async_entries_with_server(self, str name)
5353

5454
@cython.locals(
5555
cached_entry=DNSRecord,
56+
records=dict
5657
)
5758
cpdef DNSRecord get_by_details(self, str name, unsigned int type_, unsigned int class_)
5859

@@ -79,7 +80,15 @@ cdef class DNSCache:
7980
)
8081
cpdef void async_mark_unique_records_older_than_1s_to_expire(self, cython.set unique_types, object answers, double now)
8182

82-
cpdef entries_with_name(self, str name)
83+
@cython.locals(
84+
entries=dict
85+
)
86+
cpdef list entries_with_name(self, str name)
87+
88+
@cython.locals(
89+
entries=dict
90+
)
91+
cpdef list entries_with_server(self, str server)
8392

8493
@cython.locals(
8594
record=DNSRecord,

src/zeroconf/_cache.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,26 +149,26 @@ def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DN
149149
matches: List[DNSRecord] = []
150150
if records is None:
151151
return matches
152-
for record in records:
152+
for record in records.values():
153153
if type_ == record.type and class_ == record.class_:
154154
matches.append(record)
155155
return matches
156156

157-
def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]:
157+
def async_entries_with_name(self, name: str) -> List[DNSRecord]:
158158
"""Returns a dict of entries whose key matches the name.
159159
160160
This function is not threadsafe and must be called from
161161
the event loop.
162162
"""
163-
return self.cache.get(name.lower()) or {}
163+
return self.entries_with_name(name)
164164

165-
def async_entries_with_server(self, name: str) -> Dict[DNSRecord, DNSRecord]:
165+
def async_entries_with_server(self, name: str) -> List[DNSRecord]:
166166
"""Returns a dict of entries whose key matches the server.
167167
168168
This function is not threadsafe and must be called from
169169
the event loop.
170170
"""
171-
return self.service_cache.get(name.lower()) or {}
171+
return self.entries_with_server(name)
172172

173173
# The below functions are threadsafe and do not need to be run in the
174174
# event loop, however they all make copies so they significantly
@@ -179,7 +179,7 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
179179
matching entry."""
180180
if isinstance(entry, _UNIQUE_RECORD_TYPES):
181181
return self.cache.get(entry.key, {}).get(entry)
182-
for cached_entry in reversed(list(self.cache.get(entry.key, []))):
182+
for cached_entry in reversed(list(self.cache.get(entry.key, {}).values())):
183183
if entry.__eq__(cached_entry):
184184
return cached_entry
185185
return None
@@ -200,7 +200,7 @@ def get_by_details(self, name: str, type_: _int, class_: _int) -> Optional[DNSRe
200200
records = self.cache.get(key)
201201
if records is None:
202202
return None
203-
for cached_entry in reversed(list(records)):
203+
for cached_entry in reversed(list(records.values())):
204204
if type_ == cached_entry.type and class_ == cached_entry.class_:
205205
return cached_entry
206206
return None
@@ -211,15 +211,19 @@ def get_all_by_details(self, name: str, type_: _int, class_: _int) -> List[DNSRe
211211
records = self.cache.get(key)
212212
if records is None:
213213
return []
214-
return [entry for entry in list(records) if type_ == entry.type and class_ == entry.class_]
214+
return [entry for entry in list(records.values()) if type_ == entry.type and class_ == entry.class_]
215215

216216
def entries_with_server(self, server: str) -> List[DNSRecord]:
217217
"""Returns a list of entries whose server matches the name."""
218-
return list(self.service_cache.get(server.lower(), []))
218+
if entries := self.service_cache.get(server.lower()):
219+
return list(entries.values())
220+
return []
219221

220222
def entries_with_name(self, name: str) -> List[DNSRecord]:
221223
"""Returns a list of entries whose key matches the name."""
222-
return list(self.cache.get(name.lower(), []))
224+
if entries := self.cache.get(name.lower()):
225+
return list(entries.values())
226+
return []
223227

224228
def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]:
225229
now = current_time_millis()

tests/test_cache.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,82 @@ def test_name(self):
279279
cache = r.DNSCache()
280280
cache.async_add_records([record1, record2])
281281
assert cache.names() == ["irrelevant"]
282+
283+
284+
def test_async_entries_with_name_returns_newest_record():
285+
cache = r.DNSCache()
286+
record1 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=1.0)
287+
record2 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=2.0)
288+
cache.async_add_records([record1])
289+
cache.async_add_records([record2])
290+
assert next(iter(cache.async_entries_with_name("a"))) is record2
291+
292+
293+
def test_async_entries_with_server_returns_newest_record():
294+
cache = r.DNSCache()
295+
record1 = r.DNSService("a", const._TYPE_SRV, const._CLASS_IN, 1, 1, 1, 1, "a", created=1.0)
296+
record2 = r.DNSService("a", const._TYPE_SRV, const._CLASS_IN, 1, 1, 1, 1, "a", created=2.0)
297+
cache.async_add_records([record1])
298+
cache.async_add_records([record2])
299+
assert next(iter(cache.async_entries_with_server("a"))) is record2
300+
301+
302+
def test_async_get_returns_newest_record():
303+
cache = r.DNSCache()
304+
record1 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=1.0)
305+
record2 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=2.0)
306+
cache.async_add_records([record1])
307+
cache.async_add_records([record2])
308+
assert cache.get(record2) is record2
309+
310+
311+
def test_async_get_returns_newest_nsec_record():
312+
cache = r.DNSCache()
313+
record1 = r.DNSNsec("a", const._TYPE_NSEC, const._CLASS_IN, 1, "a", [], created=1.0)
314+
record2 = r.DNSNsec("a", const._TYPE_NSEC, const._CLASS_IN, 1, "a", [], created=2.0)
315+
cache.async_add_records([record1])
316+
cache.async_add_records([record2])
317+
assert cache.get(record2) is record2
318+
319+
320+
def test_get_by_details_returns_newest_record():
321+
cache = r.DNSCache()
322+
record1 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=1.0)
323+
record2 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=2.0)
324+
cache.async_add_records([record1])
325+
cache.async_add_records([record2])
326+
assert cache.get_by_details("a", const._TYPE_A, const._CLASS_IN) is record2
327+
328+
329+
def test_get_all_by_details_returns_newest_record():
330+
cache = r.DNSCache()
331+
record1 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=1.0)
332+
record2 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=2.0)
333+
cache.async_add_records([record1])
334+
cache.async_add_records([record2])
335+
records = cache.get_all_by_details("a", const._TYPE_A, const._CLASS_IN)
336+
assert len(records) == 1
337+
assert records[0] is record2
338+
339+
340+
def test_async_get_all_by_details_returns_newest_record():
341+
cache = r.DNSCache()
342+
record1 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=1.0)
343+
record2 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=2.0)
344+
cache.async_add_records([record1])
345+
cache.async_add_records([record2])
346+
records = cache.async_all_by_details("a", const._TYPE_A, const._CLASS_IN)
347+
assert len(records) == 1
348+
assert records[0] is record2
349+
350+
351+
def test_async_get_unique_returns_newest_record():
352+
cache = r.DNSCache()
353+
record1 = r.DNSPointer("a", const._TYPE_PTR, const._CLASS_IN, 1, "a", created=1.0)
354+
record2 = r.DNSPointer("a", const._TYPE_PTR, const._CLASS_IN, 1, "a", created=2.0)
355+
cache.async_add_records([record1])
356+
cache.async_add_records([record2])
357+
record = cache.async_get_unique(record1)
358+
assert record is record2
359+
record = cache.async_get_unique(record2)
360+
assert record is record2

0 commit comments

Comments
 (0)