diff --git a/src/zeroconf/_services/__init__.py b/src/zeroconf/_services/__init__.py index 968b5dafe..cf54d7f07 100644 --- a/src/zeroconf/_services/__init__.py +++ b/src/zeroconf/_services/__init__.py @@ -46,7 +46,6 @@ def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: class Signal: - __slots__ = ('_handlers',) def __init__(self) -> None: @@ -62,7 +61,6 @@ def registration_interface(self) -> 'SignalRegistrationInterface': class SignalRegistrationInterface: - __slots__ = ('_handlers',) def __init__(self, handlers: List[Callable[..., None]]) -> None: diff --git a/src/zeroconf/_services/info.py b/src/zeroconf/_services/info.py index 8ff1f6656..d3e6f082b 100644 --- a/src/zeroconf/_services/info.py +++ b/src/zeroconf/_services/info.py @@ -21,9 +21,9 @@ """ import asyncio -import ipaddress import random from functools import lru_cache +from ipaddress import IPv4Address, IPv6Address, _BaseAddress, ip_address from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast from .._dns import ( @@ -90,7 +90,7 @@ def instance_name_from_service_info(info: "ServiceInfo") -> str: return info.name[: -len(service_name) - 1] -_cached_ip_addresses = lru_cache(maxsize=256)(ipaddress.ip_address) +_cached_ip_addresses = lru_cache(maxsize=256)(ip_address) class ServiceInfo(RecordUpdateListener): @@ -158,8 +158,8 @@ def __init__( self.type = type_ self._name = name self.key = name.lower() - self._ipv4_addresses: List[ipaddress.IPv4Address] = [] - self._ipv6_addresses: List[ipaddress.IPv6Address] = [] + self._ipv4_addresses: List[IPv4Address] = [] + self._ipv6_addresses: List[IPv6Address] = [] if addresses is not None: self.addresses = addresses elif parsed_addresses is not None: @@ -260,7 +260,7 @@ def addresses_by_version(self, version: IPVersion) -> List[bytes]: def ip_addresses_by_version( self, version: IPVersion - ) -> Union[List[ipaddress.IPv4Address], List[ipaddress.IPv6Address], List[ipaddress._BaseAddress]]: + ) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]: """List ip_address objects matching IP version. Addresses are guaranteed to be returned in LIFO (last in, first out) @@ -273,7 +273,7 @@ def ip_addresses_by_version( def _ip_addresses_by_version_value( self, version_value: int - ) -> Union[List[ipaddress.IPv4Address], List[ipaddress.IPv6Address], List[ipaddress._BaseAddress]]: + ) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]: """Backend for addresses_by_version that uses the raw value.""" if version_value == _IPVersion_All_value: return [*self._ipv4_addresses, *self._ipv6_addresses] @@ -366,31 +366,31 @@ def get_name(self) -> str: def _get_ip_addresses_from_cache_lifo( self, zc: 'Zeroconf', now: float, type: int - ) -> List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]: + ) -> List[Union[IPv4Address, IPv6Address]]: """Set IPv6 addresses from the cache.""" - address_list: List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]] = [] + address_list: List[Union[IPv4Address, IPv6Address]] = [] for record in self._get_address_records_from_cache_by_type(zc, type): if record.is_expired(now): continue try: - ip_address = _cached_ip_addresses(record.address) + ip_addr = _cached_ip_addresses(record.address) except ValueError: continue else: - address_list.append(ip_address) + address_list.append(ip_addr) address_list.reverse() # Reverse to get LIFO order return address_list def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: """Set IPv6 addresses from the cache.""" self._ipv6_addresses = cast( - "List[ipaddress.IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) + "List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) ) def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: """Set IPv4 addresses from the cache.""" self._ipv4_addresses = cast( - "List[ipaddress.IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) + "List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) ) def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: @@ -431,46 +431,49 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo if record.is_expired(now): return False - if record.key == self.server_key and isinstance(record, DNSAddress): + record_key = record.key + if record_key == self.server_key and type(record) is DNSAddress: try: ip_addr = _cached_ip_addresses(record.address) except ValueError as ex: log.warning("Encountered invalid address while processing %s: %s", record, ex) return False - if ip_addr.version == 4: - if not self._ipv4_addresses: + if type(ip_addr) is IPv4Address: + if self._ipv4_addresses: self._set_ipv4_addresses_from_cache(zc, now) - if ip_addr not in self._ipv4_addresses: - self._ipv4_addresses.insert(0, ip_addr) + ipv4_addresses = self._ipv4_addresses + if ip_addr not in ipv4_addresses: + ipv4_addresses.insert(0, ip_addr) return True - elif ip_addr != self._ipv4_addresses[0]: - self._ipv4_addresses.remove(ip_addr) - self._ipv4_addresses.insert(0, ip_addr) + elif ip_addr != ipv4_addresses[0]: + ipv4_addresses.remove(ip_addr) + ipv4_addresses.insert(0, ip_addr) return False if not self._ipv6_addresses: self._set_ipv6_addresses_from_cache(zc, now) + ipv6_addresses = self._ipv6_addresses if ip_addr not in self._ipv6_addresses: - self._ipv6_addresses.insert(0, ip_addr) + ipv6_addresses.insert(0, ip_addr) return True elif ip_addr != self._ipv6_addresses[0]: - self._ipv6_addresses.remove(ip_addr) - self._ipv6_addresses.insert(0, ip_addr) + ipv6_addresses.remove(ip_addr) + ipv6_addresses.insert(0, ip_addr) return False - if record.key != self.key: + if record_key != self.key: return False - if record.type == _TYPE_TXT and isinstance(record, DNSText): + if record.type == _TYPE_TXT and type(record) is DNSText: self._set_text(record.text) return True - if record.type == _TYPE_SRV and isinstance(record, DNSService): + if record.type == _TYPE_SRV and type(record) is DNSService: old_server_key = self.server_key self.name = record.name self.server = record.server @@ -495,16 +498,17 @@ def dns_addresses( name = self.server or self.name ttl = override_ttl if override_ttl is not None else self.host_ttl class_ = _CLASS_IN | _CLASS_UNIQUE + version_value = version.value return [ DNSAddress( name, - _TYPE_AAAA if address.version == 6 else _TYPE_A, + _TYPE_AAAA if type(ip_addr) is IPv6Address else _TYPE_A, class_, ttl, - address.packed, + ip_addr.packed, created=created, ) - for address in self._ip_addresses_by_version_value(version.value) + for ip_addr in self._ip_addresses_by_version_value(version_value) ] def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer: