Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/zeroconf/_protocol/incoming.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ cdef cython.uint MAX_DNS_LABELS
cdef cython.uint DNS_COMPRESSION_POINTER_LEN
cdef cython.uint MAX_NAME_LENGTH

cdef object current_time_millis

cdef cython.uint _TYPE_A
cdef cython.uint _TYPE_CNAME
cdef cython.uint _TYPE_PTR
Expand Down Expand Up @@ -43,6 +41,7 @@ from .._dns cimport (
DNSService,
DNSText,
)
from .._utils.time cimport current_time_millis


cdef class DNSIncoming:
Expand All @@ -62,6 +61,7 @@ cdef class DNSIncoming:
cdef public cython.uint num_additionals
cdef public object valid
cdef public object now
cdef cython.float _now_float
cdef public object scope_id
cdef public object source

Expand All @@ -79,7 +79,9 @@ cdef class DNSIncoming:
label_idx=cython.uint,
length=cython.uint,
link=cython.uint,
link_data=cython.uint
link_data=cython.uint,
link_py_int=object,
linked_labels=cython.list
)
cdef _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers)

Expand All @@ -95,9 +97,12 @@ cdef class DNSIncoming:

cdef _read_questions(self)

cdef bytes _read_character_string(self)
@cython.locals(
length=cython.uint,
)
cdef str _read_character_string(self)

cdef _read_string(self, unsigned int length)
cdef bytes _read_string(self, unsigned int length)

@cython.locals(
name_start=cython.uint
Expand Down
26 changes: 15 additions & 11 deletions src/zeroconf/_protocol/incoming.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class DNSIncoming:
'num_additionals',
'valid',
'now',
'_now_float',
'scope_id',
'source',
)
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(
self.valid = False
self._did_read_others = False
self.now = now or current_time_millis()
self._now_float = self.now
self.source = source
self.scope_id = scope_id
try:
Expand Down Expand Up @@ -226,11 +228,13 @@ def _read_questions(self) -> None:
question = DNSQuestion(name, type_, class_)
self.questions.append(question)

def _read_character_string(self) -> bytes:
def _read_character_string(self) -> str:
"""Reads a character string from the packet"""
length = self.data[self.offset]
self.offset += 1
return self._read_string(length)
info = self.data[self.offset : self.offset + length].decode('utf-8', 'replace')
self.offset += length
return info

def _read_string(self, length: _int) -> bytes:
"""Reads a string of a given length from the packet"""
Expand Down Expand Up @@ -273,7 +277,7 @@ def _read_record(
"""Read known records types and skip unknown ones."""
if type_ == _TYPE_A:
dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(4))
dns_address.created = self.now
dns_address.created = self._now_float
return dns_address
if type_ in (_TYPE_CNAME, _TYPE_PTR):
return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now)
Expand All @@ -299,13 +303,13 @@ def _read_record(
type_,
class_,
ttl,
self._read_character_string().decode('utf-8', 'replace'),
self._read_character_string().decode('utf-8', 'replace'),
self._read_character_string(),
self._read_character_string(),
self.now,
)
if type_ == _TYPE_AAAA:
dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(16))
dns_address.created = self.now
dns_address.created = self._now_float
dns_address.scope_id = self.scope_id
return dns_address
if type_ == _TYPE_NSEC:
Expand Down Expand Up @@ -377,7 +381,7 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers:
# We have a DNS compression pointer
link_data = self.data[off + 1]
link = (length & 0x3F) * 256 + link_data
lint_int = int(link)
link_py_int = link
if link > self._data_len:
raise IncomingDecodeError(
f"DNS compression pointer at {off} points to {link} beyond packet from {self.source}"
Expand All @@ -386,16 +390,16 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers:
raise IncomingDecodeError(
f"DNS compression pointer at {off} points to itself from {self.source}"
)
if lint_int in seen_pointers:
if link_py_int in seen_pointers:
raise IncomingDecodeError(
f"DNS compression pointer at {off} was seen again from {self.source}"
)
linked_labels = self.name_cache.get(lint_int)
linked_labels = self.name_cache.get(link_py_int)
if not linked_labels:
linked_labels = []
seen_pointers.add(lint_int)
seen_pointers.add(link_py_int)
self._decode_labels_at_offset(link, linked_labels, seen_pointers)
self.name_cache[lint_int] = linked_labels
self.name_cache[link_py_int] = linked_labels
labels.extend(linked_labels)
if len(labels) > MAX_DNS_LABELS:
raise IncomingDecodeError(
Expand Down