Skip to content

Commit 0d60b61

Browse files
authored
feat: speed up incoming packet reader (#1314)
1 parent bfe4c24 commit 0d60b61

3 files changed

Lines changed: 59 additions & 54 deletions

File tree

src/zeroconf/_dns.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,6 @@ def __init__(
244244
class_: int,
245245
ttl: int,
246246
address: bytes,
247-
*,
248247
scope_id: Optional[int] = None,
249248
created: Optional[float] = None,
250249
) -> None:

src/zeroconf/_protocol/incoming.pxd

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,6 @@ cdef cython.uint _FLAGS_TC
2121
cdef cython.uint _FLAGS_QR_QUERY
2222
cdef cython.uint _FLAGS_QR_RESPONSE
2323

24-
cdef object UNPACK_3H
25-
cdef object UNPACK_6H
26-
cdef object UNPACK_HH
27-
cdef object UNPACK_HHiH
28-
2924
cdef object DECODE_EXCEPTIONS
3025

3126
cdef object IncomingDecodeError
@@ -62,7 +57,6 @@ cdef class DNSIncoming:
6257
cdef cython.uint _num_additionals
6358
cdef public bint valid
6459
cdef public object now
65-
cdef cython.float _now_float
6660
cdef public object scope_id
6761
cdef public object source
6862
cdef bint _has_qu_question
@@ -81,49 +75,53 @@ cdef class DNSIncoming:
8175
cpdef bint is_response(self)
8276

8377
@cython.locals(
84-
off=cython.uint,
85-
label_idx=cython.uint,
86-
length=cython.uint,
87-
link=cython.uint,
88-
link_data=cython.uint,
78+
off="unsigned int",
79+
label_idx="unsigned int",
80+
length="unsigned int",
81+
link="unsigned int",
82+
link_data="unsigned int",
8983
link_py_int=object,
9084
linked_labels=cython.list
9185
)
92-
cdef cython.uint _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers)
86+
cdef unsigned int _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers)
9387

88+
@cython.locals(offset="unsigned int")
9489
cdef _read_header(self)
9590

9691
cdef _initial_parse(self)
9792

9893
@cython.locals(
99-
end=cython.uint,
100-
length=cython.uint
94+
end="unsigned int",
95+
length="unsigned int",
96+
offset="unsigned int"
10197
)
10298
cdef _read_others(self)
10399

100+
@cython.locals(offset="unsigned int")
104101
cdef _read_questions(self)
105102

106103
@cython.locals(
107-
length=cython.uint,
104+
length="unsigned int",
108105
)
109106
cdef str _read_character_string(self)
110107

111108
cdef bytes _read_string(self, unsigned int length)
112109

113110
@cython.locals(
114-
name_start=cython.uint
111+
name_start="unsigned int",
112+
offset="unsigned int"
115113
)
116-
cdef _read_record(self, object domain, unsigned int type_, object class_, object ttl, unsigned int length)
114+
cdef _read_record(self, object domain, unsigned int type_, unsigned int class_, unsigned int ttl, unsigned int length)
117115

118116
@cython.locals(
119-
offset=cython.uint,
120-
offset_plus_one=cython.uint,
121-
offset_plus_two=cython.uint,
122-
window=cython.uint,
123-
bit=cython.uint,
124-
byte=cython.uint,
125-
i=cython.uint,
126-
bitmap_length=cython.uint,
117+
offset="unsigned int",
118+
offset_plus_one="unsigned int",
119+
offset_plus_two="unsigned int",
120+
window="unsigned int",
121+
bit="unsigned int",
122+
byte="unsigned int",
123+
i="unsigned int",
124+
bitmap_length="unsigned int",
127125
)
128126
cdef _read_bitmap(self, unsigned int end)
129127

src/zeroconf/_protocol/incoming.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,6 @@
6060

6161
DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError)
6262

63-
UNPACK_3H = struct.Struct(b'!3H').unpack_from
64-
UNPACK_6H = struct.Struct(b'!6H').unpack_from
65-
UNPACK_HH = struct.Struct(b'!HH').unpack_from
66-
UNPACK_HHiH = struct.Struct(b'!HHiH').unpack_from
6763

