diff --git a/tests/test_aio.py b/tests/test_aio.py index 47c1e2d9d..e41442500 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -16,7 +16,8 @@ from zeroconf import Zeroconf from zeroconf.const import _LISTENER_TIME from zeroconf._exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered -from zeroconf._services import ServiceInfo, ServiceListener +from zeroconf._services import ServiceListener +from zeroconf._services.info import ServiceInfo from zeroconf._utils.time import current_time_millis from . import _clear_cache diff --git a/tests/test_services.py b/tests/test_services.py index f972f9d24..2a077329c 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -19,11 +19,8 @@ from zeroconf import DNSAddress, DNSPointer, DNSQuestion, const, current_time_millis import zeroconf._services as s from zeroconf import Zeroconf -from zeroconf._services import ( - ServiceBrowser, - ServiceInfo, - ServiceStateChange, -) +from zeroconf._services import ServiceBrowser, ServiceStateChange +from zeroconf._services.info import ServiceInfo from zeroconf.aio import AsyncZeroconf from . import has_working_ipv6, _clear_cache, _inject_response diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ab2b0993e..e61a71193 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -46,15 +46,17 @@ ) from ._protocol import DNSIncoming, DNSOutgoing # noqa # import needed for backwards compat from ._services import ( # noqa # import needed for backwards compat - instance_name_from_service_info, Signal, SignalRegistrationInterface, RecordUpdateListener, ServiceBrowser, - ServiceInfo, ServiceListener, ServiceStateChange, ) +from ._services.info import ( # noqa # import needed for backwards compat + instance_name_from_service_info, + ServiceInfo, +) from ._services.registry import ServiceRegistry # noqa # import needed for backwards compat from ._services.types import ZeroconfServiceTypes from ._utils.name import service_type_name # noqa # import needed for backwards compat diff --git a/zeroconf/_core.py b/zeroconf/_core.py index a7910591a..675d7169d 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -37,13 +37,8 @@ from ._handlers import QueryHandler, RecordManager from ._logger import QuietLogger, log from ._protocol import DNSIncoming, DNSOutgoing -from ._services import ( - RecordUpdateListener, - ServiceBrowser, - ServiceInfo, - ServiceListener, - instance_name_from_service_info, -) +from ._services import RecordUpdateListener, ServiceBrowser, ServiceListener +from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.registry import ServiceRegistry from ._utils.aio import get_running_loop from ._utils.name import service_type_name diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 04288a69c..111ea4487 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -21,44 +21,28 @@ """ import enum -import socket import threading import warnings from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast from .._cache import _UniqueRecordsType -from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText -from .._exceptions import BadTypeInNameException +from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord from .._protocol import DNSOutgoing from .._utils.name import service_type_name -from .._utils.net import ( - IPVersion, - _encode_address, - _is_v6_address, -) -from .._utils.struct import int2byte from .._utils.time import current_time_millis, millis_to_seconds from ..const import ( _BROWSER_BACKOFF_LIMIT, _BROWSER_TIME, _CLASS_IN, - _CLASS_UNIQUE, - _DNS_HOST_TTL, - _DNS_OTHER_TTL, _DNS_PACKET_HEADER_LEN, _EXPIRE_REFRESH_TIME_PERCENT, _FLAGS_QR_QUERY, - _LISTENER_TIME, _MAX_MSG_TYPICAL, _MDNS_ADDR, _MDNS_ADDR6, _MDNS_PORT, - _TYPE_A, - _TYPE_AAAA, _TYPE_PTR, - _TYPE_SRV, - _TYPE_TXT, ) @@ -77,16 +61,6 @@ class ServiceStateChange(enum.Enum): Updated = 3 -def instance_name_from_service_info(info: "ServiceInfo") -> str: - """Calculate the instance name from the ServiceInfo.""" - # This is kind of funky because of the subtype based tests - # need to make subtypes a first class citizen - service_name = service_type_name(info.name) - if not info.type.endswith(service_name): - raise BadTypeInNameException - return info.name[: -len(service_name) - 1] - - class ServiceListener: def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: raise NotImplementedError() @@ -505,396 +479,3 @@ def run(self) -> None: name=name_type[0], state_change=state_change, ) - - -class ServiceInfo(RecordUpdateListener): - """Service information. - - Constructor parameters are as follows: - - * `type_`: fully qualified service type name - * `name`: fully qualified service name - * `port`: port that the service runs on - * `weight`: weight of the service - * `priority`: priority of the service - * `properties`: dictionary of properties (or a bytes object holding the contents of the `text` field). - converted to str and then encoded to bytes using UTF-8. Keys with `None` values are converted to - value-less attributes. - * `server`: fully qualified name for service host (defaults to name) - * `host_ttl`: ttl used for A/SRV records - * `other_ttl`: ttl used for PTR/TXT records - * `addresses` and `parsed_addresses`: List of IP addresses (either as bytes, network byte order, - or in parsed form as text; at most one of those parameters can be provided) - - """ - - text = b'' - - def __init__( - self, - type_: str, - name: str, - port: Optional[int] = None, - weight: int = 0, - priority: int = 0, - properties: Union[bytes, Dict] = b'', - server: Optional[str] = None, - host_ttl: int = _DNS_HOST_TTL, - other_ttl: int = _DNS_OTHER_TTL, - *, - addresses: Optional[List[bytes]] = None, - parsed_addresses: Optional[List[str]] = None - ) -> None: - # Accept both none, or one, but not both. - if addresses is not None and parsed_addresses is not None: - raise TypeError("addresses and parsed_addresses cannot be provided together") - if not type_.endswith(service_type_name(name, strict=False)): - raise BadTypeInNameException - self.type = type_ - self._name = name - self.key = name.lower() - if addresses is not None: - self._addresses = addresses - elif parsed_addresses is not None: - self._addresses = [_encode_address(a) for a in parsed_addresses] - else: - self._addresses = [] - # This results in an ugly error when registering, better check now - invalid = [a for a in self._addresses if not isinstance(a, bytes) or len(a) not in (4, 16)] - if invalid: - raise TypeError( - 'Addresses must be bytes, got %s. Hint: convert string addresses ' - 'with socket.inet_pton' % invalid - ) - self.port = port - self.weight = weight - self.priority = priority - self.server = server if server else name - self.server_key = self.server.lower() - self._properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {} - if isinstance(properties, bytes): - self._set_text(properties) - else: - self._set_properties(properties) - self.host_ttl = host_ttl - self.other_ttl = other_ttl - - @property - def name(self) -> str: - """The name of the service.""" - return self._name - - @name.setter - def name(self, name: str) -> None: - """Replace the the name and reset the key.""" - self._name = name - self.key = name.lower() - - @property - def addresses(self) -> List[bytes]: - """IPv4 addresses of this service. - - Only IPv4 addresses are returned for backward compatibility. - Use :meth:`addresses_by_version` or :meth:`parsed_addresses` to - include IPv6 addresses as well. - """ - return self.addresses_by_version(IPVersion.V4Only) - - @addresses.setter - def addresses(self, value: List[bytes]) -> None: - """Replace the addresses list. - - This replaces all currently stored addresses, both IPv4 and IPv6. - """ - self._addresses = value - - @property - def properties(self) -> Dict: - """If properties were set in the constructor this property returns the original dictionary - of type `Dict[Union[bytes, str], Any]`. - - If properties are coming from the network, after decoding a TXT record, the keys are always - bytes and the values are either bytes, if there was a value, even empty, or `None`, if there - was none. No further decoding is attempted. The type returned is `Dict[bytes, Optional[bytes]]`. - """ - return self._properties - - def addresses_by_version(self, version: IPVersion) -> List[bytes]: - """List addresses matching IP version.""" - if version == IPVersion.V4Only: - return [addr for addr in self._addresses if not _is_v6_address(addr)] - if version == IPVersion.V6Only: - return list(filter(_is_v6_address, self._addresses)) - return self._addresses - - def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: - """List addresses in their parsed string form.""" - result = self.addresses_by_version(version) - return [ - socket.inet_ntop(socket.AF_INET6 if _is_v6_address(addr) else socket.AF_INET, addr) - for addr in result - ] - - def _set_properties(self, properties: Dict) -> None: - """Sets properties and text of this info from a dictionary""" - self._properties = properties - list_ = [] - result = b'' - for key, value in properties.items(): - if isinstance(key, str): - key = key.encode('utf-8') - - record = key - if value is not None: - if not isinstance(value, bytes): - value = str(value).encode('utf-8') - record += b'=' + value - list_.append(record) - for item in list_: - result = b''.join((result, int2byte(len(item)), item)) - self.text = result - - def _set_text(self, text: bytes) -> None: - """Sets properties and text given a text field""" - self.text = text - end = len(text) - if end == 0: - self._properties = {} - return - result: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {} - index = 0 - strs = [] - while index < end: - length = text[index] - index += 1 - strs.append(text[index : index + length]) - index += length - - key: bytes - value: Optional[bytes] - for s in strs: - try: - key, value = s.split(b'=', 1) - except ValueError: - # No equals sign at all - key = s - value = None - - # Only update non-existent properties - if key and result.get(key) is None: - result[key] = value - - self._properties = result - - def get_name(self) -> str: - """Name accessor""" - return self.name[: len(self.name) - len(self.type) - 1] - - def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: - """Updates service information from a DNS record. - - This method is deprecated and will be removed in a future version. - update_records should be implemented instead. - - This method will be run in the event loop. - """ - if record is not None: - self._process_records_threadsafe(zc, now, [record]) - - def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: - """Updates service information from a DNS record. - - This method will be run in the event loop. - """ - self._process_records_threadsafe(zc, now, records) - - def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: - """Thread safe record updating.""" - update_addresses = False - for record in records: - if isinstance(record, DNSService): - update_addresses = True - self._process_record_threadsafe(record, now) - - # Only update addresses if the DNSService (.server) has changed - if not update_addresses: - return - - for record in self._get_address_records_from_cache(zc): - self._process_record_threadsafe(record, now) - - def _process_record_threadsafe(self, record: DNSRecord, now: float) -> None: - if record.is_expired(now): - return - - if isinstance(record, DNSAddress): - if record.key == self.server_key and record.address not in self._addresses: - self._addresses.append(record.address) - return - - if isinstance(record, DNSService): - if record.key != self.key: - return - self.name = record.name - self.server = record.server - self.server_key = record.server.lower() - self.port = record.port - self.weight = record.weight - self.priority = record.priority - return - - if isinstance(record, DNSText): - if record.key == self.key: - self._set_text(record.text) - - def dns_addresses( - self, - override_ttl: Optional[int] = None, - version: IPVersion = IPVersion.All, - created: Optional[float] = None, - ) -> List[DNSAddress]: - """Return matching DNSAddress from ServiceInfo.""" - return [ - DNSAddress( - self.server, - _TYPE_AAAA if _is_v6_address(address) else _TYPE_A, - _CLASS_IN | _CLASS_UNIQUE, - override_ttl if override_ttl is not None else self.host_ttl, - address, - created, - ) - for address in self.addresses_by_version(version) - ] - - def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer: - """Return DNSPointer from ServiceInfo.""" - return DNSPointer( - self.type, - _TYPE_PTR, - _CLASS_IN, - override_ttl if override_ttl is not None else self.other_ttl, - self.name, - created, - ) - - def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService: - """Return DNSService from ServiceInfo.""" - return DNSService( - self.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - override_ttl if override_ttl is not None else self.host_ttl, - self.priority, - self.weight, - cast(int, self.port), - self.server, - created, - ) - - def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText: - """Return DNSText from ServiceInfo.""" - return DNSText( - self.name, - _TYPE_TXT, - _CLASS_IN | _CLASS_UNIQUE, - override_ttl if override_ttl is not None else self.other_ttl, - self.text, - created, - ) - - def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: - """Get the address records from the cache.""" - return [ - *zc.cache.get_all_by_details(self.server, _TYPE_A, _CLASS_IN), - *zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN), - ] - - def load_from_cache(self, zc: 'Zeroconf') -> bool: - """Populate the service info from the cache. - - This method is designed to be threadsafe. - """ - now = current_time_millis() - record_updates = [] - cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) - if cached_srv_record: - # If there is a srv record, A and AAAA will already - # be called and we do not want to do it twice - record_updates.append(cached_srv_record) - else: - record_updates.extend(self._get_address_records_from_cache(zc)) - cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) - if cached_txt_record: - record_updates.append(cached_txt_record) - self._process_records_threadsafe(zc, now, record_updates) - return self._is_complete - - @property - def _is_complete(self) -> bool: - """The ServiceInfo has all expected properties.""" - return not (self.text is None or not self._addresses) - - def request(self, zc: 'Zeroconf', timeout: float) -> bool: - """Returns true if the service could be discovered on the - network, and updates this object with details discovered. - """ - if self.load_from_cache(zc): - return True - - now = current_time_millis() - delay = _LISTENER_TIME - next_ = now - last = now + timeout - try: - # Do not set a question on the listener to preload from cache - # since we just checked it above in load_from_cache - zc.add_listener(self, None) - while not self._is_complete: - if last <= now: - return False - if next_ <= now: - out = self.generate_request_query(zc, now) - if not out.questions: - return True - zc.send(out) - next_ = now + delay - delay *= 2 - - zc.wait(min(next_, last) - now) - now = current_time_millis() - finally: - zc.remove_listener(self) - - return True - - def generate_request_query(self, zc: 'Zeroconf', now: float) -> DNSOutgoing: - """Generate the request query.""" - out = DNSOutgoing(_FLAGS_QR_QUERY) - out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN) - out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN) - out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN) - out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN) - return out - - def __eq__(self, other: object) -> bool: - """Tests equality of service name""" - return isinstance(other, ServiceInfo) and other.name == self.name - - def __repr__(self) -> str: - """String representation""" - return '%s(%s)' % ( - type(self).__name__, - ', '.join( - '%s=%r' % (name, getattr(self, name)) - for name in ( - 'type', - 'name', - 'addresses', - 'port', - 'weight', - 'priority', - 'server', - 'properties', - ) - ), - ) diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py new file mode 100644 index 000000000..a3536ed1d --- /dev/null +++ b/zeroconf/_services/info.py @@ -0,0 +1,458 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import socket +from typing import Dict, List, Optional, TYPE_CHECKING, Union, cast + +from .._dns import DNSAddress, DNSPointer, DNSRecord, DNSService, DNSText +from .._exceptions import BadTypeInNameException +from .._protocol import DNSOutgoing +from .._services import RecordUpdateListener +from .._utils.name import service_type_name +from .._utils.net import ( + IPVersion, + _encode_address, + _is_v6_address, +) +from .._utils.struct import int2byte +from .._utils.time import current_time_millis +from ..const import ( + _CLASS_IN, + _CLASS_UNIQUE, + _DNS_HOST_TTL, + _DNS_OTHER_TTL, + _FLAGS_QR_QUERY, + _LISTENER_TIME, + _TYPE_A, + _TYPE_AAAA, + _TYPE_PTR, + _TYPE_SRV, + _TYPE_TXT, +) + + +if TYPE_CHECKING: + # https://github.com/PyCQA/pylint/issues/3525 + from .._core import Zeroconf # pylint: disable=cyclic-import + + +def instance_name_from_service_info(info: "ServiceInfo") -> str: + """Calculate the instance name from the ServiceInfo.""" + # This is kind of funky because of the subtype based tests + # need to make subtypes a first class citizen + service_name = service_type_name(info.name) + if not info.type.endswith(service_name): + raise BadTypeInNameException + return info.name[: -len(service_name) - 1] + + +class ServiceInfo(RecordUpdateListener): + """Service information. + + Constructor parameters are as follows: + + * `type_`: fully qualified service type name + * `name`: fully qualified service name + * `port`: port that the service runs on + * `weight`: weight of the service + * `priority`: priority of the service + * `properties`: dictionary of properties (or a bytes object holding the contents of the `text` field). + converted to str and then encoded to bytes using UTF-8. Keys with `None` values are converted to + value-less attributes. + * `server`: fully qualified name for service host (defaults to name) + * `host_ttl`: ttl used for A/SRV records + * `other_ttl`: ttl used for PTR/TXT records + * `addresses` and `parsed_addresses`: List of IP addresses (either as bytes, network byte order, + or in parsed form as text; at most one of those parameters can be provided) + + """ + + text = b'' + + def __init__( + self, + type_: str, + name: str, + port: Optional[int] = None, + weight: int = 0, + priority: int = 0, + properties: Union[bytes, Dict] = b'', + server: Optional[str] = None, + host_ttl: int = _DNS_HOST_TTL, + other_ttl: int = _DNS_OTHER_TTL, + *, + addresses: Optional[List[bytes]] = None, + parsed_addresses: Optional[List[str]] = None + ) -> None: + # Accept both none, or one, but not both. + if addresses is not None and parsed_addresses is not None: + raise TypeError("addresses and parsed_addresses cannot be provided together") + if not type_.endswith(service_type_name(name, strict=False)): + raise BadTypeInNameException + self.type = type_ + self._name = name + self.key = name.lower() + if addresses is not None: + self._addresses = addresses + elif parsed_addresses is not None: + self._addresses = [_encode_address(a) for a in parsed_addresses] + else: + self._addresses = [] + # This results in an ugly error when registering, better check now + invalid = [a for a in self._addresses if not isinstance(a, bytes) or len(a) not in (4, 16)] + if invalid: + raise TypeError( + 'Addresses must be bytes, got %s. Hint: convert string addresses ' + 'with socket.inet_pton' % invalid + ) + self.port = port + self.weight = weight + self.priority = priority + self.server = server if server else name + self.server_key = self.server.lower() + self._properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {} + if isinstance(properties, bytes): + self._set_text(properties) + else: + self._set_properties(properties) + self.host_ttl = host_ttl + self.other_ttl = other_ttl + + @property + def name(self) -> str: + """The name of the service.""" + return self._name + + @name.setter + def name(self, name: str) -> None: + """Replace the the name and reset the key.""" + self._name = name + self.key = name.lower() + + @property + def addresses(self) -> List[bytes]: + """IPv4 addresses of this service. + + Only IPv4 addresses are returned for backward compatibility. + Use :meth:`addresses_by_version` or :meth:`parsed_addresses` to + include IPv6 addresses as well. + """ + return self.addresses_by_version(IPVersion.V4Only) + + @addresses.setter + def addresses(self, value: List[bytes]) -> None: + """Replace the addresses list. + + This replaces all currently stored addresses, both IPv4 and IPv6. + """ + self._addresses = value + + @property + def properties(self) -> Dict: + """If properties were set in the constructor this property returns the original dictionary + of type `Dict[Union[bytes, str], Any]`. + + If properties are coming from the network, after decoding a TXT record, the keys are always + bytes and the values are either bytes, if there was a value, even empty, or `None`, if there + was none. No further decoding is attempted. The type returned is `Dict[bytes, Optional[bytes]]`. + """ + return self._properties + + def addresses_by_version(self, version: IPVersion) -> List[bytes]: + """List addresses matching IP version.""" + if version == IPVersion.V4Only: + return [addr for addr in self._addresses if not _is_v6_address(addr)] + if version == IPVersion.V6Only: + return list(filter(_is_v6_address, self._addresses)) + return self._addresses + + def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: + """List addresses in their parsed string form.""" + result = self.addresses_by_version(version) + return [ + socket.inet_ntop(socket.AF_INET6 if _is_v6_address(addr) else socket.AF_INET, addr) + for addr in result + ] + + def _set_properties(self, properties: Dict) -> None: + """Sets properties and text of this info from a dictionary""" + self._properties = properties + list_ = [] + result = b'' + for key, value in properties.items(): + if isinstance(key, str): + key = key.encode('utf-8') + + record = key + if value is not None: + if not isinstance(value, bytes): + value = str(value).encode('utf-8') + record += b'=' + value + list_.append(record) + for item in list_: + result = b''.join((result, int2byte(len(item)), item)) + self.text = result + + def _set_text(self, text: bytes) -> None: + """Sets properties and text given a text field""" + self.text = text + end = len(text) + if end == 0: + self._properties = {} + return + result: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {} + index = 0 + strs = [] + while index < end: + length = text[index] + index += 1 + strs.append(text[index : index + length]) + index += length + + key: bytes + value: Optional[bytes] + for s in strs: + try: + key, value = s.split(b'=', 1) + except ValueError: + # No equals sign at all + key = s + value = None + + # Only update non-existent properties + if key and result.get(key) is None: + result[key] = value + + self._properties = result + + def get_name(self) -> str: + """Name accessor""" + return self.name[: len(self.name) - len(self.type) - 1] + + def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: + """Updates service information from a DNS record. + + This method is deprecated and will be removed in a future version. + update_records should be implemented instead. + + This method will be run in the event loop. + """ + if record is not None: + self._process_records_threadsafe(zc, now, [record]) + + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + """Updates service information from a DNS record. + + This method will be run in the event loop. + """ + self._process_records_threadsafe(zc, now, records) + + def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + """Thread safe record updating.""" + update_addresses = False + for record in records: + if isinstance(record, DNSService): + update_addresses = True + self._process_record_threadsafe(record, now) + + # Only update addresses if the DNSService (.server) has changed + if not update_addresses: + return + + for record in self._get_address_records_from_cache(zc): + self._process_record_threadsafe(record, now) + + def _process_record_threadsafe(self, record: DNSRecord, now: float) -> None: + if record.is_expired(now): + return + + if isinstance(record, DNSAddress): + if record.key == self.server_key and record.address not in self._addresses: + self._addresses.append(record.address) + return + + if isinstance(record, DNSService): + if record.key != self.key: + return + self.name = record.name + self.server = record.server + self.server_key = record.server.lower() + self.port = record.port + self.weight = record.weight + self.priority = record.priority + return + + if isinstance(record, DNSText): + if record.key == self.key: + self._set_text(record.text) + + def dns_addresses( + self, + override_ttl: Optional[int] = None, + version: IPVersion = IPVersion.All, + created: Optional[float] = None, + ) -> List[DNSAddress]: + """Return matching DNSAddress from ServiceInfo.""" + return [ + DNSAddress( + self.server, + _TYPE_AAAA if _is_v6_address(address) else _TYPE_A, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.host_ttl, + address, + created, + ) + for address in self.addresses_by_version(version) + ] + + def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer: + """Return DNSPointer from ServiceInfo.""" + return DNSPointer( + self.type, + _TYPE_PTR, + _CLASS_IN, + override_ttl if override_ttl is not None else self.other_ttl, + self.name, + created, + ) + + def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService: + """Return DNSService from ServiceInfo.""" + return DNSService( + self.name, + _TYPE_SRV, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.host_ttl, + self.priority, + self.weight, + cast(int, self.port), + self.server, + created, + ) + + def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText: + """Return DNSText from ServiceInfo.""" + return DNSText( + self.name, + _TYPE_TXT, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.other_ttl, + self.text, + created, + ) + + def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: + """Get the address records from the cache.""" + return [ + *zc.cache.get_all_by_details(self.server, _TYPE_A, _CLASS_IN), + *zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN), + ] + + def load_from_cache(self, zc: 'Zeroconf') -> bool: + """Populate the service info from the cache. + + This method is designed to be threadsafe. + """ + now = current_time_millis() + record_updates = [] + cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) + if cached_srv_record: + # If there is a srv record, A and AAAA will already + # be called and we do not want to do it twice + record_updates.append(cached_srv_record) + else: + record_updates.extend(self._get_address_records_from_cache(zc)) + cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) + if cached_txt_record: + record_updates.append(cached_txt_record) + self._process_records_threadsafe(zc, now, record_updates) + return self._is_complete + + @property + def _is_complete(self) -> bool: + """The ServiceInfo has all expected properties.""" + return not (self.text is None or not self._addresses) + + def request(self, zc: 'Zeroconf', timeout: float) -> bool: + """Returns true if the service could be discovered on the + network, and updates this object with details discovered. + """ + if self.load_from_cache(zc): + return True + + now = current_time_millis() + delay = _LISTENER_TIME + next_ = now + last = now + timeout + try: + # Do not set a question on the listener to preload from cache + # since we just checked it above in load_from_cache + zc.add_listener(self, None) + while not self._is_complete: + if last <= now: + return False + if next_ <= now: + out = self.generate_request_query(zc, now) + if not out.questions: + return True + zc.send(out) + next_ = now + delay + delay *= 2 + + zc.wait(min(next_, last) - now) + now = current_time_millis() + finally: + zc.remove_listener(self) + + return True + + def generate_request_query(self, zc: 'Zeroconf', now: float) -> DNSOutgoing: + """Generate the request query.""" + out = DNSOutgoing(_FLAGS_QR_QUERY) + out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN) + out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN) + out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN) + out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN) + return out + + def __eq__(self, other: object) -> bool: + """Tests equality of service name""" + return isinstance(other, ServiceInfo) and other.name == self.name + + def __repr__(self) -> str: + """String representation""" + return '%s(%s)' % ( + type(self).__name__, + ', '.join( + '%s=%r' % (name, getattr(self, name)) + for name in ( + 'type', + 'name', + 'addresses', + 'port', + 'weight', + 'priority', + 'server', + 'properties', + ) + ), + ) diff --git a/zeroconf/_services/registry.py b/zeroconf/_services/registry.py index 20584b3a6..ebf5abbb6 100644 --- a/zeroconf/_services/registry.py +++ b/zeroconf/_services/registry.py @@ -24,8 +24,8 @@ from typing import Dict, List, Optional, Union +from .info import ServiceInfo from .._exceptions import ServiceNameAlreadyRegistered -from .._services import ServiceInfo class ServiceRegistry: diff --git a/zeroconf/aio.py b/zeroconf/aio.py index d5414d138..e64c87c38 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -26,7 +26,8 @@ from ._core import NotifyListener, Zeroconf from ._exceptions import NonUniqueNameException -from ._services import ServiceInfo, _ServiceBrowserBase, instance_name_from_service_info +from ._services import _ServiceBrowserBase +from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.types import ZeroconfServiceTypes from ._utils.aio import wait_condition_or_timeout from ._utils.net import IPVersion, InterfaceChoice, InterfacesType