Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 3 additions & 46 deletions src/zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@
from ._dns import DNSQuestion, DNSQuestionType
from ._engine import AsyncEngine
from ._exceptions import NonUniqueNameException, NotRunningException
from ._handlers.answers import (
construct_outgoing_multicast_answers,
construct_outgoing_unicast_answers,
)
from ._handlers.multicast_outgoing_queue import MulticastOutgoingQueue
from ._handlers.query_handler import QueryHandler
from ._handlers.record_manager import RecordManager
Expand Down Expand Up @@ -187,15 +183,15 @@ def __init__(
self.registry = ServiceRegistry()
self.cache = DNSCache()
self.question_history = QuestionHistory()
self.query_handler = QueryHandler(self.registry, self.cache, self.question_history)
self.query_handler = QueryHandler(self)
self.record_manager = RecordManager(self)

self._notify_futures: Set[asyncio.Future] = set()
self.loop: Optional[asyncio.AbstractEventLoop] = None
self._loop_thread: Optional[threading.Thread] = None

self._out_queue = MulticastOutgoingQueue(self, 0, _AGGREGATION_DELAY)
self._out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND, _PROTECTED_AGGREGATION_DELAY)
self.out_queue = MulticastOutgoingQueue(self, 0, _AGGREGATION_DELAY)
self.out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND, _PROTECTED_AGGREGATION_DELAY)

self.start()

Expand Down Expand Up @@ -567,45 +563,6 @@ def handle_response(self, msg: DNSIncoming) -> None:
self.log_warning_once("handle_response is deprecated, use record_manager.async_updates_from_response")
self.record_manager.async_updates_from_response(msg)

def handle_assembled_query(
self,
packets: List[DNSIncoming],
addr: str,
port: int,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
) -> None:
"""Respond to a (re)assembled query.

If the protocol received packets with the TC bit set, it will
wait a bit for the rest of the packets and only call
handle_assembled_query once it has a complete set of packets
or the timer expires. If the TC bit is not set, a single
packet will be in packets.
"""
ucast_source = port != _MDNS_PORT
question_answers = self.query_handler.async_response(packets, ucast_source)
if not question_answers:
return
now = packets[0].now
if question_answers.ucast:
questions = packets[0].questions
id_ = packets[0].id
out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_)
# When sending unicast, only send back the reply
# via the same socket that it was recieved from
# as we know its reachable from that socket
self.async_send(out, addr, port, v6_flow_scope, transport)
if question_answers.mcast_now:
self.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now))
if question_answers.mcast_aggregate:
self._out_queue.async_add(now, question_answers.mcast_aggregate)
if question_answers.mcast_aggregate_last_second:
# https://datatracker.ietf.org/doc/html/rfc6762#section-14
# If we broadcast it in the last second, we have to delay
# at least a second before we send it again
self._out_delay_queue.async_add(now, question_answers.mcast_aggregate_last_second)

def send(
self,
out: DNSOutgoing,
Expand Down
25 changes: 23 additions & 2 deletions src/zeroconf/_handlers/query_handler.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ from .._history cimport QuestionHistory
from .._protocol.incoming cimport DNSIncoming
from .._services.info cimport ServiceInfo
from .._services.registry cimport ServiceRegistry
from .answers cimport QuestionAnswers
from .answers cimport (
QuestionAnswers,
construct_outgoing_multicast_answers,
construct_outgoing_unicast_answers,
)
from .multicast_outgoing_queue cimport MulticastOutgoingQueue


cdef bint TYPE_CHECKING
Expand Down Expand Up @@ -65,6 +70,7 @@ cdef class _QueryResponse:

cdef class QueryHandler:

cdef object zc
cdef ServiceRegistry registry
cdef DNSCache cache
cdef QuestionHistory question_history
Expand Down Expand Up @@ -93,7 +99,22 @@ cdef class QueryHandler:
is_probe=object,
now=double
)
cpdef async_response(self, cython.list msgs, cython.bint unicast_source)
cpdef QuestionAnswers async_response(self, cython.list msgs, cython.bint unicast_source)

@cython.locals(name=str, question_lower_name=str)
cdef _get_answer_strategies(self, DNSQuestion question)

