diff --git a/src/zeroconf/_dns.pxd b/src/zeroconf/_dns.pxd index d4116a66a..720805177 100644 --- a/src/zeroconf/_dns.pxd +++ b/src/zeroconf/_dns.pxd @@ -54,6 +54,8 @@ cdef class DNSRecord(DNSEntry): cpdef get_remaining_ttl(self, double now) + cpdef unsigned int get_percentage_remaining_ttl(self, double now) + cpdef double get_expiration_time(self, cython.uint percent) cpdef bint is_expired(self, double now) diff --git a/src/zeroconf/_dns.py b/src/zeroconf/_dns.py index 66fb5b86d..262dbb5f4 100644 --- a/src/zeroconf/_dns.py +++ b/src/zeroconf/_dns.py @@ -193,6 +193,11 @@ def get_expiration_time(self, percent: _int) -> float: by a certain percentage.""" return self.created + (percent * self.ttl * 10) + def get_percentage_remaining_ttl(self, now: _float) -> _int: + """Returns the percentage remaining of the ttl between 0-100.""" + remain = (self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) - now) / self.ttl / 10 + return 0 if remain <= 0 else round(remain) + # TODO: Switch to just int here def get_remaining_ttl(self, now: _float) -> Union[int, float]: """Returns the remaining TTL in seconds.""" diff --git a/tests/test_dns.py b/tests/test_dns.py index 0eac568dd..b7e5a8790 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -6,7 +6,6 @@ import logging import os import socket -import time import unittest import unittest.mock @@ -86,19 +85,32 @@ def test_dns_record_abc(self): record.write(None) # type: ignore[arg-type] def test_dns_record_reset_ttl(self): - record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL) - time.sleep(1) - record2 = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL) + start = r.current_time_millis() + record = r.DNSRecord( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, created=start + ) + later = start + 1000 + record2 = r.DNSRecord( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, created=later + ) now = r.current_time_millis() assert record.created != record2.created assert record.get_remaining_ttl(now) != record2.get_remaining_ttl(now) + assert record.get_percentage_remaining_ttl(now) != record2.get_percentage_remaining_ttl(now) + assert record2.get_percentage_remaining_ttl(later) == 100 + assert record2.get_percentage_remaining_ttl(later + (const._DNS_HOST_TTL * 1000 / 2)) == 50 record.reset_ttl(record2) assert record.ttl == record2.ttl assert record.created == record2.created assert record.get_remaining_ttl(now) == record2.get_remaining_ttl(now) + assert record.get_percentage_remaining_ttl(now) == record2.get_percentage_remaining_ttl(now) + assert record.get_percentage_remaining_ttl(later) == 100 + assert record2.get_percentage_remaining_ttl(later) == 100 + assert record.get_percentage_remaining_ttl(later + (const._DNS_HOST_TTL * 1000 / 2)) == 50 + assert record2.get_percentage_remaining_ttl(later + (const._DNS_HOST_TTL * 1000 / 2)) == 50 def test_service_info_dunder(self): type_ = "_test-srvc-type._tcp.local."