Skip to content

Commit 3c6b18c

Browse files
authored
feat: speed up responding to queries (#1275)
1 parent aa8fd1a commit 3c6b18c

11 files changed

Lines changed: 93 additions & 81 deletions

File tree

src/zeroconf/_dns.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ cdef class DNSNsec(DNSRecord):
125125

126126
cdef class DNSRRSet:
127127

128-
cdef cython.list _record_sets
128+
cdef cython.list _records
129129
cdef cython.dict _lookup
130130

131131
@cython.locals(other=DNSRecord)

src/zeroconf/_dns.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use
174174
def suppressed_by(self, msg: 'DNSIncoming') -> bool:
175175
"""Returns true if any answer in a message can suffice for the
176176
information held in this record."""
177-
answers = msg.answers
177+
answers = msg.answers()
178178
for record in answers:
179179
if self._suppressed_by_answer(record):
180180
return True
@@ -521,37 +521,34 @@ def __repr__(self) -> str:
521521
class DNSRRSet:
522522
"""A set of dns records with a lookup to get the ttl."""
523523

524-
__slots__ = ('_record_sets', '_lookup')
524+
__slots__ = ('_records', '_lookup')
525525

526-
def __init__(self, record_sets: List[List[DNSRecord]]) -> None:
526+
def __init__(self, records: List[DNSRecord]) -> None:
527527
"""Create an RRset from records sets."""
528-
self._record_sets = record_sets
529-
self._lookup: Optional[Dict[DNSRecord, float]] = None
528+
self._records = records
529+
self._lookup: Optional[Dict[DNSRecord, DNSRecord]] = None
530530

531531
@property
532-
def lookup(self) -> Dict[DNSRecord, float]:
532+
def lookup(self) -> Dict[DNSRecord, DNSRecord]:
533533
"""Return the lookup table."""
534534
return self._get_lookup()
535535

536536
def lookup_set(self) -> Set[DNSRecord]:
537537
"""Return the lookup table as aset."""
538538
return set(self._get_lookup())
539539

540-
def _get_lookup(self) -> Dict[DNSRecord, float]:
540+
def _get_lookup(self) -> Dict[DNSRecord, DNSRecord]:
541541
"""Return the lookup table, building it if needed."""
542542
if self._lookup is None:
543543
# Build the hash table so we can lookup the record ttl
544-
self._lookup = {}
545-
for record_sets in self._record_sets:
546-
for record in record_sets:
547-
self._lookup[record] = record.ttl
544+
self._lookup = {record: record for record in self._records}
548545
return self._lookup
549546

550547
def suppresses(self, record: _DNSRecord) -> bool:
551548
"""Returns true if any answer in the rrset can suffice for the
552549
information held in this record."""
553550
lookup = self._get_lookup()
554-
other_ttl = lookup.get(record)
555-
if other_ttl is None:
551+
other = lookup.get(record)
552+
if other is None:
556553
return False
557-
return other_ttl > (record.ttl / 2)
554+
return other.ttl > (record.ttl / 2)

src/zeroconf/_handlers/query_handler.pxd

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ cdef object _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL
2121
cdef class _QueryResponse:
2222

2323
cdef bint _is_probe
24-
cdef DNSIncoming _msg
24+
cdef cython.list _questions
2525
cdef float _now
2626
cdef DNSCache _cache
2727
cdef cython.dict _additionals
@@ -31,20 +31,20 @@ cdef class _QueryResponse:
3131
cdef cython.set _mcast_aggregate_last_second
3232

3333
@cython.locals(record=DNSRecord)
34-
cpdef add_qu_question_response(self, cython.dict answers)
34+
cdef add_qu_question_response(self, cython.dict answers)
3535

36-
cpdef add_ucast_question_response(self, cython.dict answers)
36+
cdef add_ucast_question_response(self, cython.dict answers)
3737

38-
@cython.locals(answer=DNSRecord)
39-
cpdef add_mcast_question_response(self, cython.dict answers)
38+
@cython.locals(answer=DNSRecord, question=DNSQuestion)
39+
cdef add_mcast_question_response(self, cython.dict answers)
4040

4141
@cython.locals(maybe_entry=DNSRecord)
4242
cdef bint _has_mcast_within_one_quarter_ttl(self, DNSRecord record)
4343

4444
@cython.locals(maybe_entry=DNSRecord)
4545
cdef bint _has_mcast_record_in_last_second(self, DNSRecord record)
4646

47-
cpdef answers(self)
47+
cdef QuestionAnswers answers(self)
4848

4949
cdef class QueryHandler:
5050

@@ -70,5 +70,7 @@ cdef class QueryHandler:
7070
answer_set=cython.dict,
7171
known_answers=DNSRRSet,
7272
known_answers_set=cython.set,
73+
is_probe=object,
74+
now=object
7375
)
7476
cpdef async_response(self, cython.list msgs, cython.bint unicast_source)

