Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions zeroconf/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,30 @@ class DNSIncoming(DNSMessage, QuietLogger):

"""Object representation of an incoming DNS packet"""

__slots__ = (
'offset',
'data',
'data_len',
'name_cache',
'questions',
'_answers',
'id',
'num_questions',
'num_answers',
'num_authorities',
'num_additionals',
'valid',
'now',
'scope_id',
)

def __init__(self, data: bytes, scope_id: Optional[int] = None, now: Optional[float] = None) -> None:
"""Constructor from string holding bytes of packet"""
super().__init__(0)
self.offset = 0
self.data = data
self.data_len = len(data)
self.name_cache: Dict[int, List[str]] = {}
self.seen_pointers: Set[int] = set()
self.questions: List[DNSQuestion] = []
self._answers: List[DNSRecord] = []
self.id = 0
Expand Down Expand Up @@ -162,10 +178,9 @@ def read_header(self) -> None:

def read_questions(self) -> None:
"""Reads questions section of packet"""
for _ in range(self.num_questions):
name = self.read_name()
type_, class_ = self.unpack(b'!HH', 4)
self.questions.append(DNSQuestion(name, type_, class_))
self.questions = [
DNSQuestion(self.read_name(), *self.unpack(b'!HH', 4)) for _ in range(self.num_questions)
]

def read_character_string(self) -> bytes:
"""Reads a character string from the packet"""
Expand Down Expand Up @@ -278,14 +293,14 @@ def read_bitmap(self, end: int) -> List[int]:
def read_name(self) -> str:
"""Reads a domain name from the packet."""
labels: List[str] = []
self.seen_pointers.clear()
self.offset = self._decode_labels_at_offset(self.offset, labels)
seen_pointers: Set[int] = set()
self.offset = self._decode_labels_at_offset(self.offset, labels, seen_pointers)
name = ".".join(labels) + "."
if len(name) > MAX_NAME_LENGTH:
raise IncomingDecodeError(f"DNS name {name} exceeds maximum length of {MAX_NAME_LENGTH}")
return name

def _decode_labels_at_offset(self, off: int, labels: List[str]) -> int:
def _decode_labels_at_offset(self, off: int, labels: List[str], seen_pointers: Set[int]) -> int:
# This is a tight loop that is called frequently, small optimizations can make a difference.
while off < self.data_len:
length = self.data[off]
Expand All @@ -307,12 +322,12 @@ def _decode_labels_at_offset(self, off: int, labels: List[str]) -> int:
raise IncomingDecodeError(f"DNS compression pointer at {off} points to {link} beyond packet")
if link == off:
raise IncomingDecodeError(f"DNS compression pointer at {off} points to itself")
if link in self.seen_pointers:
if link in seen_pointers:
raise IncomingDecodeError(f"DNS compression pointer at {off} was seen again")
self.seen_pointers.add(link)
seen_pointers.add(link)
linked_labels = self.name_cache.get(link, [])
if not linked_labels:
self._decode_labels_at_offset(link, linked_labels)
self._decode_labels_at_offset(link, linked_labels, seen_pointers)
self.name_cache[link] = linked_labels
labels.extend(linked_labels)
if len(labels) > MAX_DNS_LABELS:
Expand All @@ -326,6 +341,22 @@ class DNSOutgoing(DNSMessage):

"""Object representation of an outgoing packet"""

__slots__ = (
'finished',
'id',
'multicast',
'packets_data',
'names',
'data',
'size',
'allow_long',
'state',
'questions',
'answers',
'authorities',
'additionals',
)

def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None:
super().__init__(flags)
self.finished = False
Expand Down