@cython.locals(
first_packet=DNSIncoming,
ucast_source=bint,
out_queue=MulticastOutgoingQueue,
out_delay_queue=MulticastOutgoingQueue
)
cpdef void handle_assembled_query(
self,
list packets,
object addr,
object port,
object transport,
tuple v6_flow_scope
)
71 changes: 61 additions & 10 deletions src/zeroconf/_handlers/query_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@
USA
"""

from typing import TYPE_CHECKING, List, Optional, Set, cast
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union, cast

from .._cache import DNSCache, _UniqueRecordsType
from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet
from .._history import QuestionHistory
from .._protocol.incoming import DNSIncoming
from .._services.info import ServiceInfo
from .._services.registry import ServiceRegistry
from .._transport import _WrappedTransport
from .._utils.net import IPVersion
from ..const import (
_ADDRESS_RECORD_TYPES,
_CLASS_IN,
_DNS_OTHER_TTL,
_MDNS_PORT,
_ONE_SECOND,
_SERVICE_TYPE_ENUMERATION_NAME,
_TYPE_A,
Expand All @@ -43,7 +43,12 @@
_TYPE_SRV,
_TYPE_TXT,
)
from .answers import QuestionAnswers, _AnswerWithAdditionalsType
from .answers import (
QuestionAnswers,
_AnswerWithAdditionalsType,
construct_outgoing_multicast_answers,
construct_outgoing_unicast_answers,
)

_RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES}

Expand All @@ -53,14 +58,17 @@
_IPVersion_ALL = IPVersion.All

_int = int

_str = str

_ANSWER_STRATEGY_SERVICE_TYPE_ENUMERATION = 0
_ANSWER_STRATEGY_POINTER = 1
_ANSWER_STRATEGY_ADDRESS = 2
_ANSWER_STRATEGY_SERVICE = 3
_ANSWER_STRATEGY_TEXT = 4

if TYPE_CHECKING:
from .._core import Zeroconf


class _AnswerStrategy:

Expand Down Expand Up @@ -183,13 +191,14 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool:
class QueryHandler:
"""Query the ServiceRegistry."""

__slots__ = ("registry", "cache", "question_history")
__slots__ = ("zc", "registry", "cache", "question_history")

def __init__(self, registry: ServiceRegistry, cache: DNSCache, question_history: QuestionHistory) -> None:
def __init__(self, zc: 'Zeroconf') -> None:
"""Init the query handler."""
self.registry = registry
self.cache = cache
self.question_history = question_history
self.zc = zc
self.registry = zc.registry
self.cache = zc.cache
self.question_history = zc.question_history

def _add_service_type_enumeration_query_answers(
self, types: List[str], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet
Expand Down Expand Up @@ -385,3 +394,45 @@ def _get_answer_strategies(
)

return strategies

def handle_assembled_query(
self,
packets: List[DNSIncoming],
addr: _str,
port: _int,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
) -> None:
"""Respond to a (re)assembled query.