src/zeroconf/_handlers/query_handler.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class _QueryResponse:
5555

5656
__slots__ = (
5757
"_is_probe",
58-
"_msg",
58+
"_questions",
5959
"_now",
6060
"_cache",
6161
"_additionals",
@@ -65,15 +65,11 @@ class _QueryResponse:
6565
"_mcast_aggregate_last_second",
6666
)
6767

68-
def __init__(self, cache: DNSCache, msgs: List[DNSIncoming]) -> None:
68+
def __init__(self, cache: DNSCache, questions: List[DNSQuestion], is_probe: bool, now: float) -> None:
6969
"""Build a query response."""
70-
self._is_probe = False
71-
for msg in msgs:
72-
if msg.is_probe:
73-
self._is_probe = True
74-
break
75-
self._msg = msgs[0]
76-
self._now = self._msg.now
70+
self._is_probe = is_probe
71+
self._questions = questions
72+
self._now = now
7773
self._cache = cache
7874
self._additionals: _AnswerWithAdditionalsType = {}
7975
self._ucast: Set[DNSRecord] = set()
@@ -107,10 +103,15 @@ def add_mcast_question_response(self, answers: _AnswerWithAdditionalsType) -> No
107103

108104
if self._has_mcast_record_in_last_second(answer):
109105
self._mcast_aggregate_last_second.add(answer)
110-
elif len(self._msg.questions) == 1 and self._msg.questions[0].type in _RESPOND_IMMEDIATE_TYPES:
111-
self._mcast_now.add(answer)
112-
else:
113-
self._mcast_aggregate.add(answer)
106+
continue
107+
108+
if len(self._questions) == 1:
109+
question = self._questions[0]
110+
if question.type in _RESPOND_IMMEDIATE_TYPES:
111+
self._mcast_now.add(answer)
112+
continue
113+
114+
self._mcast_aggregate.add(answer)
114115

