Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions src/zeroconf/_dns.pxd
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@

import cython

from ._protocol.incoming cimport DNSIncoming


cdef object _LEN_BYTE
cdef object _LEN_SHORT
Expand All @@ -9,9 +11,9 @@ cdef object _LEN_INT
cdef object _NAME_COMPRESSION_MIN_SIZE
cdef object _BASE_MAX_SIZE

cdef object _EXPIRE_FULL_TIME_MS
cdef object _EXPIRE_STALE_TIME_MS
cdef object _RECENT_TIME_MS
cdef cython.uint _EXPIRE_FULL_TIME_MS
cdef cython.uint _EXPIRE_STALE_TIME_MS
cdef cython.uint _RECENT_TIME_MS

cdef object _CLASS_UNIQUE
cdef object _CLASS_MASK
Expand All @@ -34,11 +36,25 @@ cdef class DNSQuestion(DNSEntry):

cdef class DNSRecord(DNSEntry):

cdef public object ttl
cdef public object created
cdef public cython.float ttl
cdef public cython.float created

cdef _suppressed_by_answer(self, DNSRecord answer)

@cython.locals(
answers=cython.list,
)
cpdef suppressed_by(self, DNSIncoming msg)

cpdef get_expiration_time(self, cython.uint percent)

cpdef is_expired(self, cython.float now)

cpdef is_stale(self, cython.float now)

cpdef is_recent(self, cython.float now)

cpdef reset_ttl(self, DNSRecord other)

cdef class DNSAddress(DNSRecord):

Expand Down
20 changes: 13 additions & 7 deletions src/zeroconf/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
_EXPIRE_STALE_TIME_MS = 500
_RECENT_TIME_MS = 250

_float = float
_int = int

if TYPE_CHECKING:
from ._protocol.incoming import DNSIncoming
Expand Down Expand Up @@ -172,32 +174,36 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use
def suppressed_by(self, msg: 'DNSIncoming') -> bool:
"""Returns true if any answer in a message can suffice for the
information held in this record."""
return any(self._suppressed_by_answer(record) for record in msg.answers)
answers = msg.answers
for record in answers:
if self._suppressed_by_answer(record):
return True
return False

def _suppressed_by_answer(self, other) -> bool: # type: ignore[no-untyped-def]
def _suppressed_by_answer(self, other: 'DNSRecord') -> bool:
"""Returns true if another record has same name, type and class,
and if its TTL is at least half of this record's."""
return self == other and other.ttl > (self.ttl / 2)

def get_expiration_time(self, percent: int) -> float:
def get_expiration_time(self, percent: _int) -> float:
"""Returns the time at which this record will have expired
by a certain percentage."""
return self.created + (percent * self.ttl * 10)

# TODO: Switch to just int here
def get_remaining_ttl(self, now: float) -> Union[int, float]:
def get_remaining_ttl(self, now: _float) -> Union[int, float]:
"""Returns the remaining TTL in seconds."""
return max(0, millis_to_seconds((self.created + (_EXPIRE_FULL_TIME_MS * self.ttl)) - now))

def is_expired(self, now: float) -> bool:
def is_expired(self, now: _float) -> bool:
"""Returns true if this record has expired."""
return self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) <= now

def is_stale(self, now: float) -> bool:
def is_stale(self, now: _float) -> bool:
"""Returns true if this record is at least half way expired."""
return self.created + (_EXPIRE_STALE_TIME_MS * self.ttl) <= now

def is_recent(self, now: float) -> bool:
def is_recent(self, now: _float) -> bool:
"""Returns true if the record more than one quarter of its TTL remaining."""
return self.created + (_RECENT_TIME_MS * self.ttl) > now

Expand Down