Skip to content

Commit 0890f62

Browse files
authored
feat: cache construction of records used to answer queries from the service registry (#1243)
1 parent af192d3 commit 0890f62

5 files changed

Lines changed: 95 additions & 46 deletions

File tree

src/zeroconf/_core.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: d
400400
#
401401
# _CLASS_UNIQUE is the "QU" bit
402402
out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN | _CLASS_UNIQUE))
403-
out.add_authorative_answer(info.dns_pointer(created=current_time_millis()))
403+
out.add_authorative_answer(info.dns_pointer())
404404
return out
405405

406406
def _add_broadcast_answer( # pylint: disable=no-self-use
@@ -411,14 +411,14 @@ def _add_broadcast_answer( # pylint: disable=no-self-use
411411
broadcast_addresses: bool = True,
412412
) -> None:
413413
"""Add answers to broadcast a service."""
414-
now = current_time_millis()
415-
other_ttl = info.other_ttl if override_ttl is None else override_ttl
416-
host_ttl = info.host_ttl if override_ttl is None else override_ttl
417-
out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl, created=now), 0)
418-
out.add_answer_at_time(info.dns_service(override_ttl=host_ttl, created=now), 0)
419-
out.add_answer_at_time(info.dns_text(override_ttl=other_ttl, created=now), 0)
414+
current_time_millis()
415+
other_ttl = None if override_ttl is None else override_ttl
416+
host_ttl = None if override_ttl is None else override_ttl
417+
out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl), 0)
418+
out.add_answer_at_time(info.dns_service(override_ttl=host_ttl), 0)
419+
out.add_answer_at_time(info.dns_text(override_ttl=other_ttl), 0)
420420
if broadcast_addresses:
421-
for record in info.get_address_and_nsec_records(override_ttl=host_ttl, created=now):
421+
for record in info.get_address_and_nsec_records(override_ttl=host_ttl):
422422
out.add_answer_at_time(record, 0)
423423

424424
def unregister_service(self, info: ServiceInfo) -> None:

src/zeroconf/_handlers/query_handler.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -163,48 +163,47 @@ def __init__(self, registry: ServiceRegistry, cache: DNSCache, question_history:
163163
self.question_history = question_history
164164

165165
def _add_service_type_enumeration_query_answers(
166-
self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
166+
self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet
167167
) -> None:
168168
"""Provide an answer to a service type enumeration query.
169169
170170
https://datatracker.ietf.org/doc/html/rfc6763#section-9
171171
"""
172172
for stype in self.registry.async_get_types():
173173
dns_pointer = DNSPointer(
174-
_SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, now
174+
_SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, 0.0
175175
)
176176
if not known_answers.suppresses(dns_pointer):
177177
answer_set[dns_pointer] = set()
178178

179179
def _add_pointer_answers(
180-
self, lower_name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
180+
self, lower_name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet
181181
) -> None:
182182
"""Answer PTR/ANY question."""
183183
for service in self.registry.async_get_infos_type(lower_name):
184184
# Add recommended additional answers according to
185185
# https://tools.ietf.org/html/rfc6763#section-12.1.
186-
dns_pointer = service.dns_pointer(created=now)
186+
dns_pointer = service.dns_pointer()
187187
if known_answers.suppresses(dns_pointer):
188188
continue
189189
answer_set[dns_pointer] = {
190-
service.dns_service(created=now),
191-
service.dns_text(created=now),
192-
} | service.get_address_and_nsec_records(created=now)
190+
service.dns_service(),
191+
service.dns_text(),
192+
} | service.get_address_and_nsec_records()
193193

