Skip to content

Commit 7e30848

Browse files
authored
Efficiently bucket queries with known answers (#698)
1 parent 26fa2fb commit 7e30848

5 files changed

Lines changed: 133 additions & 23 deletions

File tree

tests/test_services.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import pytest
1717

1818
import zeroconf as r
19-
from zeroconf import DNSAddress, const
19+
from zeroconf import DNSAddress, DNSPointer, DNSQuestion, const, current_time_millis
2020
import zeroconf._services as s
2121
from zeroconf import Zeroconf
2222
from zeroconf._services import (
@@ -1377,3 +1377,26 @@ def test_serviceinfo_accepts_bytes_or_string_dict():
13771377
addresses=addresses,
13781378
)
13791379
assert info_service.dns_text().text == b'\x0epath=/~paulsm/'
1380+
1381+
1382+
def test_group_ptr_queries_with_known_answers():
1383+
questions_with_known_answers: s._QuestionWithKnownAnswers = {}
1384+
now = current_time_millis()
1385+
for i in range(120):
1386+
name = f"_hap{i}._tcp._local."
1387+
questions_with_known_answers[DNSQuestion(name, const._TYPE_PTR, const._CLASS_IN)] = set(
1388+
DNSPointer(
1389+
name,
1390+
const._TYPE_PTR,
1391+
const._CLASS_IN,
1392+
4500,
1393+
f"zoo{counter}.{name}",
1394+
)
1395+
for counter in range(i)
1396+
)
1397+
outs = s._group_ptr_queries_with_known_answers(now, True, questions_with_known_answers)
1398+
for out in outs:
1399+
packets = out.packets()
1400+
# If we generate multiple packets there must
1401+
# only be one question
1402+
assert len(packets) == 1 or len(out.questions) == 1

zeroconf/_dns.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
_CLASSES,
3535
_CLASS_MASK,
3636
_CLASS_UNIQUE,
37+
_DNS_PACKET_HEADER_LEN,
3738
_EXPIRE_FULL_TIME_PERCENT,
3839
_EXPIRE_STALE_TIME_PERCENT,
3940
_FLAGS_QR_MASK,
@@ -54,6 +55,12 @@
5455
_TYPE_TXT,
5556
)
5657

58+
_LEN_BYTE = 1
59+
_LEN_SHORT = 2
60+
_LEN_INT = 4
61+
62+
_BASE_MAX_SIZE = _LEN_SHORT + _LEN_SHORT + _LEN_INT + _LEN_SHORT # type # class # ttl # length
63+
_NAME_COMPRESSION_MIN_SIZE = _LEN_BYTE * 2
5764

5865
if TYPE_CHECKING:
5966
# https://github.com/PyCQA/pylint/issues/3525
@@ -118,6 +125,14 @@ def answered_by(self, rec: 'DNSRecord') -> bool:
118125
and self.name == rec.name
119126
)
120127

