Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 125 additions & 53 deletions zeroconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,16 +875,25 @@ def __init__(self, flags: int, multicast: bool = True) -> None:
self.id = 0
self.multicast = multicast
self.flags = flags
self.packets_data = [] # type: List[bytes]

# these 3 are per-packet -- see also reset_for_next_packet()
self.names = {} # type: Dict[str, int]
self.data = [] # type: List[bytes]
self.size = 12

self.state = self.State.init

self.questions = [] # type: List[DNSQuestion]
self.answers = [] # type: List[Tuple[DNSRecord, float]]
self.authorities = [] # type: List[DNSPointer]
self.additionals = [] # type: List[DNSRecord]

def reset_for_next_packet(self) -> None:
self.names = {}
self.data = []
self.size = 12

def __repr__(self) -> str:
return '<DNSOutgoing:{%s}>' % ', '.join(
[
Expand Down Expand Up @@ -1059,11 +1068,13 @@ def write_question(self, question: DNSQuestion) -> None:
self.write_short(question.type)
self.write_short(question.class_)

def write_record(self, record: DNSRecord, now: float) -> int:
def write_record(self, record: DNSRecord, now: float, allow_long: bool = False) -> bool:
"""Writes a record (answer, authoritative answer, additional) to
the packet"""
the packet. Returns True on success, or False if we did not (either
because the packet was already finished or because the record does
not fit."""
if self.state == self.State.finished:
return 1
return False

start_data_length, start_size = len(self.data), self.size
self.write_name(record.name)
Expand All @@ -1087,44 +1098,102 @@ def write_record(self, record: DNSRecord, now: float) -> int:
# Here is the short we adjusted for
self.insert_short(index, length)

len_limit = _MAX_MSG_ABSOLUTE if allow_long else _MAX_MSG_TYPICAL

# if we go over, then rollback and quit
if self.size > _MAX_MSG_ABSOLUTE:
if self.size > len_limit:
while len(self.data) > start_data_length:
self.data.pop()
self.size = start_size
self.state = self.State.finished
return 1
return 0
return False
return True

def packet(self) -> bytes:
"""Returns a string containing the packet's bytes
"""Returns a bytestring containing the first packet's bytes.

Generally, you want to use packets() in case the response
does not fit in a single packet, but this exists for
backward compatibility."""
packets = self.packets()
if len(packets) > 0:
if len(packets[0]) > _MAX_MSG_ABSOLUTE:
QuietLogger.log_warning_once(
"Created over-sized packet (%d bytes) %r", len(packets[0]), packets[0]
)
return packets[0]
else:
return b''

No further parts should be added to the packet once this
is done."""
def packets(self) -> List[bytes]:
"""Returns a list of bytestrings containing the packets' bytes

overrun_answers, overrun_authorities, overrun_additionals = 0, 0, 0
No further parts should be added to the packet once this
is done. The packets are each restricted to _MAX_MSG_TYPICAL
or less in length, except for the case of a single answer which
will be written out to a single oversized packet no more than
_MAX_MSG_ABSOLUTE in length (and hence will be subject to IP
fragmentation potentially). """

if self.state != self.State.finished:
if self.state == self.State.finished:
return self.packets_data

answer_offset = 0
authority_offset = 0
additional_offset = 0

# we have to at least write out the question
first_time = True

while (
first_time
or answer_offset < len(self.answers)
or authority_offset < len(self.authorities)
or additional_offset < len(self.additionals)
):
first_time = False
log.debug("offsets = %d, %d, %d", answer_offset, authority_offset, additional_offset)
log.debug("lengths = %d, %d, %d", len(self.answers), len(self.authorities), len(self.additionals))

additionals_written = 0
authorities_written = 0
answers_written = 0
questions_written = 0
for question in self.questions:
self.write_question(question)
for answer, time_ in self.answers:
overrun_answers += self.write_record(answer, time_)
for authority in self.authorities:
overrun_authorities += self.write_record(authority, 0)
for additional in self.additionals:
overrun_additionals += self.write_record(additional, 0)
self.state = self.State.finished

self.insert_short(0, len(self.additionals) - overrun_additionals)
self.insert_short(0, len(self.authorities) - overrun_authorities)
self.insert_short(0, len(self.answers) - overrun_answers)
self.insert_short(0, len(self.questions))
questions_written += 1
allow_long = True # at most one answer is allowed to be a long packet

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it at most one answer per packet or per group of packets?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added docstring -- multiple answers per packet are fine. It's just the case of having so many answers that it overflows the smaller 1470ish MTU of ethernet -- in that scenario, the application layer MUST send separate application-level responses rather than rely on IP fragmentation to break up a bigger packet. The only exception is when a single answer does not fit inside the smaller MTU in which case it's allowed to be bigger. I suspect that scenario will still crash ChromeCast Audios, but that situation is exceedingly rare, whereas many responses that exceed 1470 bytes happens within a large but not crazy set of devices.

for answer, time_ in self.answers[answer_offset:]:
if self.write_record(answer, time_, allow_long):
answers_written += 1
allow_long = False
for authority in self.authorities[authority_offset:]:
if self.write_record(authority, 0):
authorities_written += 1
for additional in self.additionals[additional_offset:]:
if self.write_record(additional, 0):
additionals_written += 1

self.insert_short(0, additionals_written)
self.insert_short(0, authorities_written)
self.insert_short(0, answers_written)
self.insert_short(0, questions_written)
self.insert_short(0, self.flags)
if self.multicast:
self.insert_short(0, 0)
else:
self.insert_short(0, self.id)
return b''.join(self.data)
self.packets_data.append(b''.join(self.data))
self.reset_for_next_packet()

answer_offset += answers_written
authority_offset += authorities_written
additional_offset += additionals_written
log.debug("now offsets = %d, %d, %d", answer_offset, authority_offset, additional_offset)
if answers_written == 0 and authorities_written == 0 and additional_offset == 0:
log.warning("packets() made no progress adding records; returning")
break
self.state = self.State.finished
return self.packets_data


class DNSCache:
Expand Down Expand Up @@ -2708,36 +2777,39 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None

def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None:
"""Sends an outgoing packet."""
packet = out.packet()
if len(packet) > _MAX_MSG_ABSOLUTE:
self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet)
return
log.debug('Sending %r (%d bytes) as [%r]', out, len(packet), packet)
for s in self._respond_sockets:
if self._GLOBAL_DONE:
packets = out.packets()
packet_num = 0
for packet in packets:
packet_num += 1
if len(packet) > _MAX_MSG_ABSOLUTE:
self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet)
return
try:
if addr is None:
real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR
elif not can_send_to(s, addr):
continue
log.debug('Sending (%d bytes #%d) %r as %r...', len(packet), packet_num, out, packet)
for s in self._respond_sockets:
if self._GLOBAL_DONE:
return
try:
if addr is None:
real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR
elif not can_send_to(s, addr):
continue
else:
real_addr = addr
bytes_sent = s.sendto(packet, 0, (real_addr, port))
except Exception as exc: # TODO stop catching all Exceptions
if (
isinstance(exc, OSError)
and exc.errno == errno.ENETUNREACH
and s.family == socket.AF_INET6
):
# with IPv6 we don't have a reliable way to determine if an interface actually has
# IPV6 support, so we have to try and ignore errors.
continue
# on send errors, log the exception and keep going
self.log_exception_warning()
else:
real_addr = addr
bytes_sent = s.sendto(packet, 0, (real_addr, port))
except Exception as exc: # TODO stop catching all Exceptions
if (
isinstance(exc, OSError)
and exc.errno == errno.ENETUNREACH
and s.family == socket.AF_INET6
):
# with IPv6 we don't have a reliable way to determine if an interface actually has IPv6
# support, so we have to try and ignore errors.
continue
# on send errors, log the exception and keep going
self.log_exception_warning()
else:
if bytes_sent != len(packet):
self.log_warning_once('!!! sent %d out of %d bytes to %r' % (bytes_sent, len(packet), s))
if bytes_sent != len(packet):
self.log_warning_once('!!! sent %d of %d bytes to %r' % (bytes_sent, len(packet), s))

def close(self) -> None:
"""Ends the background threads, and prevent this instance from
Expand Down
34 changes: 21 additions & 13 deletions zeroconf/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,13 @@ def test_exceedingly_long_name(self):
generated.add_question(question)
r.DNSIncoming(generated.packet())

def test_extra_exceedingly_long_name(self):
generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
name = "%slocal." % ("part." * 4000)
question = r.DNSQuestion(name, r._TYPE_SRV, r._CLASS_IN)
generated.add_question(question)
r.DNSIncoming(generated.packet())

def test_exceedingly_long_name_part(self):
name = "%s.local." % ("a" * 1000)
generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
Expand Down Expand Up @@ -355,12 +362,12 @@ def test_lots_of_names(self):

def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT):
"""Sends an outgoing packet."""
packet = out.packet()
nonlocal longest_packet_len, longest_packet
if longest_packet_len < len(packet):
longest_packet_len = len(packet)
longest_packet = out
old_send(out, addr=addr, port=port)
for packet in out.packets():
nonlocal longest_packet_len, longest_packet
if longest_packet_len < len(packet):
longest_packet_len = len(packet)
longest_packet = out
old_send(out, addr=addr, port=port)

# monkey patch the zeroconf send
setattr(zc, "send", send)
Expand All @@ -374,6 +381,9 @@ def on_service_state_change(zeroconf, service_type, state_change, name):

# wait until the browse request packet has maxed out in size
sleep_count = 0
# we will never get to this large of a packet given the application-layer
# splitting of packets, but we still want to track the longest_packet_len
# for the debug message below
while sleep_count < 100 and longest_packet_len < r._MAX_MSG_ABSOLUTE - 100:
sleep_count += 1
time.sleep(0.1)
Expand All @@ -386,8 +396,8 @@ def on_service_state_change(zeroconf, service_type, state_change, name):
zeroconf.log.debug('sleep_count %d, sized %d', sleep_count, longest_packet_len)

# now the browser has sent at least one request, verify the size
assert longest_packet_len <= r._MAX_MSG_ABSOLUTE
assert longest_packet_len >= r._MAX_MSG_ABSOLUTE - 100
assert longest_packet_len <= r._MAX_MSG_TYPICAL
assert longest_packet_len >= r._MAX_MSG_TYPICAL - 100

# mock zeroconf's logger warning() and debug()
from unittest.mock import patch
Expand All @@ -407,13 +417,11 @@ def on_service_state_change(zeroconf, service_type, state_change, name):
call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count
# try to send an oversized packet
zc.send(out)
assert mocked_log_warn.call_count == call_counts[0] + 1
assert mocked_log_debug.call_count == call_counts[0]
assert mocked_log_warn.call_count == call_counts[0]
zc.send(out)
assert mocked_log_warn.call_count == call_counts[0] + 1
assert mocked_log_debug.call_count == call_counts[0] + 1
assert mocked_log_warn.call_count == call_counts[0]

# force a receive of an oversized packet
# force a receive of a packet
packet = out.packet()
s = zc._respond_sockets[0]

Expand Down