From 060bdc6fb81d1888151a697262aa8be7cda1f842 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 7 Sep 2024 19:55:51 -0500 Subject: [PATCH 1/2] feat: improve performance when IP addresses change frequently --- src/zeroconf/_services/info.py | 39 ++++++++++++++++++++------------ src/zeroconf/_utils/ipaddress.py | 28 ++++++++--------------- 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/src/zeroconf/_services/info.py b/src/zeroconf/_services/info.py index 2fc9dfc8e..fef43fa02 100644 --- a/src/zeroconf/_services/info.py +++ b/src/zeroconf/_services/info.py @@ -23,7 +23,6 @@ import asyncio import random import sys -from ipaddress import IPv4Address, IPv6Address, _BaseAddress from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast from .._cache import DNSCache @@ -50,6 +49,8 @@ wait_for_future_set_or_timeout, ) from .._utils.ipaddress import ( + ZeroconfIPv4Address, + ZeroconfIPv6Address, cached_ip_addresses, get_ip_address_object_from_record, ip_bytes_and_scope_to_address, @@ -187,8 +188,8 @@ def __init__( self.type = type_ self._name = name self.key = name.lower() - self._ipv4_addresses: List[IPv4Address] = [] - self._ipv6_addresses: List[IPv6Address] = [] + self._ipv4_addresses: List[ZeroconfIPv4Address] = [] + self._ipv6_addresses: List[ZeroconfIPv6Address] = [] if addresses is not None: self.addresses = addresses elif parsed_addresses is not None: @@ -260,11 +261,11 @@ def addresses(self, value: List[bytes]) -> None: ) if addr.version == 4: if TYPE_CHECKING: - assert isinstance(addr, IPv4Address) + assert isinstance(addr, ZeroconfIPv4Address) self._ipv4_addresses.append(addr) else: if TYPE_CHECKING: - assert isinstance(addr, IPv6Address) + assert isinstance(addr, ZeroconfIPv6Address) self._ipv6_addresses.append(addr) @property @@ -321,7 +322,7 @@ def addresses_by_version(self, version: IPVersion) -> List[bytes]: def ip_addresses_by_version( self, version: IPVersion - ) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]: + ) -> Union[List[ZeroconfIPv4Address], List[ZeroconfIPv6Address]]: """List ip_address objects matching IP version. Addresses are guaranteed to be returned in LIFO (last in, first out) @@ -334,7 +335,7 @@ def ip_addresses_by_version( def _ip_addresses_by_version_value( self, version_value: int_ - ) -> Union[List[IPv4Address], List[IPv6Address]]: + ) -> Union[List[ZeroconfIPv4Address], List[ZeroconfIPv6Address]]: """Backend for addresses_by_version that uses the raw value.""" if version_value == _IPVersion_All_value: return [*self._ipv4_addresses, *self._ipv6_addresses] # type: ignore[return-value] @@ -440,9 +441,9 @@ def get_name(self) -> str: def _get_ip_addresses_from_cache_lifo( self, zc: "Zeroconf", now: float_, type: int_ - ) -> List[Union[IPv4Address, IPv6Address]]: + ) -> List[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]]: """Set IPv6 addresses from the cache.""" - address_list: List[Union[IPv4Address, IPv6Address]] = [] + address_list: List[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]] = [] for record in self._get_address_records_from_cache_by_type(zc, type): if record.is_expired(now): continue @@ -456,7 +457,7 @@ def _set_ipv6_addresses_from_cache(self, zc: "Zeroconf", now: float_) -> None: """Set IPv6 addresses from the cache.""" if TYPE_CHECKING: self._ipv6_addresses = cast( - "List[IPv6Address]", + "List[ZeroconfIPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA), ) else: @@ -466,7 +467,7 @@ def _set_ipv4_addresses_from_cache(self, zc: "Zeroconf", now: float_) -> None: """Set IPv4 addresses from the cache.""" if TYPE_CHECKING: self._ipv4_addresses = cast( - "List[IPv4Address]", + "List[ZeroconfIPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A), ) else: @@ -509,24 +510,32 @@ def _process_record_threadsafe(self, zc: "Zeroconf", record: DNSRecord, now: flo if ip_addr.version == 4: if TYPE_CHECKING: - assert isinstance(ip_addr, IPv4Address) + assert isinstance(ip_addr, ZeroconfIPv4Address) ipv4_addresses = self._ipv4_addresses if ip_addr not in ipv4_addresses: ipv4_addresses.insert(0, ip_addr) return True - elif ip_addr != ipv4_addresses[0]: + # Use int() to compare the addresses as integers + # since by default IPv4Address.__eq__ compares the + # the addresses on version and int which more than + # we need here since we know the version is 4. + elif ip_addr.zc_integer != ipv4_addresses[0].zc_integer: ipv4_addresses.remove(ip_addr) ipv4_addresses.insert(0, ip_addr) return False if TYPE_CHECKING: - assert isinstance(ip_addr, IPv6Address) + assert isinstance(ip_addr, ZeroconfIPv6Address) ipv6_addresses = self._ipv6_addresses if ip_addr not in self._ipv6_addresses: ipv6_addresses.insert(0, ip_addr) return True - elif ip_addr != self._ipv6_addresses[0]: + # Use int() to compare the addresses as integers + # since by default IPv6Address.__eq__ compares the + # the addresses on version and int which more than + # we need here since we know the version is 6. + elif ip_addr.zc_integer != self._ipv6_addresses[0].zc_integer: ipv6_addresses.remove(ip_addr) ipv6_addresses.insert(0, ip_addr) diff --git a/src/zeroconf/_utils/ipaddress.py b/src/zeroconf/_utils/ipaddress.py index 3346e6d7b..72bb9ce83 100644 --- a/src/zeroconf/_utils/ipaddress.py +++ b/src/zeroconf/_utils/ipaddress.py @@ -39,13 +39,7 @@ class ZeroconfIPv4Address(IPv4Address): - __slots__ = ( - "_str", - "_is_link_local", - "_is_unspecified", - "_is_loopback", - "__hash__", - ) + __slots__ = ("_str", "_is_link_local", "_is_unspecified", "_is_loopback", "__hash__", "zc_integer") def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize a new IPv4 address.""" @@ -55,6 +49,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._is_unspecified = super().is_unspecified self._is_loopback = super().is_loopback self.__hash__ = cache(lambda: IPv4Address.__hash__(self)) # type: ignore[method-assign] + self.zc_integer = int(self) def __str__(self) -> str: """Return the string representation of the IPv4 address.""" @@ -77,13 +72,7 @@ def is_loopback(self) -> bool: class ZeroconfIPv6Address(IPv6Address): - __slots__ = ( - "_str", - "_is_link_local", - "_is_unspecified", - "_is_loopback", - "__hash__", - ) + __slots__ = ("_str", "_is_link_local", "_is_unspecified", "_is_loopback", "__hash__", "zc_integer") def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize a new IPv6 address.""" @@ -93,6 +82,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._is_unspecified = super().is_unspecified self._is_loopback = super().is_loopback self.__hash__ = cache(lambda: IPv6Address.__hash__(self)) # type: ignore[method-assign] + self.zc_integer = int(self) def __str__(self) -> str: """Return the string representation of the IPv6 address.""" @@ -117,7 +107,7 @@ def is_loopback(self) -> bool: @lru_cache(maxsize=512) def _cached_ip_addresses( address: Union[str, bytes, int], -) -> Optional[Union[IPv4Address, IPv6Address]]: +) -> Optional[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]]: """Cache IP addresses.""" try: return ZeroconfIPv4Address(address) @@ -136,14 +126,16 @@ def _cached_ip_addresses( def get_ip_address_object_from_record( record: DNSAddress, -) -> Optional[Union[IPv4Address, IPv6Address]]: +) -> Optional[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]]: """Get the IP address object from the record.""" if IPADDRESS_SUPPORTS_SCOPE_ID and record.type == _TYPE_AAAA and record.scope_id: return ip_bytes_and_scope_to_address(record.address, record.scope_id) return cached_ip_addresses_wrapper(record.address) -def ip_bytes_and_scope_to_address(address: bytes_, scope: int_) -> Optional[Union[IPv4Address, IPv6Address]]: +def ip_bytes_and_scope_to_address( + address: bytes_, scope: int_ +) -> Optional[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]]: """Convert the bytes and scope to an IP address object.""" base_address = cached_ip_addresses_wrapper(address) if base_address is not None and base_address.is_link_local: @@ -152,7 +144,7 @@ def ip_bytes_and_scope_to_address(address: bytes_, scope: int_) -> Optional[Unio return base_address -def str_without_scope_id(addr: Union[IPv4Address, IPv6Address]) -> str: +def str_without_scope_id(addr: Union[ZeroconfIPv4Address, ZeroconfIPv6Address]) -> str: """Return the string representation of the address without the scope id.""" if IPADDRESS_SUPPORTS_SCOPE_ID and addr.version == 6: address_str = str(addr) From 2329b39c4b981309fbebd8c7225e7287ba327ff8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 7 Sep 2024 20:57:05 -0500 Subject: [PATCH 2/2] add missing cover --- tests/services/test_info.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/services/test_info.py b/tests/services/test_info.py index 4a9b1ee2f..9d4a4958f 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -1469,6 +1469,10 @@ async def test_ipv6_changes_are_seen(): assert info.addresses_by_version(IPVersion.V6Only) == [ b"\xde\xad\xbe\xef\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" ] + info.load_from_cache(aiozc.zeroconf) + assert info.addresses_by_version(IPVersion.V6Only) == [ + b"\xde\xad\xbe\xef\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + ] generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) generated.add_answer_at_time( @@ -1494,6 +1498,7 @@ async def test_ipv6_changes_are_seen(): b"\x00\xad\xbe\xef\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", b"\xde\xad\xbe\xef\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", ] + await aiozc.async_close()