Skip to content

Commit 3a25ff7

Browse files
authored
feat: optimize equality checks for DNS records (#1120)
1 parent 255a884 commit 3a25ff7

2 files changed

Lines changed: 104 additions & 72 deletions

File tree

src/zeroconf/_dns.pxd

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
import cython
23

34

45
cdef object _LEN_BYTE
@@ -12,67 +13,90 @@ cdef object _EXPIRE_FULL_TIME_MS
1213
cdef object _EXPIRE_STALE_TIME_MS
1314
cdef object _RECENT_TIME_MS
1415

16+
cdef object _CLASS_UNIQUE
17+
cdef object _CLASS_MASK
1518

1619
cdef class DNSEntry:
1720

18-
cdef public key
19-
cdef public name
20-
cdef public type
21-
cdef public class_
22-
cdef public unique
21+
cdef public object key
22+
cdef public object name
23+
cdef public object type
24+
cdef public object class_
25+
cdef public object unique
26+
27+
cdef _dns_entry_matches(self, DNSEntry other)
2328

2429
cdef class DNSQuestion(DNSEntry):
2530

26-
cdef public _hash
31+
cdef public cython.int _hash
2732

2833
cdef class DNSRecord(DNSEntry):
2934

30-
cdef public ttl
31-
cdef public created
35+
cdef public object ttl
36+
cdef public object created
37+
38+
cdef _suppressed_by_answer(self, DNSRecord answer)
39+
3240

3341
cdef class DNSAddress(DNSRecord):
3442

35-
cdef public _hash
36-
cdef public address
37-
cdef public scope_id
43+
cdef public cython.int _hash
44+
cdef public object address
45+
cdef public object scope_id
46+
47+
cdef _eq(self, DNSAddress other)
3848

3949

4050
cdef class DNSHinfo(DNSRecord):
4151

42-
cdef public _hash
43-
cdef public cpu
44-
cdef public os
52+
cdef public cython.int _hash
53+
cdef public object cpu
54+
cdef public object os
55+
56+
cdef _eq(self, DNSHinfo other)
4557

4658

4759
cdef class DNSPointer(DNSRecord):
4860

49-
cdef public _hash
50-
cdef public alias
61+
cdef public cython.int _hash
62+
cdef public object alias
63+
64+
cdef _eq(self, DNSPointer other)
65+
5166

5267
cdef class DNSText(DNSRecord):
5368

54-
cdef public _hash
55-
cdef public text
69+
cdef public cython.int _hash
70+
cdef public object text
71+
72+
cdef _eq(self, DNSText other)
73+
5674

5775
cdef class DNSService(DNSRecord):
5876

59-
cdef public _hash
60-
cdef public priority
61-
cdef public weight
62-
cdef public port
63-
cdef public server
64-
cdef public server_key
77+
cdef public cython.int _hash
78+
cdef public object priority
79+
cdef public object weight
80+
cdef public object port
81+
cdef public object server
82+
cdef public object server_key
83+
84+
cdef _eq(self, DNSService other)
85+
6586

6687
cdef class DNSNsec(DNSRecord):
6788

68-
cdef public _hash
69-
cdef public next_name
70-
cdef public rdtypes
89+
cdef public cython.int _hash
90+
cdef public object next_name
91+
cdef public cython.list rdtypes
92+
93+
cdef _eq(self, DNSNsec other)
7194

7295

7396
cdef class DNSRRSet:
7497

7598
cdef _records
76-
cdef _lookup
99+
cdef cython.dict _lookup
77100

78-
cdef _dns_entry_matches(DNSEntry entry, object key, object type_, object class_)
101+
@cython.locals(other=DNSRecord)
102+
cpdef suppresses(self, DNSRecord record)

src/zeroconf/_dns.py

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,12 @@ def __init__(self, name: str, type_: int, class_: int) -> None:
7272
self.class_ = class_ & _CLASS_MASK
7373
self.unique = (class_ & _CLASS_UNIQUE) != 0
7474

75+
def _dns_entry_matches(self, other) -> bool: # type: ignore[no-untyped-def]
76+
return self.key == other.key and self.type == other.type and self.class_ == other.class_
77+
7578
def __eq__(self, other: Any) -> bool:
7679
"""Equality test on key (lowercase name), type, and class"""
77-
return _dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry)
80+
return isinstance(other, DNSEntry) and self._dns_entry_matches(other)
7881

7982
@staticmethod
8083
def get_class_(class_: int) -> str:
@@ -117,7 +120,7 @@ def __hash__(self) -> int:
117120

118121
def __eq__(self, other: Any) -> bool:
119122
"""Tests equality on dns question."""
120-
return isinstance(other, DNSQuestion) and _dns_entry_matches(other, self.key, self.type, self.class_)
123+
return isinstance(other, DNSQuestion) and self._dns_entry_matches(other)
121124

122125
@property
123126
def max_size(self) -> int:
@@ -169,9 +172,9 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use
169172
def suppressed_by(self, msg: 'DNSIncoming') -> bool:
170173
"""Returns true if any answer in a message can suffice for the
171174
information held in this record."""
172-
return any(self.suppressed_by_answer(record) for record in msg.answers)
175+
return any(self._suppressed_by_answer(record) for record in msg.answers)
173176

174-
def suppressed_by_answer(self, other: 'DNSRecord') -> bool:
177+
def _suppressed_by_answer(self, other) -> bool: # type: ignore[no-untyped-def]
175178
"""Returns true if another record has same name, type and class,
176179
and if its TTL is at least half of this record's."""
177180
return self == other and other.ttl > (self.ttl / 2)
@@ -246,11 +249,13 @@ def write(self, out: 'DNSOutgoing') -> None:
246249

247250
def __eq__(self, other: Any) -> bool:
248251
"""Tests equality on address"""
252+
return isinstance(other, DNSAddress) and self._eq(other)
253+
254+
def _eq(self, other) -> bool: # type: ignore[no-untyped-def]
249255
return (
250-
isinstance(other, DNSAddress)
251-
and self.address == other.address
256+
self.address == other.address
252257
and self.scope_id == other.scope_id
253-
and _dns_entry_matches(other, self.key, self.type, self.class_)
258+
and self._dns_entry_matches(other)
254259
)
255260

256261
def __hash__(self) -> int:
@@ -289,13 +294,12 @@ def write(self, out: 'DNSOutgoing') -> None:
289294
out.write_character_string(self.os.encode('utf-8'))
290295

291296
def __eq__(self, other: Any) -> bool:
292-
"""Tests equality on cpu and os"""
293-
return (
294-
isinstance(other, DNSHinfo)
295-
and self.cpu == other.cpu
296-
and self.os == other.os
297-
and _dns_entry_matches(other, self.key, self.type, self.class_)
298-
)
297+
"""Tests equality on cpu and os."""
298+
return isinstance(other, DNSHinfo) and self._eq(other)
299+
300+
def _eq(self, other) -> bool: # type: ignore[no-untyped-def]
301+
"""Tests equality on cpu and os."""
302+
return self.cpu == other.cpu and self.os == other.os and self._dns_entry_matches(other)
299303

300304
def __hash__(self) -> int:
301305
"""Hash to compare like DNSHinfo."""
@@ -334,12 +338,12 @@ def write(self, out: 'DNSOutgoing') -> None:
334338
out.write_name(self.alias)
335339

336340
def __eq__(self, other: Any) -> bool:
337-
"""Tests equality on alias"""
338-
return (
339-
isinstance(other, DNSPointer)
340-
and self.alias == other.alias
341-
and _dns_entry_matches(other, self.key, self.type, self.class_)
342-
)
341+
"""Tests equality on alias."""
342+
return isinstance(other, DNSPointer) and self._eq(other)
343+
344+
def _eq(self, other) -> bool: # type: ignore[no-untyped-def]
345+
"""Tests equality on alias."""
346+
return self.alias == other.alias and self._dns_entry_matches(other)
343347

344348
def __hash__(self) -> int:
345349
"""Hash to compare like DNSPointer."""
@@ -373,12 +377,12 @@ def __hash__(self) -> int:
373377
return self._hash
374378

375379
def __eq__(self, other: Any) -> bool:
376-
"""Tests equality on text"""
377-
return (
378-
isinstance(other, DNSText)
379-
and self.text == other.text
380-
and _dns_entry_matches(other, self.key, self.type, self.class_)
381-
)
380+
"""Tests equality on text."""
381+
return isinstance(other, DNSText) and self._eq(other)
382+
383+
def _eq(self, other) -> bool: # type: ignore[no-untyped-def]
384+
"""Tests equality on text."""
385+
return self.text == other.text and self._dns_entry_matches(other)
382386

383387
def __repr__(self) -> str:
384388
"""String representation"""
@@ -422,13 +426,16 @@ def write(self, out: 'DNSOutgoing') -> None:
422426

423427
def __eq__(self, other: Any) -> bool:
424428
"""Tests equality on priority, weight, port and server"""
429+
return isinstance(other, DNSService) and self._eq(other)
430+
431+
def _eq(self, other) -> bool: # type: ignore[no-untyped-def]
432+
"""Tests equality on priority, weight, port and server."""
425433
return (
426-
isinstance(other, DNSService)
427-
and self.priority == other.priority
434+
self.priority == other.priority
428435
and self.weight == other.weight
429436
and self.port == other.port
430437
and self.server == other.server
431-
and _dns_entry_matches(other, self.key, self.type, self.class_)
438+
and self._dns_entry_matches(other)
432439
)
433440

434441
def __hash__(self) -> int:
@@ -478,12 +485,15 @@ def write(self, out: 'DNSOutgoing') -> None:
478485
out.write_string(out_bytes)
479486

480487
def __eq__(self, other: Any) -> bool:
481-
"""Tests equality on cpu and os"""
488+
"""Tests equality on next_name and rdtypes."""
489+
return isinstance(other, DNSNsec) and self._eq(other)
490+
491+
def _eq(self, other) -> bool: # type: ignore[no-untyped-def]
492+
"""Tests equality on next_name and rdtypes."""
482493
return (
483-
isinstance(other, DNSNsec)
484-
and self.next_name == other.next_name
494+
self.next_name == other.next_name
485495
and self.rdtypes == other.rdtypes
486-
and _dns_entry_matches(other, self.key, self.type, self.class_)
496+
and self._dns_entry_matches(other)
487497
)
488498

489499
def __hash__(self) -> int:
@@ -497,6 +507,9 @@ def __repr__(self) -> str:
497507
)
498508

499509

510+
_DNSRecord = DNSRecord
511+
512+
500513
class DNSRRSet:
501514
"""A set of dns records independent of the ttl."""
502515

@@ -514,20 +527,15 @@ def lookup(self) -> Dict[DNSRecord, DNSRecord]:
514527
self._lookup = {record: record for record in self._records}
515528
return self._lookup
516529

517-
def suppresses(self, record: DNSRecord) -> bool:
530+
def suppresses(self, record: _DNSRecord) -> bool:
518531
"""Returns true if any answer in the rrset can suffice for the
519532
information held in this record."""
520-
other = self.lookup.get(record)
533+
if self._lookup is None:
534+
other = self.lookup.get(record)
535+
else:
536+
other = self._lookup.get(record)
521537
return bool(other and other.ttl > (record.ttl / 2))
522538

523539
def __contains__(self, record: DNSRecord) -> bool:
524540
"""Returns true if the rrset contains the record."""
525541
return record in self.lookup
526-
527-
528-
_DNSEntry = DNSEntry
529-
_str = str
530-
531-
532-
def _dns_entry_matches(entry: _DNSEntry, key: _str, type_: int, class_: int) -> bool:
533-
return key == entry.key and type_ == entry.type and class_ == entry.class_

0 commit comments

Comments
 (0)