Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,3 +720,29 @@ def test_qu_packet_parser():
parsed = DNSIncoming(qu_packet)
assert parsed.questions[0].unicast is True
assert ",QU," in str(parsed.questions[0])


def test_records_same_packet_share_fate():
"""Test records in the same packet all have the same created time."""
out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
type_ = "_hap._tcp.local."
out.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN))

for i in range(30):
out.add_answer_at_time(
DNSText(
("HASS Bridge W9DN %s._hap._tcp.local." % i),
const._TYPE_TXT,
const._CLASS_IN | const._CLASS_UNIQUE,
const._DNS_OTHER_TTL,
b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1'
b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==',
),
0,
)

for packet in out.packets():
dnsin = DNSIncoming(packet)
first_time = dnsin.answers[0].created
for answer in dnsin.answers:
assert answer.created == first_time
11 changes: 6 additions & 5 deletions zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,19 +465,20 @@ def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: d
#
# _CLASS_UNIQUE is the "QU" bit
out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN | _CLASS_UNIQUE))
out.add_authorative_answer(info.dns_pointer())
out.add_authorative_answer(info.dns_pointer(created=current_time_millis()))
return out

def _add_broadcast_answer( # pylint: disable=no-self-use
self, out: DNSOutgoing, info: ServiceInfo, override_ttl: Optional[int]
) -> None:
"""Add answers to broadcast a service."""
now = current_time_millis()
other_ttl = info.other_ttl if override_ttl is None else override_ttl
host_ttl = info.host_ttl if override_ttl is None else override_ttl
out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl), 0)
out.add_answer_at_time(info.dns_service(override_ttl=host_ttl), 0)
out.add_answer_at_time(info.dns_text(override_ttl=other_ttl), 0)
for dns_address in info.dns_addresses(override_ttl=host_ttl):
out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl, created=now), 0)
out.add_answer_at_time(info.dns_service(override_ttl=host_ttl, created=now), 0)
out.add_answer_at_time(info.dns_text(override_ttl=other_ttl, created=now), 0)
for dns_address in info.dns_addresses(override_ttl=host_ttl, created=now):
out.add_answer_at_time(dns_address, 0)

def unregister_service(self, info: ServiceInfo) -> None:
Expand Down
33 changes: 22 additions & 11 deletions zeroconf/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,12 @@ class DNSRecord(DNSEntry):
"""A DNS record - like a DNS entry, but has a TTL"""

# TODO: Switch to just int ttl
def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) -> None:
def __init__(
self, name: str, type_: int, class_: int, ttl: Union[float, int], created: Optional[float] = None
) -> None:
super().__init__(name, type_, class_)
self.ttl = ttl
self.created = current_time_millis()
self.created = created or current_time_millis()
self._expiration_time: Optional[float] = None
self._stale_time: Optional[float] = None
self._recent_time: Optional[float] = None
Expand Down Expand Up @@ -218,8 +220,10 @@ class DNSAddress(DNSRecord):

"""A DNS address record"""

def __init__(self, name: str, type_: int, class_: int, ttl: int, address: bytes) -> None:
super().__init__(name, type_, class_, ttl)
def __init__(
self, name: str, type_: int, class_: int, ttl: int, address: bytes, created: Optional[float] = None
) -> None:
super().__init__(name, type_, class_, ttl, created)
self.address = address

def write(self, out: 'DNSOutgoing') -> None:
Expand Down Expand Up @@ -252,8 +256,10 @@ class DNSHinfo(DNSRecord):

"""A DNS host information record"""

def __init__(self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str) -> None:
super().__init__(name, type_, class_, ttl)
def __init__(
self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str, created: Optional[float] = None
) -> None:
super().__init__(name, type_, class_, ttl, created)
self.cpu = cpu
self.os = os

Expand Down Expand Up @@ -284,8 +290,10 @@ class DNSPointer(DNSRecord):

"""A DNS pointer record"""

def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) -> None:
super().__init__(name, type_, class_, ttl)
def __init__(
self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None
) -> None:
super().__init__(name, type_, class_, ttl, created)
self.alias = alias

@property
Expand Down Expand Up @@ -319,9 +327,11 @@ class DNSText(DNSRecord):

"""A DNS text record"""

