From 5f87045719338d23597483e8492691b03bbd2013 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Aug 2023 11:30:11 -0500 Subject: [PATCH 1/4] chore: split _engine.py into _transport.py and _listener.py --- src/zeroconf/_engine.py | 243 ++----------------------------------- src/zeroconf/_listener.py | 216 +++++++++++++++++++++++++++++++++ src/zeroconf/_transport.py | 67 ++++++++++ 3 files changed, 291 insertions(+), 235 deletions(-) create mode 100644 src/zeroconf/_listener.py create mode 100644 src/zeroconf/_transport.py diff --git a/src/zeroconf/_engine.py b/src/zeroconf/_engine.py index 00ecf51a..44435750 100644 --- a/src/zeroconf/_engine.py +++ b/src/zeroconf/_engine.py @@ -22,71 +22,23 @@ import asyncio import itertools -import logging -import random import socket import threading -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, List, Optional, cast -from ._logger import QuietLogger, log -from ._protocol.incoming import DNSIncoming from ._updates import RecordUpdate from ._utils.asyncio import get_running_loop, run_coro_with_timeout -from ._utils.time import current_time_millis, millis_to_seconds -from .const import ( - _CACHE_CLEANUP_INTERVAL, - _DUPLICATE_PACKET_SUPPRESSION_INTERVAL, - _MAX_MSG_ABSOLUTE, -) +from ._utils.time import current_time_millis +from .const import _CACHE_CLEANUP_INTERVAL if TYPE_CHECKING: from ._core import Zeroconf -_TC_DELAY_RANDOM_INTERVAL = (400, 500) - -_CLOSE_TIMEOUT = 3000 # ms - - -class _WrappedTransport: - """A wrapper for transports.""" - - __slots__ = ( - 'transport', - 'is_ipv6', - 'sock', - 'fileno', - 'sock_name', - ) - - def __init__( - self, - transport: asyncio.DatagramTransport, - is_ipv6: bool, - sock: socket.socket, - fileno: int, - sock_name: Any, - ) -> None: - """Initialize the wrapped transport. - - These attributes are used when sending packets. - """ - self.transport = transport - self.is_ipv6 = is_ipv6 - self.sock = sock - self.fileno = fileno - self.sock_name = sock_name +from ._listener import AsyncListener +from ._transport import _WrappedTransport, make_wrapped_transport -def _make_wrapped_transport(transport: asyncio.DatagramTransport) -> _WrappedTransport: - """Make a wrapped transport.""" - sock: socket.socket = transport.get_extra_info('socket') - return _WrappedTransport( - transport=transport, - is_ipv6=sock.family == socket.AF_INET6, - sock=sock, - fileno=sock.fileno(), - sock_name=sock.getsockname(), - ) +_CLOSE_TIMEOUT = 3000 # ms class AsyncEngine: @@ -154,9 +106,9 @@ async def _async_create_endpoints(self) -> None: lambda: AsyncListener(self.zc), sock=s # type: ignore[arg-type, return-value] ) self.protocols.append(cast(AsyncListener, protocol)) - self.readers.append(_make_wrapped_transport(cast(asyncio.DatagramTransport, transport))) + self.readers.append(make_wrapped_transport(cast(asyncio.DatagramTransport, transport))) if s in sender_sockets: - self.senders.append(_make_wrapped_transport(cast(asyncio.DatagramTransport, transport))) + self.senders.append(make_wrapped_transport(cast(asyncio.DatagramTransport, transport))) def _async_cache_cleanup(self) -> None: """Periodic cache cleanup.""" @@ -198,182 +150,3 @@ def close(self) -> None: if not self.loop.is_running(): return run_coro_with_timeout(self._async_close(), self.loop, _CLOSE_TIMEOUT) - - -class AsyncListener: - - """A Listener is used by this module to listen on the multicast - group to which DNS messages are sent, allowing the implementation - to cache information as it arrives. - - It requires registration with an Engine object in order to have - the read() method called when a socket is available for reading.""" - - __slots__ = ( - 'zc', - 'data', - 'last_time', - 'last_message', - 'transport', - 'sock_description', - '_deferred', - '_timers', - ) - - def __init__(self, zc: 'Zeroconf') -> None: - self.zc = zc - self.data: Optional[bytes] = None - self.last_time: float = 0 - self.last_message: Optional[DNSIncoming] = None - self.transport: Optional[_WrappedTransport] = None - self.sock_description: Optional[str] = None - self._deferred: Dict[str, List[DNSIncoming]] = {} - self._timers: Dict[str, asyncio.TimerHandle] = {} - super().__init__() - - def datagram_received( - self, data: bytes, addrs: Union[Tuple[str, int], Tuple[str, int, int, int]] - ) -> None: - assert self.transport is not None - data_len = len(data) - debug = log.isEnabledFor(logging.DEBUG) - - if data_len > _MAX_MSG_ABSOLUTE: - # Guard against oversized packets to ensure bad implementations cannot overwhelm - # the system. - if debug: - log.debug( - "Discarding incoming packet with length %s, which is larger " - "than the absolute maximum size of %s", - data_len, - _MAX_MSG_ABSOLUTE, - ) - return - - now = current_time_millis() - if ( - self.data == data - and (now - _DUPLICATE_PACKET_SUPPRESSION_INTERVAL) < self.last_time - and self.last_message is not None - and not self.last_message.has_qu_question() - ): - # Guard against duplicate packets - if debug: - log.debug( - 'Ignoring duplicate message with no unicast questions received from %s [socket %s] (%d bytes) as [%r]', - addrs, - self.sock_description, - data_len, - data, - ) - return - - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = () - if len(addrs) == 2: - # https://github.com/python/mypy/issues/1178 - addr, port = addrs # type: ignore - scope = None - else: - # https://github.com/python/mypy/issues/1178 - addr, port, flow, scope = addrs # type: ignore - if debug: # pragma: no branch - log.debug('IPv6 scope_id %d associated to the receiving interface', scope) - v6_flow_scope = (flow, scope) - - msg = DNSIncoming(data, (addr, port), scope, now) - self.data = data - self.last_time = now - self.last_message = msg - if msg.valid: - if debug: - log.debug( - 'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]', - addr, - port, - self.sock_description, - msg, - data_len, - data, - ) - else: - if debug: - log.debug( - 'Received from %r:%r [socket %s]: (%d bytes) [%r]', - addr, - port, - self.sock_description, - data_len, - data, - ) - return - - if not msg.is_query(): - self.zc.handle_response(msg) - return - - self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope) - - def handle_query_or_defer( - self, - msg: DNSIncoming, - addr: str, - port: int, - transport: _WrappedTransport, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), - ) -> None: - """Deal with incoming query packets. Provides a response if - possible.""" - if not msg.truncated: - self._respond_query(msg, addr, port, transport, v6_flow_scope) - return - - deferred = self._deferred.setdefault(addr, []) - # If we get the same packet we ignore it - for incoming in reversed(deferred): - if incoming.data == msg.data: - return - deferred.append(msg) - delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL)) - assert self.zc.loop is not None - self._cancel_any_timers_for_addr(addr) - self._timers[addr] = self.zc.loop.call_later( - delay, self._respond_query, None, addr, port, transport, v6_flow_scope - ) - - def _cancel_any_timers_for_addr(self, addr: str) -> None: - """Cancel any future truncated packet timers for the address.""" - if addr in self._timers: - self._timers.pop(addr).cancel() - - def _respond_query( - self, - msg: Optional[DNSIncoming], - addr: str, - port: int, - transport: _WrappedTransport, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), - ) -> None: - """Respond to a query and reassemble any truncated deferred packets.""" - self._cancel_any_timers_for_addr(addr) - packets = self._deferred.pop(addr, []) - if msg: - packets.append(msg) - - self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope) - - def error_received(self, exc: Exception) -> None: - """Likely socket closed or IPv6.""" - # We preformat the message string with the socket as we want - # log_exception_once to log a warrning message once PER EACH - # different socket in case there are problems with multiple - # sockets - msg_str = f"Error with socket {self.sock_description}): %s" - QuietLogger.log_exception_once(exc, msg_str, exc) - - def connection_made(self, transport: asyncio.BaseTransport) -> None: - wrapped_transport = _make_wrapped_transport(cast(asyncio.DatagramTransport, transport)) - self.transport = wrapped_transport - self.sock_description = f"{wrapped_transport.fileno} ({wrapped_transport.sock_name})" - - def connection_lost(self, exc: Optional[Exception]) -> None: - """Handle connection lost.""" diff --git a/src/zeroconf/_listener.py b/src/zeroconf/_listener.py new file mode 100644 index 00000000..97bcf007 --- /dev/null +++ b/src/zeroconf/_listener.py @@ -0,0 +1,216 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import asyncio +import logging +import random +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast + +from ._logger import QuietLogger, log +from ._protocol.incoming import DNSIncoming +from ._transport import _WrappedTransport, make_wrapped_transport +from ._utils.time import current_time_millis, millis_to_seconds +from .const import _DUPLICATE_PACKET_SUPPRESSION_INTERVAL, _MAX_MSG_ABSOLUTE + +if TYPE_CHECKING: + from ._core import Zeroconf + +_TC_DELAY_RANDOM_INTERVAL = (400, 500) + + +class AsyncListener: + + """A Listener is used by this module to listen on the multicast + group to which DNS messages are sent, allowing the implementation + to cache information as it arrives. + + It requires registration with an Engine object in order to have + the read() method called when a socket is available for reading.""" + + __slots__ = ( + 'zc', + 'data', + 'last_time', + 'last_message', + 'transport', + 'sock_description', + '_deferred', + '_timers', + ) + + def __init__(self, zc: 'Zeroconf') -> None: + self.zc = zc + self.data: Optional[bytes] = None + self.last_time: float = 0 + self.last_message: Optional[DNSIncoming] = None + self.transport: Optional[_WrappedTransport] = None + self.sock_description: Optional[str] = None + self._deferred: Dict[str, List[DNSIncoming]] = {} + self._timers: Dict[str, asyncio.TimerHandle] = {} + super().__init__() + + def datagram_received( + self, data: bytes, addrs: Union[Tuple[str, int], Tuple[str, int, int, int]] + ) -> None: + assert self.transport is not None + data_len = len(data) + debug = log.isEnabledFor(logging.DEBUG) + + if data_len > _MAX_MSG_ABSOLUTE: + # Guard against oversized packets to ensure bad implementations cannot overwhelm + # the system. + if debug: + log.debug( + "Discarding incoming packet with length %s, which is larger " + "than the absolute maximum size of %s", + data_len, + _MAX_MSG_ABSOLUTE, + ) + return + + now = current_time_millis() + if ( + self.data == data + and (now - _DUPLICATE_PACKET_SUPPRESSION_INTERVAL) < self.last_time + and self.last_message is not None + and not self.last_message.has_qu_question() + ): + # Guard against duplicate packets + if debug: + log.debug( + 'Ignoring duplicate message with no unicast questions received from %s [socket %s] (%d bytes) as [%r]', + addrs, + self.sock_description, + data_len, + data, + ) + return + + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = () + if len(addrs) == 2: + # https://github.com/python/mypy/issues/1178 + addr, port = addrs # type: ignore + scope = None + else: + # https://github.com/python/mypy/issues/1178 + addr, port, flow, scope = addrs # type: ignore + if debug: # pragma: no branch + log.debug('IPv6 scope_id %d associated to the receiving interface', scope) + v6_flow_scope = (flow, scope) + + msg = DNSIncoming(data, (addr, port), scope, now) + self.data = data + self.last_time = now + self.last_message = msg + if msg.valid: + if debug: + log.debug( + 'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]', + addr, + port, + self.sock_description, + msg, + data_len, + data, + ) + else: + if debug: + log.debug( + 'Received from %r:%r [socket %s]: (%d bytes) [%r]', + addr, + port, + self.sock_description, + data_len, + data, + ) + return + + if not msg.is_query(): + self.zc.handle_response(msg) + return + + self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope) + + def handle_query_or_defer( + self, + msg: DNSIncoming, + addr: str, + port: int, + transport: _WrappedTransport, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + ) -> None: + """Deal with incoming query packets. Provides a response if + possible.""" + if not msg.truncated: + self._respond_query(msg, addr, port, transport, v6_flow_scope) + return + + deferred = self._deferred.setdefault(addr, []) + # If we get the same packet we ignore it + for incoming in reversed(deferred): + if incoming.data == msg.data: + return + deferred.append(msg) + delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL)) + assert self.zc.loop is not None + self._cancel_any_timers_for_addr(addr) + self._timers[addr] = self.zc.loop.call_later( + delay, self._respond_query, None, addr, port, transport, v6_flow_scope + ) + + def _cancel_any_timers_for_addr(self, addr: str) -> None: + """Cancel any future truncated packet timers for the address.""" + if addr in self._timers: + self._timers.pop(addr).cancel() + + def _respond_query( + self, + msg: Optional[DNSIncoming], + addr: str, + port: int, + transport: _WrappedTransport, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + ) -> None: + """Respond to a query and reassemble any truncated deferred packets.""" + self._cancel_any_timers_for_addr(addr) + packets = self._deferred.pop(addr, []) + if msg: + packets.append(msg) + + self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope) + + def error_received(self, exc: Exception) -> None: + """Likely socket closed or IPv6.""" + # We preformat the message string with the socket as we want + # log_exception_once to log a warrning message once PER EACH + # different socket in case there are problems with multiple + # sockets + msg_str = f"Error with socket {self.sock_description}): %s" + QuietLogger.log_exception_once(exc, msg_str, exc) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + wrapped_transport = make_wrapped_transport(cast(asyncio.DatagramTransport, transport)) + self.transport = wrapped_transport + self.sock_description = f"{wrapped_transport.fileno} ({wrapped_transport.sock_name})" + + def connection_lost(self, exc: Optional[Exception]) -> None: + """Handle connection lost.""" diff --git a/src/zeroconf/_transport.py b/src/zeroconf/_transport.py new file mode 100644 index 00000000..7f6d7ac8 --- /dev/null +++ b/src/zeroconf/_transport.py @@ -0,0 +1,67 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import asyncio +import socket +from typing import Any + + +class _WrappedTransport: + """A wrapper for transports.""" + + __slots__ = ( + 'transport', + 'is_ipv6', + 'sock', + 'fileno', + 'sock_name', + ) + + def __init__( + self, + transport: asyncio.DatagramTransport, + is_ipv6: bool, + sock: socket.socket, + fileno: int, + sock_name: Any, + ) -> None: + """Initialize the wrapped transport. + + These attributes are used when sending packets. + """ + self.transport = transport + self.is_ipv6 = is_ipv6 + self.sock = sock + self.fileno = fileno + self.sock_name = sock_name + + +def make_wrapped_transport(transport: asyncio.DatagramTransport) -> _WrappedTransport: + """Make a wrapped transport.""" + sock: socket.socket = transport.get_extra_info('socket') + return _WrappedTransport( + transport=transport, + is_ipv6=sock.family == socket.AF_INET6, + sock=sock, + fileno=sock.fileno(), + sock_name=sock.getsockname(), + ) From 063b0d287584c7ccf50904f534e05a226946dd8c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Aug 2023 11:31:28 -0500 Subject: [PATCH 2/4] chore: split _engine.py into _transport.py and _listener.py --- src/zeroconf/_core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 1548ec5b..0f9b45df 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -29,7 +29,7 @@ from ._cache import DNSCache from ._dns import DNSQuestion, DNSQuestionType -from ._engine import AsyncEngine, _WrappedTransport +from ._engine import AsyncEngine from ._exceptions import NonUniqueNameException, NotRunningException from ._handlers import ( MulticastOutgoingQueue, @@ -46,6 +46,7 @@ from ._services.browser import ServiceBrowser from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.registry import ServiceRegistry +from ._transport import _WrappedTransport from ._updates import RecordUpdateListener from ._utils.asyncio import ( await_awaitable, From 770a3c4e53f4fd2c60f76551a939ac8d386a6687 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Aug 2023 11:35:45 -0500 Subject: [PATCH 3/4] chore: split tests --- tests/conftest.py | 4 +- tests/test_engine.py | 209 +-------------------------------------- tests/test_listener.py | 219 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 225 insertions(+), 207 deletions(-) create mode 100644 tests/test_listener.py diff --git a/tests/conftest.py b/tests/conftest.py index 34fdeb72..5cdff18e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ import pytest -from zeroconf import _core, _engine, const +from zeroconf import _core, _listener, const @pytest.fixture(autouse=True) @@ -34,7 +34,7 @@ def disable_duplicate_packet_suppression(): Some tests run too slowly because of the duplicate packet suppression. """ - with patch.object(_engine, "_DUPLICATE_PACKET_SUPPRESSION_INTERVAL", 0), patch.object( + with patch.object(_listener, "_DUPLICATE_PACKET_SUPPRESSION_INTERVAL", 0), patch.object( const, "_DUPLICATE_PACKET_SUPPRESSION_INTERVAL", 0 ): yield diff --git a/tests/test_engine.py b/tests/test_engine.py index 2c7e14be..dc6674dd 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,22 +1,18 @@ #!/usr/bin/env python -""" Unit tests for zeroconf._core """ +""" Unit tests for zeroconf._engine """ import asyncio import itertools import logging -import unittest -import unittest.mock -from typing import Set, Tuple, Union -from unittest.mock import MagicMock, patch +from typing import Set +from unittest.mock import patch import pytest import zeroconf as r -from zeroconf import Zeroconf, _engine, const, current_time_millis -from zeroconf._protocol import outgoing -from zeroconf._protocol.incoming import DNSIncoming +from zeroconf import _engine, const from zeroconf.asyncio import AsyncZeroconf log = logging.getLogger('zeroconf') @@ -34,13 +30,6 @@ def teardown_module(): log.setLevel(original_logging_level) -def threadsafe_query(zc, protocol, *args): - async def make_query(): - protocol.handle_query_or_defer(*args) - - asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result() - - # This test uses asyncio because it needs to access the cache directly # which is not threadsafe @pytest.mark.asyncio @@ -93,193 +82,3 @@ async def test_reaper_aborts_when_done(): await asyncio.sleep(1.2) assert zeroconf.cache.get(record_with_10s_ttl) is not None assert zeroconf.cache.get(record_with_1s_ttl) is not None - - -def test_guard_against_oversized_packets(): - """Ensure we do not process oversized packets. - - These packets can quickly overwhelm the system. - """ - zc = Zeroconf(interfaces=['127.0.0.1']) - - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - - for i in range(5000): - generated.add_answer_at_time( - r.DNSText( - "packet{i}.local.", - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - 500, - b'path=/~paulsm/', - ), - 0, - ) - - try: - # We are patching to generate an oversized packet - with patch.object(outgoing, "_MAX_MSG_ABSOLUTE", 100000), patch.object( - outgoing, "_MAX_MSG_TYPICAL", 100000 - ): - over_sized_packet = generated.packets()[0] - assert len(over_sized_packet) > const._MAX_MSG_ABSOLUTE - except AttributeError: - # cannot patch with cython - zc.close() - return - - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - okpacket_record = r.DNSText( - "okpacket.local.", - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - 500, - b'path=/~paulsm/', - ) - - generated.add_answer_at_time( - okpacket_record, - 0, - ) - ok_packet = generated.packets()[0] - - # We cannot test though the network interface as some operating systems - # will guard against the oversized packet and we won't see it. - listener = _engine.AsyncListener(zc) - listener.transport = unittest.mock.MagicMock() - - listener.datagram_received(ok_packet, ('127.0.0.1', const._MDNS_PORT)) - assert zc.cache.async_get_unique(okpacket_record) is not None - - listener.datagram_received(over_sized_packet, ('127.0.0.1', const._MDNS_PORT)) - assert ( - zc.cache.async_get_unique( - r.DNSText( - "packet0.local.", - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - 500, - b'path=/~paulsm/', - ) - ) - is None - ) - - logging.getLogger('zeroconf').setLevel(logging.INFO) - - listener.datagram_received(over_sized_packet, ('::1', const._MDNS_PORT, 1, 1)) - assert ( - zc.cache.async_get_unique( - r.DNSText( - "packet0.local.", - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - 500, - b'path=/~paulsm/', - ) - ) - is None - ) - - zc.close() - - -def test_guard_against_duplicate_packets(): - """Ensure we do not process duplicate packets. - These packets can quickly overwhelm the system. - """ - zc = Zeroconf(interfaces=['127.0.0.1']) - - class SubListener(_engine.AsyncListener): - def handle_query_or_defer( - self, - msg: DNSIncoming, - addr: str, - port: int, - transport: _engine._WrappedTransport, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), - ) -> None: - """Handle a query or defer it for later processing.""" - super().handle_query_or_defer(msg, addr, port, transport, v6_flow_scope) - - listener = SubListener(zc) - listener.transport = MagicMock() - - query = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) - question = r.DNSQuestion("x._http._tcp.local.", const._TYPE_PTR, const._CLASS_IN) - query.add_question(question) - packet_with_qm_question = query.packets()[0] - - query3 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) - question3 = r.DNSQuestion("x._ay._tcp.local.", const._TYPE_PTR, const._CLASS_IN) - query3.add_question(question3) - packet_with_qm_question2 = query3.packets()[0] - - query2 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) - question2 = r.DNSQuestion("x._http._tcp.local.", const._TYPE_PTR, const._CLASS_IN) - question2.unicast = True - query2.add_question(question2) - packet_with_qu_question = query2.packets()[0] - - addrs = ("1.2.3.4", 43) - - with patch.object(_engine, "current_time_millis") as _current_time_millis, patch.object( - listener, "handle_query_or_defer" - ) as _handle_query_or_defer: - start_time = current_time_millis() - - _current_time_millis.return_value = start_time - listener.datagram_received(packet_with_qm_question, addrs) - _handle_query_or_defer.assert_called_once() - _handle_query_or_defer.reset_mock() - - # Now call with the same packet again and handle_query_or_defer should not fire - listener.datagram_received(packet_with_qm_question, addrs) - _handle_query_or_defer.assert_not_called() - _handle_query_or_defer.reset_mock() - - # Now walk time forward 1000 seconds - _current_time_millis.return_value = start_time + 1000 - # Now call with the same packet again and handle_query_or_defer should fire - listener.datagram_received(packet_with_qm_question, addrs) - _handle_query_or_defer.assert_called_once() - _handle_query_or_defer.reset_mock() - - # Now call with the different packet and handle_query_or_defer should fire - listener.datagram_received(packet_with_qm_question2, addrs) - _handle_query_or_defer.assert_called_once() - _handle_query_or_defer.reset_mock() - - # Now call with the different packet and handle_query_or_defer should fire - listener.datagram_received(packet_with_qm_question, addrs) - _handle_query_or_defer.assert_called_once() - _handle_query_or_defer.reset_mock() - - # Now call with the different packet with qu question and handle_query_or_defer should fire - listener.datagram_received(packet_with_qu_question, addrs) - _handle_query_or_defer.assert_called_once() - _handle_query_or_defer.reset_mock() - - # Now call again with the same packet that has a qu question and handle_query_or_defer should fire - listener.datagram_received(packet_with_qu_question, addrs) - _handle_query_or_defer.assert_called_once() - _handle_query_or_defer.reset_mock() - - log.setLevel(logging.WARNING) - - # Call with the QM packet again - listener.datagram_received(packet_with_qm_question, addrs) - _handle_query_or_defer.assert_called_once() - _handle_query_or_defer.reset_mock() - - # Now call with the same packet again and handle_query_or_defer should not fire - listener.datagram_received(packet_with_qm_question, addrs) - _handle_query_or_defer.assert_not_called() - _handle_query_or_defer.reset_mock() - - # Now call with garbage - listener.datagram_received(b'garbage', addrs) - _handle_query_or_defer.assert_not_called() - _handle_query_or_defer.reset_mock() - - zc.close() diff --git a/tests/test_listener.py b/tests/test_listener.py new file mode 100644 index 00000000..fa42d91b --- /dev/null +++ b/tests/test_listener.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python + + +""" Unit tests for zeroconf._listener """ + +import logging +import unittest +import unittest.mock +from typing import Tuple, Union +from unittest.mock import MagicMock, patch + +import zeroconf as r +from zeroconf import Zeroconf, _engine, _listener, const, current_time_millis +from zeroconf._protocol import outgoing +from zeroconf._protocol.incoming import DNSIncoming + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +def test_guard_against_oversized_packets(): + """Ensure we do not process oversized packets. + + These packets can quickly overwhelm the system. + """ + zc = Zeroconf(interfaces=['127.0.0.1']) + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + + for i in range(5000): + generated.add_answer_at_time( + r.DNSText( + "packet{i}.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ), + 0, + ) + + try: + # We are patching to generate an oversized packet + with patch.object(outgoing, "_MAX_MSG_ABSOLUTE", 100000), patch.object( + outgoing, "_MAX_MSG_TYPICAL", 100000 + ): + over_sized_packet = generated.packets()[0] + assert len(over_sized_packet) > const._MAX_MSG_ABSOLUTE + except AttributeError: + # cannot patch with cython + zc.close() + return + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + okpacket_record = r.DNSText( + "okpacket.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ) + + generated.add_answer_at_time( + okpacket_record, + 0, + ) + ok_packet = generated.packets()[0] + + # We cannot test though the network interface as some operating systems + # will guard against the oversized packet and we won't see it. + listener = _listener.AsyncListener(zc) + listener.transport = unittest.mock.MagicMock() + + listener.datagram_received(ok_packet, ('127.0.0.1', const._MDNS_PORT)) + assert zc.cache.async_get_unique(okpacket_record) is not None + + listener.datagram_received(over_sized_packet, ('127.0.0.1', const._MDNS_PORT)) + assert ( + zc.cache.async_get_unique( + r.DNSText( + "packet0.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ) + ) + is None + ) + + logging.getLogger('zeroconf').setLevel(logging.INFO) + + listener.datagram_received(over_sized_packet, ('::1', const._MDNS_PORT, 1, 1)) + assert ( + zc.cache.async_get_unique( + r.DNSText( + "packet0.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ) + ) + is None + ) + + zc.close() + + +def test_guard_against_duplicate_packets(): + """Ensure we do not process duplicate packets. + These packets can quickly overwhelm the system. + """ + zc = Zeroconf(interfaces=['127.0.0.1']) + + class SubListener(_listener.AsyncListener): + def handle_query_or_defer( + self, + msg: DNSIncoming, + addr: str, + port: int, + transport: _engine._WrappedTransport, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + ) -> None: + """Handle a query or defer it for later processing.""" + super().handle_query_or_defer(msg, addr, port, transport, v6_flow_scope) + + listener = SubListener(zc) + listener.transport = MagicMock() + + query = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) + question = r.DNSQuestion("x._http._tcp.local.", const._TYPE_PTR, const._CLASS_IN) + query.add_question(question) + packet_with_qm_question = query.packets()[0] + + query3 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) + question3 = r.DNSQuestion("x._ay._tcp.local.", const._TYPE_PTR, const._CLASS_IN) + query3.add_question(question3) + packet_with_qm_question2 = query3.packets()[0] + + query2 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) + question2 = r.DNSQuestion("x._http._tcp.local.", const._TYPE_PTR, const._CLASS_IN) + question2.unicast = True + query2.add_question(question2) + packet_with_qu_question = query2.packets()[0] + + addrs = ("1.2.3.4", 43) + + with patch.object(_engine, "current_time_millis") as _current_time_millis, patch.object( + listener, "handle_query_or_defer" + ) as _handle_query_or_defer: + start_time = current_time_millis() + + _current_time_millis.return_value = start_time + listener.datagram_received(packet_with_qm_question, addrs) + _handle_query_or_defer.assert_called_once() + _handle_query_or_defer.reset_mock() + + # Now call with the same packet again and handle_query_or_defer should not fire + listener.datagram_received(packet_with_qm_question, addrs) + _handle_query_or_defer.assert_not_called() + _handle_query_or_defer.reset_mock() + + # Now walk time forward 1000 seconds + _current_time_millis.return_value = start_time + 1000 + # Now call with the same packet again and handle_query_or_defer should fire + listener.datagram_received(packet_with_qm_question, addrs) + _handle_query_or_defer.assert_called_once() + _handle_query_or_defer.reset_mock() + + # Now call with the different packet and handle_query_or_defer should fire + listener.datagram_received(packet_with_qm_question2, addrs) + _handle_query_or_defer.assert_called_once() + _handle_query_or_defer.reset_mock() + + # Now call with the different packet and handle_query_or_defer should fire + listener.datagram_received(packet_with_qm_question, addrs) + _handle_query_or_defer.assert_called_once() + _handle_query_or_defer.reset_mock() + + # Now call with the different packet with qu question and handle_query_or_defer should fire + listener.datagram_received(packet_with_qu_question, addrs) + _handle_query_or_defer.assert_called_once() + _handle_query_or_defer.reset_mock() + + # Now call again with the same packet that has a qu question and handle_query_or_defer should fire + listener.datagram_received(packet_with_qu_question, addrs) + _handle_query_or_defer.assert_called_once() + _handle_query_or_defer.reset_mock() + + log.setLevel(logging.WARNING) + + # Call with the QM packet again + listener.datagram_received(packet_with_qm_question, addrs) + _handle_query_or_defer.assert_called_once() + _handle_query_or_defer.reset_mock() + + # Now call with the same packet again and handle_query_or_defer should not fire + listener.datagram_received(packet_with_qm_question, addrs) + _handle_query_or_defer.assert_not_called() + _handle_query_or_defer.reset_mock() + + # Now call with garbage + listener.datagram_received(b'garbage', addrs) + _handle_query_or_defer.assert_not_called() + _handle_query_or_defer.reset_mock() + + zc.close() From 1088a1aa3f71197ae4691ae5c16a562db54e187e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Aug 2023 11:41:21 -0500 Subject: [PATCH 4/4] chore: fix patch target --- tests/test_listener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_listener.py b/tests/test_listener.py index fa42d91b..737b8111 100644 --- a/tests/test_listener.py +++ b/tests/test_listener.py @@ -157,7 +157,7 @@ def handle_query_or_defer( addrs = ("1.2.3.4", 43) - with patch.object(_engine, "current_time_millis") as _current_time_millis, patch.object( + with patch.object(_listener, "current_time_millis") as _current_time_millis, patch.object( listener, "handle_query_or_defer" ) as _handle_query_or_defer: start_time = current_time_millis()