From 00c736552454997bc9c746fee3a79dde7895eebe Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 1 Aug 2023 17:18:28 -1000 Subject: [PATCH] feat: speed up processing incoming records add cython defs for dns record ttl compares --- src/zeroconf/_dns.pxd | 26 +++++++++++++++++++++----- src/zeroconf/_dns.py | 20 +++++++++++++------- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/src/zeroconf/_dns.pxd b/src/zeroconf/_dns.pxd index cd4f1f9e3..5908ff1bf 100644 --- a/src/zeroconf/_dns.pxd +++ b/src/zeroconf/_dns.pxd @@ -1,6 +1,8 @@ import cython +from ._protocol.incoming cimport DNSIncoming + cdef object _LEN_BYTE cdef object _LEN_SHORT @@ -9,9 +11,9 @@ cdef object _LEN_INT cdef object _NAME_COMPRESSION_MIN_SIZE cdef object _BASE_MAX_SIZE -cdef object _EXPIRE_FULL_TIME_MS -cdef object _EXPIRE_STALE_TIME_MS -cdef object _RECENT_TIME_MS +cdef cython.uint _EXPIRE_FULL_TIME_MS +cdef cython.uint _EXPIRE_STALE_TIME_MS +cdef cython.uint _RECENT_TIME_MS cdef object _CLASS_UNIQUE cdef object _CLASS_MASK @@ -34,11 +36,25 @@ cdef class DNSQuestion(DNSEntry): cdef class DNSRecord(DNSEntry): - cdef public object ttl - cdef public object created + cdef public cython.float ttl + cdef public cython.float created cdef _suppressed_by_answer(self, DNSRecord answer) + @cython.locals( + answers=cython.list, + ) + cpdef suppressed_by(self, DNSIncoming msg) + + cpdef get_expiration_time(self, cython.uint percent) + + cpdef is_expired(self, cython.float now) + + cpdef is_stale(self, cython.float now) + + cpdef is_recent(self, cython.float now) + + cpdef reset_ttl(self, DNSRecord other) cdef class DNSAddress(DNSRecord): diff --git a/src/zeroconf/_dns.py b/src/zeroconf/_dns.py index 34d7fdb24..561b16ffc 100644 --- a/src/zeroconf/_dns.py +++ b/src/zeroconf/_dns.py @@ -40,6 +40,8 @@ _EXPIRE_STALE_TIME_MS = 500 _RECENT_TIME_MS = 250 +_float = float +_int = int if TYPE_CHECKING: from ._protocol.incoming import DNSIncoming @@ -172,32 +174,36 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use def suppressed_by(self, msg: 'DNSIncoming') -> bool: """Returns true if any answer in a message can suffice for the information held in this record.""" - return any(self._suppressed_by_answer(record) for record in msg.answers) + answers = msg.answers + for record in answers: + if self._suppressed_by_answer(record): + return True + return False - def _suppressed_by_answer(self, other) -> bool: # type: ignore[no-untyped-def] + def _suppressed_by_answer(self, other: 'DNSRecord') -> bool: """Returns true if another record has same name, type and class, and if its TTL is at least half of this record's.""" return self == other and other.ttl > (self.ttl / 2) - def get_expiration_time(self, percent: int) -> float: + def get_expiration_time(self, percent: _int) -> float: """Returns the time at which this record will have expired by a certain percentage.""" return self.created + (percent * self.ttl * 10) # TODO: Switch to just int here - def get_remaining_ttl(self, now: float) -> Union[int, float]: + def get_remaining_ttl(self, now: _float) -> Union[int, float]: """Returns the remaining TTL in seconds.""" return max(0, millis_to_seconds((self.created + (_EXPIRE_FULL_TIME_MS * self.ttl)) - now)) - def is_expired(self, now: float) -> bool: + def is_expired(self, now: _float) -> bool: """Returns true if this record has expired.""" return self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) <= now - def is_stale(self, now: float) -> bool: + def is_stale(self, now: _float) -> bool: """Returns true if this record is at least half way expired.""" return self.created + (_EXPIRE_STALE_TIME_MS * self.ttl) <= now - def is_recent(self, now: float) -> bool: + def is_recent(self, now: _float) -> bool: """Returns true if the record more than one quarter of its TTL remaining.""" return self.created + (_RECENT_TIME_MS * self.ttl) > now