Skip to content

Commit 00c439a

Browse files
authored
feat: improve performance of constructing outgoing queries (#1267)
1 parent aed6391 commit 00c439a

4 files changed

Lines changed: 52 additions & 20 deletions

File tree

src/zeroconf/_handlers/answers.pxd

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
import cython
33

4+
from .._dns cimport DNSRecord
45
from .._protocol.outgoing cimport DNSOutgoing
56

67

@@ -10,7 +11,8 @@ cdef object NAME_GETTER
1011
cpdef construct_outgoing_multicast_answers(cython.dict answers)
1112

1213
cpdef construct_outgoing_unicast_answers(
13-
cython.dict answers, object ucast_source, cython.list questions, object id_
14+
cython.dict answers, bint ucast_source, cython.list questions, object id_
1415
)
1516

17+
@cython.locals(answer=DNSRecord, additionals=cython.set, additional=DNSRecord)
1618
cdef _add_answers_additionals(DNSOutgoing out, cython.dict answers)

src/zeroconf/_handlers/answers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def _add_answers_additionals(out: DNSOutgoing, answers: _AnswerWithAdditionalsTy
8282
# overall size of the outgoing response via name compression
8383
for answer in sorted(answers, key=NAME_GETTER):
8484
out.add_answer_at_time(answer, 0)
85-
for additional in answers[answer]:
85+
additionals = answers[answer]
86+
for additional in additionals:
8687
if additional not in sending:
8788
out.add_additional_answer(additional)
8889
sending.add(additional)

src/zeroconf/_protocol/outgoing.pxd

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,25 @@ cdef object PACK_BYTE
2121
cdef object PACK_SHORT
2222
cdef object PACK_LONG
2323

24+
cdef object STATE_INIT
25+
cdef object STATE_FINISHED
26+
27+
cdef object LOGGING_IS_ENABLED_FOR
28+
cdef object LOGGING_DEBUG
29+
30+
cdef cython.tuple BYTE_TABLE
31+
2432
cdef class DNSOutgoing:
2533

2634
cdef public unsigned int flags
27-
cdef public object finished
35+
cdef public bint finished
2836
cdef public object id
2937
cdef public bint multicast
3038
cdef public cython.list packets_data
3139
cdef public cython.dict names
3240
cdef public cython.list data
3341
cdef public unsigned int size
34-
cdef public object allow_long
42+
cdef public bint allow_long
3543
cdef public object state
3644
cdef public cython.list questions
3745
cdef public cython.list answers
@@ -48,18 +56,21 @@ cdef class DNSOutgoing:
4856

4957
cdef _write_int(self, object value)
5058

51-
cdef _write_question(self, DNSQuestion question)
59+
cdef cython.bint _write_question(self, DNSQuestion question)
5260

5361
@cython.locals(
5462
d=cython.bytes,
5563
data_view=cython.list,
5664
length=cython.uint
5765
)
58-
cdef _write_record(self, DNSRecord record, object now)
66+
cdef cython.bint _write_record(self, DNSRecord record, object now)
5967

6068
cdef _write_record_class(self, DNSEntry record)
6169

62-
cdef _check_data_limit_or_rollback(self, object start_data_length, object start_size)
70+
@cython.locals(
71+
start_size_int=object
72+
)
73+
cdef cython.bint _check_data_limit_or_rollback(self, cython.uint start_data_length, cython.uint start_size)
6374

6475
cdef _write_questions_from_offset(self, object questions_offset)
6576

@@ -74,6 +85,9 @@ cdef class DNSOutgoing:
7485
@cython.locals(
7586
labels=cython.list,
7687
label=cython.str,
88+
index=cython.uint,
89+
start_size=cython.uint,
90+
name_length=cython.uint,
7791
)
7892
cpdef write_name(self, cython.str name)
7993

@@ -103,6 +117,7 @@ cdef class DNSOutgoing:
103117

104118
cpdef add_answer(self, DNSIncoming inp, DNSRecord record)
105119

120+
@cython.locals(now_float=cython.float)
106121
cpdef add_answer_at_time(self, DNSRecord record, object now)
107122

108123
cpdef add_authorative_answer(self, DNSPointer record)

src/zeroconf/_protocol/outgoing.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,21 @@
5353
PACK_SHORT = Struct('>H').pack
5454
PACK_LONG = Struct('>L').pack
5555

56+
BYTE_TABLE = tuple(PACK_BYTE(i) for i in range(256))
57+
5658

5759
class State(enum.Enum):
5860
init = 0
5961
finished = 1
6062

6163

64+
STATE_INIT = State.init
65+
STATE_FINISHED = State.finished
66+
67+
LOGGING_IS_ENABLED_FOR = log.isEnabledFor
68+
LOGGING_DEBUG = logging.DEBUG
69+
70+
6271
class DNSOutgoing:
6372

6473
"""Object representation of an outgoing packet"""
@@ -93,7 +102,7 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None:
93102
self.size: int = _DNS_PACKET_HEADER_LEN
94103
self.allow_long: bool = True
95104

96-
self.state = State.init
105+
self.state = STATE_INIT
97106

98107
self.questions: List[DNSQuestion] = []
99108
self.answers: List[Tuple[DNSRecord, float]] = []
@@ -137,7 +146,8 @@ def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None:
137146

138147
def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None:
139148
"""Adds an answer if it does not expire by a certain time"""
140-
if record is not None and (now == 0 or not record.is_expired(now)):
149+
now_float = now
150+
if record is not None and (now_float == 0 or not record.is_expired(now_float)):
141151
self.answers.append((record, now))
142152

143153
def add_authorative_answer(self, record: DNSPointer) -> None:
@@ -207,7 +217,7 @@ def add_question_or_all_cache(
207217

208218
def _write_byte(self, value: int_) -> None:
209219
"""Writes a single byte to the packet"""
210-
self.data.append(PACK_BYTE(value))
220+
self.data.append(BYTE_TABLE[value])
211221
self.size += 1
212222

213223
def _insert_short_at_start(self, value: int_) -> None:
@@ -267,7 +277,7 @@ def write_name(self, name: str_) -> None:
267277
"""
268278

269279
# split name into each label
270-
name_length = None
280+
name_length = 0
271281
if name.endswith('.'):
272282
name = name[: len(name) - 1]
273283
labels = name.split('.')
@@ -276,14 +286,14 @@ def write_name(self, name: str_) -> None:
276286
start_size = self.size
277287
for count in range(len(labels)):
278288
label = name if count == 0 else '.'.join(labels[count:])
279-
index = self.names.get(label)
289+
index = self.names.get(label, 0)
280290
if index:
281291
# If part of the name already exists in the packet,
282292
# create a pointer to it
283293
self._write_byte((index >> 8) | 0xC0)
284294
self._write_byte(index & 0xFF)
285295
return
286-
if name_length is None:
296+
if name_length == 0:
287297
name_length = len(name.encode('utf-8'))
288298
self.names[label] = start_size + name_length - len(label.encode('utf-8'))
289299
self._write_utf(labels[count])
@@ -293,7 +303,8 @@ def write_name(self, name: str_) -> None:
293303

294304
def _write_question(self, question: DNSQuestion_) -> bool:
295305
"""Writes a question to the packet"""
296-
start_data_length, start_size = len(self.data), self.size
306+
start_data_length = len(self.data)
307+
start_size = self.size
297308
self.write_name(question.name)
298309
self.write_short(question.type)
299310
self._write_record_class(question)
@@ -314,7 +325,8 @@ def _write_record(self, record: DNSRecord_, now: float_) -> bool:
314325
"""Writes a record (answer, authoritative answer, additional) to
315326
the packet. Returns True on success, or False if we did not
316327
because the packet because the record does not fit."""
317-
start_data_length, start_size = len(self.data), self.size
328+
start_data_length = len(self.data)
329+
start_size = self.size
318330
self.write_name(record.name)
319331
self.write_short(record.type)
320332
self._write_record_class(record)
@@ -339,11 +351,13 @@ def _check_data_limit_or_rollback(self, start_data_length: int_, start_size: int
339351
if self.size <= len_limit:
340352
return True
341353

342-
log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit)
354+
if LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG): # pragma: no branch
355+
log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit)
343356
del self.data[start_data_length:]
344357
self.size = start_size
345358

346-
rollback_names = [name for name, idx in self.names.items() if idx >= start_size]
359+
start_size_int = start_size
360+
rollback_names = [name for name, idx in self.names.items() if idx >= start_size_int]
347361
for name in rollback_names:
348362
del self.names[name]
349363
return False
@@ -395,7 +409,7 @@ def packets(self) -> List[bytes]:
395409
return self._packets()
396410

397411
def _packets(self) -> List[bytes]:
398-
if self.state == State.finished:
412+
if self.state == STATE_FINISHED:
399413
return self.packets_data
400414

401415
questions_offset = 0
@@ -404,7 +418,7 @@ def _packets(self) -> List[bytes]:
404418
additional_offset = 0
405419
# we have to at least write out the question
406420
first_time = True
407-
debug_enable = log.isEnabledFor(logging.DEBUG)
421+
debug_enable = LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG)
408422

409423
while first_time or self._has_more_to_add(
410424
questions_offset, answer_offset, authority_offset, additional_offset
@@ -476,5 +490,5 @@ def _packets(self) -> List[bytes]:
476490
):
477491
log.warning("packets() made no progress adding records; returning")
478492
break
479-
self.state = State.finished
493+
self.state = STATE_FINISHED
480494
return self.packets_data

0 commit comments

Comments
 (0)