Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest

import zeroconf as r
from zeroconf import DNSAddress, const
from zeroconf import DNSAddress, DNSPointer, DNSQuestion, const, current_time_millis
import zeroconf._services as s
from zeroconf import Zeroconf
from zeroconf._services import (
Expand Down Expand Up @@ -1377,3 +1377,26 @@ def test_serviceinfo_accepts_bytes_or_string_dict():
addresses=addresses,
)
assert info_service.dns_text().text == b'\x0epath=/~paulsm/'


def test_group_ptr_queries_with_known_answers():
questions_with_known_answers: s._QuestionWithKnownAnswers = {}
now = current_time_millis()
for i in range(120):
name = f"_hap{i}._tcp._local."
questions_with_known_answers[DNSQuestion(name, const._TYPE_PTR, const._CLASS_IN)] = set(
DNSPointer(
name,
const._TYPE_PTR,
const._CLASS_IN,
4500,
f"zoo{counter}.{name}",
)
for counter in range(i)
)
outs = s._group_ptr_queries_with_known_answers(now, True, questions_with_known_answers)
for out in outs:
packets = out.packets()
# If we generate multiple packets there must
# only be one question
assert len(packets) == 1 or len(out.questions) == 1
29 changes: 27 additions & 2 deletions zeroconf/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
_CLASSES,
_CLASS_MASK,
_CLASS_UNIQUE,
_DNS_PACKET_HEADER_LEN,
_EXPIRE_FULL_TIME_PERCENT,
_EXPIRE_STALE_TIME_PERCENT,
_FLAGS_QR_MASK,
Expand All @@ -54,6 +55,12 @@
_TYPE_TXT,
)

_LEN_BYTE = 1
_LEN_SHORT = 2
_LEN_INT = 4

_BASE_MAX_SIZE = _LEN_SHORT + _LEN_SHORT + _LEN_INT + _LEN_SHORT # type # class # ttl # length
_NAME_COMPRESSION_MIN_SIZE = _LEN_BYTE * 2

if TYPE_CHECKING:
# https://github.com/PyCQA/pylint/issues/3525
Expand Down Expand Up @@ -118,6 +125,14 @@ def answered_by(self, rec: 'DNSRecord') -> bool:
and self.name == rec.name
)

def __hash__(self) -> int:
return hash((self.name, self.class_, self.type))

@property
def max_size(self) -> int:
"""Maximum size of the question in the packet."""
return len(self.name.encode('utf-8')) + _LEN_BYTE + _LEN_SHORT + _LEN_SHORT # type # class

@property
def unicast(self) -> bool:
"""Returns true if the QU (not QM) is set.
Expand Down Expand Up @@ -291,6 +306,16 @@ def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) ->
super().__init__(name, type_, class_, ttl)
self.alias = alias

@property
def max_size_compressed(self) -> int:
"""Maximum size of the record in the packet assuming the name has been compressed."""
return (
_BASE_MAX_SIZE
+ _NAME_COMPRESSION_MIN_SIZE
+ (len(self.alias) - len(self.name))
+ _NAME_COMPRESSION_MIN_SIZE
)

def write(self, out: 'DNSOutgoing') -> None:
"""Used in constructing an outgoing packet"""
out.write_name(self.alias)
Expand Down Expand Up @@ -590,7 +615,7 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None:
# these 3 are per-packet -- see also _reset_for_next_packet()
self.names: Dict[str, int] = {}
self.data: List[bytes] = []
self.size: int = 12
self.size: int = _DNS_PACKET_HEADER_LEN
self.allow_long: bool = True

self.state = self.State.init
Expand All @@ -603,7 +628,7 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None:
def _reset_for_next_packet(self) -> None:
self.names = {}
self.data = []
self.size = 12
self.size = _DNS_PACKET_HEADER_LEN
self.allow_long = True

def __repr__(self) -> str:
Expand Down
96 changes: 78 additions & 18 deletions zeroconf/_services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@
_CLASS_UNIQUE,
_DNS_HOST_TTL,
_DNS_OTHER_TTL,
_DNS_PACKET_HEADER_LEN,
_EXPIRE_REFRESH_TIME_PERCENT,
_FLAGS_QR_QUERY,
_LISTENER_TIME,
_MAX_MSG_TYPICAL,
_MDNS_ADDR,
_MDNS_ADDR6,
_MDNS_PORT,
Expand All @@ -63,6 +65,9 @@
from .._core import Zeroconf # pylint: disable=cyclic-import


_QuestionWithKnownAnswers = Dict[DNSQuestion, Set[DNSPointer]]


@enum.unique
class ServiceStateChange(enum.Enum):
Added = 1
Expand Down Expand Up @@ -151,6 +156,67 @@ def update_records_complete(self) -> None:
"""


class _DNSPointerOutgoingBucket:
"""A DNSOutgoing bucket."""

def __init__(self, now: float, multicast: bool) -> None:
"""Create a bucke to wrap a DNSOutgoing."""
self.now = now
self.out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=multicast)
self.bytes = 0

