Skip to content

Commit ac081cf

Browse files
authored
feat: speed up decoding incoming packets (#1256)
1 parent aebabd9 commit ac081cf

2 files changed

Lines changed: 25 additions & 16 deletions

File tree

src/zeroconf/_protocol/incoming.pxd

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ cdef cython.uint MAX_DNS_LABELS
77
cdef cython.uint DNS_COMPRESSION_POINTER_LEN
88
cdef cython.uint MAX_NAME_LENGTH
99

10-
cdef object current_time_millis
11-
1210
cdef cython.uint _TYPE_A
1311
cdef cython.uint _TYPE_CNAME
1412
cdef cython.uint _TYPE_PTR
@@ -43,6 +41,7 @@ from .._dns cimport (
4341
DNSService,
4442
DNSText,
4543
)
44+
from .._utils.time cimport current_time_millis
4645

4746

4847
cdef class DNSIncoming:
@@ -62,6 +61,7 @@ cdef class DNSIncoming:
6261
cdef public cython.uint num_additionals
6362
cdef public object valid
6463
cdef public object now
64+
cdef cython.float _now_float
6565
cdef public object scope_id
6666
cdef public object source
6767

@@ -79,7 +79,9 @@ cdef class DNSIncoming:
7979
label_idx=cython.uint,
8080
length=cython.uint,
8181
link=cython.uint,
82-
link_data=cython.uint
82+
link_data=cython.uint,
83+
link_py_int=object,
84+
linked_labels=cython.list
8385
)
8486
cdef _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers)
8587

@@ -95,9 +97,12 @@ cdef class DNSIncoming:
9597

9698
cdef _read_questions(self)
9799

98-
cdef bytes _read_character_string(self)
100+
@cython.locals(
101+
length=cython.uint,
102+
)
103+
cdef str _read_character_string(self)
99104

100-
cdef _read_string(self, unsigned int length)
105+
cdef bytes _read_string(self, unsigned int length)
101106

102107
@cython.locals(
103108
name_start=cython.uint

src/zeroconf/_protocol/incoming.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class DNSIncoming:
8989
'num_additionals',
9090
'valid',
9191
'now',
92+
'_now_float',
9293
'scope_id',
9394
'source',
9495
)
@@ -116,6 +117,7 @@ def __init__(
116117
self.valid = False
117118
self._did_read_others = False
118119
self.now = now or current_time_millis()
120+
self._now_float = self.now
119121
self.source = source
120122
self.scope_id = scope_id
121123
try:
@@ -226,11 +228,13 @@ def _read_questions(self) -> None:
226228
question = DNSQuestion(name, type_, class_)
227229
self.questions.append(question)
228230

229-
def _read_character_string(self) -> bytes:
231+
def _read_character_string(self) -> str:
230232
"""Reads a character string from the packet"""
231233
length = self.data[self.offset]
232234
self.offset += 1
233-
return self._read_string(length)
235+
info = self.data[self.offset : self.offset + length].decode('utf-8', 'replace')
236+
self.offset += length
237+
return info
234238

235239
def _read_string(self, length: _int) -> bytes:
236240
"""Reads a string of a given length from the packet"""
@@ -273,7 +277,7 @@ def _read_record(
273277
"""Read known records types and skip unknown ones."""
274278
if type_ == _TYPE_A:
275279
dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(4))
276-
dns_address.created = self.now
280+
dns_address.created = self._now_float
277281
return dns_address
278282
if type_ in (_TYPE_CNAME, _TYPE_PTR):
279283
return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now)
@@ -299,13 +303,13 @@ def _read_record(
299303
type_,
300304
class_,
301305
ttl,
302-
self._read_character_string().decode('utf-8', 'replace'),
303-
self._read_character_string().decode('utf-8', 'replace'),
306+
self._read_character_string(),
307+
self._read_character_string(),
304308
self.now,
305309
)
306310
if type_ == _TYPE_AAAA:
307311
dns_address = DNSAddress(domain, type_, class_, ttl, self._read_string(16))
308-
dns_address.created = self.now
312+
dns_address.created = self._now_float
309313
dns_address.scope_id = self.scope_id
310314
return dns_address
311315
if type_ == _TYPE_NSEC:
@@ -377,7 +381,7 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers:
377381
# We have a DNS compression pointer
378382
link_data = self.data[off + 1]
379383
link = (length & 0x3F) * 256 + link_data
380-
lint_int = int(link)
384+
link_py_int = link
381385
if link > self._data_len:
382386
raise IncomingDecodeError(
383387
f"DNS compression pointer at {off} points to {link} beyond packet from {self.source}"
@@ -386,16 +390,16 @@ def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers:
386390
raise IncomingDecodeError(
387391
f"DNS compression pointer at {off} points to itself from {self.source}"
388392
)
389-
if lint_int in seen_pointers:
393+
if link_py_int in seen_pointers:
390394
raise IncomingDecodeError(
391395
f"DNS compression pointer at {off} was seen again from {self.source}"
392396
)
393-
linked_labels = self.name_cache.get(lint_int)
397+
linked_labels = self.name_cache.get(link_py_int)
394398
if not linked_labels:
395399
linked_labels = []
396-
seen_pointers.add(lint_int)
400+
seen_pointers.add(link_py_int)
397401
self._decode_labels_at_offset(link, linked_labels, seen_pointers)
398-
self.name_cache[lint_int] = linked_labels
402+
self.name_cache[link_py_int] = linked_labels
399403
labels.extend(linked_labels)
400404
if len(labels) > MAX_DNS_LABELS:
401405
raise IncomingDecodeError(

0 commit comments

Comments
 (0)