Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: d
#
# _CLASS_UNIQUE is the "QU" bit
out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN | _CLASS_UNIQUE))
out.add_authorative_answer(info.dns_pointer(created=current_time_millis()))
out.add_authorative_answer(info.dns_pointer())
return out

def _add_broadcast_answer( # pylint: disable=no-self-use
Expand All @@ -411,14 +411,14 @@ def _add_broadcast_answer( # pylint: disable=no-self-use
broadcast_addresses: bool = True,
) -> None:
"""Add answers to broadcast a service."""
now = current_time_millis()
other_ttl = info.other_ttl if override_ttl is None else override_ttl
host_ttl = info.host_ttl if override_ttl is None else override_ttl
out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl, created=now), 0)
out.add_answer_at_time(info.dns_service(override_ttl=host_ttl, created=now), 0)
out.add_answer_at_time(info.dns_text(override_ttl=other_ttl, created=now), 0)
current_time_millis()
other_ttl = None if override_ttl is None else override_ttl
host_ttl = None if override_ttl is None else override_ttl
out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl), 0)
out.add_answer_at_time(info.dns_service(override_ttl=host_ttl), 0)
out.add_answer_at_time(info.dns_text(override_ttl=other_ttl), 0)
if broadcast_addresses:
for record in info.get_address_and_nsec_records(override_ttl=host_ttl, created=now):
for record in info.get_address_and_nsec_records(override_ttl=host_ttl):
out.add_answer_at_time(record, 0)

def unregister_service(self, info: ServiceInfo) -> None:
Expand Down
33 changes: 16 additions & 17 deletions src/zeroconf/_handlers/query_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,48 +163,47 @@ def __init__(self, registry: ServiceRegistry, cache: DNSCache, question_history:
self.question_history = question_history

def _add_service_type_enumeration_query_answers(
self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet
) -> None:
"""Provide an answer to a service type enumeration query.

https://datatracker.ietf.org/doc/html/rfc6763#section-9
"""
for stype in self.registry.async_get_types():
dns_pointer = DNSPointer(
_SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, now
_SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, 0.0
)
if not known_answers.suppresses(dns_pointer):
answer_set[dns_pointer] = set()

def _add_pointer_answers(
self, lower_name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
self, lower_name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet
) -> None:
"""Answer PTR/ANY question."""
for service in self.registry.async_get_infos_type(lower_name):
# Add recommended additional answers according to
# https://tools.ietf.org/html/rfc6763#section-12.1.
dns_pointer = service.dns_pointer(created=now)
dns_pointer = service.dns_pointer()
if known_answers.suppresses(dns_pointer):
continue
answer_set[dns_pointer] = {
service.dns_service(created=now),
service.dns_text(created=now),
} | service.get_address_and_nsec_records(created=now)
service.dns_service(),
service.dns_text(),
} | service.get_address_and_nsec_records()

def _add_address_answers(
self,
lower_name: str,
answer_set: _AnswerWithAdditionalsType,
known_answers: DNSRRSet,
now: float,
type_: int,
) -> None:
"""Answer A/AAAA/ANY question."""
for service in self.registry.async_get_infos_server(lower_name):
answers: List[DNSAddress] = []
additionals: Set[DNSRecord] = set()
seen_types: Set[int] = set()
for dns_address in service.dns_addresses(created=now):
for dns_address in service.dns_addresses():
seen_types.add(dns_address.type)
if dns_address.type != type_:
additionals.add(dns_address)
Expand All @@ -214,12 +213,12 @@ def _add_address_answers(
if answers:
if missing_types:
assert service.server is not None, "Service server must be set for NSEC record."
additionals.add(service.dns_nsec(list(missing_types), created=now))
additionals.add(service.dns_nsec(list(missing_types)))
for answer in answers:
answer_set[answer] = additionals
elif type_ in missing_types:
assert service.server is not None, "Service server must be set for NSEC record."
answer_set[service.dns_nsec(list(missing_types), created=now)] = set()
answer_set[service.dns_nsec(list(missing_types))] = set()

def _answer_question(
self,
Expand All @@ -231,28 +230,28 @@ def _answer_question(
question_lower_name = question.name.lower()

if question.type == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME:
self._add_service_type_enumeration_query_answers(answer_set, known_answers, now)
self._add_service_type_enumeration_query_answers(answer_set, known_answers)
return answer_set

type_ = question.type

if type_ in (_TYPE_PTR, _TYPE_ANY):
self._add_pointer_answers(question_lower_name, answer_set, known_answers, now)
self._add_pointer_answers(question_lower_name, answer_set, known_answers)

if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY):
self._add_address_answers(question_lower_name, answer_set, known_answers, now, type_)
self._add_address_answers(question_lower_name, answer_set, known_answers, type_)

if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY):
service = self.registry.async_get_info_name(question_lower_name)
if service is not None:
if type_ in (_TYPE_SRV, _TYPE_ANY):
# Add recommended additional answers according to
# https://tools.ietf.org/html/rfc6763#section-12.2.
dns_service = service.dns_service(created=now)
dns_service = service.dns_service()
if not known_answers.suppresses(dns_service):
answer_set[dns_service] = service.get_address_and_nsec_records(created=now)
answer_set[dns_service] = service.get_address_and_nsec_records()
if type_ in (_TYPE_TXT, _TYPE_ANY):
dns_text = service.dns_text(created=now)
dns_text = service.dns_text()
if not known_answers.suppresses(dns_text):
answer_set[dns_text] = set()

Expand Down
81 changes: 60 additions & 21 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ class ServiceInfo(RecordUpdateListener):
"other_ttl",
"interface_index",
"_new_records_futures",
"_dns_pointer_cache",
"_dns_service_cache",
"_dns_text_cache",
"_dns_address_cache",
"_get_address_and_nsec_records_cache",
)

