Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/zeroconf/_dns.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ cdef object _RECENT_TIME_MS
cdef object _CLASS_UNIQUE
cdef object _CLASS_MASK

cdef object current_time_millis

cdef class DNSEntry:

cdef public object key
Expand Down Expand Up @@ -96,8 +98,11 @@ cdef class DNSNsec(DNSRecord):

cdef class DNSRRSet:

cdef _records
cdef _record_sets
cdef cython.dict _lookup

@cython.locals(other=DNSRecord)
cpdef suppresses(self, DNSRecord record)

@cython.locals(lookup=cython.dict)
cdef _get_lookup(self)
35 changes: 17 additions & 18 deletions src/zeroconf/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,31 +512,30 @@ def __repr__(self) -> str:


class DNSRRSet:
"""A set of dns records independent of the ttl."""
"""A set of dns records with a lookup to get the ttl."""

__slots__ = ('_records', '_lookup')
__slots__ = ('_record_sets', '_lookup')

def __init__(self, records: Iterable[DNSRecord]) -> None:
"""Create an RRset from records."""
self._records = records
self._lookup: Optional[Dict[DNSRecord, DNSRecord]] = None
def __init__(self, record_sets: Iterable[List[DNSRecord]]) -> None:
"""Create an RRset from records sets."""
self._record_sets = record_sets
self._lookup: Optional[Dict[DNSRecord, float]] = None

@property
def lookup(self) -> Dict[DNSRecord, DNSRecord]:
def lookup(self) -> Dict[DNSRecord, float]:
"""Return the lookup table."""
return self._get_lookup()

def _get_lookup(self) -> Dict[DNSRecord, float]:
"""Return the lookup table, building it if needed."""
if self._lookup is None:
# Build the hash table so we can lookup the record independent of the ttl
self._lookup = {record: record for record in self._records}
# Build the hash table so we can lookup the record ttl
self._lookup = {record: record.ttl for record_sets in self._record_sets for record in record_sets}
return self._lookup

def suppresses(self, record: _DNSRecord) -> bool:
"""Returns true if any answer in the rrset can suffice for the
information held in this record."""
if self._lookup is None:
other = self.lookup.get(record)
else:
other = self._lookup.get(record)
return bool(other and other.ttl > (record.ttl / 2))

def __contains__(self, record: DNSRecord) -> bool:
"""Returns true if the rrset contains the record."""
return record in self.lookup
lookup = self._get_lookup()
other_ttl = lookup.get(record)
return bool(other_ttl and other_ttl > (record.ttl / 2))
10 changes: 2 additions & 8 deletions src/zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,6 @@ class AnswerGroup(NamedTuple):
answers: _AnswerWithAdditionalsType


def _message_is_probe(msg: DNSIncoming) -> bool:
return msg.num_authorities > 0


def construct_nsec_record(name: str, types: List[int], now: float) -> DNSNsec:
"""Construct an NSEC record for name and a list of dns types.

Expand Down Expand Up @@ -159,7 +155,7 @@ class _QueryResponse:

def __init__(self, cache: DNSCache, msgs: List[DNSIncoming]) -> None:
"""Build a query response."""
self._is_probe = any(_message_is_probe(msg) for msg in msgs)
self._is_probe = any(msg.is_probe for msg in msgs)
self._msg = msgs[0]
self._now = self._msg.now
self._cache = cache
Expand Down Expand Up @@ -363,9 +359,7 @@ def async_response( # pylint: disable=unused-argument
This function must be run in the event loop as it is not
threadsafe.
"""
known_answers = DNSRRSet(
itertools.chain.from_iterable(msg.answers for msg in msgs if not _message_is_probe(msg))
)
known_answers = DNSRRSet(msg.answers for msg in msgs if not msg.is_probe)
query_res = _QueryResponse(self.cache, msgs)

for msg in msgs:
Expand Down
5 changes: 5 additions & 0 deletions src/zeroconf/_protocol/incoming.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ def answers(self) -> List[DNSRecord]:
)
return self._answers

@property
def is_probe(self) -> bool:
"""Returns true if this is a probe."""
return self.num_authorities > 0

def __repr__(self) -> str:
return '<DNSIncoming:{%s}>' % ', '.join(
[
Expand Down
4 changes: 2 additions & 2 deletions tests/test_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def test_rrset_does_not_consider_ttl():
longaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 100, b'same')
shortaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 10, b'same')

rrset = DNSRRSet([longarec, shortaaaarec])
rrset = DNSRRSet([[longarec, shortaaaarec]])

assert rrset.suppresses(longarec)
assert rrset.suppresses(shortarec)
Expand All @@ -404,7 +404,7 @@ def test_rrset_does_not_consider_ttl():
mediumarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 60, b'same')
shortarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 10, b'same')

rrset2 = DNSRRSet([mediumarec])
rrset2 = DNSRRSet([[mediumarec]])
assert not rrset2.suppresses(verylongarec)
assert rrset2.suppresses(longarec)
assert rrset2.suppresses(mediumarec)
Expand Down