If the protocol recieved packets with the TC bit set, it will
wait a bit for the rest of the packets and only call
handle_assembled_query once it has a complete set of packets
or the timer expires. If the TC bit is not set, a single
packet will be in packets.
"""
first_packet = packets[0]
now = first_packet.now
ucast_source = port != _MDNS_PORT
question_answers = self.async_response(packets, ucast_source)
if not question_answers:
return
if question_answers.ucast:
questions = first_packet.questions
id_ = first_packet.id
out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_)
# When sending unicast, only send back the reply
# via the same socket that it was recieved from
# as we know its reachable from that socket
self.zc.async_send(out, addr, port, v6_flow_scope, transport)
if question_answers.mcast_now:
self.zc.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now))
if question_answers.mcast_aggregate:
out_queue = self.zc.out_queue
out_queue.async_add(now, question_answers.mcast_aggregate)
if question_answers.mcast_aggregate_last_second:
# https://datatracker.ietf.org/doc/html/rfc6762#section-14
# If we broadcast it in the last second, we have to delay
# at least a second before we send it again
out_delay_queue = self.zc.out_delay_queue
out_delay_queue.async_add(now, question_answers.mcast_aggregate_last_second)
Comment on lines +431 to +438
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out_queue and out_delay_queue never change, we could rebind them to the QueryHandler object to avoid the getattr here but it probably doesn't make much difference as they should rarely be called in practice

13 changes: 6 additions & 7 deletions src/zeroconf/_handlers/record_manager.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,21 @@ cdef class RecordManager:
cdef public DNSCache cache
cdef public cython.set listeners

cpdef async_updates(self, object now, object records)
cpdef void async_updates(self, object now, object records)

cpdef async_updates_complete(self, object notify)
cpdef void async_updates_complete(self, bint notify)

@cython.locals(
cache=DNSCache,
record=DNSRecord,
answers=cython.list,
maybe_entry=DNSRecord,
now_double=double
)
cpdef async_updates_from_response(self, DNSIncoming msg)
cpdef void async_updates_from_response(self, DNSIncoming msg)

cpdef async_add_listener(self, RecordUpdateListener listener, object question)
cpdef void async_add_listener(self, RecordUpdateListener listener, object question)

cpdef async_remove_listener(self, RecordUpdateListener listener)
cpdef void async_remove_listener(self, RecordUpdateListener listener)

@cython.locals(question=DNSQuestion, record=DNSRecord)
cdef _async_update_matching_records(self, RecordUpdateListener listener, cython.list questions)
cdef void _async_update_matching_records(self, RecordUpdateListener listener, cython.list questions)
5 changes: 2 additions & 3 deletions src/zeroconf/_handlers/record_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
other_adds: List[DNSRecord] = []
removes: Set[DNSRecord] = set()
now = msg.now
now_double = now
unique_types: Set[Tuple[str, int, int]] = set()
cache = self.cache
answers = msg.answers()
Expand Down Expand Up @@ -113,7 +112,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
record = cast(_UniqueRecordsType, record)

maybe_entry = cache.async_get_unique(record)
if not record.is_expired(now_double):
if not record.is_expired(now):
if maybe_entry is not None:
maybe_entry.reset_ttl(record)
else:
Expand All @@ -129,7 +128,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
removes.add(record)

if unique_types:
cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now_double)
cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now)

if updates:
self.async_updates(now, updates)
Expand Down
2 changes: 2 additions & 0 deletions src/zeroconf/_listener.pxd
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import cython

from ._handlers.query_handler cimport QueryHandler
from ._handlers.record_manager cimport RecordManager
from ._protocol.incoming cimport DNSIncoming
from ._services.registry cimport ServiceRegistry
Expand All @@ -21,6 +22,7 @@ cdef class AsyncListener:
cdef public object zc
cdef ServiceRegistry _registry
cdef RecordManager _record_manager
cdef QueryHandler _query_handler
cdef public cython.bytes data
cdef public double last_time
cdef public DNSIncoming last_message
Expand Down
4 changes: 3 additions & 1 deletion src/zeroconf/_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class AsyncListener:
'zc',
'_registry',
'_record_manager',
"_query_handler",
'data',
'last_time',
'last_message',
Expand All @@ -72,6 +73,7 @@ def __init__(self, zc: 'Zeroconf') -> None:
self.zc = zc
self._registry = zc.registry
self._record_manager = zc.record_manager
self._query_handler = zc.query_handler
self.data: Optional[bytes] = None
self.last_time: float = 0
self.last_message: Optional[DNSIncoming] = None
Expand Down Expand Up @@ -228,7 +230,7 @@ def _respond_query(
if msg:
packets.append(msg)

self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope)
self._query_handler.handle_assembled_query(packets, addr, port, transport, v6_flow_scope)

def error_received(self, exc: Exception) -> None:
"""Likely socket closed or IPv6."""
Expand Down
2 changes: 1 addition & 1 deletion src/zeroconf/_protocol/incoming.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ cdef class DNSIncoming:
cdef cython.uint _num_authorities
cdef cython.uint _num_additionals
cdef public bint valid
cdef public object now
cdef public double now
cdef public object scope_id
cdef public object source
cdef bint _has_qu_question
Expand Down
4 changes: 2 additions & 2 deletions src/zeroconf/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import asyncio
import socket
from typing import Any
from typing import Tuple


class _WrappedTransport:
Expand All @@ -42,7 +42,7 @@ def __init__(
is_ipv6: bool,
sock: socket.socket,
fileno: int,
sock_name: Any,
sock_name: Tuple,
) -> None:
"""Initialize the wrapped transport.

Expand Down
5 changes: 4 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

from zeroconf import _core, const
from zeroconf._handlers import query_handler


@pytest.fixture(autouse=True)
Expand All @@ -23,7 +24,9 @@ def verify_threads_ended():
@pytest.fixture
def run_isolated():
"""Change the mDNS port to run the test in isolation."""
with patch.object(_core, "_MDNS_PORT", 5454), patch.object(const, "_MDNS_PORT", 5454):
with patch.object(query_handler, "_MDNS_PORT", 5454), patch.object(
_core, "_MDNS_PORT", 5454
), patch.object(const, "_MDNS_PORT", 5454):
yield


Expand Down