Skip to content

Commit 74d7ba1

Browse files
authored
fix: always answer QU questions when the exact same packet is received from different sources in sequence (#1178)
If the exact same packet with a QU question is asked from two different sources in a 1s window we end up ignoring the second one as a duplicate. We should still respond in this case because the client wants a unicast response and the question may not be answered by the previous packet since the response may not be multicast. fix: include NSEC records in initial broadcast when registering a new service This also revealed that we do not send NSEC records in the initial broadcast. This needed to be fixed in this PR as well for everything to work as expected since all the tests would fail with 2 updates otherwise.
1 parent b356bc8 commit 74d7ba1

15 files changed

Lines changed: 437 additions & 371 deletions

src/zeroconf/_core.py

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
_CHECK_TIME,
7272
_CLASS_IN,
7373
_CLASS_UNIQUE,
74+
_DUPLICATE_PACKET_SUPPRESSION_INTERVAL,
7475
_FLAGS_AA,
7576
_FLAGS_QR_QUERY,
7677
_FLAGS_QR_RESPONSE,
@@ -259,26 +260,20 @@ def __init__(self, zc: 'Zeroconf') -> None:
259260
self.zc = zc
260261
self.data: Optional[bytes] = None
261262
self.last_time: float = 0
263+
self.last_message: Optional[DNSIncoming] = None
262264
self.transport: Optional[_WrappedTransport] = None
263265
self.sock_description: Optional[str] = None
264266
self._deferred: Dict[str, List[DNSIncoming]] = {}
265267
self._timers: Dict[str, asyncio.TimerHandle] = {}
266268
super().__init__()
267269

268-
def suppress_duplicate_packet(self, data: bytes, now: float) -> bool:
269-
"""Suppress duplicate packet if the last one was the same in the last second."""
270-
if self.data == data and (now - 1000) < self.last_time:
271-
return True
272-
self.data = data
273-
self.last_time = now
274-
return False
275-
276270
def datagram_received(
277271
self, data: bytes, addrs: Union[Tuple[str, int], Tuple[str, int, int, int]]
278272
) -> None:
279273
assert self.transport is not None
280274
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = ()
281275
data_len = len(data)
276+
debug = log.isEnabledFor(logging.DEBUG)
282277

283278
if len(addrs) == 2:
284279
# https://github.com/python/mypy/issues/1178
@@ -290,19 +285,6 @@ def datagram_received(
290285
log.debug('IPv6 scope_id %d associated to the receiving interface', scope)
291286
v6_flow_scope = (flow, scope)
292287

293-
now = current_time_millis()
294-
if self.suppress_duplicate_packet(data, now):
295-
# Guard against duplicate packets
296-
log.debug(
297-
'Ignoring duplicate message received from %r:%r [socket %s] (%d bytes) as [%r]',
298-
addr,
299-
port,
300-
self.sock_description,
301-
data_len,
302-
data,
303-
)
304-
return
305-
306288
if data_len > _MAX_MSG_ABSOLUTE:
307289
# Guard against oversized packets to ensure bad implementations cannot overwhelm
308290
# the system.
@@ -314,26 +296,50 @@ def datagram_received(
314296
)
315297
return
316298

299+
now = current_time_millis()
300+
if (
301+
self.data == data
302+
and (now - _DUPLICATE_PACKET_SUPPRESSION_INTERVAL) < self.last_time
303+
and self.last_message is not None
304+
and not self.last_message.has_qu_question()
305+
):
306+
# Guard against duplicate packets
307+
if debug:
308+
log.debug(
309+
'Ignoring duplicate message with no unicast questions received from %r:%r [socket %s] (%d bytes) as [%r]',
310+
addr,
311+
port,
312+
self.sock_description,
313+
data_len,
314+
data,
315+
)
316+
return
317+
317318
msg = DNSIncoming(data, (addr, port), scope, now)
319+
self.data = data
320+
self.last_time = now
321+
self.last_message = msg
318322
if msg.valid:
319-
log.debug(
320-
'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]',
321-
addr,
322-
port,
323-
self.sock_description,
324-
msg,
325-
data_len,
326-
data,
327-
)
323+
if debug:
324+
log.debug(
325+
'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]',
326+
addr,
327+
port,
328+
self.sock_description,
329+
msg,
330+
data_len,
331+
data,
332+
)
328333
else:
329-
log.debug(
330-
'Received from %r:%r [socket %s]: (%d bytes) [%r]',
331-
addr,
332-
port,
333-
self.sock_description,
334-
data_len,
335-
data,
336-
)
334+
if debug:
335+
log.debug(
336+
'Received from %r:%r [socket %s]: (%d bytes) [%r]',
337+
addr,
338+
port,
339+
self.sock_description,
340+
data_len,
341+
data,
342+
)
337343
return
338344

339345
if not msg.is_query():
@@ -722,8 +728,8 @@ def _add_broadcast_answer( # pylint: disable=no-self-use
722728
out.add_answer_at_time(info.dns_service(override_ttl=host_ttl, created=now), 0)
723729
out.add_answer_at_time(info.dns_text(override_ttl=other_ttl, created=now), 0)
724730
if broadcast_addresses:
725-
for dns_address in info.dns_addresses(override_ttl=host_ttl, created=now):
726-
out.add_answer_at_time(dns_address, 0)
731+
for record in info.get_address_and_nsec_records(override_ttl=host_ttl, created=now):
732+
out.add_answer_at_time(record, 0)
727733

728734
def unregister_service(self, info: ServiceInfo) -> None:
729735
"""Unregister a service.

src/zeroconf/_handlers.py

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,17 @@
3737
)
3838

3939
from ._cache import DNSCache, _UniqueRecordsType
40-
from ._dns import DNSAddress, DNSNsec, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet
40+
from ._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet
4141
from ._history import QuestionHistory
4242
from ._logger import log
4343
from ._protocol.incoming import DNSIncoming
4444
from ._protocol.outgoing import DNSOutgoing
45-
from ._services.info import ServiceInfo
4645
from ._services.registry import ServiceRegistry
4746
from ._updates import RecordUpdate, RecordUpdateListener
4847
from ._utils.time import current_time_millis, millis_to_seconds
4948
from .const import (
5049
_ADDRESS_RECORD_TYPES,
5150
_CLASS_IN,
52-
_CLASS_UNIQUE,
5351
_DNS_OTHER_TTL,
5452
_DNS_PTR_MIN_TTL,
5553
_FLAGS_AA,
@@ -90,15 +88,6 @@ class AnswerGroup(NamedTuple):
9088
answers: _AnswerWithAdditionalsType
9189

9290

93-
def construct_nsec_record(name: str, types: List[int], now: float) -> DNSNsec:
94-
"""Construct an NSEC record for name and a list of dns types.
95-
96-
This function should only be used for SRV/A/AAAA records
97-
which have a TTL of _DNS_OTHER_TTL
98-
"""
99-
return DNSNsec(name, _TYPE_NSEC, _CLASS_IN | _CLASS_UNIQUE, _DNS_OTHER_TTL, name, types, created=now)
100-
101-
10291
def construct_outgoing_multicast_answers(answers: _AnswerWithAdditionalsType) -> DNSOutgoing:
10392
"""Add answers and additionals to a DNSOutgoing."""
10493
out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=True)
@@ -217,20 +206,6 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool:
217206
return bool(maybe_entry and self._now - maybe_entry.created < _ONE_SECOND)
218207

219208

220-
def _get_address_and_nsec_records(service: ServiceInfo, now: float) -> Set[DNSRecord]:
221-
"""Build a set of address records and NSEC records for non-present record types."""
222-
seen_types: Set[int] = set()
223-
records: Set[DNSRecord] = set()
224-
for dns_address in service.dns_addresses(created=now):
225-
seen_types.add(dns_address.type)
226-
records.add(dns_address)
227-
missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
228-
if missing_types:
229-
assert service.server is not None, "Service server must be set for NSEC record."
230-
records.add(construct_nsec_record(service.server, list(missing_types), now))
231-
return records
232-
233-
234209
class QueryHandler:
235210
"""Query the ServiceRegistry."""
236211

@@ -264,9 +239,10 @@ def _add_pointer_answers(
264239
dns_pointer = service.dns_pointer(created=now)
265240
if known_answers.suppresses(dns_pointer):
266241
continue
267-
additionals: Set[DNSRecord] = {service.dns_service(created=now), service.dns_text(created=now)}
268-
additionals |= _get_address_and_nsec_records(service, now)
269-
answer_set[dns_pointer] = additionals
242+
answer_set[dns_pointer] = {
243+
service.dns_service(created=now),
244+
service.dns_text(created=now),
245+
} | service.get_address_and_nsec_records(created=now)
270246

271247
def _add_address_answers(
272248
self,
@@ -291,12 +267,12 @@ def _add_address_answers(
291267
if answers:
292268
if missing_types:
293269
assert service.server is not None, "Service server must be set for NSEC record."
294-
additionals.add(construct_nsec_record(service.server, list(missing_types), now))
270+
additionals.add(service.dns_nsec(list(missing_types), created=now))
295271
for answer in answers:
296272
answer_set[answer] = additionals
297273
elif type_ in missing_types:
298274
assert service.server is not None, "Service server must be set for NSEC record."
299-
answer_set[construct_nsec_record(service.server, list(missing_types), now)] = set()
275+
answer_set[service.dns_nsec(list(missing_types), created=now)] = set()
300276

301277
def _answer_question(
302278
self,
@@ -327,7 +303,7 @@ def _answer_question(
327303
# https://tools.ietf.org/html/rfc6763#section-12.2.
328304
dns_service = service.dns_service(created=now)
329305
if not known_answers.suppresses(dns_service):
330-
answer_set[dns_service] = _get_address_and_nsec_records(service, now)
306+
answer_set[dns_service] = service.get_address_and_nsec_records(created=now)
331307
if type_ in (_TYPE_TXT, _TYPE_ANY):
332308
dns_text = service.dns_text(created=now)
333309
if not known_answers.suppresses(dns_text):
@@ -496,7 +472,7 @@ def async_add_listener(
496472
its update_record method called when information is available to
497473
answer the question(s).
498474
499-
This function is not threadsafe and must be called in the eventloop.
475+
This function is not thread-safe and must be called in the eventloop.
500476
"""
501477
if not isinstance(listener, RecordUpdateListener):
502478
log.error( # type: ignore[unreachable]

src/zeroconf/_protocol/incoming.pxd

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ cdef class DNSIncoming:
6565
cdef public object scope_id
6666
cdef public object source
6767

68+
@cython.locals(
69+
question=DNSQuestion
70+
)
71+
cpdef has_qu_question(self)
72+
6873
@cython.locals(
6974
off=cython.uint,
7075
label_idx=cython.uint,

src/zeroconf/_protocol/incoming.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,16 @@ def is_response(self) -> bool:
135135
"""Returns true if this is a response."""
136136
return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE
137137

138+
def has_qu_question(self) -> bool:
139+
"""Returns true if any question is a QU question."""
140+
if not self.num_questions:
141+
return False
142+
for question in self.questions:
143+
# QU questions use the same bit as unique
144+
if question.unique:
145+
return True
146+
return False
147+
138148
@property
139149
def truncated(self) -> bool:
140150
"""Returns true if this is a truncated."""

src/zeroconf/_services/info.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
import ipaddress
2525
import random
2626
from functools import lru_cache
27-
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
27+
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast
2828

2929
from .._dns import (
3030
DNSAddress,
31+
DNSNsec,
3132
DNSPointer,
3233
DNSQuestionType,
3334
DNSRecord,
@@ -47,6 +48,7 @@
4748
from .._utils.net import IPVersion, _encode_address
4849
from .._utils.time import current_time_millis, millis_to_seconds
4950
from ..const import (
51+
_ADDRESS_RECORD_TYPES,
5052
_CLASS_IN,
5153
_CLASS_UNIQUE,
5254
_DNS_HOST_TTL,
@@ -55,6 +57,7 @@
5557
_LISTENER_TIME,
5658
_TYPE_A,
5759
_TYPE_AAAA,
60+
_TYPE_NSEC,
5861
_TYPE_PTR,
5962
_TYPE_SRV,
6063
_TYPE_TXT,
@@ -530,6 +533,35 @@ def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float]
530533
created,
531534
)
532535

536+
def dns_nsec(
537+
self, missing_types: List[int], override_ttl: Optional[int] = None, created: Optional[float] = None
538+
) -> DNSNsec:
539+
"""Return DNSNsec from ServiceInfo."""
540+
return DNSNsec(
541+
self.name,
542+
_TYPE_NSEC,
543+
_CLASS_IN | _CLASS_UNIQUE,
544+
override_ttl if override_ttl is not None else self.host_ttl,
545+
self.name,
546+
missing_types,
547+
created,
548+
)
549+
550+
def get_address_and_nsec_records(
551+
self, override_ttl: Optional[int] = None, created: Optional[float] = None
552+
) -> Set[DNSRecord]:
553+
"""Build a set of address records and NSEC records for non-present record types."""
554+
seen_types: Set[int] = set()
555+
records: Set[DNSRecord] = set()
556+
for dns_address in self.dns_addresses(override_ttl, IPVersion.All, created):
557+
seen_types.add(dns_address.type)
558+
records.add(dns_address)
559+
missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
560+
if missing_types:
561+
assert self.server is not None, "Service server must be set for NSEC record."
562+
records.add(self.dns_nsec(list(missing_types), override_ttl, created))
563+
return records
564+
533565
def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int) -> List[DNSAddress]:
534566
"""Get the addresses from the cache."""
535567
if self.server_key is None:

src/zeroconf/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
_LISTENER_TIME = 200 # ms
3232
_BROWSER_TIME = 1000 # ms
3333
_DUPLICATE_QUESTION_INTERVAL = _BROWSER_TIME - 1 # ms
34+
_DUPLICATE_PACKET_SUPPRESSION_INTERVAL = 1000
3435
_BROWSER_BACKOFF_LIMIT = 3600 # s
3536
_CACHE_CLEANUP_INTERVAL = 10 # s
3637
_LOADED_SYSTEM_TIMEOUT = 10 # s

tests/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,16 @@ def run_isolated():
2727
const, "_MDNS_PORT", 5454
2828
):
2929
yield
30+
31+
32+
@pytest.fixture
33+
def disable_duplicate_packet_suppression():
34+
"""Disable duplicate packet suppress.
35+
36+
Some tests run too slowly because of the duplicate
37+
packet suppression.
38+
"""
39+
with unittest.mock.patch.object(
40+
_core, "_DUPLICATE_PACKET_SUPPRESSION_INTERVAL", 0
41+
), unittest.mock.patch.object(const, "_DUPLICATE_PACKET_SUPPRESSION_INTERVAL", 0):
42+
yield

tests/services/test_browser.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,6 @@ def update_service(self, zc, type_, name) -> None: # type: ignore[no-untyped-de
180180
service_updated_event.set()
181181

182182
def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
183-
184183
generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
185184
assert generated.is_response() is True
186185

@@ -331,7 +330,6 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi
331330

332331
class TestServiceBrowserMultipleTypes(unittest.TestCase):
333332
def test_update_record(self):
334-
335333
service_names = ['name2._type2._tcp.local.', 'name._type._tcp.local.', 'name._type._udp.local']
336334
service_types = ['_type2._tcp.local.', '_type._tcp.local.', '_type._udp.local.']
337335

@@ -580,7 +578,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name):
580578
pass
581579

582580
browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change], delay=5)
583-
time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 120 + 5))
581+
time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 120 + 50))
584582
try:
585583
assert first_outgoing.questions[0].unicast is True # type: ignore[union-attr]
586584
assert second_outgoing.questions[0].unicast is False # type: ignore[attr-defined]

0 commit comments

Comments
 (0)