diff --git a/tests/test_handlers.py b/tests/test_handlers.py index a621f0378..44ee1d5af 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -108,8 +108,9 @@ def _process_outgoing_packet(out): _process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate)) # The additonals should all be suppresed since they are all in the answers section + # There will be one NSEC additional to indicate the lack of AAAA record # - assert nbr_answers == 4 and nbr_additionals == 0 and nbr_authorities == 0 + assert nbr_answers == 4 and nbr_additionals == 1 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 # unregister @@ -143,7 +144,9 @@ def _process_outgoing_packet(out): [r.DNSIncoming(packet) for packet in query.packets()], False ) _process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate)) - assert nbr_answers == 4 and nbr_additionals == 0 and nbr_authorities == 0 + + # There will be one NSEC additional to indicate the lack of AAAA record + assert nbr_answers == 4 and nbr_additionals == 1 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 # unregister @@ -271,7 +274,9 @@ def test_ptr_optimization(): has_txt = True elif answer.type == const._TYPE_A: has_a = True - assert nbr_answers == 1 and nbr_additionals == 3 + assert nbr_answers == 1 and nbr_additionals == 4 + # There will be one NSEC additional to indicate the lack of AAAA record + assert has_srv and has_txt and has_a # unregister @@ -406,7 +411,7 @@ def test_unicast_response(): [r.DNSIncoming(packet) for packet in query.packets()], True ) for answers in (question_answers.ucast, question_answers.mcast_aggregate): - has_srv = has_txt = has_a = False + has_srv = has_txt = has_a = has_aaaa = has_nsec = False nbr_additionals = 0 nbr_answers = len(answers) additionals = set().union(*answers.values()) @@ -418,8 +423,14 @@ def test_unicast_response(): has_txt = True elif answer.type == const._TYPE_A: has_a = True - assert nbr_answers == 1 and nbr_additionals == 3 - assert has_srv and has_txt and has_a + elif answer.type == const._TYPE_AAAA: + has_aaaa = True + elif answer.type == const._TYPE_NSEC: + has_nsec = True + # There will be one NSEC additional to indicate the lack of AAAA record + assert nbr_answers == 1 and nbr_additionals == 4 + assert has_srv and has_txt and has_a and has_nsec + assert not has_aaaa # unregister zc.registry.async_remove(info) @@ -497,7 +508,7 @@ def test_qu_response(): zc.register_service(info) def _validate_complete_response(answers): - has_srv = has_txt = has_a = False + has_srv = has_txt = has_a = has_aaaa = has_nsec = False nbr_answers = len(answers.keys()) additionals = set().union(*answers.values()) nbr_additionals = len(additionals) @@ -509,8 +520,13 @@ def _validate_complete_response(answers): has_txt = True elif answer.type == const._TYPE_A: has_a = True - assert nbr_answers == 1 and nbr_additionals == 3 - assert has_srv and has_txt and has_a + elif answer.type == const._TYPE_AAAA: + has_aaaa = True + elif answer.type == const._TYPE_NSEC: + has_nsec = True + assert nbr_answers == 1 and nbr_additionals == 4 + assert has_srv and has_txt and has_a and has_nsec + assert not has_aaaa # With QU should respond to only unicast when the answer has been recently multicast query = r.DNSOutgoing(const._FLAGS_QR_QUERY) @@ -635,6 +651,21 @@ def test_known_answer_supression(): assert not question_answers.mcast_aggregate assert not question_answers.mcast_aggregate_last_second + # Test NSEC record returned when there is no AAAA record and we expectly ask + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN) + generated.add_question(question) + for dns_address in info.dns_addresses(): + generated.add_answer_at_time(dns_address, now) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + expected_nsec_record: r.DNSNsec = list(question_answers.mcast_now)[0] + assert const._TYPE_A not in expected_nsec_record.rdtypes + assert const._TYPE_AAAA in expected_nsec_record.rdtypes + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + # Test SRV supression generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(registration_name, const._TYPE_SRV, const._CLASS_IN) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 06ed54cd1..76ba6cc3a 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -26,15 +26,17 @@ from typing import Dict, Iterable, List, NamedTuple, Optional, Set, TYPE_CHECKING, Tuple, Union, cast from ._cache import DNSCache, _UniqueRecordsType -from ._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord +from ._dns import DNSAddress, DNSNsec, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord from ._history import QuestionHistory from ._logger import log from ._protocol import DNSIncoming, DNSOutgoing +from ._services.info import ServiceInfo from ._services.registry import ServiceRegistry from ._updates import RecordUpdate, RecordUpdateListener from ._utils.time import current_time_millis, millis_to_seconds from .const import ( _CLASS_IN, + _CLASS_UNIQUE, _DNS_OTHER_TTL, _DNS_PTR_MIN_TTL, _FLAGS_AA, @@ -44,6 +46,7 @@ _TYPE_A, _TYPE_AAAA, _TYPE_ANY, + _TYPE_NSEC, _TYPE_PTR, _TYPE_SRV, _TYPE_TXT, @@ -56,7 +59,8 @@ _AnswerWithAdditionalsType = Dict[DNSRecord, Set[DNSRecord]] _MULTICAST_DELAY_RANDOM_INTERVAL = (20, 120) -_RESPOND_IMMEDIATE_TYPES = {_TYPE_SRV, _TYPE_A, _TYPE_AAAA} +_ADDRESS_RECORD_TYPES = {_TYPE_A, _TYPE_AAAA} +_RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES} class QuestionAnswers(NamedTuple): @@ -78,6 +82,15 @@ def _message_is_probe(msg: DNSIncoming) -> bool: return msg.num_authorities > 0 +def construct_nsec_record(name: str, types: List[int], now: float) -> DNSNsec: + """Construct an NSEC record for name and a list of dns types. + + This function should only be used for SRV/A/AAAA records + which have a TTL of _DNS_OTHER_TTL + """ + return DNSNsec(name, _TYPE_NSEC, _CLASS_IN | _CLASS_UNIQUE, _DNS_OTHER_TTL, name, types, created=now) + + def construct_outgoing_multicast_answers(answers: _AnswerWithAdditionalsType) -> DNSOutgoing: """Add answers and additionals to a DNSOutgoing.""" out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=True) @@ -244,12 +257,23 @@ def _add_pointer_answers( # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. dns_pointer = service.dns_pointer(created=now) - if not known_answers.suppresses(dns_pointer): - answer_set[dns_pointer] = { - service.dns_service(created=now), - service.dns_text(created=now), - *service.dns_addresses(created=now), - } + if known_answers.suppresses(dns_pointer): + continue + additionals: Set[DNSRecord] = {service.dns_service(created=now), service.dns_text(created=now)} + additionals |= self._get_address_and_nsec_records(service, now) + answer_set[dns_pointer] = additionals + + def _get_address_and_nsec_records(self, service: ServiceInfo, now: float) -> Set[DNSRecord]: + """Build a set of address records and NSEC records for non-present record types.""" + seen_types: Set[int] = set() + records: Set[DNSRecord] = set() + for dns_address in service.dns_addresses(created=now): + seen_types.add(dns_address.type) + records.add(dns_address) + missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types + if missing_types: + records.add(construct_nsec_record(service.server, list(missing_types), now)) + return records def _add_address_answers( self, @@ -263,13 +287,21 @@ def _add_address_answers( for service in self.registry.async_get_infos_server(name): answers: List[DNSAddress] = [] additionals: Set[DNSRecord] = set() + seen_types: Set[int] = set() for dns_address in service.dns_addresses(created=now): + seen_types.add(dns_address.type) if dns_address.type != type_: additionals.add(dns_address) elif not known_answers.suppresses(dns_address): answers.append(dns_address) - for answer in answers: - answer_set[answer] = additionals + missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types + if answers: + if missing_types: + additionals.add(construct_nsec_record(service.server, list(missing_types), now)) + for answer in answers: + answer_set[answer] = additionals + elif type_ in missing_types: + answer_set[construct_nsec_record(service.server, list(missing_types), now)] = set() def _answer_question( self, @@ -299,7 +331,7 @@ def _answer_question( # https://tools.ietf.org/html/rfc6763#section-12.2. dns_service = service.dns_service(created=now) if not known_answers.suppresses(dns_service): - answer_set[dns_service] = set(service.dns_addresses(created=now)) + answer_set[dns_service] = self._get_address_and_nsec_records(service, now) if type_ in (_TYPE_TXT, _TYPE_ANY): dns_text = service.dns_text(created=now) if not known_answers.suppresses(dns_text):