Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 40 additions & 9 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ def _process_outgoing_packet(out):
_process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate))

# The additonals should all be suppresed since they are all in the answers section
# There will be one NSEC additional to indicate the lack of AAAA record
#
assert nbr_answers == 4 and nbr_additionals == 0 and nbr_authorities == 0
assert nbr_answers == 4 and nbr_additionals == 1 and nbr_authorities == 0
nbr_answers = nbr_additionals = nbr_authorities = 0

# unregister
Expand Down Expand Up @@ -143,7 +144,9 @@ def _process_outgoing_packet(out):
[r.DNSIncoming(packet) for packet in query.packets()], False
)
_process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate))
assert nbr_answers == 4 and nbr_additionals == 0 and nbr_authorities == 0

# There will be one NSEC additional to indicate the lack of AAAA record
assert nbr_answers == 4 and nbr_additionals == 1 and nbr_authorities == 0
nbr_answers = nbr_additionals = nbr_authorities = 0

# unregister
Expand Down Expand Up @@ -271,7 +274,9 @@ def test_ptr_optimization():
has_txt = True
elif answer.type == const._TYPE_A:
has_a = True
assert nbr_answers == 1 and nbr_additionals == 3
assert nbr_answers == 1 and nbr_additionals == 4
# There will be one NSEC additional to indicate the lack of AAAA record

assert has_srv and has_txt and has_a

# unregister
Expand Down Expand Up @@ -406,7 +411,7 @@ def test_unicast_response():
[r.DNSIncoming(packet) for packet in query.packets()], True
)
for answers in (question_answers.ucast, question_answers.mcast_aggregate):
has_srv = has_txt = has_a = False
has_srv = has_txt = has_a = has_aaaa = has_nsec = False
nbr_additionals = 0
nbr_answers = len(answers)
additionals = set().union(*answers.values())
Expand All @@ -418,8 +423,14 @@ def test_unicast_response():
has_txt = True
elif answer.type == const._TYPE_A:
has_a = True
assert nbr_answers == 1 and nbr_additionals == 3
assert has_srv and has_txt and has_a
elif answer.type == const._TYPE_AAAA:
has_aaaa = True
elif answer.type == const._TYPE_NSEC:
has_nsec = True
# There will be one NSEC additional to indicate the lack of AAAA record
assert nbr_answers == 1 and nbr_additionals == 4
assert has_srv and has_txt and has_a and has_nsec
assert not has_aaaa

# unregister
zc.registry.async_remove(info)
Expand Down Expand Up @@ -497,7 +508,7 @@ def test_qu_response():
zc.register_service(info)

def _validate_complete_response(answers):
has_srv = has_txt = has_a = False
has_srv = has_txt = has_a = has_aaaa = has_nsec = False
nbr_answers = len(answers.keys())
additionals = set().union(*answers.values())
nbr_additionals = len(additionals)
Expand All @@ -509,8 +520,13 @@ def _validate_complete_response(answers):
has_txt = True
elif answer.type == const._TYPE_A:
has_a = True
assert nbr_answers == 1 and nbr_additionals == 3
assert has_srv and has_txt and has_a
elif answer.type == const._TYPE_AAAA:
has_aaaa = True
elif answer.type == const._TYPE_NSEC:
has_nsec = True
assert nbr_answers == 1 and nbr_additionals == 4
assert has_srv and has_txt and has_a and has_nsec
assert not has_aaaa

# With QU should respond to only unicast when the answer has been recently multicast
query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
Expand Down Expand Up @@ -635,6 +651,21 @@ def test_known_answer_supression():
assert not question_answers.mcast_aggregate
assert not question_answers.mcast_aggregate_last_second

# Test NSEC record returned when there is no AAAA record and we expectly ask
generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN)
generated.add_question(question)
for dns_address in info.dns_addresses():
generated.add_answer_at_time(dns_address, now)
packets = generated.packets()
question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
assert not question_answers.ucast
expected_nsec_record: r.DNSNsec = list(question_answers.mcast_now)[0]
assert const._TYPE_A not in expected_nsec_record.rdtypes
assert const._TYPE_AAAA in expected_nsec_record.rdtypes
assert not question_answers.mcast_aggregate
assert not question_answers.mcast_aggregate_last_second

# Test SRV supression
generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(registration_name, const._TYPE_SRV, const._CLASS_IN)
Expand Down
54 changes: 43 additions & 11 deletions zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@
from typing import Dict, Iterable, List, NamedTuple, Optional, Set, TYPE_CHECKING, Tuple, Union, cast