def __init__(self, name: str, type_: int, class_: int, ttl: int, text: bytes) -> None:
def __init__(
self, name: str, type_: int, class_: int, ttl: int, text: bytes, created: Optional[float] = None
) -> None:
assert isinstance(text, (bytes, type(None)))
super().__init__(name, type_, class_, ttl)
super().__init__(name, type_, class_, ttl, created)
self.text = text

def write(self, out: 'DNSOutgoing') -> None:
Expand Down Expand Up @@ -357,8 +367,9 @@ def __init__(
weight: int,
port: int,
server: str,
created: Optional[float] = None,
) -> None:
super().__init__(name, type_, class_, ttl)
super().__init__(name, type_, class_, ttl, created)
self.priority = priority
self.weight = weight
self.port = port
Expand Down
44 changes: 29 additions & 15 deletions zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,67 +162,80 @@ def __init__(self, registry: ServiceRegistry, cache: DNSCache) -> None:
self.cache = cache

def _add_service_type_enumeration_query_answers(
self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet
self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
) -> None:
"""Provide an answer to a service type enumeration query.

https://datatracker.ietf.org/doc/html/rfc6763#section-9
"""
for stype in self.registry.get_types():
dns_pointer = DNSPointer(
_SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype
_SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, now
)
if not known_answers.suppresses(dns_pointer):
answer_set[dns_pointer] = set()

def _add_pointer_answers(
self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet
self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
) -> None:
"""Answer PTR/ANY question."""
for service in self.registry.get_infos_type(name):
# Add recommended additional answers according to
# https://tools.ietf.org/html/rfc6763#section-12.1.
dns_pointer = service.dns_pointer()
dns_pointer = service.dns_pointer(created=now)
if not known_answers.suppresses(dns_pointer):
answer_set[dns_pointer] = set(
[service.dns_service(), service.dns_text(), *service.dns_addresses()]
[
service.dns_service(created=now),
service.dns_text(created=now),
*service.dns_addresses(created=now),
]
)

def _add_address_answers(
self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, type_: int
self,
name: str,
answer_set: _AnswerWithAdditionalsType,
known_answers: DNSRRSet,
now: float,
type_: int,
) -> None:
"""Answer A/AAAA/ANY question."""
for service in self.registry.get_infos_server(name):
for dns_address in service.dns_addresses(version=_TYPE_TO_IP_VERSION[type_]):
for dns_address in service.dns_addresses(version=_TYPE_TO_IP_VERSION[type_], created=now):
if not known_answers.suppresses(dns_address):
answer_set[dns_address] = set()

def _answer_question(
self, question: DNSQuestion, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet
self,
question: DNSQuestion,
answer_set: _AnswerWithAdditionalsType,
known_answers: DNSRRSet,
now: float,
) -> None:
if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME:
self._add_service_type_enumeration_query_answers(answer_set, known_answers)
self._add_service_type_enumeration_query_answers(answer_set, known_answers, now)
return

type_ = question.type

if type_ in (_TYPE_PTR, _TYPE_ANY):
self._add_pointer_answers(question.name, answer_set, known_answers)
self._add_pointer_answers(question.name, answer_set, known_answers, now)

if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY):
self._add_address_answers(question.name, answer_set, known_answers, type_)
self._add_address_answers(question.name, answer_set, known_answers, now, type_)

if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY):
service = self.registry.get_info_name(question.name) # type: ignore
if service is not None:
if type_ in (_TYPE_SRV, _TYPE_ANY):
# Add recommended additional answers according to
# https://tools.ietf.org/html/rfc6763#section-12.2.
dns_service = service.dns_service()
dns_service = service.dns_service(created=now)
if not known_answers.suppresses(dns_service):
answer_set[dns_service] = set(service.dns_addresses())
answer_set[dns_service] = set(service.dns_addresses(created=now))
if type_ in (_TYPE_TXT, _TYPE_ANY):
dns_text = service.dns_text()
dns_text = service.dns_text(created=now)
if not known_answers.suppresses(dns_text):
answer_set[dns_text] = set()

