From ada4280393b087c98981652376981eaf897cdf24 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 10:43:45 -1000 Subject: [PATCH] Synchronize created time for incoming and outgoing queries - The created time would differ by a few ms for records, these should all be the same in each packet to reduce unexpected inconsistency - Records received in the same response MUST be subject to fate sharing. Its possible we could end up expiring part of the unique record set because one of the addresses is a few ms older. Fixes #700 --- tests/test_protocol.py | 26 ++++++++++++++++++++ zeroconf/_core.py | 11 +++++---- zeroconf/_dns.py | 33 ++++++++++++++++--------- zeroconf/_handlers.py | 44 ++++++++++++++++++++++------------ zeroconf/_protocol.py | 13 ++++++---- zeroconf/_services/__init__.py | 17 +++++++++---- 6 files changed, 104 insertions(+), 40 deletions(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 8b4ebd043..ebdb71105 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -720,3 +720,29 @@ def test_qu_packet_parser(): parsed = DNSIncoming(qu_packet) assert parsed.questions[0].unicast is True assert ",QU," in str(parsed.questions[0]) + + +def test_records_same_packet_share_fate(): + """Test records in the same packet all have the same created time.""" + out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + type_ = "_hap._tcp.local." + out.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)) + + for i in range(30): + out.add_answer_at_time( + DNSText( + ("HASS Bridge W9DN %s._hap._tcp.local." % i), + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + for packet in out.packets(): + dnsin = DNSIncoming(packet) + first_time = dnsin.answers[0].created + for answer in dnsin.answers: + assert answer.created == first_time diff --git a/zeroconf/_core.py b/zeroconf/_core.py index f4fb647aa..12f3c5f43 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -465,19 +465,20 @@ def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: d # # _CLASS_UNIQUE is the "QU" bit out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN | _CLASS_UNIQUE)) - out.add_authorative_answer(info.dns_pointer()) + out.add_authorative_answer(info.dns_pointer(created=current_time_millis())) return out def _add_broadcast_answer( # pylint: disable=no-self-use self, out: DNSOutgoing, info: ServiceInfo, override_ttl: Optional[int] ) -> None: """Add answers to broadcast a service.""" + now = current_time_millis() other_ttl = info.other_ttl if override_ttl is None else override_ttl host_ttl = info.host_ttl if override_ttl is None else override_ttl - out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl), 0) - out.add_answer_at_time(info.dns_service(override_ttl=host_ttl), 0) - out.add_answer_at_time(info.dns_text(override_ttl=other_ttl), 0) - for dns_address in info.dns_addresses(override_ttl=host_ttl): + out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl, created=now), 0) + out.add_answer_at_time(info.dns_service(override_ttl=host_ttl, created=now), 0) + out.add_answer_at_time(info.dns_text(override_ttl=other_ttl, created=now), 0) + for dns_address in info.dns_addresses(override_ttl=host_ttl, created=now): out.add_answer_at_time(dns_address, 0) def unregister_service(self, info: ServiceInfo) -> None: diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index b5b2bb790..c6d0108e0 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -139,10 +139,12 @@ class DNSRecord(DNSEntry): """A DNS record - like a DNS entry, but has a TTL""" # TODO: Switch to just int ttl - def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) -> None: + def __init__( + self, name: str, type_: int, class_: int, ttl: Union[float, int], created: Optional[float] = None + ) -> None: super().__init__(name, type_, class_) self.ttl = ttl - self.created = current_time_millis() + self.created = created or current_time_millis() self._expiration_time: Optional[float] = None self._stale_time: Optional[float] = None self._recent_time: Optional[float] = None @@ -218,8 +220,10 @@ class DNSAddress(DNSRecord): """A DNS address record""" - def __init__(self, name: str, type_: int, class_: int, ttl: int, address: bytes) -> None: - super().__init__(name, type_, class_, ttl) + def __init__( + self, name: str, type_: int, class_: int, ttl: int, address: bytes, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) self.address = address def write(self, out: 'DNSOutgoing') -> None: @@ -252,8 +256,10 @@ class DNSHinfo(DNSRecord): """A DNS host information record""" - def __init__(self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str) -> None: - super().__init__(name, type_, class_, ttl) + def __init__( + self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) self.cpu = cpu self.os = os @@ -284,8 +290,10 @@ class DNSPointer(DNSRecord): """A DNS pointer record""" - def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) -> None: - super().__init__(name, type_, class_, ttl) + def __init__( + self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) self.alias = alias @property @@ -319,9 +327,11 @@ class DNSText(DNSRecord): """A DNS text record""" - def __init__(self, name: str, type_: int, class_: int, ttl: int, text: bytes) -> None: + def __init__( + self, name: str, type_: int, class_: int, ttl: int, text: bytes, created: Optional[float] = None + ) -> None: assert isinstance(text, (bytes, type(None))) - super().__init__(name, type_, class_, ttl) + super().__init__(name, type_, class_, ttl, created) self.text = text def write(self, out: 'DNSOutgoing') -> None: @@ -357,8 +367,9 @@ def __init__( weight: int, port: int, server: str, + created: Optional[float] = None, ) -> None: - super().__init__(name, type_, class_, ttl) + super().__init__(name, type_, class_, ttl, created) self.priority = priority self.weight = weight self.port = port diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index f8e9c6dfb..1d6cac4c5 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -162,7 +162,7 @@ def __init__(self, registry: ServiceRegistry, cache: DNSCache) -> None: self.cache = cache def _add_service_type_enumeration_query_answers( - self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float ) -> None: """Provide an answer to a service type enumeration query. @@ -170,47 +170,60 @@ def _add_service_type_enumeration_query_answers( """ for stype in self.registry.get_types(): dns_pointer = DNSPointer( - _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype + _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, now ) if not known_answers.suppresses(dns_pointer): answer_set[dns_pointer] = set() def _add_pointer_answers( - self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float ) -> None: """Answer PTR/ANY question.""" for service in self.registry.get_infos_type(name): # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. - dns_pointer = service.dns_pointer() + dns_pointer = service.dns_pointer(created=now) if not known_answers.suppresses(dns_pointer): answer_set[dns_pointer] = set( - [service.dns_service(), service.dns_text(), *service.dns_addresses()] + [ + service.dns_service(created=now), + service.dns_text(created=now), + *service.dns_addresses(created=now), + ] ) def _add_address_answers( - self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, type_: int + self, + name: str, + answer_set: _AnswerWithAdditionalsType, + known_answers: DNSRRSet, + now: float, + type_: int, ) -> None: """Answer A/AAAA/ANY question.""" for service in self.registry.get_infos_server(name): - for dns_address in service.dns_addresses(version=_TYPE_TO_IP_VERSION[type_]): + for dns_address in service.dns_addresses(version=_TYPE_TO_IP_VERSION[type_], created=now): if not known_answers.suppresses(dns_address): answer_set[dns_address] = set() def _answer_question( - self, question: DNSQuestion, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + self, + question: DNSQuestion, + answer_set: _AnswerWithAdditionalsType, + known_answers: DNSRRSet, + now: float, ) -> None: if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: - self._add_service_type_enumeration_query_answers(answer_set, known_answers) + self._add_service_type_enumeration_query_answers(answer_set, known_answers, now) return type_ = question.type if type_ in (_TYPE_PTR, _TYPE_ANY): - self._add_pointer_answers(question.name, answer_set, known_answers) + self._add_pointer_answers(question.name, answer_set, known_answers, now) if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY): - self._add_address_answers(question.name, answer_set, known_answers, type_) + self._add_address_answers(question.name, answer_set, known_answers, now, type_) if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY): service = self.registry.get_info_name(question.name) # type: ignore @@ -218,11 +231,11 @@ def _answer_question( if type_ in (_TYPE_SRV, _TYPE_ANY): # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.2. - dns_service = service.dns_service() + dns_service = service.dns_service(created=now) if not known_answers.suppresses(dns_service): - answer_set[dns_service] = set(service.dns_addresses()) + answer_set[dns_service] = set(service.dns_addresses(created=now)) if type_ in (_TYPE_TXT, _TYPE_ANY): - dns_text = service.dns_text() + dns_text = service.dns_text(created=now) if not known_answers.suppresses(dns_text): answer_set[dns_text] = set() @@ -233,10 +246,11 @@ def response( # pylint: disable=unused-argument ucast_source = port != _MDNS_PORT known_answers = DNSRRSet(itertools.chain(*[msg.answers for msg in msgs])) query_res = _QueryResponse(self.cache, msgs[0], ucast_source) + now = current_time_millis() for question in itertools.chain(*[msg.questions for msg in msgs]): answer_set: _AnswerWithAdditionalsType = {} - self._answer_question(question, answer_set, known_answers) + self._answer_question(question, answer_set, known_answers, now) if not ucast_source and question.unicast: query_res.add_qu_question_response(answer_set) else: diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index 64c65b96b..80ca7b886 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -24,10 +24,12 @@ import struct from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast + from ._dns import DNSAddress, DNSHinfo, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText from ._exceptions import IncomingDecodeError, NamePartTooLongException from ._logger import QuietLogger, log from ._utils.struct import int2byte +from ._utils.time import current_time_millis from .const import ( _CLASS_UNIQUE, _DNS_PACKET_HEADER_LEN, @@ -90,6 +92,7 @@ def __init__(self, data: bytes) -> None: self.num_authorities = 0 self.num_additionals = 0 self.valid = False + self.now = current_time_millis() try: self.read_header() @@ -166,11 +169,11 @@ def read_others(self) -> None: type_, class_, ttl, length = self.unpack(b'!HHiH') rec: Optional[DNSRecord] = None if type_ == _TYPE_A: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4)) + rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4), self.now) elif type_ in (_TYPE_CNAME, _TYPE_PTR): - rec = DNSPointer(domain, type_, class_, ttl, self.read_name()) + rec = DNSPointer(domain, type_, class_, ttl, self.read_name(), self.now) elif type_ == _TYPE_TXT: - rec = DNSText(domain, type_, class_, ttl, self.read_string(length)) + rec = DNSText(domain, type_, class_, ttl, self.read_string(length), self.now) elif type_ == _TYPE_SRV: rec = DNSService( domain, @@ -181,6 +184,7 @@ def read_others(self) -> None: self.read_unsigned_short(), self.read_unsigned_short(), self.read_name(), + self.now, ) elif type_ == _TYPE_HINFO: rec = DNSHinfo( @@ -190,9 +194,10 @@ def read_others(self) -> None: ttl, self.read_character_string().decode('utf-8'), self.read_character_string().decode('utf-8'), + self.now, ) elif type_ == _TYPE_AAAA: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16)) + rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16), self.now) else: # Try to ignore types we don't know about # Skip the payload for the resource record so the next diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index b6481eb31..c6efcc8e3 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -369,7 +369,7 @@ def update_records_complete(self) -> None: At this point the cache will have the new records. """ - # Cannot use .update here since PyPy can fail with + # Cannot use .update here since can fail with # RuntimeError: dictionary changed size during iteration # for threaded ServiceBrowsers while self._pending_handlers: @@ -722,7 +722,10 @@ def _process_record(self, record: DNSRecord, now: float) -> None: self._set_text(record.text) def dns_addresses( - self, override_ttl: Optional[int] = None, version: IPVersion = IPVersion.All + self, + override_ttl: Optional[int] = None, + version: IPVersion = IPVersion.All, + created: Optional[float] = None, ) -> List[DNSAddress]: """Return matching DNSAddress from ServiceInfo.""" return [ @@ -732,11 +735,12 @@ def dns_addresses( _CLASS_IN | _CLASS_UNIQUE, override_ttl if override_ttl is not None else self.host_ttl, address, + created, ) for address in self.addresses_by_version(version) ] - def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: + def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer: """Return DNSPointer from ServiceInfo.""" return DNSPointer( self.type, @@ -744,9 +748,10 @@ def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: _CLASS_IN, override_ttl if override_ttl is not None else self.other_ttl, self.name, + created, ) - def dns_service(self, override_ttl: Optional[int] = None) -> DNSService: + def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService: """Return DNSService from ServiceInfo.""" return DNSService( self.name, @@ -757,9 +762,10 @@ def dns_service(self, override_ttl: Optional[int] = None) -> DNSService: self.weight, cast(int, self.port), self.server, + created, ) - def dns_text(self, override_ttl: Optional[int] = None) -> DNSText: + def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText: """Return DNSText from ServiceInfo.""" return DNSText( self.name, @@ -767,6 +773,7 @@ def dns_text(self, override_ttl: Optional[int] = None) -> DNSText: _CLASS_IN | _CLASS_UNIQUE, override_ttl if override_ttl is not None else self.other_ttl, self.text, + created, ) def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: