Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions examples/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,18 @@ def on_service_state_change(
parser.add_argument('--debug', action='store_true')
parser.add_argument('--find', action='store_true', help='Browse all available services')
version_group = parser.add_mutually_exclusive_group()
version_group.add_argument('--v6', action='store_true')
version_group.add_argument('--v6-only', action='store_true')
version_group.add_argument('--v4-only', action='store_true')
args = parser.parse_args()

if args.debug:
logging.getLogger('zeroconf').setLevel(logging.DEBUG)
if args.v6:
ip_version = IPVersion.All
elif args.v6_only:
if args.v6_only:
ip_version = IPVersion.V6Only
else:
elif args.v4_only:
ip_version = IPVersion.V4Only
else:
ip_version = IPVersion.All

zeroconf = Zeroconf(ip_version=ip_version)

Expand Down
8 changes: 8 additions & 0 deletions src/zeroconf/_services/info.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ cdef object _IPVersion_V4Only_value
cdef cython.set _ADDRESS_RECORD_TYPES

cdef bint TYPE_CHECKING
cdef bint IPADDRESS_SUPPORTS_SCOPE_ID

cdef _get_ip_address_object_from_record(DNSAddress record)

@cython.locals(address_str=str)
cdef _str_without_scope_id(object addr)

cdef _ip_bytes_and_scope_to_address(object addr, object scope_id)

cdef class ServiceInfo(RecordUpdateListener):

Expand Down
53 changes: 39 additions & 14 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import asyncio
import random
import sys
from functools import lru_cache
from ipaddress import IPv4Address, IPv6Address, _BaseAddress, ip_address
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast
Expand Down Expand Up @@ -78,12 +79,15 @@
# the A/AAAA/SRV records for a host.
_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120)

bytes_ = bytes
float_ = float
int_ = int

DNS_QUESTION_TYPE_QU = DNSQuestionType.QU
DNS_QUESTION_TYPE_QM = DNSQuestionType.QM

IPADDRESS_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0)

if TYPE_CHECKING:
from .._core import Zeroconf

Expand All @@ -110,6 +114,29 @@ def _cached_ip_addresses(address: Union[str, bytes, int]) -> Optional[Union[IPv4
_cached_ip_addresses_wrapper = _cached_ip_addresses


def _get_ip_address_object_from_record(record: DNSAddress) -> Optional[Union[IPv4Address, IPv6Address]]:
"""Get the IP address object from the record."""
if IPADDRESS_SUPPORTS_SCOPE_ID and record.type == _TYPE_AAAA and record.scope_id is not None:
return _ip_bytes_and_scope_to_address(record.address, record.scope_id)
return _cached_ip_addresses_wrapper(record.address)


def _ip_bytes_and_scope_to_address(address: bytes_, scope: int_) -> Optional[Union[IPv4Address, IPv6Address]]:
"""Convert the bytes and scope to an IP address object."""
base_address = _cached_ip_addresses_wrapper(address)
if base_address is not None and base_address.is_link_local:
return _cached_ip_addresses_wrapper(f"{base_address}%{scope}")
return base_address


def _str_without_scope_id(addr: Union[IPv4Address, IPv6Address]) -> str:
"""Return the string representation of the address without the scope id."""
if IPADDRESS_SUPPORTS_SCOPE_ID and addr.version == 6:
address_str = str(addr)
return address_str.partition('%')[0]
return str(addr)


class ServiceInfo(RecordUpdateListener):
"""Service information.

Expand Down Expand Up @@ -177,6 +204,7 @@ def __init__(
raise TypeError("addresses and parsed_addresses cannot be provided together")
if not type_.endswith(service_type_name(name, strict=False)):
raise BadTypeInNameException
self.interface_index = interface_index
self.text = b''
self.type = type_
self._name = name
Expand All @@ -199,7 +227,6 @@ def __init__(
self._set_properties(properties)
self.host_ttl = host_ttl
self.other_ttl = other_ttl
self.interface_index = interface_index
self._new_records_futures: Optional[Set[asyncio.Future]] = None
self._dns_address_cache: Optional[List[DNSAddress]] = None
self._dns_pointer_cache: Optional[DNSPointer] = None
Expand Down Expand Up @@ -243,7 +270,10 @@ def addresses(self, value: List[bytes]) -> None:
self._get_address_and_nsec_records_cache = None

for address in value:
addr = _cached_ip_addresses_wrapper(address)
if IPADDRESS_SUPPORTS_SCOPE_ID and len(address) == 16 and self.interface_index is not None:
addr = _ip_bytes_and_scope_to_address(address, self.interface_index)
else:
addr = _cached_ip_addresses_wrapper(address)
if addr is None:
raise TypeError(
"Addresses must either be IPv4 or IPv6 strings, bytes, or integers;"
Expand Down Expand Up @@ -322,10 +352,10 @@ def ip_addresses_by_version(

def _ip_addresses_by_version_value(
self, version_value: int_
) -> Union[List[IPv4Address], List[IPv6Address], List[_BaseAddress]]:
) -> Union[List[IPv4Address], List[IPv6Address]]:
"""Backend for addresses_by_version that uses the raw value."""
if version_value == _IPVersion_All_value:
return [*self._ipv4_addresses, *self._ipv6_addresses]
return [*self._ipv4_addresses, *self._ipv6_addresses] # type: ignore[return-value]
if version_value == _IPVersion_V4Only_value:
return self._ipv4_addresses
return self._ipv6_addresses
Expand All @@ -339,7 +369,7 @@ def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:
This means the first address will always be the most recently added
address of the given IP version.
"""
return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)]
return [_str_without_scope_id(addr) for addr in self._ip_addresses_by_version_value(version.value)]