from ._cache import DNSCache, _UniqueRecordsType
from ._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord
from ._dns import DNSAddress, DNSNsec, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord
from ._history import QuestionHistory
from ._logger import log
from ._protocol import DNSIncoming, DNSOutgoing
from ._services.info import ServiceInfo
from ._services.registry import ServiceRegistry
from ._updates import RecordUpdate, RecordUpdateListener
from ._utils.time import current_time_millis, millis_to_seconds
from .const import (
_CLASS_IN,
_CLASS_UNIQUE,
_DNS_OTHER_TTL,
_DNS_PTR_MIN_TTL,
_FLAGS_AA,
Expand All @@ -44,6 +46,7 @@
_TYPE_A,
_TYPE_AAAA,
_TYPE_ANY,
_TYPE_NSEC,
_TYPE_PTR,
_TYPE_SRV,
_TYPE_TXT,
Expand All @@ -56,7 +59,8 @@
_AnswerWithAdditionalsType = Dict[DNSRecord, Set[DNSRecord]]

_MULTICAST_DELAY_RANDOM_INTERVAL = (20, 120)
_RESPOND_IMMEDIATE_TYPES = {_TYPE_SRV, _TYPE_A, _TYPE_AAAA}
_ADDRESS_RECORD_TYPES = {_TYPE_A, _TYPE_AAAA}
_RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES}


class QuestionAnswers(NamedTuple):
Expand All @@ -78,6 +82,15 @@ def _message_is_probe(msg: DNSIncoming) -> bool:
return msg.num_authorities > 0


def construct_nsec_record(name: str, types: List[int], now: float) -> DNSNsec:
"""Construct an NSEC record for name and a list of dns types.

This function should only be used for SRV/A/AAAA records
which have a TTL of _DNS_OTHER_TTL
"""
return DNSNsec(name, _TYPE_NSEC, _CLASS_IN | _CLASS_UNIQUE, _DNS_OTHER_TTL, name, types, created=now)


def construct_outgoing_multicast_answers(answers: _AnswerWithAdditionalsType) -> DNSOutgoing:
"""Add answers and additionals to a DNSOutgoing."""
out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=True)
Expand Down Expand Up @@ -244,12 +257,23 @@ def _add_pointer_answers(
# Add recommended additional answers according to
# https://tools.ietf.org/html/rfc6763#section-12.1.
dns_pointer = service.dns_pointer(created=now)
if not known_answers.suppresses(dns_pointer):
answer_set[dns_pointer] = {
service.dns_service(created=now),
service.dns_text(created=now),
*service.dns_addresses(created=now),
}
if known_answers.suppresses(dns_pointer):
continue
additionals: Set[DNSRecord] = {service.dns_service(created=now), service.dns_text(created=now)}
additionals |= self._get_address_and_nsec_records(service, now)
answer_set[dns_pointer] = additionals

def _get_address_and_nsec_records(self, service: ServiceInfo, now: float) -> Set[DNSRecord]:
"""Build a set of address records and NSEC records for non-present record types."""
seen_types: Set[int] = set()
records: Set[DNSRecord] = set()
for dns_address in service.dns_addresses(created=now):
seen_types.add(dns_address.type)
records.add(dns_address)
missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
if missing_types:
records.add(construct_nsec_record(service.server, list(missing_types), now))
return records

def _add_address_answers(
self,
Expand All @@ -263,13 +287,21 @@ def _add_address_answers(
for service in self.registry.async_get_infos_server(name):
answers: List[DNSAddress] = []
additionals: Set[DNSRecord] = set()
seen_types: Set[int] = set()
for dns_address in service.dns_addresses(created=now):
seen_types.add(dns_address.type)
if dns_address.type != type_:
additionals.add(dns_address)
elif not known_answers.suppresses(dns_address):
answers.append(dns_address)
for answer in answers:
answer_set[answer] = additionals
missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
if answers:
if missing_types:
additionals.add(construct_nsec_record(service.server, list(missing_types), now))
for answer in answers:
answer_set[answer] = additionals
elif type_ in missing_types:
answer_set[construct_nsec_record(service.server, list(missing_types), now)] = set()

def _answer_question(
self,
Expand Down Expand Up @@ -299,7 +331,7 @@ def _answer_question(
# https://tools.ietf.org/html/rfc6763#section-12.2.
dns_service = service.dns_service(created=now)
if not known_answers.suppresses(dns_service):
answer_set[dns_service] = set(service.dns_addresses(created=now))
answer_set[dns_service] = self._get_address_and_nsec_records(service, now)
if type_ in (_TYPE_TXT, _TYPE_ANY):
dns_text = service.dns_text(created=now)
if not known_answers.suppresses(dns_text):
Expand Down