128+
def __hash__(self) -> int:
129+
return hash((self.name, self.class_, self.type))
130+
131+
@property
132+
def max_size(self) -> int:
133+
"""Maximum size of the question in the packet."""
134+
return len(self.name.encode('utf-8')) + _LEN_BYTE + _LEN_SHORT + _LEN_SHORT # type # class
135+
121136
@property
122137
def unicast(self) -> bool:
123138
"""Returns true if the QU (not QM) is set.
@@ -291,6 +306,16 @@ def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) ->
291306
super().__init__(name, type_, class_, ttl)
292307
self.alias = alias
293308

309+
@property
310+
def max_size_compressed(self) -> int:
311+
"""Maximum size of the record in the packet assuming the name has been compressed."""
312+
return (
313+
_BASE_MAX_SIZE
314+
+ _NAME_COMPRESSION_MIN_SIZE
315+
+ (len(self.alias) - len(self.name))
316+
+ _NAME_COMPRESSION_MIN_SIZE
317+
)
318+
294319
def write(self, out: 'DNSOutgoing') -> None:
295320
"""Used in constructing an outgoing packet"""
296321
out.write_name(self.alias)
@@ -590,7 +615,7 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None:
590615
# these 3 are per-packet -- see also _reset_for_next_packet()
591616
self.names: Dict[str, int] = {}
592617
self.data: List[bytes] = []
593-
self.size: int = 12
618+
self.size: int = _DNS_PACKET_HEADER_LEN
594619
self.allow_long: bool = True
595620

596621
self.state = self.State.init
@@ -603,7 +628,7 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None:
603628
def _reset_for_next_packet(self) -> None:
604629
self.names = {}
605630
self.data = []
606-
self.size = 12
631+
self.size = _DNS_PACKET_HEADER_LEN
607632
self.allow_long = True
608633

609634
def __repr__(self) -> str:

zeroconf/_services/__init__.py

Lines changed: 78 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,11 @@
4444
_CLASS_UNIQUE,
4545
_DNS_HOST_TTL,
4646
_DNS_OTHER_TTL,
47+
_DNS_PACKET_HEADER_LEN,
4748
_EXPIRE_REFRESH_TIME_PERCENT,
4849
_FLAGS_QR_QUERY,
4950
_LISTENER_TIME,
51+
_MAX_MSG_TYPICAL,
5052
_MDNS_ADDR,
5153
_MDNS_ADDR6,
5254
_MDNS_PORT,
@@ -63,6 +65,9 @@
6365
from .._core import Zeroconf # pylint: disable=cyclic-import
6466

6567

68+
_QuestionWithKnownAnswers = Dict[DNSQuestion, Set[DNSPointer]]
69+
70+
6671
@enum.unique
6772
class ServiceStateChange(enum.Enum):
6873
Added = 1
@@ -151,6 +156,67 @@ def update_records_complete(self) -> None:
151156
"""
152157

153158

159+
class _DNSPointerOutgoingBucket:
160+
"""A DNSOutgoing bucket."""
161+
162+
def __init__(self, now: float, multicast: bool) -> None:
163+
"""Create a bucke to wrap a DNSOutgoing."""
164+
self.now = now
165+
self.out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=multicast)
166+
self.bytes = 0
167+
168+
def add(self, max_compressed_size: int, question: DNSQuestion, answers: Set[DNSPointer]) -> None:
169+
"""Add a new set of questions and known answers to the outgoing."""
170+
self.out.add_question(question)
171+
for answer in answers:
172+
self.out.add_answer_at_time(answer, self.now)
173+
self.bytes += max_compressed_size
174+
175+
176+
def _group_ptr_queries_with_known_answers(
177+
now: float, multicast: bool, question_with_known_answers: _QuestionWithKnownAnswers
178+
) -> List[DNSOutgoing]:
179+
"""Aggregate queries so that as many known answers as possible fit in the same packet
180+
without having known answers spill over into the next packet unless the
181+
question and known answers are always going to exceed the packet size.
182+
183+
Some responders do not implement multi-packet known answer suppression
184+
so we try to keep all the known answers in the same packet as the
185+
questions.
186+
"""
187+
# This is the maximum size the query + known answers can be with name compression.
188+
# The actual size of the query + known answers may be a bit smaller since other
189+
# parts may be shared when the final DNSOutgoing packets are constructed. The
190+
# goal of this algorithm is to quickly bucket the query + known answers without
191+
# the overhead of actually constructing the packets.
192+
query_by_size: Dict[DNSQuestion, int] = {
193+
question: (question.max_size + sum([answer.max_size_compressed for answer in known_answers]))
194+
for question, known_answers in question_with_known_answers.items()
195+
}
196+
max_bucket_size = _MAX_MSG_TYPICAL - _DNS_PACKET_HEADER_LEN
197+
query_buckets: List[_DNSPointerOutgoingBucket] = []
198+
for question in sorted(
199+
query_by_size,
200+
key=query_by_size.get, # type: ignore
201+
reverse=True,
202+
):
203+
max_compressed_size = query_by_size[question]
204+
answers = question_with_known_answers[question]
205+
for query_bucket in query_buckets:
206+
if query_bucket.bytes + max_compressed_size <= max_bucket_size:
207+
query_bucket.add(max_compressed_size, question, answers)
208+
break
209+
else:
210+
# If a single question and known answers won't fit in a packet
211+
# we will end up generating multiple packets, but there will never
212+
# be multiple questions
213+
query_bucket = _DNSPointerOutgoingBucket(now, multicast)
214+
query_bucket.add(max_compressed_size, question, answers)
215+
query_buckets.append(query_bucket)
216+
217+
return [query_bucket.out for query_bucket in query_buckets]
218+
219+
154220
class _ServiceBrowserBase(RecordUpdateListener):
155221
"""Base class for ServiceBrowser."""
156222

