diff --git a/src/zeroconf/_dns.pxd b/src/zeroconf/_dns.pxd index 2d50c07a0..b28f73237 100644 --- a/src/zeroconf/_dns.pxd +++ b/src/zeroconf/_dns.pxd @@ -105,4 +105,4 @@ cdef class DNSRRSet: cpdef suppresses(self, DNSRecord record) @cython.locals(lookup=cython.dict) - cdef _get_lookup(self) + cdef cython.dict _get_lookup(self) diff --git a/src/zeroconf/_dns.py b/src/zeroconf/_dns.py index 3764edf72..ada6e9df4 100644 --- a/src/zeroconf/_dns.py +++ b/src/zeroconf/_dns.py @@ -538,4 +538,6 @@ def suppresses(self, record: _DNSRecord) -> bool: information held in this record.""" lookup = self._get_lookup() other_ttl = lookup.get(record) - return bool(other_ttl and other_ttl > (record.ttl / 2)) + if other_ttl is None: + return False + return other_ttl > (record.ttl / 2) diff --git a/src/zeroconf/_protocol/incoming.pxd b/src/zeroconf/_protocol/incoming.pxd index 79130d8a2..4233c810f 100644 --- a/src/zeroconf/_protocol/incoming.pxd +++ b/src/zeroconf/_protocol/incoming.pxd @@ -35,10 +35,10 @@ cdef class DNSIncoming: cdef bint _did_read_others cdef public unsigned int flags - cdef unsigned int offset + cdef object offset cdef public bytes data cdef unsigned int _data_len - cdef public object name_cache + cdef public cython.dict name_cache cdef public object questions cdef object _answers cdef public object id @@ -55,9 +55,10 @@ cdef class DNSIncoming: off=cython.uint, label_idx=cython.uint, length=cython.uint, - link=cython.uint + link=cython.uint, + link_data=cython.uint ) - cdef _decode_labels_at_offset(self, unsigned int off, cython.list labels, object seen_pointers) + cdef _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers) cdef _read_header(self) diff --git a/src/zeroconf/_protocol/incoming.py b/src/zeroconf/_protocol/incoming.py index 9996b37c9..9c0c39a2d 100644 --- a/src/zeroconf/_protocol/incoming.py +++ b/src/zeroconf/_protocol/incoming.py @@ -353,7 +353,9 @@ def _decode_labels_at_offset(self, off: int, labels: List[str], seen_pointers: S ) # We have a DNS compression pointer - link = (length & 0x3F) * 256 + self.data[off + 1] + link_data = self.data[off + 1] + link = (length & 0x3F) * 256 + link_data + lint_int = int(link) if link > self._data_len: raise IncomingDecodeError( f"DNS compression pointer at {off} points to {link} beyond packet from {self.source}" @@ -362,15 +364,16 @@ def _decode_labels_at_offset(self, off: int, labels: List[str], seen_pointers: S raise IncomingDecodeError( f"DNS compression pointer at {off} points to itself from {self.source}" ) - if link in seen_pointers: + if lint_int in seen_pointers: raise IncomingDecodeError( f"DNS compression pointer at {off} was seen again from {self.source}" ) - linked_labels = self.name_cache.get(link, []) + linked_labels = self.name_cache.get(lint_int) if not linked_labels: - seen_pointers.add(link) + linked_labels = [] + seen_pointers.add(lint_int) self._decode_labels_at_offset(link, linked_labels, seen_pointers) - self.name_cache[link] = linked_labels + self.name_cache[lint_int] = linked_labels labels.extend(linked_labels) if len(labels) > MAX_DNS_LABELS: raise IncomingDecodeError(