Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/zeroconf/_dns.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ cdef class DNSRecord(DNSEntry):
)
cpdef suppressed_by(self, object msg)

cpdef get_remaining_ttl(self, cython.float now)

cpdef get_expiration_time(self, cython.uint percent)

cpdef is_expired(self, cython.float now)
Expand All @@ -54,6 +56,8 @@ cdef class DNSRecord(DNSEntry):

cpdef reset_ttl(self, DNSRecord other)

cpdef set_created_ttl(self, cython.float now, cython.float ttl)

cdef class DNSAddress(DNSRecord):

cdef public cython.int _hash
Expand Down
7 changes: 4 additions & 3 deletions src/zeroconf/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from ._exceptions import AbstractMethodException
from ._utils.net import _is_v6_address
from ._utils.time import current_time_millis, millis_to_seconds
from ._utils.time import current_time_millis
from .const import _CLASS_MASK, _CLASS_UNIQUE, _CLASSES, _TYPE_ANY, _TYPES

_LEN_BYTE = 1
Expand Down Expand Up @@ -193,7 +193,8 @@ def get_expiration_time(self, percent: _int) -> float:
# TODO: Switch to just int here
def get_remaining_ttl(self, now: _float) -> Union[int, float]:
"""Returns the remaining TTL in seconds."""
return max(0, millis_to_seconds((self.created + (_EXPIRE_FULL_TIME_MS * self.ttl)) - now))
remain = (self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) - now) / 1000.0
return 0 if remain < 0 else remain

def is_expired(self, now: _float) -> bool:
"""Returns true if this record has expired."""
Expand All @@ -212,7 +213,7 @@ def reset_ttl(self, other) -> None: # type: ignore[no-untyped-def]
another record."""
self.set_created_ttl(other.created, other.ttl)

def set_created_ttl(self, created: float, ttl: Union[float, int]) -> None:
def set_created_ttl(self, created: _float, ttl: Union[float, int]) -> None:
"""Set the created and ttl of a record."""
self.created = created
self.ttl = ttl
Expand Down
15 changes: 9 additions & 6 deletions src/zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
removes: Set[DNSRecord] = set()
now = msg.now
unique_types: Set[Tuple[str, int, int]] = set()
cache = self.cache

for record in msg.answers:
# Protect zeroconf from records that can cause denial of service.
Expand All @@ -416,7 +417,9 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
# ServiceBrowsers generating excessive queries refresh queries.
# Apple uses a 15s minimum TTL, however we do not have the same
# level of rate limit and safe guards so we use 1/4 of the recommended value.
if record.ttl and record.type == _TYPE_PTR and record.ttl < _DNS_PTR_MIN_TTL:
record_type = record.type
record_ttl = record.ttl
if record_ttl and record_type == _TYPE_PTR and record_ttl < _DNS_PTR_MIN_TTL:
log.debug(
"Increasing effective ttl of %s to minimum of %s to protect against excessive refreshes.",
record,
Expand All @@ -425,12 +428,12 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
record.set_created_ttl(record.created, _DNS_PTR_MIN_TTL)

if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2
unique_types.add((record.name, record.type, record.class_))
unique_types.add((record.name, record_type, record.class_))

if TYPE_CHECKING:
record = cast(_UniqueRecordsType, record)

maybe_entry = self.cache.async_get_unique(record)
maybe_entry = cache.async_get_unique(record)
if not record.is_expired(now):
if maybe_entry is not None:
maybe_entry.reset_ttl(record)
Expand All @@ -447,7 +450,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
removes.add(record)

if unique_types:
self.cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, msg.answers, now)
cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, msg.answers, now)

if updates:
self.async_updates(now, updates)
Expand All @@ -468,12 +471,12 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
# processsed.
new = False
if other_adds or address_adds:
new = self.cache.async_add_records(itertools.chain(address_adds, other_adds))
new = cache.async_add_records(itertools.chain(address_adds, other_adds))
# Removes are processed last since
# ServiceInfo could generate an un-needed query
# because the data was not yet populated.
if removes:
self.cache.async_remove_records(removes)
cache.async_remove_records(removes)
if updates:
self.async_updates_complete(new)

Expand Down