Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions zeroconf/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,10 @@ def suppresses(self, record: DNSRecord) -> bool:
self._lookup = {record: record for record in self._records}
other = self._lookup.get(record)
return bool(other and other.ttl > (record.ttl / 2))

def __contains__(self, record: DNSRecord) -> bool:
"""Returns true if the rrset contains the record."""
if self._lookup is None:
# Build the hash table so we can lookup the record independent of the ttl
self._lookup = {record: record for record in self._records}
return record in self._lookup
56 changes: 32 additions & 24 deletions zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,42 +311,33 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
other_adds: List[DNSRecord] = []
removes: List[DNSRecord] = []
now = msg.now
for record in msg.answers:

updated = True
unique_types: Set[Tuple[str, int, int]] = set()

for record in msg.answers:
if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2
# rfc6762#section-10.2 para 2
# Since unique is set, all old records with that name, rrtype,
# and rrclass that were received more than one second ago are declared
# invalid, and marked to expire from the cache in one second.
for entry in self.cache.get_all_by_details(record.name, record.type, record.class_):
if entry == record:
updated = False
if record.created - entry.created > 1000 and entry not in msg.answers:
# Expire in 1s
entry.set_created_ttl(now, 1)

expired = record.is_expired(now)
unique_types.add((record.name, record.type, record.class_))

maybe_entry = self.cache.get(record)
if not expired:
if not record.is_expired(now):
if maybe_entry is not None:
maybe_entry.reset_ttl(record)
else:
if isinstance(record, DNSAddress):
address_adds.append(record)
else:
other_adds.append(record)
if updated:
updates.append(record)
updates.append(record)
# This is likely a goodbye since the record is
# expired and exists in the cache
elif maybe_entry is not None:
updates.append(record)
removes.append(record)

if not updates and not address_adds and not other_adds and not removes:
return
if unique_types:
self._async_mark_unique_cached_records_older_than_1s_to_expire(unique_types, msg.answers, now)

self.async_updates(now, updates)
if updates:
self.async_updates(now, updates)
# The cache adds must be processed AFTER we trigger
# the updates since we compare existing data
# with the new data and updating the cache
Expand All @@ -362,12 +353,29 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
# zc.get_service_info will see the cached value
# but ONLY after all the record updates have been
# processsed.
self.cache.async_add_records(itertools.chain(address_adds, other_adds))
if other_adds or address_adds:
self.cache.async_add_records(itertools.chain(address_adds, other_adds))
# Removes are processed last since
# ServiceInfo could generate an un-needed query
# because the data was not yet populated.
self.cache.async_remove_records(removes)
self.async_updates_complete()
if removes:
self.cache.async_remove_records(removes)
if updates:
self.async_updates_complete()

def _async_mark_unique_cached_records_older_than_1s_to_expire(
self, unique_types: Set[Tuple[str, int, int]], answers: List[DNSRecord], now: float
) -> None:
# rfc6762#section-10.2 para 2
# Since unique is set, all old records with that name, rrtype,
# and rrclass that were received more than one second ago are declared
# invalid, and marked to expire from the cache in one second.
answers_rrset = DNSRRSet(answers)
for name, type_, class_ in unique_types:
for entry in self.cache.get_all_by_details(name, type_, class_):
if (now - entry.created > 1000) and entry not in answers_rrset:
# Expire in 1s
entry.set_created_ttl(now, 1)

def add_listener(
self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]]
Expand Down