Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def build(setup_kwargs: Any) -> None:
"src/zeroconf/_handlers/answers.py",
"src/zeroconf/_handlers/record_manager.py",
"src/zeroconf/_handlers/query_handler.py",
"src/zeroconf/_services/info.py",
"src/zeroconf/_services/registry.py",
"src/zeroconf/_updates.py",
"src/zeroconf/_utils/time.py",
Expand Down
87 changes: 87 additions & 0 deletions src/zeroconf/_services/info.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@

import cython

from .._cache cimport DNSCache
from .._dns cimport DNSPointer, DNSRecord, DNSService, DNSText
from .._protocol.outgoing cimport DNSOutgoing
from .._updates cimport RecordUpdateListener
from .._utils.time cimport current_time_millis


cdef object _resolve_all_futures_to_none

cdef object _TYPE_SRV
cdef object _TYPE_TXT
cdef object _TYPE_A
cdef object _TYPE_AAAA
cdef object _TYPE_PTR
cdef object _TYPE_NSEC
cdef object _CLASS_IN
cdef object _FLAGS_QR_QUERY

cdef object service_type_name

cdef object DNS_QUESTION_TYPE_QU
cdef object DNS_QUESTION_TYPE_QM

cdef object _IPVersion_All_value
cdef object _IPVersion_V4Only_value

cdef object TYPE_CHECKING

cdef class ServiceInfo(RecordUpdateListener):

cdef public cython.bytes text
cdef public str type
cdef str _name
cdef public str key
cdef public cython.list _ipv4_addresses
cdef public cython.list _ipv6_addresses
cdef public object port
cdef public object weight
cdef public object priority
cdef public str server
cdef public str server_key
cdef public cython.dict _properties
cdef public object host_ttl
cdef public object other_ttl
cdef public object interface_index
cdef public cython.set _new_records_futures
cdef public DNSPointer _dns_pointer_cache
cdef public DNSService _dns_service_cache
cdef public DNSText _dns_text_cache
cdef public cython.list _dns_address_cache
cdef public cython.set _get_address_and_nsec_records_cache

@cython.locals(
cache=DNSCache
)
cpdef async_update_records(self, object zc, object now, cython.list records)

@cython.locals(
cache=DNSCache
)
cpdef _load_from_cache(self, object zc, object now)

cdef _unpack_text_into_properties(self)

cdef _set_properties(self, cython.dict properties)

cdef _set_text(self, cython.bytes text)

cdef _get_ip_addresses_from_cache_lifo(self, object zc, object now, object type)

cdef _process_record_threadsafe(self, object zc, DNSRecord record, object now)

@cython.locals(
cache=DNSCache
)
cdef cython.list _get_address_records_from_cache_by_type(self, object zc, object _type)

cdef _set_ipv4_addresses_from_cache(self, object zc, object now)

cdef _set_ipv6_addresses_from_cache(self, object zc, object now)

cdef cython.list _ip_addresses_by_version_value(self, object version_value)

cdef addresses_by_version(self, object version)
67 changes: 40 additions & 27 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@
# the A/AAAA/SRV records for a host.
_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120)

float_ = float
int_ = int

DNS_QUESTION_TYPE_QU = DNSQuestionType.QU
DNS_QUESTION_TYPE_QM = DNSQuestionType.QM

if TYPE_CHECKING:
from .._core import Zeroconf

Expand Down Expand Up @@ -281,10 +287,9 @@ def addresses_by_version(self, version: IPVersion) -> List[bytes]:
"""
version_value = version.value
if version_value == _IPVersion_All_value:
return [
*(addr.packed for addr in self._ipv4_addresses),
*(addr.packed for addr in self._ipv6_addresses),
]
ip_v4_packed = [addr.packed for addr in self._ipv4_addresses]
ip_v6_packed = [addr.packed for addr in self._ipv6_addresses]
return [*ip_v4_packed, *ip_v6_packed]
if version_value == _IPVersion_V4Only_value:
return [addr.packed for addr in self._ipv4_addresses]
return [addr.packed for addr in self._ipv6_addresses]
Expand All @@ -303,7 +308,7 @@ def ip_addresses_by_version(
return self._ip_addresses_by_version_value(version.value)

def _ip_addresses_by_version_value(
self, version_value: int
self, version_value: int_
) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]:
"""Backend for addresses_by_version that uses the raw value."""
if version_value == _IPVersion_All_value:
Expand Down Expand Up @@ -397,7 +402,7 @@ def get_name(self) -> str:
return self._name[: len(self._name) - len(self.type) - 1]

def _get_ip_addresses_from_cache_lifo(
self, zc: 'Zeroconf', now: float, type: int
self, zc: 'Zeroconf', now: float_, type: int_
) -> List[Union[IPv4Address, IPv6Address]]:
"""Set IPv6 addresses from the cache."""
address_list: List[Union[IPv4Address, IPv6Address]] = []
Expand All @@ -410,7 +415,7 @@ def _get_ip_addresses_from_cache_lifo(
address_list.reverse() # Reverse to get LIFO order
return address_list

def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float_) -> None:
"""Set IPv6 addresses from the cache."""
if TYPE_CHECKING:
self._ipv6_addresses = cast(
Expand All @@ -419,7 +424,7 @@ def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
else:
self._ipv6_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)

def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float_) -> None:
"""Set IPv4 addresses from the cache."""
if TYPE_CHECKING:
self._ipv4_addresses = cast(
Expand All @@ -428,7 +433,7 @@ def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
else:
self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)

def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
def async_update_records(self, zc: 'Zeroconf', now: float_, records: List[RecordUpdate]) -> None:
"""Updates service information from a DNS record.

This method will be run in the event loop.
Expand All @@ -440,7 +445,7 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU
if updated and new_records_futures:
_resolve_all_futures_to_none(new_records_futures)

def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float) -> bool:
def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float_) -> bool:
"""Thread safe record updating.

Returns True if a new record was added.
Expand Down Expand Up @@ -624,14 +629,15 @@ def get_address_and_nsec_records(self, override_ttl: Optional[int] = None) -> Se
self._get_address_and_nsec_records_cache = records
return records

def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int) -> List[DNSAddress]:
def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int_) -> List[DNSAddress]:
"""Get the addresses from the cache."""
if self.server_key is None:
return []
cache = zc.cache
if TYPE_CHECKING:
records = cast("List[DNSAddress]", zc.cache.get_all_by_details(self.server_key, _type, _CLASS_IN))
records = cast("List[DNSAddress]", cache.get_all_by_details(self.server_key, _type, _CLASS_IN))
else:
records = zc.cache.get_all_by_details(self.server_key, _type, _CLASS_IN)
records = cache.get_all_by_details(self.server_key, _type, _CLASS_IN)
return records

def set_server_if_missing(self) -> None:
Expand All @@ -643,28 +649,33 @@ def set_server_if_missing(self) -> None:
self.server = self._name
self.server_key = self.key

def load_from_cache(self, zc: 'Zeroconf', now: Optional[float] = None) -> bool:
def load_from_cache(self, zc: 'Zeroconf', now: Optional[float_] = None) -> bool:
"""Populate the service info from the cache.

This method is designed to be threadsafe.
"""
return self._load_from_cache(zc, now or current_time_millis())

def _load_from_cache(self, zc: 'Zeroconf', now: float_) -> bool:
"""Populate the service info from the cache.

This method is designed to be threadsafe.
"""
if not now:
now = current_time_millis()
cache = zc.cache
original_server_key = self.server_key
cached_srv_record = zc.cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN)
cached_srv_record = cache.get_by_details(self._name, _TYPE_SRV, _CLASS_IN)
if cached_srv_record:
self._process_record_threadsafe(zc, cached_srv_record, now)
cached_txt_record = zc.cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN)
cached_txt_record = cache.get_by_details(self._name, _TYPE_TXT, _CLASS_IN)
if cached_txt_record:
self._process_record_threadsafe(zc, cached_txt_record, now)
if original_server_key == self.server_key:
# If there is a srv which changes the server_key,
# A and AAAA will already be loaded from the cache
# and we do not want to do it twice
for record in [
*self._get_address_records_from_cache_by_type(zc, _TYPE_A),
*self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA),
]:
for record in self._get_address_records_from_cache_by_type(zc, _TYPE_A):
self._process_record_threadsafe(zc, record, now)
for record in self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA):
self._process_record_threadsafe(zc, record, now)
return self._is_complete

Expand Down Expand Up @@ -720,7 +731,7 @@ async def async_request(

now = current_time_millis()

if self.load_from_cache(zc, now):
if self._load_from_cache(zc, now):
return True

if TYPE_CHECKING:
Expand All @@ -737,11 +748,13 @@ async def async_request(
return False
if next_ <= now:
out = self.generate_request_query(
zc, now, question_type or DNSQuestionType.QU if first_request else DNSQuestionType.QM
zc,
now,
question_type or DNS_QUESTION_TYPE_QU if first_request else DNS_QUESTION_TYPE_QM,
)
first_request = False
if not out.questions:
return self.load_from_cache(zc, now)
return self._load_from_cache(zc, now)
zc.async_send(out, addr, port)
next_ = now + delay
delay *= 2
Expand All @@ -755,7 +768,7 @@ async def async_request(
return True

def generate_request_query(
self, zc: 'Zeroconf', now: float, question_type: Optional[DNSQuestionType] = None
self, zc: 'Zeroconf', now: float_, question_type: Optional[DNSQuestionType] = None
) -> DNSOutgoing:
"""Generate the request query."""
out = DNSOutgoing(_FLAGS_QR_QUERY)
Expand All @@ -766,7 +779,7 @@ def generate_request_query(
out.add_question_or_one_cache(cache, now, name, _TYPE_TXT, _CLASS_IN)
out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_A, _CLASS_IN)
out.add_question_or_all_cache(cache, now, server_or_name, _TYPE_AAAA, _CLASS_IN)
if question_type == DNSQuestionType.QU:
if question_type == DNS_QUESTION_TYPE_QU:
for question in out.questions:
question.unicast = True
return out
Expand Down