diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 8b4ebd043..ebdb71105 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -720,3 +720,29 @@ def test_qu_packet_parser(): parsed = DNSIncoming(qu_packet) assert parsed.questions[0].unicast is True assert ",QU," in str(parsed.questions[0]) + + +def test_records_same_packet_share_fate(): + """Test records in the same packet all have the same created time.""" + out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + type_ = "_hap._tcp.local." + out.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)) + + for i in range(30): + out.add_answer_at_time( + DNSText( + ("HASS Bridge W9DN %s._hap._tcp.local." % i), + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + for packet in out.packets(): + dnsin = DNSIncoming(packet) + first_time = dnsin.answers[0].created + for answer in dnsin.answers: + assert answer.created == first_time diff --git a/zeroconf/_core.py b/zeroconf/_core.py index f4fb647aa..12f3c5f43 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -465,19 +465,20 @@ def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: d # # _CLASS_UNIQUE is the "QU" bit out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN | _CLASS_UNIQUE)) - out.add_authorative_answer(info.dns_pointer()) + out.add_authorative_answer(info.dns_pointer(created=current_time_millis())) return out def _add_broadcast_answer( # pylint: disable=no-self-use self, out: DNSOutgoing, info: ServiceInfo, override_ttl: Optional[int] ) -> None: """Add answers to broadcast a service.""" + now = current_time_millis() other_ttl = info.other_ttl if override_ttl is None else override_ttl host_ttl = info.host_ttl if override_ttl is None else override_ttl - out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl), 0) - out.add_answer_at_time(info.dns_service(override_ttl=host_ttl), 0) - out.add_answer_at_time(info.dns_text(override_ttl=other_ttl), 0) - for dns_address in info.dns_addresses(override_ttl=host_ttl): + out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl, created=now), 0) + out.add_answer_at_time(info.dns_service(override_ttl=host_ttl, created=now), 0) + out.add_answer_at_time(info.dns_text(override_ttl=other_ttl, created=now), 0) + for dns_address in info.dns_addresses(override_ttl=host_ttl, created=now): out.add_answer_at_time(dns_address, 0) def unregister_service(self, info: ServiceInfo) -> None: diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index b5b2bb790..c6d0108e0 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -139,10 +139,12 @@ class DNSRecord(DNSEntry): """A DNS record - like a DNS entry, but has a TTL""" # TODO: Switch to just int ttl - def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) -> None: + def __init__( + self, name: str, type_: int, class_: int, ttl: Union[float, int], created: Optional[float] = None + ) -> None: super().__init__(name, type_, class_) self.ttl = ttl - self.created = current_time_millis() + self.created = created or current_time_millis() self._expiration_time: Optional[float] = None self._stale_time: Optional[float] = None self._recent_time: Optional[float] = None @@ -218,8 +220,10 @@ class DNSAddress(DNSRecord): """A DNS address record""" - def __init__(self, name: str, type_: int, class_: int, ttl: int, address: bytes) -> None: - super().__init__(name, type_, class_, ttl) + def __init__( + self, name: str, type_: int, class_: int, ttl: int, address: bytes, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) self.address = address def write(self, out: 'DNSOutgoing') -> None: @@ -252,8 +256,10 @@ class DNSHinfo(DNSRecord): """A DNS host information record""" - def __init__(self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str) -> None: - super().__init__(name, type_, class_, ttl) + def __init__( + self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) self.cpu = cpu self.os = os @@ -284,8 +290,10 @@ class DNSPointer(DNSRecord): """A DNS pointer record""" - def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) -> None: - super().__init__(name, type_, class_, ttl) + def __init__( + self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) self.alias = alias @property @@ -319,9 +327,11 @@ class DNSText(DNSRecord): """A DNS text record""" - def __init__(self, name: str, type_: int, class_: int, ttl: int, text: bytes) -> None: + def __init__( + self, name: str, type_: int, class_: int, ttl: int, text: bytes, created: Optional[float] = None + ) -> None: assert isinstance(text, (bytes, type(None))) - super().__init__(name, type_, class_, ttl) + super().__init__(name, type_, class_, ttl, created) self.text = text def write(self, out: 'DNSOutgoing') -> None: @@ -357,8 +367,9 @@ def __init__( weight: int, port: int, server: str, + created: Optional[float] = None, ) -> None: - super().__init__(name, type_, class_, ttl) + super().__init__(name, type_, class_, ttl, created) self.priority = priority self.weight = weight self.port = port diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index f8e9c6dfb..1d6cac4c5 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -162,7 +162,7 @@ def __init__(self, registry: ServiceRegistry, cache: DNSCache) -> None: self.cache = cache def _add_service_type_enumeration_query_answers( - self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float ) -> None: """Provide an answer to a service type enumeration query. @@ -170,47 +170,60 @@ def _add_service_type_enumeration_query_answers( """ for stype in self.registry.get_types(): dns_pointer = DNSPointer( - _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype + _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, now ) if not known_answers.suppresses(dns_pointer): answer_set[dns_pointer] = set() def _add_pointer_answers( - self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float ) -> None: """Answer PTR/ANY question.""" for service in self.registry.get_infos_type(name): # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. - dns_pointer = service.dns_pointer() + dns_pointer = service.dns_pointer(created=now) if not known_answers.suppresses(dns_pointer): answer_set[dns_pointer] = set( - [service.dns_service(), service.dns_text(), *service.dns_addresses()] + [ + service.dns_service(created=now), + service.dns_text(created=now), + *service.dns_addresses(created=now), + ] ) def _add_address_answers( - self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, type_: int + self, + name: str, + answer_set: _AnswerWithAdditionalsType, + known_answers: DNSRRSet, + now: float, + type_: int, ) -> None: """Answer A/AAAA/ANY question.""" for service in self.registry.get_infos_server(name): - for dns_address in service.dns_addresses(version=_TYPE_TO_IP_VERSION[type_]): + for dns_address in service.dns_addresses(version=_TYPE_TO_IP_VERSION[type_], created=now): if not known_answers.suppresses(dns_address): answer_set[dns_address] = set() def _answer_question( - self, question: DNSQuestion, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + self, + question: DNSQuestion, + answer_set: _AnswerWithAdditionalsType, + known_answers: DNSRRSet, + now: float, ) -> None: if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: - self._add_service_type_enumeration_query_answers(answer_set, known_answers) + self._add_service_type_enumeration_query_answers(answer_set, known_answers, now) return type_ = question.type if type_ in (_TYPE_PTR, _TYPE_ANY): - self._add_pointer_answers(question.name, answer_set, known_answers) + self._add_pointer_answers(question.name, answer_set, known_answers, now) if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY): - self._add_address_answers(question.name, answer_set, known_answers, type_) + self._add_address_answers(question.name, answer_set, known_answers, now, type_) if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY): service = self.registry.get_info_name(question.name) # type: ignore @@ -218,11 +231,11 @@ def _answer_question( if type_ in (_TYPE_SRV, _TYPE_ANY): # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.2. - dns_service = service.dns_service() + dns_service = service.dns_service(created=now) if not known_answers.suppresses(dns_service): - answer_set[dns_service] = set(service.dns_addresses()) + answer_set[dns_service] = set(service.dns_addresses(created=now)) if type_ in (_TYPE_TXT, _TYPE_ANY): - dns_text = service.dns_text() + dns_text = service.dns_text(created=now) if not known_answers.suppresses(dns_text): answer_set[dns_text] = set() @@ -233,10 +246,11 @@ def response( # pylint: disable=unused-argument ucast_source = port != _MDNS_PORT known_answers = DNSRRSet(itertools.chain(*[msg.answers for msg in msgs])) query_res = _QueryResponse(self.cache, msgs[0], ucast_source) + now = current_time_millis() for question in itertools.chain(*[msg.questions for msg in msgs]): answer_set: _AnswerWithAdditionalsType = {} - self._answer_question(question, answer_set, known_answers) + self._answer_question(question, answer_set, known_answers, now) if not ucast_source and question.unicast: query_res.add_qu_question_response(answer_set) else: diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index 64c65b96b..80ca7b886 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -24,10 +24,12 @@ import struct from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast + from ._dns import DNSAddress, DNSHinfo, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText from ._exceptions import IncomingDecodeError, NamePartTooLongException from ._logger import QuietLogger, log from ._utils.struct import int2byte +from ._utils.time import current_time_millis from .const import ( _CLASS_UNIQUE, _DNS_PACKET_HEADER_LEN, @@ -90,6 +92,7 @@ def __init__(self, data: bytes) -> None: self.num_authorities = 0 self.num_additionals = 0 self.valid = False + self.now = current_time_millis() try: self.read_header() @@ -166,11 +169,11 @@ def read_others(self) -> None: type_, class_, ttl, length = self.unpack(b'!HHiH') rec: Optional[DNSRecord] = None if type_ == _TYPE_A: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4)) + rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4), self.now) elif type_ in (_TYPE_CNAME, _TYPE_PTR): - rec = DNSPointer(domain, type_, class_, ttl, self.read_name()) + rec = DNSPointer(domain, type_, class_, ttl, self.read_name(), self.now) elif type_ == _TYPE_TXT: - rec = DNSText(domain, type_, class_, ttl, self.read_string(length)) + rec = DNSText(domain, type_, class_, ttl, self.read_string(length), self.now) elif type_ == _TYPE_SRV: rec = DNSService( domain, @@ -181,6 +184,7 @@ def read_others(self) -> None: self.read_unsigned_short(), self.read_unsigned_short(), self.read_name(), + self.now, ) elif type_ == _TYPE_HINFO: rec = DNSHinfo( @@ -190,9 +194,10 @@ def read_others(self) -> None: ttl, self.read_character_string().decode('utf-8'), self.read_character_string().decode('utf-8'), + self.now, ) elif type_ == _TYPE_AAAA: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16)) + rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16), self.now) else: # Try to ignore types we don't know about # Skip the payload for the resource record so the next diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index b6481eb31..c6efcc8e3 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -369,7 +369,7 @@ def update_records_complete(self) -> None: At this point the cache will have the new records. """ - # Cannot use .update here since PyPy can fail with + # Cannot use .update here since can fail with # RuntimeError: dictionary changed size during iteration # for threaded ServiceBrowsers while self._pending_handlers: @@ -722,7 +722,10 @@ def _process_record(self, record: DNSRecord, now: float) -> None: self._set_text(record.text) def dns_addresses( - self, override_ttl: Optional[int] = None, version: IPVersion = IPVersion.All + self, + override_ttl: Optional[int] = None, + version: IPVersion = IPVersion.All, + created: Optional[float] = None, ) -> List[DNSAddress]: """Return matching DNSAddress from ServiceInfo.""" return [ @@ -732,11 +735,12 @@ def dns_addresses( _CLASS_IN | _CLASS_UNIQUE, override_ttl if override_ttl is not None else self.host_ttl, address, + created, ) for address in self.addresses_by_version(version) ] - def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: + def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer: """Return DNSPointer from ServiceInfo.""" return DNSPointer( self.type, @@ -744,9 +748,10 @@ def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: _CLASS_IN, override_ttl if override_ttl is not None else self.other_ttl, self.name, + created, ) - def dns_service(self, override_ttl: Optional[int] = None) -> DNSService: + def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService: """Return DNSService from ServiceInfo.""" return DNSService( self.name, @@ -757,9 +762,10 @@ def dns_service(self, override_ttl: Optional[int] = None) -> DNSService: self.weight, cast(int, self.port), self.server, + created, ) - def dns_text(self, override_ttl: Optional[int] = None) -> DNSText: + def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText: """Return DNSText from ServiceInfo.""" return DNSText( self.name, @@ -767,6 +773,7 @@ def dns_text(self, override_ttl: Optional[int] = None) -> DNSText: _CLASS_IN | _CLASS_UNIQUE, override_ttl if override_ttl is not None else self.other_ttl, self.text, + created, ) def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: