Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions src/zeroconf/_cache.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ from ._dns cimport (
DNSAddress,
DNSEntry,
DNSHinfo,
DNSNsec,
DNSPointer,
DNSRecord,
DNSService,
Expand All @@ -13,7 +14,7 @@ from ._dns cimport (

cdef object _UNIQUE_RECORD_TYPES
cdef object _TYPE_PTR
cdef object _ONE_SECOND
cdef cython.uint _ONE_SECOND

cdef _remove_key(cython.dict cache, object key, DNSRecord record)

Expand All @@ -27,23 +28,37 @@ cdef class DNSCache:

cpdef async_remove_records(self, object entries)

@cython.locals(
store=cython.dict,
)
cpdef async_get_unique(self, DNSRecord entry)

@cython.locals(
record=DNSRecord,
)
cpdef async_expire(self, float now)

@cython.locals(
records=cython.dict,
record=DNSRecord,
)
cdef _async_all_by_details(self, object name, object type_, object class_)
cpdef async_all_by_details(self, str name, object type_, object class_)

cpdef async_entries_with_name(self, str name)

cpdef async_entries_with_server(self, str name)

@cython.locals(
store=cython.dict,
)
cdef _async_add(self, DNSRecord record)

cdef _async_remove(self, DNSRecord record)

cpdef async_mark_unique_records_older_than_1s_to_expire(self, object unique_types, object answers, object now)

@cython.locals(
record=DNSRecord,
created_float=cython.float,
)
cdef _async_mark_unique_records_older_than_1s_to_expire(self, object unique_types, object answers, object now)
cpdef async_mark_unique_records_older_than_1s_to_expire(self, cython.set unique_types, object answers, float now)

cdef _dns_record_matches(DNSRecord record, object key, object type_, object class_)
25 changes: 6 additions & 19 deletions src/zeroconf/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
USA
"""

import itertools
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast

from ._dns import (
Expand Down Expand Up @@ -115,12 +114,12 @@ def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
for entry in entries:
self._async_remove(entry)

def async_expire(self, now: float) -> List[DNSRecord]:
def async_expire(self, now: _float) -> List[DNSRecord]:
"""Purge expired entries from the cache.

This function must be run in from event loop.
"""
expired = [record for record in itertools.chain(*self.cache.values()) if record.is_expired(now)]
expired = [record for records in self.cache.values() for record in records if record.is_expired(now)]
self.async_remove_records(expired)
return expired

Expand All @@ -136,15 +135,7 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:
return None
return store.get(entry)

def async_all_by_details(self, name: _str, type_: int, class_: int) -> Iterable[DNSRecord]:
"""Gets all matching entries by details.

This function is not thread-safe and must be called from
the event loop.
"""
return self._async_all_by_details(name, type_, class_)

def _async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DNSRecord]:
def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DNSRecord]:
"""Gets all matching entries by details.

This function is not thread-safe and must be called from
Expand Down Expand Up @@ -240,20 +231,16 @@ def names(self) -> List[str]:

def async_mark_unique_records_older_than_1s_to_expire(
self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float
) -> None:
self._async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now)

def _async_mark_unique_records_older_than_1s_to_expire(
self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float
) -> None:
# rfc6762#section-10.2 para 2
# Since unique is set, all old records with that name, rrtype,
# and rrclass that were received more than one second ago are declared
# invalid, and marked to expire from the cache in one second.
answers_rrset = set(answers)
for name, type_, class_ in unique_types:
for record in self._async_all_by_details(name, type_, class_):
if (now - record.created > _ONE_SECOND) and record not in answers_rrset:
for record in self.async_all_by_details(name, type_, class_):
created_float = record.created
if (now - created_float > _ONE_SECOND) and record not in answers_rrset:
# Expire in 1s
record.set_created_ttl(now, 1)

Expand Down
5 changes: 3 additions & 2 deletions src/zeroconf/_handlers/record_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
now_float = now
unique_types: Set[Tuple[str, int, int]] = set()
cache = self.cache
answers = msg.answers

for record in msg.answers:
for record in answers:
# Protect zeroconf from records that can cause denial of service.
#
# We enforce a minimum TTL for PTR records to avoid
Expand Down Expand Up @@ -127,7 +128,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
removes.add(record)

if unique_types:
cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, msg.answers, now)
cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now)

if updates:
self.async_updates(now, updates)
Expand Down