115116
def answers(
116117
self,
@@ -262,16 +263,26 @@ def async_response( # pylint: disable=unused-argument
262263
This function must be run in the event loop as it is not
263264
threadsafe.
264265
"""
265-
known_answers = DNSRRSet([msg.answers for msg in msgs if not msg.is_probe])
266-
query_res = _QueryResponse(self.cache, msgs)
266+
answers: List[DNSRecord] = []
267+
is_probe = False
268+
msg = msgs[0]
269+
questions = msg.questions
270+
now = msg.now
271+
for msg in msgs:
272+
if not msg.is_probe():
273+
answers.extend(msg.answers())
274+
else:
275+
is_probe = True
276+
known_answers = DNSRRSet(answers)
277+
query_res = _QueryResponse(self.cache, questions, is_probe, now)
267278
known_answers_set: Optional[Set[DNSRecord]] = None
268279

269280
for msg in msgs:
270281
for question in msg.questions:
271282
if not question.unique: # unique and unicast are the same flag
272283
if not known_answers_set: # pragma: no branch
273284
known_answers_set = known_answers.lookup_set()
274-
self.question_history.add_question_at_time(question, msg.now, known_answers_set)
285+
self.question_history.add_question_at_time(question, now, known_answers_set)
275286
answer_set = self._answer_question(question, known_answers)
276287
if not ucast_source and question.unique: # unique and unicast are the same flag
277288
query_res.add_qu_question_response(answer_set)

src/zeroconf/_handlers/record_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
8787
now_float = now
8888
unique_types: Set[Tuple[str, int, int]] = set()
8989
cache = self.cache
90-
answers = msg.answers
90+
answers = msg.answers()
9191

9292
for record in answers:
9393
# Protect zeroconf from records that can cause denial of service.

src/zeroconf/_protocol/incoming.pxd

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ cdef class DNSIncoming:
7272

7373
cpdef is_query(self)
7474

75+
cpdef is_probe(self)
76+
77+
cpdef answers(self)
78+
7579
cpdef is_response(self)
7680

7781
@cython.locals(

src/zeroconf/_protocol/incoming.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ def _log_exception_debug(cls, *logger_data: Any) -> None:
172172
log_exc_info = True
173173
log.debug(*(logger_data or ['Exception occurred']), exc_info=log_exc_info)
174174

175-
@property
176175
def answers(self) -> List[DNSRecord]:
177176
"""Answers in the packet."""
178177
if not self._did_read_others:
@@ -187,7 +186,6 @@ def answers(self) -> List[DNSRecord]:
187186
)
188187
return self._answers
189188

190-
@property
191189
def is_probe(self) -> bool:
192190
"""Returns true if this is a probe."""
193191
return self.num_authorities > 0
@@ -203,7 +201,7 @@ def __repr__(self) -> str:
203201
'n_auth=%s' % self.num_authorities,
204202
'n_add=%s' % self.num_additionals,
205203
'questions=%s' % self.questions,
206-
'answers=%s' % self.answers,
204+
'answers=%s' % self.answers(),
207205
]
208206
)
209207

tests/test_asyncio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -997,7 +997,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
997997
"""Sends an outgoing packet."""
998998
pout = DNSIncoming(out.packets()[0])
999999
nonlocal nbr_answers
1000-
for answer in pout.answers:
1000+
for answer in pout.answers():
10011001
nbr_answers += 1
10021002
if not answer.ttl > expected_ttl / 2:
10031003
unexpected_ttl.set()

tests/test_dns.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def test_rrset_does_not_consider_ttl():
392392
longaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 100, b'same')
393393
shortaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 10, b'same')
394394

395-
rrset = DNSRRSet([[longarec, shortaaaarec]])
395+
rrset = DNSRRSet([longarec, shortaaaarec])
396396

397397
assert rrset.suppresses(longarec)
398398
assert rrset.suppresses(shortarec)
@@ -404,7 +404,7 @@ def test_rrset_does_not_consider_ttl():
404404
mediumarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 60, b'same')
405405
shortarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 10, b'same')
406406

407-
rrset2 = DNSRRSet([[mediumarec]])
407+
rrset2 = DNSRRSet([mediumarec])
408408
assert not rrset2.suppresses(verylongarec)
409409
assert rrset2.suppresses(longarec)
410410
assert rrset2.suppresses(mediumarec)

tests/test_handlers.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1425,8 +1425,8 @@ async def test_response_aggregation_timings(run_isolated):
14251425
outgoing = send_mock.call_args[0][0]
14261426
incoming = r.DNSIncoming(outgoing.packets()[0])
14271427
zc.record_manager.async_updates_from_response(incoming)
1428-
assert info.dns_pointer() in incoming.answers
1429-
assert info2.dns_pointer() in incoming.answers
1428+
assert info.dns_pointer() in incoming.answers()
1429+
assert info2.dns_pointer() in incoming.answers()
14301430
send_mock.reset_mock()
14311431

14321432
protocol.datagram_received(query3.packets()[0], ('127.0.0.1', const._MDNS_PORT))
@@ -1439,7 +1439,7 @@ async def test_response_aggregation_timings(run_isolated):
14391439
outgoing = send_mock.call_args[0][0]
14401440
incoming = r.DNSIncoming(outgoing.packets()[0])
14411441
zc.record_manager.async_updates_from_response(incoming)
1442-
assert info3.dns_pointer() in incoming.answers
1442+
assert info3.dns_pointer() in incoming.answers()
14431443
send_mock.reset_mock()
14441444

14451445
# Because the response was sent in the last second we need to make
@@ -1461,7 +1461,7 @@ async def test_response_aggregation_timings(run_isolated):
14611461
assert len(calls) == 1
14621462
outgoing = send_mock.call_args[0][0]
14631463
incoming = r.DNSIncoming(outgoing.packets()[0])
1464-
assert info.dns_pointer() in incoming.answers
1464+
assert info.dns_pointer() in incoming.answers()
14651465

14661466
await aiozc.async_close()
14671467

@@ -1501,7 +1501,7 @@ async def test_response_aggregation_timings_multiple(run_isolated, disable_dupli
15011501
outgoing = send_mock.call_args[0][0]
15021502
incoming = r.DNSIncoming(outgoing.packets()[0])
15031503
zc.record_manager.async_updates_from_response(incoming)
1504-
assert info2.dns_pointer() in incoming.answers
1504+
assert info2.dns_pointer() in incoming.answers()
15051505

15061506
send_mock.reset_mock()
15071507
protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT))
@@ -1511,7 +1511,7 @@ async def test_response_aggregation_timings_multiple(run_isolated, disable_dupli
15111511
outgoing = send_mock.call_args[0][0]
15121512
incoming = r.DNSIncoming(outgoing.packets()[0])
15131513
zc.record_manager.async_updates_from_response(incoming)
1514-
assert info2.dns_pointer() in incoming.answers
1514+
assert info2.dns_pointer() in incoming.answers()
15151515

15161516
send_mock.reset_mock()
15171517
protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT))
@@ -1534,7 +1534,7 @@ async def test_response_aggregation_timings_multiple(run_isolated, disable_dupli
15341534
outgoing = send_mock.call_args[0][0]
15351535
incoming = r.DNSIncoming(outgoing.packets()[0])
15361536
zc.record_manager.async_updates_from_response(incoming)
1537-
assert info2.dns_pointer() in incoming.answers
1537+
assert info2.dns_pointer() in incoming.answers()
15381538

15391539

15401540
@pytest.mark.asyncio

0 commit comments

Comments
 (0)