Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def build(setup_kwargs: Any) -> None:
dict(
ext_modules=cythonize(
[
"src/zeroconf/_cache.py",
"src/zeroconf/_dns.py",
"src/zeroconf/_protocol/incoming.py",
"src/zeroconf/_protocol/outgoing.py",
Expand Down
28 changes: 28 additions & 0 deletions src/zeroconf/_cache.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import cython
from ._dns cimport (
DNSAddress,
DNSEntry,
DNSHinfo,
DNSPointer,
DNSRecord,
DNSService,
DNSText,
)


cdef object _TYPE_PTR

cdef _remove_key(cython.dict cache, object key, DNSRecord record)


cdef class DNSCache:

cdef public cython.dict cache
cdef public cython.dict service_cache

cdef _async_add(self, DNSRecord record)

cdef _async_remove(self, DNSRecord record)


cdef _dns_record_matches(DNSRecord record, object key, object type_, object class_)
42 changes: 25 additions & 17 deletions src/zeroconf/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,23 @@
DNSRecord,
DNSService,
DNSText,
dns_entry_matches,
)
from ._utils.time import current_time_millis
from .const import _TYPE_PTR

_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService)
_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService]
_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]]
_DNSRecord = DNSRecord
_str = str


def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None:
def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None:
"""Remove a key from a DNSRecord cache

This function must be run in from event loop.
"""
del cache[key][entry]
del cache[key][record]
if not cache[key]:
del cache[key]

Expand All @@ -62,7 +63,7 @@ def __init__(self) -> None:
# Functions prefixed with async_ are NOT threadsafe and must
# be run in the event loop.

def _async_add(self, entry: DNSRecord) -> bool:
def _async_add(self, record: _DNSRecord) -> bool:
"""Adds an entry.

Returns true if the entry was not already in the cache.
Expand All @@ -75,11 +76,11 @@ def _async_add(self, entry: DNSRecord) -> bool:
# replaces any existing records that are __eq__ to each other which
# removes the risk that accessing the cache from the wrong
# direction would return the old incorrect entry.
store = self.cache.setdefault(entry.key, {})
new = entry not in store and not isinstance(entry, DNSNsec)
store[entry] = entry
if isinstance(entry, DNSService):
self.service_cache.setdefault(entry.server_key, {})[entry] = entry
store = self.cache.setdefault(record.key, {})
new = record not in store and not isinstance(record, DNSNsec)
store[record] = record
if isinstance(record, DNSService):
self.service_cache.setdefault(record.server_key, {})[record] = record
return new

def async_add_records(self, entries: Iterable[DNSRecord]) -> bool:
Expand All @@ -95,14 +96,14 @@ def async_add_records(self, entries: Iterable[DNSRecord]) -> bool:
new = True
return new

def _async_remove(self, entry: DNSRecord) -> None:
def _async_remove(self, record: _DNSRecord) -> None:
"""Removes an entry.

This function must be run in from event loop.
"""
if isinstance(entry, DNSService):
_remove_key(self.service_cache, entry.server_key, entry)
_remove_key(self.cache, entry.key, entry)
if isinstance(record, DNSService):
_remove_key(self.service_cache, record.server_key, record)
_remove_key(self.cache, record.key, record)

def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
"""Remove multiple records.
Expand All @@ -128,7 +129,10 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:
This function is not threadsafe and must be called from
the event loop.
"""
return self.cache.get(entry.key, {}).get(entry)
store = self.cache.get(entry.key)
if store is None:
return None
return store.get(entry)

def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[DNSRecord]:
"""Gets all matching entries by details.
Expand All @@ -138,7 +142,7 @@ def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[D
"""
key = name.lower()
for entry in self.cache.get(key, []):
if dns_entry_matches(entry, key, type_, class_):
if _dns_record_matches(entry, key, type_, class_):
yield entry

def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]:
Expand Down Expand Up @@ -185,15 +189,15 @@ def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSReco
"""
key = name.lower()
for cached_entry in reversed(list(self.cache.get(key, []))):
if dns_entry_matches(cached_entry, key, type_, class_):
if _dns_record_matches(cached_entry, key, type_, class_):
return cached_entry
return None

def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]:
"""Gets all matching entries by details."""
key = name.lower()
return [
entry for entry in list(self.cache.get(key, [])) if dns_entry_matches(entry, key, type_, class_)
entry for entry in list(self.cache.get(key, [])) if _dns_record_matches(entry, key, type_, class_)
]

def entries_with_server(self, server: str) -> List[DNSRecord]:
Expand All @@ -218,3 +222,7 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D
def names(self) -> List[str]:
"""Return a copy of the list of current cache names."""
return list(self.cache)


