Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/zeroconf/_services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None:


class Signal:

__slots__ = ('_handlers',)

def __init__(self) -> None:
Expand All @@ -62,7 +61,6 @@ def registration_interface(self) -> 'SignalRegistrationInterface':


class SignalRegistrationInterface:

__slots__ = ('_handlers',)

def __init__(self, handlers: List[Callable[..., None]]) -> None:
Expand Down
62 changes: 33 additions & 29 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
"""

import asyncio
import ipaddress
import random
from functools import lru_cache
from ipaddress import IPv4Address, IPv6Address, _BaseAddress, ip_address
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast

from .._dns import (
Expand Down Expand Up @@ -90,7 +90,7 @@ def instance_name_from_service_info(info: "ServiceInfo") -> str:
return info.name[: -len(service_name) - 1]


_cached_ip_addresses = lru_cache(maxsize=256)(ipaddress.ip_address)
_cached_ip_addresses = lru_cache(maxsize=256)(ip_address)


class ServiceInfo(RecordUpdateListener):
Expand Down Expand Up @@ -158,8 +158,8 @@ def __init__(
self.type = type_
self._name = name
self.key = name.lower()
self._ipv4_addresses: List[ipaddress.IPv4Address] = []
self._ipv6_addresses: List[ipaddress.IPv6Address] = []
self._ipv4_addresses: List[IPv4Address] = []
self._ipv6_addresses: List[IPv6Address] = []
if addresses is not None:
self.addresses = addresses
elif parsed_addresses is not None:
Expand Down Expand Up @@ -260,7 +260,7 @@ def addresses_by_version(self, version: IPVersion) -> List[bytes]:

def ip_addresses_by_version(
self, version: IPVersion
) -> Union[List[ipaddress.IPv4Address], List[ipaddress.IPv6Address], List[ipaddress._BaseAddress]]:
) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]:
"""List ip_address objects matching IP version.

Addresses are guaranteed to be returned in LIFO (last in, first out)
Expand All @@ -273,7 +273,7 @@ def ip_addresses_by_version(

def _ip_addresses_by_version_value(
self, version_value: int
) -> Union[List[ipaddress.IPv4Address], List[ipaddress.IPv6Address], List[ipaddress._BaseAddress]]:
) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]:
"""Backend for addresses_by_version that uses the raw value."""
if version_value == _IPVersion_All_value:
return [*self._ipv4_addresses, *self._ipv6_addresses]
Expand Down Expand Up @@ -366,31 +366,31 @@ def get_name(self) -> str:

def _get_ip_addresses_from_cache_lifo(
self, zc: 'Zeroconf', now: float, type: int
) -> List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]:
) -> List[Union[IPv4Address, IPv6Address]]:
"""Set IPv6 addresses from the cache."""
address_list: List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]] = []
address_list: List[Union[IPv4Address, IPv6Address]] = []
for record in self._get_address_records_from_cache_by_type(zc, type):
if record.is_expired(now):
continue
try:
ip_address = _cached_ip_addresses(record.address)
ip_addr = _cached_ip_addresses(record.address)
except ValueError:
continue
else:
address_list.append(ip_address)
address_list.append(ip_addr)
address_list.reverse() # Reverse to get LIFO order
return address_list

def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
"""Set IPv6 addresses from the cache."""
self._ipv6_addresses = cast(
"List[ipaddress.IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)
"List[IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)
)

def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
"""Set IPv4 addresses from the cache."""
self._ipv4_addresses = cast(
"List[ipaddress.IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)
"List[IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)
)

def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None:
Expand Down Expand Up @@ -431,46 +431,49 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo
if record.is_expired(now):
return False

if record.key == self.server_key and isinstance(record, DNSAddress):
record_key = record.key
if record_key == self.server_key and type(record) is DNSAddress:
Comment thread
bdraco marked this conversation as resolved.
try:
ip_addr = _cached_ip_addresses(record.address)
except ValueError as ex:
log.warning("Encountered invalid address while processing %s: %s", record, ex)
return False

if ip_addr.version == 4:
if not self._ipv4_addresses:
if type(ip_addr) is IPv4Address:
if self._ipv4_addresses:
self._set_ipv4_addresses_from_cache(zc, now)

if ip_addr not in self._ipv4_addresses:
self._ipv4_addresses.insert(0, ip_addr)
ipv4_addresses = self._ipv4_addresses
if ip_addr not in ipv4_addresses:
ipv4_addresses.insert(0, ip_addr)
return True
elif ip_addr != self._ipv4_addresses[0]:
self._ipv4_addresses.remove(ip_addr)
self._ipv4_addresses.insert(0, ip_addr)
elif ip_addr != ipv4_addresses[0]:
ipv4_addresses.remove(ip_addr)
ipv4_addresses.insert(0, ip_addr)

return False

if not self._ipv6_addresses:
self._set_ipv6_addresses_from_cache(zc, now)

ipv6_addresses = self._ipv6_addresses
if ip_addr not in self._ipv6_addresses:
self._ipv6_addresses.insert(0, ip_addr)
ipv6_addresses.insert(0, ip_addr)
return True
elif ip_addr != self._ipv6_addresses[0]:
self._ipv6_addresses.remove(ip_addr)
self._ipv6_addresses.insert(0, ip_addr)
ipv6_addresses.remove(ip_addr)
ipv6_addresses.insert(0, ip_addr)

return False

if record.key != self.key:
if record_key != self.key:
return False

if record.type == _TYPE_TXT and isinstance(record, DNSText):
if record.type == _TYPE_TXT and type(record) is DNSText:
self._set_text(record.text)
return True

if record.type == _TYPE_SRV and isinstance(record, DNSService):
if record.type == _TYPE_SRV and type(record) is DNSService:
old_server_key = self.server_key
self.name = record.name
self.server = record.server
Expand All @@ -495,16 +498,17 @@ def dns_addresses(
name = self.server or self.name
ttl = override_ttl if override_ttl is not None else self.host_ttl
class_ = _CLASS_IN | _CLASS_UNIQUE
version_value = version.value
return [
DNSAddress(
name,
_TYPE_AAAA if address.version == 6 else _TYPE_A,
_TYPE_AAAA if type(ip_addr) is IPv6Address else _TYPE_A,
class_,
ttl,
address.packed,
ip_addr.packed,
created=created,
)
for address in self._ip_addresses_by_version_value(version.value)
for ip_addr in self._ip_addresses_by_version_value(version_value)
]

def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer:
Expand Down