diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 5827e2d5b..3a3381a91 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -31,10 +31,6 @@ from ._dns import DNSQuestion, DNSQuestionType from ._engine import AsyncEngine from ._exceptions import NonUniqueNameException, NotRunningException -from ._handlers.answers import ( - construct_outgoing_multicast_answers, - construct_outgoing_unicast_answers, -) from ._handlers.multicast_outgoing_queue import MulticastOutgoingQueue from ._handlers.query_handler import QueryHandler from ._handlers.record_manager import RecordManager @@ -187,15 +183,15 @@ def __init__( self.registry = ServiceRegistry() self.cache = DNSCache() self.question_history = QuestionHistory() - self.query_handler = QueryHandler(self.registry, self.cache, self.question_history) + self.query_handler = QueryHandler(self) self.record_manager = RecordManager(self) self._notify_futures: Set[asyncio.Future] = set() self.loop: Optional[asyncio.AbstractEventLoop] = None self._loop_thread: Optional[threading.Thread] = None - self._out_queue = MulticastOutgoingQueue(self, 0, _AGGREGATION_DELAY) - self._out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND, _PROTECTED_AGGREGATION_DELAY) + self.out_queue = MulticastOutgoingQueue(self, 0, _AGGREGATION_DELAY) + self.out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND, _PROTECTED_AGGREGATION_DELAY) self.start() @@ -567,45 +563,6 @@ def handle_response(self, msg: DNSIncoming) -> None: self.log_warning_once("handle_response is deprecated, use record_manager.async_updates_from_response") self.record_manager.async_updates_from_response(msg) - def handle_assembled_query( - self, - packets: List[DNSIncoming], - addr: str, - port: int, - transport: _WrappedTransport, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]], - ) -> None: - """Respond to a (re)assembled query. - - If the protocol received packets with the TC bit set, it will - wait a bit for the rest of the packets and only call - handle_assembled_query once it has a complete set of packets - or the timer expires. If the TC bit is not set, a single - packet will be in packets. - """ - ucast_source = port != _MDNS_PORT - question_answers = self.query_handler.async_response(packets, ucast_source) - if not question_answers: - return - now = packets[0].now - if question_answers.ucast: - questions = packets[0].questions - id_ = packets[0].id - out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_) - # When sending unicast, only send back the reply - # via the same socket that it was recieved from - # as we know its reachable from that socket - self.async_send(out, addr, port, v6_flow_scope, transport) - if question_answers.mcast_now: - self.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now)) - if question_answers.mcast_aggregate: - self._out_queue.async_add(now, question_answers.mcast_aggregate) - if question_answers.mcast_aggregate_last_second: - # https://datatracker.ietf.org/doc/html/rfc6762#section-14 - # If we broadcast it in the last second, we have to delay - # at least a second before we send it again - self._out_delay_queue.async_add(now, question_answers.mcast_aggregate_last_second) - def send( self, out: DNSOutgoing, diff --git a/src/zeroconf/_handlers/query_handler.pxd b/src/zeroconf/_handlers/query_handler.pxd index 3e726a533..bb7198be5 100644 --- a/src/zeroconf/_handlers/query_handler.pxd +++ b/src/zeroconf/_handlers/query_handler.pxd @@ -7,7 +7,12 @@ from .._history cimport QuestionHistory from .._protocol.incoming cimport DNSIncoming from .._services.info cimport ServiceInfo from .._services.registry cimport ServiceRegistry -from .answers cimport QuestionAnswers +from .answers cimport ( + QuestionAnswers, + construct_outgoing_multicast_answers, + construct_outgoing_unicast_answers, +) +from .multicast_outgoing_queue cimport MulticastOutgoingQueue cdef bint TYPE_CHECKING @@ -65,6 +70,7 @@ cdef class _QueryResponse: cdef class QueryHandler: + cdef object zc cdef ServiceRegistry registry cdef DNSCache cache cdef QuestionHistory question_history @@ -93,7 +99,22 @@ cdef class QueryHandler: is_probe=object, now=double ) - cpdef async_response(self, cython.list msgs, cython.bint unicast_source) + cpdef QuestionAnswers async_response(self, cython.list msgs, cython.bint unicast_source) @cython.locals(name=str, question_lower_name=str) cdef _get_answer_strategies(self, DNSQuestion question) + + @cython.locals( + first_packet=DNSIncoming, + ucast_source=bint, + out_queue=MulticastOutgoingQueue, + out_delay_queue=MulticastOutgoingQueue + ) + cpdef void handle_assembled_query( + self, + list packets, + object addr, + object port, + object transport, + tuple v6_flow_scope + ) diff --git a/src/zeroconf/_handlers/query_handler.py b/src/zeroconf/_handlers/query_handler.py index c66d9c302..8349b584b 100644 --- a/src/zeroconf/_handlers/query_handler.py +++ b/src/zeroconf/_handlers/query_handler.py @@ -20,19 +20,19 @@ USA """ -from typing import TYPE_CHECKING, List, Optional, Set, cast +from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union, cast from .._cache import DNSCache, _UniqueRecordsType from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet -from .._history import QuestionHistory from .._protocol.incoming import DNSIncoming from .._services.info import ServiceInfo -from .._services.registry import ServiceRegistry +from .._transport import _WrappedTransport from .._utils.net import IPVersion from ..const import ( _ADDRESS_RECORD_TYPES, _CLASS_IN, _DNS_OTHER_TTL, + _MDNS_PORT, _ONE_SECOND, _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_A, @@ -43,7 +43,12 @@ _TYPE_SRV, _TYPE_TXT, ) -from .answers import QuestionAnswers, _AnswerWithAdditionalsType +from .answers import ( + QuestionAnswers, + _AnswerWithAdditionalsType, + construct_outgoing_multicast_answers, + construct_outgoing_unicast_answers, +) _RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES} @@ -53,7 +58,7 @@ _IPVersion_ALL = IPVersion.All _int = int - +_str = str _ANSWER_STRATEGY_SERVICE_TYPE_ENUMERATION = 0 _ANSWER_STRATEGY_POINTER = 1 @@ -61,6 +66,9 @@ _ANSWER_STRATEGY_SERVICE = 3 _ANSWER_STRATEGY_TEXT = 4 +if TYPE_CHECKING: + from .._core import Zeroconf + class _AnswerStrategy: @@ -183,13 +191,14 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: class QueryHandler: """Query the ServiceRegistry.""" - __slots__ = ("registry", "cache", "question_history") + __slots__ = ("zc", "registry", "cache", "question_history") - def __init__(self, registry: ServiceRegistry, cache: DNSCache, question_history: QuestionHistory) -> None: + def __init__(self, zc: 'Zeroconf') -> None: """Init the query handler.""" - self.registry = registry - self.cache = cache - self.question_history = question_history + self.zc = zc + self.registry = zc.registry + self.cache = zc.cache + self.question_history = zc.question_history def _add_service_type_enumeration_query_answers( self, types: List[str], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet @@ -385,3 +394,45 @@ def _get_answer_strategies( ) return strategies + + def handle_assembled_query( + self, + packets: List[DNSIncoming], + addr: _str, + port: _int, + transport: _WrappedTransport, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]], + ) -> None: + """Respond to a (re)assembled query. + + If the protocol recieved packets with the TC bit set, it will + wait a bit for the rest of the packets and only call + handle_assembled_query once it has a complete set of packets + or the timer expires. If the TC bit is not set, a single + packet will be in packets. + """ + first_packet = packets[0] + now = first_packet.now + ucast_source = port != _MDNS_PORT + question_answers = self.async_response(packets, ucast_source) + if not question_answers: + return + if question_answers.ucast: + questions = first_packet.questions + id_ = first_packet.id + out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_) + # When sending unicast, only send back the reply + # via the same socket that it was recieved from + # as we know its reachable from that socket + self.zc.async_send(out, addr, port, v6_flow_scope, transport) + if question_answers.mcast_now: + self.zc.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now)) + if question_answers.mcast_aggregate: + out_queue = self.zc.out_queue + out_queue.async_add(now, question_answers.mcast_aggregate) + if question_answers.mcast_aggregate_last_second: + # https://datatracker.ietf.org/doc/html/rfc6762#section-14 + # If we broadcast it in the last second, we have to delay + # at least a second before we send it again + out_delay_queue = self.zc.out_delay_queue + out_delay_queue.async_add(now, question_answers.mcast_aggregate_last_second) diff --git a/src/zeroconf/_handlers/record_manager.pxd b/src/zeroconf/_handlers/record_manager.pxd index 0f543aff2..5be2c283b 100644 --- a/src/zeroconf/_handlers/record_manager.pxd +++ b/src/zeroconf/_handlers/record_manager.pxd @@ -22,22 +22,21 @@ cdef class RecordManager: cdef public DNSCache cache cdef public cython.set listeners - cpdef async_updates(self, object now, object records) + cpdef void async_updates(self, object now, object records) - cpdef async_updates_complete(self, object notify) + cpdef void async_updates_complete(self, bint notify) @cython.locals( cache=DNSCache, record=DNSRecord, answers=cython.list, maybe_entry=DNSRecord, - now_double=double ) - cpdef async_updates_from_response(self, DNSIncoming msg) + cpdef void async_updates_from_response(self, DNSIncoming msg) - cpdef async_add_listener(self, RecordUpdateListener listener, object question) + cpdef void async_add_listener(self, RecordUpdateListener listener, object question) - cpdef async_remove_listener(self, RecordUpdateListener listener) + cpdef void async_remove_listener(self, RecordUpdateListener listener) @cython.locals(question=DNSQuestion, record=DNSRecord) - cdef _async_update_matching_records(self, RecordUpdateListener listener, cython.list questions) + cdef void _async_update_matching_records(self, RecordUpdateListener listener, cython.list questions) diff --git a/src/zeroconf/_handlers/record_manager.py b/src/zeroconf/_handlers/record_manager.py index 129acd0b6..0a0f6c54b 100644 --- a/src/zeroconf/_handlers/record_manager.py +++ b/src/zeroconf/_handlers/record_manager.py @@ -84,7 +84,6 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: other_adds: List[DNSRecord] = [] removes: Set[DNSRecord] = set() now = msg.now - now_double = now unique_types: Set[Tuple[str, int, int]] = set() cache = self.cache answers = msg.answers() @@ -113,7 +112,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: record = cast(_UniqueRecordsType, record) maybe_entry = cache.async_get_unique(record) - if not record.is_expired(now_double): + if not record.is_expired(now): if maybe_entry is not None: maybe_entry.reset_ttl(record) else: @@ -129,7 +128,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: removes.add(record) if unique_types: - cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now_double) + cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now) if updates: self.async_updates(now, updates) diff --git a/src/zeroconf/_listener.pxd b/src/zeroconf/_listener.pxd index 8b144653e..96f52be02 100644 --- a/src/zeroconf/_listener.pxd +++ b/src/zeroconf/_listener.pxd @@ -1,6 +1,7 @@ import cython +from ._handlers.query_handler cimport QueryHandler from ._handlers.record_manager cimport RecordManager from ._protocol.incoming cimport DNSIncoming from ._services.registry cimport ServiceRegistry @@ -21,6 +22,7 @@ cdef class AsyncListener: cdef public object zc cdef ServiceRegistry _registry cdef RecordManager _record_manager + cdef QueryHandler _query_handler cdef public cython.bytes data cdef public double last_time cdef public DNSIncoming last_message diff --git a/src/zeroconf/_listener.py b/src/zeroconf/_listener.py index 23d245785..0f8a8cac7 100644 --- a/src/zeroconf/_listener.py +++ b/src/zeroconf/_listener.py @@ -59,6 +59,7 @@ class AsyncListener: 'zc', '_registry', '_record_manager', + "_query_handler", 'data', 'last_time', 'last_message', @@ -72,6 +73,7 @@ def __init__(self, zc: 'Zeroconf') -> None: self.zc = zc self._registry = zc.registry self._record_manager = zc.record_manager + self._query_handler = zc.query_handler self.data: Optional[bytes] = None self.last_time: float = 0 self.last_message: Optional[DNSIncoming] = None @@ -228,7 +230,7 @@ def _respond_query( if msg: packets.append(msg) - self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope) + self._query_handler.handle_assembled_query(packets, addr, port, transport, v6_flow_scope) def error_received(self, exc: Exception) -> None: """Likely socket closed or IPv6.""" diff --git a/src/zeroconf/_protocol/incoming.pxd b/src/zeroconf/_protocol/incoming.pxd index 07ae6e78e..a8c0dbdb8 100644 --- a/src/zeroconf/_protocol/incoming.pxd +++ b/src/zeroconf/_protocol/incoming.pxd @@ -56,7 +56,7 @@ cdef class DNSIncoming: cdef cython.uint _num_authorities cdef cython.uint _num_additionals cdef public bint valid - cdef public object now + cdef public double now cdef public object scope_id cdef public object source cdef bint _has_qu_question diff --git a/src/zeroconf/_transport.py b/src/zeroconf/_transport.py index 7f6d7ac8c..c37af2efd 100644 --- a/src/zeroconf/_transport.py +++ b/src/zeroconf/_transport.py @@ -22,7 +22,7 @@ import asyncio import socket -from typing import Any +from typing import Tuple class _WrappedTransport: @@ -42,7 +42,7 @@ def __init__( is_ipv6: bool, sock: socket.socket, fileno: int, - sock_name: Any, + sock_name: Tuple, ) -> None: """Initialize the wrapped transport. diff --git a/tests/conftest.py b/tests/conftest.py index c0e926a34..5525c4ee0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ import pytest from zeroconf import _core, const +from zeroconf._handlers import query_handler @pytest.fixture(autouse=True) @@ -23,7 +24,9 @@ def verify_threads_ended(): @pytest.fixture def run_isolated(): """Change the mDNS port to run the test in isolation.""" - with patch.object(_core, "_MDNS_PORT", 5454), patch.object(const, "_MDNS_PORT", 5454): + with patch.object(query_handler, "_MDNS_PORT", 5454), patch.object( + _core, "_MDNS_PORT", 5454 + ), patch.object(const, "_MDNS_PORT", 5454): yield