Skip to content
Merged
86 changes: 46 additions & 40 deletions src/zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
_CHECK_TIME,
_CLASS_IN,
_CLASS_UNIQUE,
_DUPLICATE_PACKET_SUPPRESSION_INTERVAL,
_FLAGS_AA,
_FLAGS_QR_QUERY,
_FLAGS_QR_RESPONSE,
Expand Down Expand Up @@ -259,26 +260,20 @@ def __init__(self, zc: 'Zeroconf') -> None:
self.zc = zc
self.data: Optional[bytes] = None
self.last_time: float = 0
self.last_message: Optional[DNSIncoming] = None
self.transport: Optional[_WrappedTransport] = None
self.sock_description: Optional[str] = None
self._deferred: Dict[str, List[DNSIncoming]] = {}
self._timers: Dict[str, asyncio.TimerHandle] = {}
super().__init__()

def suppress_duplicate_packet(self, data: bytes, now: float) -> bool:
"""Suppress duplicate packet if the last one was the same in the last second."""
if self.data == data and (now - 1000) < self.last_time:
return True
self.data = data
self.last_time = now
return False

def datagram_received(
self, data: bytes, addrs: Union[Tuple[str, int], Tuple[str, int, int, int]]
) -> None:
assert self.transport is not None
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = ()
data_len = len(data)
debug = log.isEnabledFor(logging.DEBUG)