194194
def _add_address_answers(
195195
self,
196196
lower_name: str,
197197
answer_set: _AnswerWithAdditionalsType,
198198
known_answers: DNSRRSet,
199-
now: float,
200199
type_: int,
201200
) -> None:
202201
"""Answer A/AAAA/ANY question."""
203202
for service in self.registry.async_get_infos_server(lower_name):
204203
answers: List[DNSAddress] = []
205204
additionals: Set[DNSRecord] = set()
206205
seen_types: Set[int] = set()
207-
for dns_address in service.dns_addresses(created=now):
206+
for dns_address in service.dns_addresses():
208207
seen_types.add(dns_address.type)
209208
if dns_address.type != type_:
210209
additionals.add(dns_address)
@@ -214,12 +213,12 @@ def _add_address_answers(
214213
if answers:
215214
if missing_types:
216215
assert service.server is not None, "Service server must be set for NSEC record."
217-
additionals.add(service.dns_nsec(list(missing_types), created=now))
216+
additionals.add(service.dns_nsec(list(missing_types)))
218217
for answer in answers:
219218
answer_set[answer] = additionals
220219
elif type_ in missing_types:
221220
assert service.server is not None, "Service server must be set for NSEC record."
222-
answer_set[service.dns_nsec(list(missing_types), created=now)] = set()
221+
answer_set[service.dns_nsec(list(missing_types))] = set()
223222

224223
def _answer_question(
225224
self,
@@ -231,28 +230,28 @@ def _answer_question(
231230
question_lower_name = question.name.lower()
232231

233232
if question.type == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME:
234-
self._add_service_type_enumeration_query_answers(answer_set, known_answers, now)
233+
self._add_service_type_enumeration_query_answers(answer_set, known_answers)
235234
return answer_set
236235

237236
type_ = question.type
238237

239238
if type_ in (_TYPE_PTR, _TYPE_ANY):
240-
self._add_pointer_answers(question_lower_name, answer_set, known_answers, now)
239+
self._add_pointer_answers(question_lower_name, answer_set, known_answers)
241240

242241
if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY):
243-
self._add_address_answers(question_lower_name, answer_set, known_answers, now, type_)
242+
self._add_address_answers(question_lower_name, answer_set, known_answers, type_)
244243

245244
if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY):
246245
service = self.registry.async_get_info_name(question_lower_name)
247246
if service is not None:
248247
if type_ in (_TYPE_SRV, _TYPE_ANY):
249248
# Add recommended additional answers according to
250249
# https://tools.ietf.org/html/rfc6763#section-12.2.
251-
dns_service = service.dns_service(created=now)
250+
dns_service = service.dns_service()
252251
if not known_answers.suppresses(dns_service):
253-
answer_set[dns_service] = service.get_address_and_nsec_records(created=now)
252+
answer_set[dns_service] = service.get_address_and_nsec_records()
254253
if type_ in (_TYPE_TXT, _TYPE_ANY):
255-
dns_text = service.dns_text(created=now)
254+
dns_text = service.dns_text()
256255
if not known_answers.suppresses(dns_text):
257256
answer_set[dns_text] = set()
258257

src/zeroconf/_services/info.py

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,11 @@ class ServiceInfo(RecordUpdateListener):
133133
"other_ttl",
134134
"interface_index",
135135
"_new_records_futures",
136+
"_dns_pointer_cache",
137+
"_dns_service_cache",
138+
"_dns_text_cache",
139+
"_dns_address_cache",
140+
"_get_address_and_nsec_records_cache",
136141
)
137142