def __init__(
Expand Down Expand Up @@ -180,6 +185,11 @@ def __init__(
self.other_ttl = other_ttl
self.interface_index = interface_index
self._new_records_futures: Set[asyncio.Future] = set()
self._dns_address_cache: Optional[List[DNSAddress]] = None
self._dns_pointer_cache: Optional[DNSPointer] = None
self._dns_service_cache: Optional[DNSService] = None
self._dns_text_cache: Optional[DNSText] = None
Comment on lines +188 to +191

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we would have required a subclass of ServiceInfo to be registered with the registry but thats a design choice we would have needed to make many years ago and would be a breaking change now

self._get_address_and_nsec_records_cache: Optional[Set[DNSRecord]] = None

@property
def name(self) -> str:
Expand All @@ -191,6 +201,9 @@ def name(self, name: str) -> None:
"""Replace the the name and reset the key."""
self._name = name
self.key = name.lower()
self._dns_service_cache = None
self._dns_pointer_cache = None
self._dns_text_cache = None

@property
def addresses(self) -> List[bytes]:
Expand All @@ -210,6 +223,8 @@ def addresses(self, value: List[bytes]) -> None:
"""
self._ipv4_addresses.clear()
self._ipv6_addresses.clear()
self._dns_address_cache = None
self._get_address_and_nsec_records_cache = None

for address in value:
try:
Expand Down Expand Up @@ -489,42 +504,56 @@ def dns_addresses(
self,
override_ttl: Optional[int] = None,
version: IPVersion = IPVersion.All,
created: Optional[float] = None,
) -> List[DNSAddress]:
"""Return matching DNSAddress from ServiceInfo."""
cacheable = version is IPVersion.All and override_ttl is None
if self._dns_address_cache is not None and cacheable:
return self._dns_address_cache
name = self.server or self._name
ttl = override_ttl if override_ttl is not None else self.host_ttl
class_ = _CLASS_IN_UNIQUE
version_value = version.value
return [
records = [
DNSAddress(
name,
_TYPE_AAAA if type(ip_addr) is IPv6Address else _TYPE_A,
class_,
ttl,
ip_addr.packed,
created=created,
created=0.0,
)
for ip_addr in self._ip_addresses_by_version_value(version_value)
]
if cacheable:
self._dns_address_cache = records
return records

def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer:
def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer:
"""Return DNSPointer from ServiceInfo."""
return DNSPointer(
cacheable = override_ttl is None
if self._dns_pointer_cache is not None and cacheable:
return self._dns_pointer_cache
record = DNSPointer(
self.type,
_TYPE_PTR,
_CLASS_IN,
override_ttl if override_ttl is not None else self.other_ttl,
self._name,
created,
0.0,
)
if cacheable:
self._dns_pointer_cache = record
return record

def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService:
def dns_service(self, override_ttl: Optional[int] = None) -> DNSService:
"""Return DNSService from ServiceInfo."""
cacheable = override_ttl is None
if self._dns_service_cache is not None and cacheable:
return self._dns_service_cache
port = self.port
if TYPE_CHECKING:
assert isinstance(port, int)
return DNSService(
record = DNSService(
self._name,
_TYPE_SRV,
_CLASS_IN_UNIQUE,
Expand All @@ -533,23 +562,30 @@ def dns_service(self, override_ttl: Optional[int] = None, created: Optional[floa
self.weight,
port,
self.server or self._name,
created,
0.0,
)
if cacheable:
self._dns_service_cache = record
return record

def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText:
def dns_text(self, override_ttl: Optional[int] = None) -> DNSText:
"""Return DNSText from ServiceInfo."""
return DNSText(
cacheable = override_ttl is None
if self._dns_text_cache is not None and cacheable:
return self._dns_text_cache
record = DNSText(
self._name,
_TYPE_TXT,
_CLASS_IN_UNIQUE,
override_ttl if override_ttl is not None else self.other_ttl,
self.text,
created,
0.0,
)
if cacheable:
self._dns_text_cache = record
return record

def dns_nsec(
self, missing_types: List[int], override_ttl: Optional[int] = None, created: Optional[float] = None
) -> DNSNsec:
def dns_nsec(self, missing_types: List[int], override_ttl: Optional[int] = None) -> DNSNsec:
"""Return DNSNsec from ServiceInfo."""
return DNSNsec(
self._name,
Expand All @@ -558,21 +594,24 @@ def dns_nsec(
override_ttl if override_ttl is not None else self.host_ttl,
self._name,
missing_types,
created,
0.0,
)

def get_address_and_nsec_records(
self, override_ttl: Optional[int] = None, created: Optional[float] = None
) -> Set[DNSRecord]:
def get_address_and_nsec_records(self, override_ttl: Optional[int] = None) -> Set[DNSRecord]:
"""Build a set of address records and NSEC records for non-present record types."""
cacheable = override_ttl is None
if self._get_address_and_nsec_records_cache is not None and cacheable:
return self._get_address_and_nsec_records_cache
missing_types: Set[int] = _ADDRESS_RECORD_TYPES.copy()
records: Set[DNSRecord] = set()
for dns_address in self.dns_addresses(override_ttl, IPVersion.All, created):
for dns_address in self.dns_addresses(override_ttl, IPVersion.All):
missing_types.discard(dns_address.type)
records.add(dns_address)
if missing_types:
assert self.server is not None, "Service server must be set for NSEC record."
records.add(self.dns_nsec(list(missing_types), override_ttl, created))
records.add(self.dns_nsec(list(missing_types), override_ttl))
if cacheable:
self._get_address_and_nsec_records_cache = records
return records

def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int) -> List[DNSAddress]:
Expand Down
10 changes: 10 additions & 0 deletions tests/services/test_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,8 @@ def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
assert service_info.port == 80

info.port = 400
info._dns_service_cache = None # we are mutating the record so clear the cache

_inject_response(
zc,
mock_incoming_msg([info.dns_service()]),
Expand Down Expand Up @@ -856,6 +858,8 @@ def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]),
)
time.sleep(0.2)
info._dns_service_cache = None # we are mutating the record so clear the cache

info.port = 400
_inject_response(
zc,
Expand Down Expand Up @@ -914,6 +918,8 @@ def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
)
time.sleep(0.2)
info.port = 400
info._dns_service_cache = None # we are mutating the record so clear the cache

_inject_response(
zc,
mock_incoming_msg([info.dns_service()]),
Expand Down Expand Up @@ -1131,6 +1137,8 @@ def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
)
time.sleep(0.2)
info.port = 400
info._dns_service_cache = None # we are mutating the record so clear the cache

_inject_response(
zc,
mock_incoming_msg([info.dns_service()]),
Expand Down Expand Up @@ -1210,6 +1218,8 @@ def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
)
time.sleep(0.3)
info.port = 400
info._dns_service_cache = None # we are mutating the record so clear the cache

_inject_response(
zc,
mock_incoming_msg([info.dns_service()]),
Expand Down
1 change: 1 addition & 0 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,7 @@ async def test_qu_response_only_sends_additionals_if_sends_answer():
a_record = info.dns_addresses()[0]
a_record.set_created_ttl(current_time_millis() - (a_record.ttl * 1000 / 2), a_record.ttl)
assert not a_record.is_recent(current_time_millis())
info._dns_address_cache = None # we are mutating the record so clear the cache
zc.cache.async_add_records([a_record])

# With QU should respond to only unicast when the answer has been recently multicast
Expand Down