def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:
"""Equivalent to parsed_addresses, with the exception that IPv6 Link-Local
Expand All @@ -351,12 +381,7 @@ def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[st
This means the first address will always be the most recently added
address of the given IP version.
"""
if self.interface_index is None:
return self.parsed_addresses(version)
return [
f"{addr}%{self.interface_index}" if addr.version == 6 and addr.is_link_local else str(addr)
for addr in self._ip_addresses_by_version_value(version.value)
]
return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)]

def _set_properties(self, properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]]) -> None:
"""Sets properties and text of this info from a dictionary"""
Expand Down Expand Up @@ -421,8 +446,8 @@ def _get_ip_addresses_from_cache_lifo(
for record in self._get_address_records_from_cache_by_type(zc, type):
if record.is_expired(now):
continue
ip_addr = _cached_ip_addresses_wrapper(record.address)
if ip_addr is not None:
ip_addr = _get_ip_address_object_from_record(record)
if ip_addr is not None and ip_addr not in address_list:
address_list.append(ip_addr)
address_list.reverse() # Reverse to get LIFO order
return address_list
Expand Down Expand Up @@ -471,7 +496,7 @@ def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: flo
dns_address_record = record
if TYPE_CHECKING:
assert isinstance(dns_address_record, DNSAddress)
ip_addr = _cached_ip_addresses_wrapper(dns_address_record.address)
ip_addr = _get_ip_address_object_from_record(dns_address_record)
if ip_addr is None:
log.warning(
"Encountered invalid address while processing %s: %s",
Expand Down
59 changes: 55 additions & 4 deletions tests/services/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import os
import socket
import sys
import threading
import unittest
from ipaddress import ip_address
Expand Down Expand Up @@ -538,6 +539,7 @@ def test_multiple_addresses():
assert info.addresses == [address, address]
assert info.parsed_addresses() == [address_parsed, address_parsed]
assert info.parsed_scoped_addresses() == [address_parsed, address_parsed]
ipaddress_supports_scope_id = sys.version_info >= (3, 9, 0)

if has_working_ipv6() and not os.environ.get('SKIP_IPV6'):
address_v6_parsed = "2001:db8::1"
Expand Down Expand Up @@ -576,30 +578,79 @@ def test_multiple_addresses():
assert info.ip_addresses_by_version(r.IPVersion.All) == [
ip_address(address),
ip_address(address_v6),
ip_address(address_v6_ll),
ip_address(address_v6_ll_scoped_parsed)
if ipaddress_supports_scope_id
else ip_address(address_v6_ll),
]
assert info.addresses_by_version(r.IPVersion.V4Only) == [address]
assert info.ip_addresses_by_version(r.IPVersion.V4Only) == [ip_address(address)]
assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6, address_v6_ll]
assert info.ip_addresses_by_version(r.IPVersion.V6Only) == [
ip_address(address_v6),
ip_address(address_v6_ll),
ip_address(address_v6_ll_scoped_parsed)
if ipaddress_supports_scope_id
else ip_address(address_v6_ll),
]
assert info.parsed_addresses() == [address_parsed, address_v6_parsed, address_v6_ll_parsed]
assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed]
assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed, address_v6_ll_parsed]
assert info.parsed_scoped_addresses() == [
address_parsed,
address_v6_parsed,
address_v6_ll_scoped_parsed,
address_v6_ll_scoped_parsed if ipaddress_supports_scope_id else address_v6_ll_parsed,
]
assert info.parsed_scoped_addresses(r.IPVersion.V4Only) == [address_parsed]
assert info.parsed_scoped_addresses(r.IPVersion.V6Only) == [
address_v6_parsed,
address_v6_ll_scoped_parsed,
address_v6_ll_scoped_parsed if ipaddress_supports_scope_id else address_v6_ll_parsed,
]


@unittest.skipIf(sys.version_info < (3, 9, 0), 'Requires newer python')
def test_scoped_addresses_from_cache():
type_ = "_http._tcp.local."
registration_name = f"scoped.{type_}"
zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])
host = "scoped.local."

zeroconf.cache.async_add_records(
[
r.DNSPointer(
type_,
const._TYPE_PTR,
const._CLASS_IN | const._CLASS_UNIQUE,
120,
registration_name,
),
r.DNSService(
registration_name,
const._TYPE_SRV,
const._CLASS_IN | const._CLASS_UNIQUE,
120,
0,
0,
80,
host,
),
r.DNSAddress(
host,
const._TYPE_AAAA,
const._CLASS_IN | const._CLASS_UNIQUE,
120,
socket.inet_pton(socket.AF_INET6, "fe80::52e:c2f2:bc5f:e9c6"),
scope_id=12,
),
]
)

# New kwarg way
info = ServiceInfo(type_, registration_name)
info.load_from_cache(zeroconf)
assert info.parsed_scoped_addresses() == ["fe80::52e:c2f2:bc5f:e9c6%12"]
assert info.ip_addresses_by_version(r.IPVersion.V6Only) == [ip_address("fe80::52e:c2f2:bc5f:e9c6%12")]
zeroconf.close()


# This test uses asyncio because it needs to access the cache directly
# which is not threadsafe
@pytest.mark.asyncio
Expand Down