def _dns_record_matches(record: _DNSRecord, key: _str, type_: int, class_: int) -> bool:
return key == record.key and type_ == record.type and class_ == record.class_
2 changes: 2 additions & 0 deletions src/zeroconf/_dns.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,5 @@ cdef class DNSRRSet:

cdef _records
cdef _lookup

cdef _dns_entry_matches(DNSEntry entry, object key, object type_, object class_)
28 changes: 16 additions & 12 deletions src/zeroconf/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,6 @@ class DNSQuestionType(enum.Enum):
QM = 2


def dns_entry_matches(record: 'DNSEntry', key: str, type_: int, class_: int) -> bool:
return key == record.key and type_ == record.type and class_ == record.class_


class DNSEntry:

"""A DNS entry"""
Expand All @@ -78,7 +74,7 @@ def __init__(self, name: str, type_: int, class_: int) -> None:

def __eq__(self, other: Any) -> bool:
"""Equality test on key (lowercase name), type, and class"""
return dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry)
return _dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry)

@staticmethod
def get_class_(class_: int) -> str:
Expand Down Expand Up @@ -121,7 +117,7 @@ def __hash__(self) -> int:

def __eq__(self, other: Any) -> bool:
"""Tests equality on dns question."""
return isinstance(other, DNSQuestion) and dns_entry_matches(other, self.key, self.type, self.class_)
return isinstance(other, DNSQuestion) and _dns_entry_matches(other, self.key, self.type, self.class_)

@property
def max_size(self) -> int:
Expand Down Expand Up @@ -254,7 +250,7 @@ def __eq__(self, other: Any) -> bool:
isinstance(other, DNSAddress)
and self.address == other.address
and self.scope_id == other.scope_id
and dns_entry_matches(other, self.key, self.type, self.class_)
and _dns_entry_matches(other, self.key, self.type, self.class_)
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -298,7 +294,7 @@ def __eq__(self, other: Any) -> bool:
isinstance(other, DNSHinfo)
and self.cpu == other.cpu
and self.os == other.os
and dns_entry_matches(other, self.key, self.type, self.class_)
and _dns_entry_matches(other, self.key, self.type, self.class_)
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -342,7 +338,7 @@ def __eq__(self, other: Any) -> bool:
return (
isinstance(other, DNSPointer)
and self.alias == other.alias
and dns_entry_matches(other, self.key, self.type, self.class_)
and _dns_entry_matches(other, self.key, self.type, self.class_)
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -381,7 +377,7 @@ def __eq__(self, other: Any) -> bool:
return (
isinstance(other, DNSText)
and self.text == other.text
and dns_entry_matches(other, self.key, self.type, self.class_)
and _dns_entry_matches(other, self.key, self.type, self.class_)
)

def __repr__(self) -> str:
Expand Down Expand Up @@ -432,7 +428,7 @@ def __eq__(self, other: Any) -> bool:
and self.weight == other.weight
and self.port == other.port
and self.server == other.server
and dns_entry_matches(other, self.key, self.type, self.class_)
and _dns_entry_matches(other, self.key, self.type, self.class_)
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -487,7 +483,7 @@ def __eq__(self, other: Any) -> bool:
isinstance(other, DNSNsec)
and self.next_name == other.next_name

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can move all the checks after the isinstance into a cdef with the other typed to the same type which will avoid all the python accessor overhead since we can access the struct values in C instead

and self.rdtypes == other.rdtypes
and dns_entry_matches(other, self.key, self.type, self.class_)
and _dns_entry_matches(other, self.key, self.type, self.class_)
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -527,3 +523,11 @@ def suppresses(self, record: DNSRecord) -> bool:
def __contains__(self, record: DNSRecord) -> bool:
"""Returns true if the rrset contains the record."""
return record in self.lookup


_DNSEntry = DNSEntry
_str = str


def _dns_entry_matches(entry: _DNSEntry, key: _str, type_: int, class_: int) -> bool:
return key == entry.key and type_ == entry.type and class_ == entry.class_
5 changes: 4 additions & 1 deletion src/zeroconf/_protocol/incoming.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
UNPACK_HHiH = struct.Struct(b'!HHiH').unpack_from

_seen_logs: Dict[str, Union[int, tuple]] = {}
_str = str


class DNSIncoming:
Expand Down Expand Up @@ -250,7 +251,9 @@ def _read_others(self) -> None:
if rec is not None:
self._answers.append(rec)

def _read_record(self, domain, type_: int, class_: int, ttl: int, length: int) -> Optional[DNSRecord]: # type: ignore[no-untyped-def]
def _read_record(
self, domain: _str, type_: int, class_: int, ttl: int, length: int
) -> Optional[DNSRecord]:
"""Read known records types and skip unknown ones."""
if type_ == _TYPE_A:
return DNSAddress(domain, type_, class_, ttl, self._read_string(4), created=self.now)
Expand Down