Expand All @@ -233,10 +246,11 @@ def response( # pylint: disable=unused-argument
ucast_source = port != _MDNS_PORT
known_answers = DNSRRSet(itertools.chain(*[msg.answers for msg in msgs]))
query_res = _QueryResponse(self.cache, msgs[0], ucast_source)
now = current_time_millis()

for question in itertools.chain(*[msg.questions for msg in msgs]):
answer_set: _AnswerWithAdditionalsType = {}
self._answer_question(question, answer_set, known_answers)
self._answer_question(question, answer_set, known_answers, now)
if not ucast_source and question.unicast:
query_res.add_qu_question_response(answer_set)
else:
Expand Down
13 changes: 9 additions & 4 deletions zeroconf/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
import struct
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast


from ._dns import DNSAddress, DNSHinfo, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText
from ._exceptions import IncomingDecodeError, NamePartTooLongException
from ._logger import QuietLogger, log
from ._utils.struct import int2byte
from ._utils.time import current_time_millis
from .const import (
_CLASS_UNIQUE,
_DNS_PACKET_HEADER_LEN,
Expand Down Expand Up @@ -90,6 +92,7 @@ def __init__(self, data: bytes) -> None:
self.num_authorities = 0
self.num_additionals = 0
self.valid = False
self.now = current_time_millis()

try:
self.read_header()
Expand Down Expand Up @@ -166,11 +169,11 @@ def read_others(self) -> None:
type_, class_, ttl, length = self.unpack(b'!HHiH')
rec: Optional[DNSRecord] = None
if type_ == _TYPE_A:
rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4))
rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4), self.now)
elif type_ in (_TYPE_CNAME, _TYPE_PTR):
rec = DNSPointer(domain, type_, class_, ttl, self.read_name())
rec = DNSPointer(domain, type_, class_, ttl, self.read_name(), self.now)
elif type_ == _TYPE_TXT:
rec = DNSText(domain, type_, class_, ttl, self.read_string(length))
rec = DNSText(domain, type_, class_, ttl, self.read_string(length), self.now)
elif type_ == _TYPE_SRV:
rec = DNSService(
domain,
Expand All @@ -181,6 +184,7 @@ def read_others(self) -> None:
self.read_unsigned_short(),
self.read_unsigned_short(),
self.read_name(),
self.now,
)
elif type_ == _TYPE_HINFO:
rec = DNSHinfo(
Expand All @@ -190,9 +194,10 @@ def read_others(self) -> None:
ttl,
self.read_character_string().decode('utf-8'),
self.read_character_string().decode('utf-8'),
self.now,
)
elif type_ == _TYPE_AAAA:
rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16))
rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16), self.now)
else:
# Try to ignore types we don't know about
# Skip the payload for the resource record so the next
Expand Down
17 changes: 12 additions & 5 deletions zeroconf/_services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def update_records_complete(self) -> None:

At this point the cache will have the new records.
"""
# Cannot use .update here since PyPy can fail with
# Cannot use .update here since can fail with
# RuntimeError: dictionary changed size during iteration
# for threaded ServiceBrowsers
while self._pending_handlers:
Expand Down Expand Up @@ -722,7 +722,10 @@ def _process_record(self, record: DNSRecord, now: float) -> None:
self._set_text(record.text)

def dns_addresses(
self, override_ttl: Optional[int] = None, version: IPVersion = IPVersion.All
self,
override_ttl: Optional[int] = None,
version: IPVersion = IPVersion.All,
created: Optional[float] = None,
) -> List[DNSAddress]:
"""Return matching DNSAddress from ServiceInfo."""
return [
Expand All @@ -732,21 +735,23 @@ def dns_addresses(
_CLASS_IN | _CLASS_UNIQUE,
override_ttl if override_ttl is not None else self.host_ttl,
address,
created,
)
for address in self.addresses_by_version(version)
]

def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer:
def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer:
"""Return DNSPointer from ServiceInfo."""
return DNSPointer(
self.type,
_TYPE_PTR,
_CLASS_IN,
override_ttl if override_ttl is not None else self.other_ttl,
self.name,
created,
)

def dns_service(self, override_ttl: Optional[int] = None) -> DNSService:
def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService:
"""Return DNSService from ServiceInfo."""
return DNSService(
self.name,
Expand All @@ -757,16 +762,18 @@ def dns_service(self, override_ttl: Optional[int] = None) -> DNSService:
self.weight,
cast(int, self.port),
self.server,
created,
)

def dns_text(self, override_ttl: Optional[int] = None) -> DNSText:
def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText:
"""Return DNSText from ServiceInfo."""
return DNSText(
self.name,
_TYPE_TXT,
_CLASS_IN | _CLASS_UNIQUE,
override_ttl if override_ttl is not None else self.other_ttl,
self.text,
created,
)

def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]:
Expand Down