@@ -174,9 +240,7 @@ def __init__(
174240
self.addr = addr
175241
self.port = port
176242
self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6)
177-
self._services = {
178-
check_type_: {} for check_type_ in self.types
179-
} # type: Dict[str, Dict[str, DNSRecord]]
243+
self._services: Dict[str, Dict[str, DNSPointer]] = {check_type_: {} for check_type_ in self.types}
180244
current_time = current_time_millis()
181245
self._next_time = {check_type_: current_time for check_type_ in self.types}
182246
self._delay = {check_type_: delay for check_type_ in self.types}
@@ -317,29 +381,25 @@ def run(self) -> None:
317381
questions = [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]
318382
self.zc.add_listener(self, questions)
319383

320-
def generate_ready_queries(self) -> Optional[DNSOutgoing]:
384+
def generate_ready_queries(self) -> List[DNSOutgoing]:
321385
"""Generate the service browser query for any type that is due."""
322-
out = None
323386
now = current_time_millis()
324387

325388
if min(self._next_time.values()) > now:
326-
return out
389+
return []
390+
391+
questions_with_known_answers: _QuestionWithKnownAnswers = {}
327392

328393
for type_, due in self._next_time.items():
329394
if due > now:
330395
continue
331-
332-
if out is None:
333-
out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast)
334-
out.add_question(DNSQuestion(type_, _TYPE_PTR, _CLASS_IN))
335-
336-
for record in self._services[type_].values():
337-
if not record.is_stale(now):
338-
out.add_answer_at_time(record, now)
339-
396+
questions_with_known_answers[DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)] = set(
397+
record for record in self._services[type_].values() if not record.is_stale(now)
398+
)
340399
self._next_time[type_] = now + self._delay[type_]
341400
self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2)
342-
return out
401+
402+
return _group_ptr_queries_with_known_answers(now, self.multicast, questions_with_known_answers)
343403

344404
def _seconds_to_wait(self) -> Optional[float]:
345405
"""Returns the number of seconds to wait for the next event."""
@@ -406,8 +466,8 @@ def run(self) -> None:
406466
if self.zc.done or self.done:
407467
return
408468

409-
out = self.generate_ready_queries()
410-
if out:
469+
outs = self.generate_ready_queries()
470+
for out in outs:
411471
self.zc.send(out, addr=self.addr, port=self.port)
412472

413473
if not self._handlers_to_call:

zeroconf/aio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ async def async_run(self) -> None:
159159
if not self._handlers_to_call:
160160
await wait_condition_or_timeout(self.aiozc.condition, timeout)
161161

162-
out = self.generate_ready_queries()
163-
if out:
162+
outs = self.generate_ready_queries()
163+
for out in outs:
164164
self.aiozc.zeroconf.async_send(out, addr=self.addr, port=self.port)
165165

166166
if not self._handlers_to_call:

zeroconf/const.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
_DNS_HOST_TTL = 120 # two minute for host records (A, SRV etc) as-per RFC6762
4848
_DNS_OTHER_TTL = 4500 # 75 minutes for non-host records (PTR, TXT etc) as-per RFC6762
4949

50+
_DNS_PACKET_HEADER_LEN = 12
51+
5052
_MAX_MSG_TYPICAL = 1460 # unused
5153
_MAX_MSG_ABSOLUTE = 8966
5254

0 commit comments

Comments
 (0)