if len(addrs) == 2:
# https://github.com/python/mypy/issues/1178
Expand All @@ -290,19 +285,6 @@ def datagram_received(
log.debug('IPv6 scope_id %d associated to the receiving interface', scope)
v6_flow_scope = (flow, scope)

now = current_time_millis()
if self.suppress_duplicate_packet(data, now):
# Guard against duplicate packets
log.debug(
'Ignoring duplicate message received from %r:%r [socket %s] (%d bytes) as [%r]',
addr,
port,
self.sock_description,
data_len,
data,
)
return

if data_len > _MAX_MSG_ABSOLUTE:
# Guard against oversized packets to ensure bad implementations cannot overwhelm
# the system.
Expand All @@ -314,26 +296,50 @@ def datagram_received(
)
return

now = current_time_millis()
if (
self.data == data
and (now - _DUPLICATE_PACKET_SUPPRESSION_INTERVAL) < self.last_time
and self.last_message is not None
and not self.last_message.has_qu_question()
):
# Guard against duplicate packets
if debug:
log.debug(
'Ignoring duplicate message with no unicast questions received from %r:%r [socket %s] (%d bytes) as [%r]',
addr,
port,
self.sock_description,
data_len,
data,
)
return

msg = DNSIncoming(data, (addr, port), scope, now)
self.data = data
self.last_time = now
self.last_message = msg
if msg.valid:
log.debug(
'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]',
addr,
port,
self.sock_description,
msg,
data_len,
data,
)
if debug:
log.debug(
'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]',
addr,
port,
self.sock_description,
msg,
data_len,
data,
)
else:
log.debug(
'Received from %r:%r [socket %s]: (%d bytes) [%r]',
addr,
port,
self.sock_description,
data_len,
data,
)
if debug:
log.debug(
'Received from %r:%r [socket %s]: (%d bytes) [%r]',
addr,
port,
self.sock_description,
data_len,
data,
)
return

if not msg.is_query():
Expand Down Expand Up @@ -722,8 +728,8 @@ def _add_broadcast_answer( # pylint: disable=no-self-use
out.add_answer_at_time(info.dns_service(override_ttl=host_ttl, created=now), 0)
out.add_answer_at_time(info.dns_text(override_ttl=other_ttl, created=now), 0)
if broadcast_addresses:
for dns_address in info.dns_addresses(override_ttl=host_ttl, created=now):
out.add_answer_at_time(dns_address, 0)
for record in info.get_address_and_nsec_records(override_ttl=host_ttl, created=now):
out.add_answer_at_time(record, 0)

def unregister_service(self, info: ServiceInfo) -> None:
"""Unregister a service.
Expand Down
42 changes: 9 additions & 33 deletions src/zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,17 @@
)

from ._cache import DNSCache, _UniqueRecordsType
from ._dns import DNSAddress, DNSNsec, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet
from ._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet
from ._history import QuestionHistory
from ._logger import log
from ._protocol.incoming import DNSIncoming
from ._protocol.outgoing import DNSOutgoing
from ._services.info import ServiceInfo
from ._services.registry import ServiceRegistry
from ._updates import RecordUpdate, RecordUpdateListener
from ._utils.time import current_time_millis, millis_to_seconds
from .const import (
_ADDRESS_RECORD_TYPES,
_CLASS_IN,
_CLASS_UNIQUE,
_DNS_OTHER_TTL,
_DNS_PTR_MIN_TTL,
_FLAGS_AA,
Expand Down Expand Up @@ -90,15 +88,6 @@ class AnswerGroup(NamedTuple):
answers: _AnswerWithAdditionalsType


def construct_nsec_record(name: str, types: List[int], now: float) -> DNSNsec:
"""Construct an NSEC record for name and a list of dns types.

This function should only be used for SRV/A/AAAA records
which have a TTL of _DNS_OTHER_TTL
"""
return DNSNsec(name, _TYPE_NSEC, _CLASS_IN | _CLASS_UNIQUE, _DNS_OTHER_TTL, name, types, created=now)


def construct_outgoing_multicast_answers(answers: _AnswerWithAdditionalsType) -> DNSOutgoing:
"""Add answers and additionals to a DNSOutgoing."""
out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=True)
Expand Down Expand Up @@ -217,20 +206,6 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool:
return bool(maybe_entry and self._now - maybe_entry.created < _ONE_SECOND)


def _get_address_and_nsec_records(service: ServiceInfo, now: float) -> Set[DNSRecord]:
"""Build a set of address records and NSEC records for non-present record types."""
seen_types: Set[int] = set()
records: Set[DNSRecord] = set()
for dns_address in service.dns_addresses(created=now):
seen_types.add(dns_address.type)
records.add(dns_address)
missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
if missing_types:
assert service.server is not None, "Service server must be set for NSEC record."
records.add(construct_nsec_record(service.server, list(missing_types), now))
return records


class QueryHandler:
"""Query the ServiceRegistry."""

Expand Down Expand Up @@ -264,9 +239,10 @@ def _add_pointer_answers(
dns_pointer = service.dns_pointer(created=now)
if known_answers.suppresses(dns_pointer):
continue
additionals: Set[DNSRecord] = {service.dns_service(created=now), service.dns_text(created=now)}
additionals |= _get_address_and_nsec_records(service, now)
answer_set[dns_pointer] = additionals
answer_set[dns_pointer] = {
service.dns_service(created=now),
service.dns_text(created=now),
} | service.get_address_and_nsec_records(created=now)

def _add_address_answers(
self,
Expand All @@ -291,12 +267,12 @@ def _add_address_answers(
if answers:
if missing_types:
assert service.server is not None, "Service server must be set for NSEC record."
additionals.add(construct_nsec_record(service.server, list(missing_types), now))
additionals.add(service.dns_nsec(list(missing_types), created=now))
for answer in answers:
answer_set[answer] = additionals
elif type_ in missing_types:
assert service.server is not None, "Service server must be set for NSEC record."
answer_set[construct_nsec_record(service.server, list(missing_types), now)] = set()
answer_set[service.dns_nsec(list(missing_types), created=now)] = set()

def _answer_question(
self,
Expand Down Expand Up @@ -327,7 +303,7 @@ def _answer_question(
# https://tools.ietf.org/html/rfc6763#section-12.2.
dns_service = service.dns_service(created=now)
if not known_answers.suppresses(dns_service):
answer_set[dns_service] = _get_address_and_nsec_records(service, now)
answer_set[dns_service] = service.get_address_and_nsec_records(created=now)
if type_ in (_TYPE_TXT, _TYPE_ANY):
dns_text = service.dns_text(created=now)
if not known_answers.suppresses(dns_text):
Expand Down Expand Up @@ -496,7 +472,7 @@ def async_add_listener(
its update_record method called when information is available to
answer the question(s).

This function is not threadsafe and must be called in the eventloop.
This function is not thread-safe and must be called in the eventloop.
"""
if not isinstance(listener, RecordUpdateListener):
log.error( # type: ignore[unreachable]
Expand Down
5 changes: 5 additions & 0 deletions src/zeroconf/_protocol/incoming.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ cdef class DNSIncoming:
cdef public object scope_id
cdef public object source

@cython.locals(
question=DNSQuestion
)
cpdef has_qu_question(self)

@cython.locals(
off=cython.uint,
label_idx=cython.uint,
Expand Down
10 changes: 10 additions & 0 deletions src/zeroconf/_protocol/incoming.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,16 @@ def is_response(self) -> bool:
"""Returns true if this is a response."""
return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE

def has_qu_question(self) -> bool:
"""Returns true if any question is a QU question."""
if not self.num_questions:
return False
for question in self.questions:
# QU questions use the same bit as unique
if question.unique:
return True
return False

@property
def truncated(self) -> bool:
"""Returns true if this is a truncated."""
Expand Down
34 changes: 33 additions & 1 deletion src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
import ipaddress
import random
from functools import lru_cache
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast

from .._dns import (
DNSAddress,
DNSNsec,
DNSPointer,
DNSQuestionType,
DNSRecord,
Expand All @@ -47,6 +48,7 @@
from .._utils.net import IPVersion, _encode_address
from .._utils.time import current_time_millis, millis_to_seconds
from ..const import (
_ADDRESS_RECORD_TYPES,
_CLASS_IN,
_CLASS_UNIQUE,
_DNS_HOST_TTL,
Expand All @@ -55,6 +57,7 @@
_LISTENER_TIME,
_TYPE_A,
_TYPE_AAAA,
_TYPE_NSEC,
_TYPE_PTR,
_TYPE_SRV,
_TYPE_TXT,
Expand Down Expand Up @@ -530,6 +533,35 @@ def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float]
created,
)

def dns_nsec(
self, missing_types: List[int], override_ttl: Optional[int] = None, created: Optional[float] = None
) -> DNSNsec:
"""Return DNSNsec from ServiceInfo."""
return DNSNsec(
self.name,
_TYPE_NSEC,
_CLASS_IN | _CLASS_UNIQUE,
override_ttl if override_ttl is not None else self.host_ttl,
self.name,
missing_types,
created,
)

def get_address_and_nsec_records(
self, override_ttl: Optional[int] = None, created: Optional[float] = None
) -> Set[DNSRecord]:
"""Build a set of address records and NSEC records for non-present record types."""
seen_types: Set[int] = set()
records: Set[DNSRecord] = set()
for dns_address in self.dns_addresses(override_ttl, IPVersion.All, created):
seen_types.add(dns_address.type)
records.add(dns_address)
missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
if missing_types:
assert self.server is not None, "Service server must be set for NSEC record."
records.add(self.dns_nsec(list(missing_types), override_ttl, created))
return records

def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int) -> List[DNSAddress]:
"""Get the addresses from the cache."""
if self.server_key is None:
Expand Down
1 change: 1 addition & 0 deletions src/zeroconf/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
_LISTENER_TIME = 200 # ms
_BROWSER_TIME = 1000 # ms
_DUPLICATE_QUESTION_INTERVAL = _BROWSER_TIME - 1 # ms
_DUPLICATE_PACKET_SUPPRESSION_INTERVAL = 1000
_BROWSER_BACKOFF_LIMIT = 3600 # s
_CACHE_CLEANUP_INTERVAL = 10 # s
_LOADED_SYSTEM_TIMEOUT = 10 # s
Expand Down
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,16 @@ def run_isolated():
const, "_MDNS_PORT", 5454
):
yield


@pytest.fixture
def disable_duplicate_packet_suppression():
"""Disable duplicate packet suppress.

Some tests run too slowly because of the duplicate
packet suppression.
"""
with unittest.mock.patch.object(
_core, "_DUPLICATE_PACKET_SUPPRESSION_INTERVAL", 0
), unittest.mock.patch.object(const, "_DUPLICATE_PACKET_SUPPRESSION_INTERVAL", 0):
yield
4 changes: 1 addition & 3 deletions tests/services/test_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ def update_service(self, zc, type_, name) -> None: # type: ignore[no-untyped-de
service_updated_event.set()

def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:

generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
assert generated.is_response() is True

Expand Down Expand Up @@ -331,7 +330,6 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi

class TestServiceBrowserMultipleTypes(unittest.TestCase):
def test_update_record(self):

service_names = ['name2._type2._tcp.local.', 'name._type._tcp.local.', 'name._type._udp.local']
service_types = ['_type2._tcp.local.', '_type._tcp.local.', '_type._udp.local.']

Expand Down Expand Up @@ -580,7 +578,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name):
pass

browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change], delay=5)
time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 120 + 5))
time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 120 + 50))
try:
assert first_outgoing.questions[0].unicast is True # type: ignore[union-attr]
assert second_outgoing.questions[0].unicast is False # type: ignore[attr-defined]
Expand Down
Loading