diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 7f60a695b..5827e2d5b 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -577,15 +577,17 @@ def handle_assembled_query( ) -> None: """Respond to a (re)assembled query. - If the protocol recieved packets with the TC bit set, it will + If the protocol received packets with the TC bit set, it will wait a bit for the rest of the packets and only call handle_assembled_query once it has a complete set of packets or the timer expires. If the TC bit is not set, a single packet will be in packets. """ - now = packets[0].now ucast_source = port != _MDNS_PORT question_answers = self.query_handler.async_response(packets, ucast_source) + if not question_answers: + return + now = packets[0].now if question_answers.ucast: questions = packets[0].questions id_ = packets[0].id diff --git a/src/zeroconf/_handlers/answers.py b/src/zeroconf/_handlers/answers.py index 6ba502ac9..a2dbd66aa 100644 --- a/src/zeroconf/_handlers/answers.py +++ b/src/zeroconf/_handlers/answers.py @@ -59,6 +59,14 @@ def __init__( self.mcast_aggregate = mcast_aggregate self.mcast_aggregate_last_second = mcast_aggregate_last_second + def __repr__(self) -> str: + """Return a string representation of this QuestionAnswers.""" + return ( + f'QuestionAnswers(ucast={self.ucast}, mcast_now={self.mcast_now}, ' + f'mcast_aggregate={self.mcast_aggregate}, ' + f'mcast_aggregate_last_second={self.mcast_aggregate_last_second})' + ) + class AnswerGroup: """A group of answers scheduled to be sent at the same time.""" diff --git a/src/zeroconf/_handlers/multicast_outgoing_queue.py b/src/zeroconf/_handlers/multicast_outgoing_queue.py index 1d398d736..23288d18d 100644 --- a/src/zeroconf/_handlers/multicast_outgoing_queue.py +++ b/src/zeroconf/_handlers/multicast_outgoing_queue.py @@ -77,7 +77,7 @@ def async_add(self, now: _float, answers: _AnswerWithAdditionalsType) -> None: # If we calculate a random delay for the send after time # that is less than the last group scheduled to go out, # we instead add the answers to the last group as this - # allows aggregating additonal responses + # allows aggregating additional responses last_group = self.queue[-1] if send_after <= last_group.send_after: last_group.answers.update(answers) @@ -116,7 +116,7 @@ def async_ready(self) -> None: # be sure we schedule them to go out later loop.call_at(loop.time() + millis_to_seconds(self.queue[0].send_after - now), self.async_ready) - if answers: + if answers: # pragma: no branch # If we have the same answer scheduled to go out, remove them self._remove_answers_from_queue(answers) zc.async_send(construct_outgoing_multicast_answers(answers)) diff --git a/src/zeroconf/_handlers/query_handler.pxd b/src/zeroconf/_handlers/query_handler.pxd index ff970d766..8c42144ca 100644 --- a/src/zeroconf/_handlers/query_handler.pxd +++ b/src/zeroconf/_handlers/query_handler.pxd @@ -18,6 +18,23 @@ cdef cython.set _ADDRESS_RECORD_TYPES cdef object IPVersion, _IPVersion_ALL cdef object _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL +cdef unsigned int _ANSWER_STRATEGY_SERVICE_TYPE_ENUMERATION +cdef unsigned int _ANSWER_STRATEGY_POINTER +cdef unsigned int _ANSWER_STRATEGY_ADDRESS +cdef unsigned int _ANSWER_STRATEGY_SERVICE +cdef unsigned int _ANSWER_STRATEGY_TEXT + +cdef list _EMPTY_SERVICES_LIST +cdef list _EMPTY_TYPES_LIST + +cdef class _AnswerStrategy: + + cdef public DNSQuestion question + cdef public unsigned int strategy_type + cdef public list types + cdef public list services + + cdef class _QueryResponse: cdef bint _is_probe @@ -53,24 +70,30 @@ cdef class QueryHandler: cdef QuestionHistory question_history @cython.locals(service=ServiceInfo) - cdef _add_service_type_enumeration_query_answers(self, cython.dict answer_set, DNSRRSet known_answers) + cdef _add_service_type_enumeration_query_answers(self, list types, cython.dict answer_set, DNSRRSet known_answers) @cython.locals(service=ServiceInfo) - cdef _add_pointer_answers(self, str lower_name, cython.dict answer_set, DNSRRSet known_answers) + cdef _add_pointer_answers(self, list services, cython.dict answer_set, DNSRRSet known_answers) @cython.locals(service=ServiceInfo, dns_address=DNSAddress) - cdef _add_address_answers(self, str lower_name, cython.dict answer_set, DNSRRSet known_answers, cython.uint type_) + cdef _add_address_answers(self, list services, cython.dict answer_set, DNSRRSet known_answers, cython.uint type_) @cython.locals(question_lower_name=str, type_=cython.uint, service=ServiceInfo) - cdef cython.dict _answer_question(self, DNSQuestion question, DNSRRSet known_answers) + cdef cython.dict _answer_question(self, DNSQuestion question, unsigned int strategy_type, list types, list services, DNSRRSet known_answers) @cython.locals( msg=DNSIncoming, + msgs=list, + strategy=_AnswerStrategy, question=DNSQuestion, answer_set=cython.dict, known_answers=DNSRRSet, known_answers_set=cython.set, + is_unicast=bint, is_probe=object, - now=object + now=float ) cpdef async_response(self, cython.list msgs, cython.bint unicast_source) + + @cython.locals(name=str, question_lower_name=str) + cdef _get_answer_strategies(self, DNSQuestion question) diff --git a/src/zeroconf/_handlers/query_handler.py b/src/zeroconf/_handlers/query_handler.py index 4e74aa5c0..0af72f4c6 100644 --- a/src/zeroconf/_handlers/query_handler.py +++ b/src/zeroconf/_handlers/query_handler.py @@ -20,13 +20,13 @@ USA """ - from typing import TYPE_CHECKING, List, Optional, Set, cast from .._cache import DNSCache, _UniqueRecordsType from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet from .._history import QuestionHistory from .._protocol.incoming import DNSIncoming +from .._services.info import ServiceInfo from .._services.registry import ServiceRegistry from .._utils.net import IPVersion from ..const import ( @@ -47,11 +47,39 @@ _RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES} +_EMPTY_SERVICES_LIST: List[ServiceInfo] = [] +_EMPTY_TYPES_LIST: List[str] = [] + _IPVersion_ALL = IPVersion.All _int = int +_ANSWER_STRATEGY_SERVICE_TYPE_ENUMERATION = 0 +_ANSWER_STRATEGY_POINTER = 1 +_ANSWER_STRATEGY_ADDRESS = 2 +_ANSWER_STRATEGY_SERVICE = 3 +_ANSWER_STRATEGY_TEXT = 4 + + +class _AnswerStrategy: + + __slots__ = ("question", "strategy_type", "types", "services") + + def __init__( + self, + question: DNSQuestion, + strategy_type: _int, + types: List[str], + services: List[ServiceInfo], + ) -> None: + """Create an answer strategy.""" + self.question = question + self.strategy_type = strategy_type + self.types = types + self.services = services + + class _QueryResponse: """A pair for unicast and multicast DNSOutgoing responses.""" @@ -164,13 +192,13 @@ def __init__(self, registry: ServiceRegistry, cache: DNSCache, question_history: self.question_history = question_history def _add_service_type_enumeration_query_answers( - self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + self, types: List[str], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet ) -> None: """Provide an answer to a service type enumeration query. https://datatracker.ietf.org/doc/html/rfc6763#section-9 """ - for stype in self.registry.async_get_types(): + for stype in types: dns_pointer = DNSPointer( _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, 0.0 ) @@ -178,10 +206,10 @@ def _add_service_type_enumeration_query_answers( answer_set[dns_pointer] = set() def _add_pointer_answers( - self, lower_name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + self, services: List[ServiceInfo], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet ) -> None: """Answer PTR/ANY question.""" - for service in self.registry.async_get_infos_type(lower_name): + for service in services: # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. dns_pointer = service._dns_pointer(None) @@ -190,17 +218,18 @@ def _add_pointer_answers( answer_set[dns_pointer] = { service._dns_service(None), service._dns_text(None), - } | service._get_address_and_nsec_records(None) + *service._get_address_and_nsec_records(None), + } def _add_address_answers( self, - lower_name: str, + services: List[ServiceInfo], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, type_: _int, ) -> None: """Answer A/AAAA/ANY question.""" - for service in self.registry.async_get_infos_server(lower_name): + for service in services: answers: List[DNSAddress] = [] additionals: Set[DNSRecord] = set() seen_types: Set[int] = set() @@ -224,75 +253,135 @@ def _add_address_answers( def _answer_question( self, question: DNSQuestion, + strategy_type: _int, + types: List[str], + services: List[ServiceInfo], known_answers: DNSRRSet, ) -> _AnswerWithAdditionalsType: """Answer a question.""" answer_set: _AnswerWithAdditionalsType = {} - question_lower_name = question.name.lower() - type_ = question.type - - if type_ == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME: - self._add_service_type_enumeration_query_answers(answer_set, known_answers) - return answer_set - - if type_ in (_TYPE_PTR, _TYPE_ANY): - self._add_pointer_answers(question_lower_name, answer_set, known_answers) - if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY): - self._add_address_answers(question_lower_name, answer_set, known_answers, type_) - - if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY): - service = self.registry.async_get_info_name(question_lower_name) - if service is not None: - if type_ in (_TYPE_SRV, _TYPE_ANY): - # Add recommended additional answers according to - # https://tools.ietf.org/html/rfc6763#section-12.2. - dns_service = service._dns_service(None) - if known_answers.suppresses(dns_service) is False: - answer_set[dns_service] = service._get_address_and_nsec_records(None) - if type_ in (_TYPE_TXT, _TYPE_ANY): - dns_text = service._dns_text(None) - if known_answers.suppresses(dns_text) is False: - answer_set[dns_text] = set() + if strategy_type == _ANSWER_STRATEGY_SERVICE_TYPE_ENUMERATION: + self._add_service_type_enumeration_query_answers(types, answer_set, known_answers) + elif strategy_type == _ANSWER_STRATEGY_POINTER: + self._add_pointer_answers(services, answer_set, known_answers) + elif strategy_type == _ANSWER_STRATEGY_ADDRESS: + self._add_address_answers(services, answer_set, known_answers, question.type) + elif strategy_type == _ANSWER_STRATEGY_SERVICE: + # Add recommended additional answers according to + # https://tools.ietf.org/html/rfc6763#section-12.2. + service = services[0] + dns_service = service._dns_service(None) + if known_answers.suppresses(dns_service) is False: + answer_set[dns_service] = service._get_address_and_nsec_records(None) + elif strategy_type == _ANSWER_STRATEGY_TEXT: # pragma: no branch + service = services[0] + dns_text = service._dns_text(None) + if known_answers.suppresses(dns_text) is False: + answer_set[dns_text] = set() return answer_set def async_response( # pylint: disable=unused-argument self, msgs: List[DNSIncoming], ucast_source: bool - ) -> QuestionAnswers: + ) -> Optional[QuestionAnswers]: """Deal with incoming query packets. Provides a response if possible. This function must be run in the event loop as it is not threadsafe. """ - answers: List[DNSRecord] = [] + strategies: List[_AnswerStrategy] = [] + for msg in msgs: + for question in msg.questions: + strategies.extend(self._get_answer_strategies(question)) + + if not strategies: + # We have no way to answer the question because we have + # nothing in the ServiceRegistry that matches or we do not + # understand the question. + return None + is_probe = False - msg = msgs[0] questions = msg.questions - now = msg.now + # Only decode known answers if we are not a probe and we have + # at least one answer strategy + answers: List[DNSRecord] = [] for msg in msgs: - if msg.is_probe() is False: - answers.extend(msg.answers()) - else: + if msg.is_probe() is True: is_probe = True + else: + answers.extend(msg.answers()) + + msg = msgs[0] + query_res = _QueryResponse(self.cache, questions, is_probe, msg.now) known_answers = DNSRRSet(answers) - query_res = _QueryResponse(self.cache, questions, is_probe, now) known_answers_set: Optional[Set[DNSRecord]] = None - - for msg in msgs: - for question in msg.questions: - if not question.unique: # unique and unicast are the same flag - if not known_answers_set: # pragma: no branch - known_answers_set = known_answers.lookup_set() - self.question_history.add_question_at_time(question, now, known_answers_set) - answer_set = self._answer_question(question, known_answers) - if not ucast_source and question.unique: # unique and unicast are the same flag - query_res.add_qu_question_response(answer_set) - continue - if ucast_source: - query_res.add_ucast_question_response(answer_set) - # We always multicast as well even if its a unicast - # source as long as we haven't done it recently (75% of ttl) - query_res.add_mcast_question_response(answer_set) + now = msg.now + for strategy in strategies: + question = strategy.question + is_unicast = question.unique is True # unique and unicast are the same flag + if not is_unicast: + if known_answers_set is None: # pragma: no branch + known_answers_set = known_answers.lookup_set() + self.question_history.add_question_at_time(question, now, known_answers_set) + answer_set = self._answer_question( + question, strategy.strategy_type, strategy.types, strategy.services, known_answers + ) + if not ucast_source and is_unicast: + query_res.add_qu_question_response(answer_set) + continue + if ucast_source: + query_res.add_ucast_question_response(answer_set) + # We always multicast as well even if its a unicast + # source as long as we haven't done it recently (75% of ttl) + query_res.add_mcast_question_response(answer_set) return query_res.answers() + + def _get_answer_strategies( + self, + question: DNSQuestion, + ) -> List[_AnswerStrategy]: + """Collect strategies to answer a question.""" + name = question.name + question_lower_name = name.lower() + type_ = question.type + strategies: List[_AnswerStrategy] = [] + + if type_ == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME: + types = self.registry.async_get_types() + if types: + strategies.append( + _AnswerStrategy( + question, _ANSWER_STRATEGY_SERVICE_TYPE_ENUMERATION, types, _EMPTY_SERVICES_LIST + ) + ) + return strategies + + if type_ in (_TYPE_PTR, _TYPE_ANY): + services = self.registry.async_get_infos_type(question_lower_name) + if services: + strategies.append( + _AnswerStrategy(question, _ANSWER_STRATEGY_POINTER, _EMPTY_TYPES_LIST, services) + ) + + if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY): + services = self.registry.async_get_infos_server(question_lower_name) + if services: + strategies.append( + _AnswerStrategy(question, _ANSWER_STRATEGY_ADDRESS, _EMPTY_TYPES_LIST, services) + ) + + if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY): + service = self.registry.async_get_info_name(question_lower_name) + if service is not None: + if type_ in (_TYPE_SRV, _TYPE_ANY): + strategies.append( + _AnswerStrategy(question, _ANSWER_STRATEGY_SERVICE, _EMPTY_TYPES_LIST, [service]) + ) + if type_ in (_TYPE_TXT, _TYPE_ANY): + strategies.append( + _AnswerStrategy(question, _ANSWER_STRATEGY_TEXT, _EMPTY_TYPES_LIST, [service]) + ) + + return strategies diff --git a/src/zeroconf/_record_update.py b/src/zeroconf/_record_update.py index 5a3625340..8e0e4bdb0 100644 --- a/src/zeroconf/_record_update.py +++ b/src/zeroconf/_record_update.py @@ -26,7 +26,6 @@ class RecordUpdate: - __slots__ = ("new", "old") def __init__(self, new: DNSRecord, old: Optional[DNSRecord] = None): diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 13fe3a516..1a1066fa2 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -107,6 +107,7 @@ def _process_outgoing_packet(out): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers _process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate)) # The additonals should all be suppresed since they are all in the answers section @@ -145,6 +146,7 @@ def _process_outgoing_packet(out): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers _process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate)) # There will be one NSEC additional to indicate the lack of AAAA record @@ -244,6 +246,7 @@ def test_ptr_optimization(): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_now assert not question_answers.mcast_aggregate @@ -260,6 +263,7 @@ def test_ptr_optimization(): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_now assert not question_answers.mcast_aggregate_last_second @@ -305,6 +309,7 @@ def test_any_query_for_ptr(): generated.add_question(question) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers mcast_answers = list(question_answers.mcast_aggregate) assert mcast_answers[0].name == type_ assert mcast_answers[0].alias == registration_name # type: ignore[attr-defined] @@ -332,6 +337,7 @@ def test_aaaa_query(): generated.add_question(question) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers mcast_answers = list(question_answers.mcast_now) assert mcast_answers[0].address == ipv6_address # type: ignore[attr-defined] # unregister @@ -358,6 +364,7 @@ def test_aaaa_query_upper_case(): generated.add_question(question) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers mcast_answers = list(question_answers.mcast_now) assert mcast_answers[0].address == ipv6_address # type: ignore[attr-defined] # unregister @@ -391,6 +398,7 @@ def test_a_and_aaaa_record_fate_sharing(): generated.add_question(question) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers additionals = set().union(*question_answers.mcast_now.values()) assert aaaa_record in question_answers.mcast_now assert a_record in additionals @@ -403,6 +411,7 @@ def test_a_and_aaaa_record_fate_sharing(): generated.add_question(question) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers additionals = set().union(*question_answers.mcast_now.values()) assert a_record in question_answers.mcast_now assert aaaa_record in additionals @@ -437,6 +446,7 @@ def test_unicast_response(): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], True ) + assert question_answers for answers in (question_answers.ucast, question_answers.mcast_aggregate): has_srv = has_txt = has_a = has_aaaa = has_nsec = False nbr_additionals = 0 @@ -486,6 +496,7 @@ async def test_probe_answered_immediately(): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_aggregate assert not question_answers.mcast_aggregate_last_second @@ -499,6 +510,7 @@ async def test_probe_answered_immediately(): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers assert question_answers.ucast assert question_answers.mcast_now assert not question_answers.mcast_aggregate @@ -528,6 +540,7 @@ async def test_probe_answered_immediately_with_uppercase_name(): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_aggregate assert not question_answers.mcast_aggregate_last_second @@ -541,6 +554,7 @@ async def test_probe_answered_immediately_with_uppercase_name(): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers assert question_answers.ucast assert question_answers.mcast_now assert not question_answers.mcast_aggregate @@ -607,6 +621,7 @@ def _validate_complete_response(answers): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers _validate_complete_response(question_answers.ucast) assert not question_answers.mcast_now assert not question_answers.mcast_aggregate @@ -622,6 +637,7 @@ def _validate_complete_response(answers): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_aggregate assert not question_answers.mcast_aggregate @@ -637,6 +653,7 @@ def _validate_complete_response(answers): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers _validate_complete_response(question_answers.ucast) _validate_complete_response(question_answers.mcast_now) @@ -652,6 +669,7 @@ def _validate_complete_response(answers): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers assert not question_answers.mcast_now assert not question_answers.mcast_aggregate assert not question_answers.mcast_aggregate_last_second @@ -681,6 +699,7 @@ def test_known_answer_supression(): generated.add_question(question) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_now assert question_answers.mcast_aggregate @@ -692,6 +711,7 @@ def test_known_answer_supression(): generated.add_answer_at_time(info.dns_pointer(), now) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_now assert not question_answers.mcast_aggregate @@ -703,6 +723,7 @@ def test_known_answer_supression(): generated.add_question(question) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast assert question_answers.mcast_now assert not question_answers.mcast_aggregate @@ -715,6 +736,7 @@ def test_known_answer_supression(): generated.add_answer_at_time(dns_address, now) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_now assert not question_answers.mcast_aggregate @@ -728,6 +750,7 @@ def test_known_answer_supression(): generated.add_answer_at_time(dns_address, now) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast expected_nsec_record = cast(r.DNSNsec, list(question_answers.mcast_now)[0]) assert const._TYPE_A not in expected_nsec_record.rdtypes @@ -741,6 +764,7 @@ def test_known_answer_supression(): generated.add_question(question) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast assert question_answers.mcast_now assert not question_answers.mcast_aggregate @@ -752,6 +776,7 @@ def test_known_answer_supression(): generated.add_answer_at_time(info.dns_service(), now) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_now assert not question_answers.mcast_aggregate @@ -763,6 +788,7 @@ def test_known_answer_supression(): generated.add_question(question) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_now assert question_answers.mcast_aggregate @@ -774,6 +800,7 @@ def test_known_answer_supression(): generated.add_answer_at_time(info.dns_text(), now) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_now assert not question_answers.mcast_aggregate @@ -827,6 +854,7 @@ def test_multi_packet_known_answer_supression(): packets = generated.packets() assert len(packets) > 1 question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_now assert not question_answers.mcast_aggregate @@ -868,6 +896,7 @@ def test_known_answer_supression_service_type_enumeration_query(): generated.add_question(question) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_now assert question_answers.mcast_aggregate @@ -898,6 +927,7 @@ def test_known_answer_supression_service_type_enumeration_query(): ) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_now assert not question_answers.mcast_aggregate @@ -938,6 +968,7 @@ def test_upper_case_enumeration_query(): generated.add_question(question) packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_now assert question_answers.mcast_aggregate @@ -948,6 +979,19 @@ def test_upper_case_enumeration_query(): zc.close() +def test_enumeration_query_with_no_registered_services(): + zc = Zeroconf(interfaces=['127.0.0.1']) + _clear_cache(zc) + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME.upper(), const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers + # unregister + zc.close() + + # This test uses asyncio because it needs to access the cache directly # which is not threadsafe @pytest.mark.asyncio @@ -1000,6 +1044,7 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers assert not question_answers.mcast_now assert not question_answers.mcast_aggregate assert not question_answers.mcast_aggregate_last_second @@ -1024,6 +1069,7 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers assert not question_answers.mcast_now assert not question_answers.mcast_aggregate assert not question_answers.mcast_aggregate_last_second @@ -1047,6 +1093,7 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_aggregate assert not question_answers.mcast_aggregate_last_second @@ -1075,6 +1122,7 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): question_answers = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], False ) + assert question_answers assert not question_answers.mcast_aggregate assert not question_answers.mcast_aggregate_last_second @@ -1235,8 +1283,22 @@ async def test_questions_query_handler_populates_the_question_history_from_qm_qu now = current_time_millis() _clear_cache(zc) + aiozc.zeroconf.registry.async_add( + ServiceInfo( + "_hap._tcp.local.", + "other._hap._tcp.local.", + 80, + 0, + 0, + {"md": "known"}, + "ash-2.local.", + addresses=[socket.inet_aton("1.2.3.4")], + ) + ) + services = aiozc.zeroconf.registry.async_get_infos_type("_hap._tcp.local.") + assert len(services) == 1 generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) + question = r.DNSQuestion("_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN) question.unicast = False known_answer = r.DNSPointer( "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' @@ -1246,9 +1308,10 @@ async def test_questions_query_handler_populates_the_question_history_from_qm_qu now = r.current_time_millis() packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert question_answers assert not question_answers.ucast assert not question_answers.mcast_now - assert not question_answers.mcast_aggregate + assert question_answers.mcast_aggregate assert not question_answers.mcast_aggregate_last_second assert zc.question_history.suppresses(question, now, {known_answer}) @@ -1261,20 +1324,32 @@ async def test_questions_query_handler_does_not_put_qu_questions_in_history(): zc = aiozc.zeroconf now = current_time_millis() _clear_cache(zc) - + info = ServiceInfo( + "_hap._tcp.local.", + "qu._hap._tcp.local.", + 80, + 0, + 0, + {"md": "known"}, + "ash-2.local.", + addresses=[socket.inet_aton("1.2.3.4")], + ) + aiozc.zeroconf.registry.async_add(info) generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) + question = r.DNSQuestion("_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN) question.unicast = True known_answer = r.DNSPointer( - "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'notqu._hap._tcp.local.' ) generated.add_question(question) generated.add_answer_at_time(known_answer, 0) now = r.current_time_millis() packets = generated.packets() question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) - assert not question_answers.ucast - assert not question_answers.mcast_now + assert question_answers + assert "qu._hap._tcp.local." in str(question_answers) + assert not question_answers.ucast # has not multicast recently + assert question_answers.mcast_now assert not question_answers.mcast_aggregate assert not question_answers.mcast_aggregate_last_second assert not zc.question_history.suppresses(question, now, {known_answer})