6864
_seen_logs: Dict[str, Union[int, tuple]] = {}
6965
_str = str
@@ -90,7 +86,6 @@ class DNSIncoming:
9086
'_num_additionals',
9187
'valid',
9288
'now',
93-
'_now_float',
9489
'scope_id',
9590
'source',
9691
'_has_qu_question',
@@ -120,7 +115,6 @@ def __init__(
120115
self.valid = False
121116
self._did_read_others = False
122117
self.now = now or current_time_millis()
123-
self._now_float = self.now
124118
self.source = source
125119
self.scope_id = scope_id
126120
self._has_qu_question = False
@@ -230,23 +224,28 @@ def __repr__(self) -> str:
230224

231225
def _read_header(self) -> None:
232226
"""Reads header portion of packet"""
233-
(
234-
self.id,
235-
self.flags,
236-
self._num_questions,
237-
self._num_answers,
238-
self._num_authorities,
239-
self._num_additionals,
240-
) = UNPACK_6H(self.data)
227+
view = self.view
228+
offset = self.offset
241229
self.offset += 12
230+
# The header has 6 unsigned shorts in network order
231+
self.id = view[offset] << 8 | view[offset + 1]
232+
self.flags = view[offset + 2] << 8 | view[offset + 3]
233+
self._num_questions = view[offset + 4] << 8 | view[offset + 5]
234+
self._num_answers = view[offset + 6] << 8 | view[offset + 7]
235+
self._num_authorities = view[offset + 8] << 8 | view[offset + 9]
236+
self._num_additionals = view[offset + 10] << 8 | view[offset + 11]
242237

243238
def _read_questions(self) -> None:
244239
"""Reads questions section of packet"""
240+
view = self.view
245241
questions = self._questions
246242
for _ in range(self._num_questions):
247243
name = self._read_name()
248-
type_, class_ = UNPACK_HH(self.data, self.offset)
244+
offset = self.offset
249245
self.offset += 4
246+
# The question has 2 unsigned shorts in network order
247+
type_ = view[offset] << 8 | view[offset + 1]
248+
class_ = view[offset + 2] << 8 | view[offset + 3]
250249
question = DNSQuestion(name, type_, class_)
251250
if question.unique: # QU questions use the same bit as unique
252251
self._has_qu_question = True
@@ -270,11 +269,18 @@ def _read_others(self) -> None:
270269
"""Reads the answers, authorities and additionals section of the
271270
packet"""
272271
self._did_read_others = True
272+
view = self.view
273273
n = self._num_answers + self._num_authorities + self._num_additionals
274274
for _ in range(n):
275275
domain = self._read_name()
276-
type_, class_, ttl, length = UNPACK_HHiH(self.data, self.offset)
276+
offset = self.offset
277277
self.offset += 10
278+
# type_, class_ and length are unsigned shorts in network order
279+
# ttl is an unsigned long in network order https://www.rfc-editor.org/errata/eid2130
280+
type_ = view[offset] << 8 | view[offset + 1]
281+
class_ = view[offset + 2] << 8 | view[offset + 3]
282+
ttl = view[offset + 4] << 24 | view[offset + 5] << 16 | view[offset + 6] << 8 | view[offset + 7]
283+
length = view[offset + 8] << 8 | view[offset + 9]
278284
end = self.offset + length
279285
rec = None
280286
try:
@@ -300,16 +306,19 @@ def _read_record(
300306
) -> Optional[DNSRecord]:
301307
"""Read known records types and skip unknown ones."""
302308
if type_ == _TYPE_A:
303-
dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(4))
304-
dns_address.created = self._now_float
305-
return dns_address
309+
return DNSAddress(domain, type_, class_, ttl, self._read_string(4), None, self.now)
306310
if type_ in (_TYPE_CNAME, _TYPE_PTR):
307311
return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now)
308312
if type_ == _TYPE_TXT:
309313
return DNSText(domain, type_, class_, ttl, self._read_string(length), self.now)
310314
if type_ == _TYPE_SRV:
311-
priority, weight, port = UNPACK_3H(self.data, self.offset)
315+
view = self.view
316+
offset = self.offset
312317
self.offset += 6
318+
# The SRV record has 3 unsigned shorts in network order
319+
priority = view[offset] << 8 | view[offset + 1]
320+
weight = view[offset + 2] << 8 | view[offset + 3]
321+
port = view[offset + 4] << 8 | view[offset + 5]
313322
return DNSService(
314323
domain,
315324
type_,
@@ -332,10 +341,7 @@ def _read_record(
332341
self.now,
333342
)
334343
if type_ == _TYPE_AAAA:
335-
dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(16))
336-
dns_address.created = self._now_float
337-
dns_address.scope_id = self.scope_id
338-
return dns_address
344+
return DNSAddress(domain, type_, class_, ttl, self._read_string(16), self.scope_id, self.now)
339345
if type_ == _TYPE_NSEC:
340346
name_start = self.offset
341347
return DNSNsec(
@@ -356,12 +362,13 @@ def _read_record(
356362
def _read_bitmap(self, end: _int) -> List[int]:
357363
"""Reads an NSEC bitmap from the packet."""
358364
rdtypes = []
365+
view = self.view
359366
while self.offset < end:
360367
offset = self.offset
361368
offset_plus_one = offset + 1
362369
offset_plus_two = offset + 2
363-
window = self.view[offset]
364-
bitmap_length = self.view[offset_plus_one]
370+
window = view[offset]
371+
bitmap_length = view[offset_plus_one]
365372
bitmap_end = offset_plus_two + bitmap_length
366373
for i, byte in enumerate(self.data[offset_plus_two:bitmap_end]):
367374
for bit in range(0, 8):
@@ -386,8 +393,9 @@ def _read_name(self) -> str:
386393

387394
def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: Set[int]) -> int:
388395
# This is a tight loop that is called frequently, small optimizations can make a difference.
396+
view = self.view
389397
while off < self._data_len:
390-
length = self.view[off]
398+
length = view[off]
391399
if length == 0:
392400
return off + DNS_COMPRESSION_HEADER_LEN
393401

@@ -403,7 +411,7 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers:
403411
)
404412

405413
# We have a DNS compression pointer
406-
link_data = self.view[off + 1]
414+
link_data = view[off + 1]
407415
link = (length & 0x3F) * 256 + link_data
408416
link_py_int = link
409417
if link > self._data_len:

0 commit comments

Comments
 (0)