138143
def __init__(
@@ -180,6 +185,11 @@ def __init__(
180185
self.other_ttl = other_ttl
181186
self.interface_index = interface_index
182187
self._new_records_futures: Set[asyncio.Future] = set()
188+
self._dns_address_cache: Optional[List[DNSAddress]] = None
189+
self._dns_pointer_cache: Optional[DNSPointer] = None
190+
self._dns_service_cache: Optional[DNSService] = None
191+
self._dns_text_cache: Optional[DNSText] = None
192+
self._get_address_and_nsec_records_cache: Optional[Set[DNSRecord]] = None
183193

184194
@property
185195
def name(self) -> str:
@@ -191,6 +201,9 @@ def name(self, name: str) -> None:
191201
"""Replace the the name and reset the key."""
192202
self._name = name
193203
self.key = name.lower()
204+
self._dns_service_cache = None
205+
self._dns_pointer_cache = None
206+
self._dns_text_cache = None
194207

195208
@property
196209
def addresses(self) -> List[bytes]:
@@ -210,6 +223,8 @@ def addresses(self, value: List[bytes]) -> None:
210223
"""
211224
self._ipv4_addresses.clear()
212225
self._ipv6_addresses.clear()
226+
self._dns_address_cache = None
227+
self._get_address_and_nsec_records_cache = None
213228

214229
for address in value:
215230
try:
@@ -489,42 +504,56 @@ def dns_addresses(
489504
self,
490505
override_ttl: Optional[int] = None,
491506
version: IPVersion = IPVersion.All,
492-
created: Optional[float] = None,
493507
) -> List[DNSAddress]:
494508
"""Return matching DNSAddress from ServiceInfo."""
509+
cacheable = version is IPVersion.All and override_ttl is None
510+
if self._dns_address_cache is not None and cacheable:
511+
return self._dns_address_cache
495512
name = self.server or self._name
496513
ttl = override_ttl if override_ttl is not None else self.host_ttl
497514
class_ = _CLASS_IN_UNIQUE
498515
version_value = version.value
499-
return [
516+
records = [
500517
DNSAddress(
501518
name,
502519
_TYPE_AAAA if type(ip_addr) is IPv6Address else _TYPE_A,
503520
class_,
504521
ttl,
505522
ip_addr.packed,
506-
created=created,
523+
created=0.0,
507524
)
508525
for ip_addr in self._ip_addresses_by_version_value(version_value)
509526
]
527+
if cacheable:
528+
self._dns_address_cache = records
529+
return records
510530

511-
def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer:
531+
def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer:
512532
"""Return DNSPointer from ServiceInfo."""
513-
return DNSPointer(
533+
cacheable = override_ttl is None
534+
if self._dns_pointer_cache is not None and cacheable:
535+
return self._dns_pointer_cache
536+
record = DNSPointer(
514537
self.type,
515538
_TYPE_PTR,
516539
_CLASS_IN,
517540
override_ttl if override_ttl is not None else self.other_ttl,
518541
self._name,
519-
created,
542+
0.0,
520543
)
544+
if cacheable:
545+
self._dns_pointer_cache = record
546+
return record
521547

522-
def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService:
548+
def dns_service(self, override_ttl: Optional[int] = None) -> DNSService:
523549
"""Return DNSService from ServiceInfo."""
550+
cacheable = override_ttl is None
551+
if self._dns_service_cache is not None and cacheable:
552+
return self._dns_service_cache
524553
port = self.port
525554
if TYPE_CHECKING:
526555
assert isinstance(port, int)
527-
return DNSService(
556+
record = DNSService(
528557
self._name,
529558
_TYPE_SRV,
530559
_CLASS_IN_UNIQUE,
@@ -533,23 +562,30 @@ def dns_service(self, override_ttl: Optional[int] = None, created: Optional[floa
533562
self.weight,
534563
port,
535564
self.server or self._name,
536-
created,
565+
0.0,
537566
)
567+
if cacheable:
568+
self._dns_service_cache = record
569+
return record
538570

539-
def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText:
571+
def dns_text(self, override_ttl: Optional[int] = None) -> DNSText:
540572
"""Return DNSText from ServiceInfo."""
541-
return DNSText(
573+
cacheable = override_ttl is None
574+
if self._dns_text_cache is not None and cacheable:
575+
return self._dns_text_cache
576+
record = DNSText(
542577
self._name,
543578
_TYPE_TXT,
544579
_CLASS_IN_UNIQUE,
545580
override_ttl if override_ttl is not None else self.other_ttl,
546581
self.text,
547-
created,
582+
0.0,
548583
)
584+
if cacheable:
585+
self._dns_text_cache = record
586+
return record
549587

550-
def dns_nsec(
551-
self, missing_types: List[int], override_ttl: Optional[int] = None, created: Optional[float] = None
552-
) -> DNSNsec:
588+
def dns_nsec(self, missing_types: List[int], override_ttl: Optional[int] = None) -> DNSNsec:
553589
"""Return DNSNsec from ServiceInfo."""
554590
return DNSNsec(
555591
self._name,
@@ -558,21 +594,24 @@ def dns_nsec(
558594
override_ttl if override_ttl is not None else self.host_ttl,
559595
self._name,
560596
missing_types,
561-
created,
597+
0.0,
562598
)
563599

564-
def get_address_and_nsec_records(
565-
self, override_ttl: Optional[int] = None, created: Optional[float] = None
566-
) -> Set[DNSRecord]:
600+
def get_address_and_nsec_records(self, override_ttl: Optional[int] = None) -> Set[DNSRecord]:
567601
"""Build a set of address records and NSEC records for non-present record types."""
602+
cacheable = override_ttl is None
603+
if self._get_address_and_nsec_records_cache is not None and cacheable:
604+
return self._get_address_and_nsec_records_cache
568605
missing_types: Set[int] = _ADDRESS_RECORD_TYPES.copy()
569606
records: Set[DNSRecord] = set()
570-
for dns_address in self.dns_addresses(override_ttl, IPVersion.All, created):
607+
for dns_address in self.dns_addresses(override_ttl, IPVersion.All):
571608
missing_types.discard(dns_address.type)
572609
records.add(dns_address)
573610
if missing_types:
574611
assert self.server is not None, "Service server must be set for NSEC record."
575-
records.add(self.dns_nsec(list(missing_types), override_ttl, created))
612+
records.add(self.dns_nsec(list(missing_types), override_ttl))
613+
if cacheable:
614+
self._get_address_and_nsec_records_cache = records
576615
return records
577616

578617
def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int) -> List[DNSAddress]:

tests/services/test_browser.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,8 @@ def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
792792
assert service_info.port == 80
793793

794794
info.port = 400
795+
info._dns_service_cache = None # we are mutating the record so clear the cache
796+
795797
_inject_response(
796798
zc,
797799
mock_incoming_msg([info.dns_service()]),
@@ -856,6 +858,8 @@ def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
856858
mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]),
857859
)
858860
time.sleep(0.2)
861+
info._dns_service_cache = None # we are mutating the record so clear the cache
862+
859863
info.port = 400
860864
_inject_response(
861865
zc,
@@ -914,6 +918,8 @@ def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
914918
)
915919
time.sleep(0.2)
916920
info.port = 400
921+
info._dns_service_cache = None # we are mutating the record so clear the cache
922+
917923
_inject_response(
918924
zc,
919925
mock_incoming_msg([info.dns_service()]),
@@ -1131,6 +1137,8 @@ def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
11311137
)
11321138
time.sleep(0.2)
11331139
info.port = 400
1140+
info._dns_service_cache = None # we are mutating the record so clear the cache
1141+
11341142
_inject_response(
11351143
zc,
11361144
mock_incoming_msg([info.dns_service()]),
@@ -1210,6 +1218,8 @@ def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
12101218
)
12111219
time.sleep(0.3)
12121220
info.port = 400
1221+
info._dns_service_cache = None # we are mutating the record so clear the cache
1222+
12131223
_inject_response(
12141224
zc,
12151225
mock_incoming_msg([info.dns_service()]),

tests/test_handlers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,7 @@ async def test_qu_response_only_sends_additionals_if_sends_answer():
986986
a_record = info.dns_addresses()[0]
987987
a_record.set_created_ttl(current_time_millis() - (a_record.ttl * 1000 / 2), a_record.ttl)
988988
assert not a_record.is_recent(current_time_millis())
989+
info._dns_address_cache = None # we are mutating the record so clear the cache
989990
zc.cache.async_add_records([a_record])
990991

991992
# With QU should respond to only unicast when the answer has been recently multicast

0 commit comments

Comments
 (0)