def add(self, max_compressed_size: int, question: DNSQuestion, answers: Set[DNSPointer]) -> None:
"""Add a new set of questions and known answers to the outgoing."""
self.out.add_question(question)
for answer in answers:
self.out.add_answer_at_time(answer, self.now)
self.bytes += max_compressed_size


def _group_ptr_queries_with_known_answers(
now: float, multicast: bool, question_with_known_answers: _QuestionWithKnownAnswers
) -> List[DNSOutgoing]:
"""Aggregate queries so that as many known answers as possible fit in the same packet
without having known answers spill over into the next packet unless the
question and known answers are always going to exceed the packet size.

Some responders do not implement multi-packet known answer suppression
so we try to keep all the known answers in the same packet as the
questions.
"""
# This is the maximum size the query + known answers can be with name compression.
# The actual size of the query + known answers may be a bit smaller since other
# parts may be shared when the final DNSOutgoing packets are constructed. The
# goal of this algorithm is to quickly bucket the query + known answers without
# the overhead of actually constructing the packets.
query_by_size: Dict[DNSQuestion, int] = {
question: (question.max_size + sum([answer.max_size_compressed for answer in known_answers]))
for question, known_answers in question_with_known_answers.items()
}
max_bucket_size = _MAX_MSG_TYPICAL - _DNS_PACKET_HEADER_LEN
query_buckets: List[_DNSPointerOutgoingBucket] = []
for question in sorted(
query_by_size,
key=query_by_size.get, # type: ignore
reverse=True,
):
max_compressed_size = query_by_size[question]
answers = question_with_known_answers[question]
for query_bucket in query_buckets:
if query_bucket.bytes + max_compressed_size <= max_bucket_size:
query_bucket.add(max_compressed_size, question, answers)
break
else:
# If a single question and known answers won't fit in a packet
# we will end up generating multiple packets, but there will never
# be multiple questions
query_bucket = _DNSPointerOutgoingBucket(now, multicast)
query_bucket.add(max_compressed_size, question, answers)
query_buckets.append(query_bucket)

return [query_bucket.out for query_bucket in query_buckets]


class _ServiceBrowserBase(RecordUpdateListener):
"""Base class for ServiceBrowser."""

Expand All @@ -174,9 +240,7 @@ def __init__(
self.addr = addr
self.port = port
self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6)
self._services = {
check_type_: {} for check_type_ in self.types
} # type: Dict[str, Dict[str, DNSRecord]]
self._services: Dict[str, Dict[str, DNSPointer]] = {check_type_: {} for check_type_ in self.types}
current_time = current_time_millis()
self._next_time = {check_type_: current_time for check_type_ in self.types}
self._delay = {check_type_: delay for check_type_ in self.types}
Expand Down Expand Up @@ -317,29 +381,25 @@ def run(self) -> None:
questions = [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]
self.zc.add_listener(self, questions)

def generate_ready_queries(self) -> Optional[DNSOutgoing]:
def generate_ready_queries(self) -> List[DNSOutgoing]:
"""Generate the service browser query for any type that is due."""
out = None
now = current_time_millis()

if min(self._next_time.values()) > now:
return out
return []

questions_with_known_answers: _QuestionWithKnownAnswers = {}

for type_, due in self._next_time.items():
if due > now:
continue

if out is None:
out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast)
out.add_question(DNSQuestion(type_, _TYPE_PTR, _CLASS_IN))

for record in self._services[type_].values():
if not record.is_stale(now):
out.add_answer_at_time(record, now)

questions_with_known_answers[DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)] = set(
record for record in self._services[type_].values() if not record.is_stale(now)
)
self._next_time[type_] = now + self._delay[type_]
self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2)
return out

return _group_ptr_queries_with_known_answers(now, self.multicast, questions_with_known_answers)

def _seconds_to_wait(self) -> Optional[float]:
"""Returns the number of seconds to wait for the next event."""
Expand Down Expand Up @@ -406,8 +466,8 @@ def run(self) -> None:
if self.zc.done or self.done:
return

out = self.generate_ready_queries()
if out:
outs = self.generate_ready_queries()
for out in outs:
self.zc.send(out, addr=self.addr, port=self.port)

if not self._handlers_to_call:
Expand Down
4 changes: 2 additions & 2 deletions zeroconf/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ async def async_run(self) -> None:
if not self._handlers_to_call:
await wait_condition_or_timeout(self.aiozc.condition, timeout)

out = self.generate_ready_queries()
if out:
outs = self.generate_ready_queries()
for out in outs:
self.aiozc.zeroconf.async_send(out, addr=self.addr, port=self.port)

if not self._handlers_to_call:
Expand Down
2 changes: 2 additions & 0 deletions zeroconf/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
_DNS_HOST_TTL = 120 # two minute for host records (A, SRV etc) as-per RFC6762
_DNS_OTHER_TTL = 4500 # 75 minutes for non-host records (PTR, TXT etc) as-per RFC6762

_DNS_PACKET_HEADER_LEN = 12

_MAX_MSG_TYPICAL = 1460 # unused
_MAX_MSG_ABSOLUTE = 8966

Expand Down