Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ async def async_register_service(
info.host_ttl = ttl
info.other_ttl = ttl

info.set_server_if_missing()
await self.async_wait_for_start()
await self.async_check_service(info, allow_name_change, cooperating_responders)
self.registry.async_add(info)
Expand Down Expand Up @@ -738,10 +739,12 @@ def unregister_service(self, info: ServiceInfo) -> None:

async def async_unregister_service(self, info: ServiceInfo) -> Awaitable:
"""Unregister a service."""
info.set_server_if_missing()
self.registry.async_remove(info)
# If another server uses the same addresses, we do not want to send
# goodbye packets for the address records

assert info.server is not None
entries = self.registry.async_get_infos_server(info.server)
broadcast_addresses = not bool(entries)
return asyncio.ensure_future(
Expand Down
3 changes: 3 additions & 0 deletions src/zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def _get_address_and_nsec_records(service: ServiceInfo, now: float) -> Set[DNSRe
records.add(dns_address)
missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
if missing_types:
assert service.server is not None, "Service server must be set for NSEC record."
records.add(construct_nsec_record(service.server, list(missing_types), now))
return records

Expand Down Expand Up @@ -310,10 +311,12 @@ def _add_address_answers(
missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
if answers:
if missing_types:
assert service.server is not None, "Service server must be set for NSEC record."
additionals.add(construct_nsec_record(service.server, list(missing_types), now))
for answer in answers:
answer_set[answer] = additionals
elif type_ in missing_types:
assert service.server is not None, "Service server must be set for NSEC record."
answer_set[construct_nsec_record(service.server, list(missing_types), now)] = set()

def _answer_question(
Expand Down
173 changes: 116 additions & 57 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import ipaddress
import random
from functools import lru_cache
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast

from .._dns import (
DNSAddress,
Expand Down Expand Up @@ -156,8 +156,8 @@ def __init__(
self.port = port
self.weight = weight
self.priority = priority
self.server = server if server else name
self.server_key = self.server.lower()
self.server = server if server else None
self.server_key = server.lower() if server else None
self._properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {}
if isinstance(properties, bytes):
self._set_text(properties)
Expand Down Expand Up @@ -205,7 +205,7 @@ def addresses(self, value: List[bytes]) -> None:
"Addresses must either be IPv4 or IPv6 strings, bytes, or integers;"
f" got {address!r}. Hint: convert string addresses with socket.inet_pton"
)
if isinstance(addr, ipaddress.IPv4Address):
if addr.version == 4:
self._ipv4_addresses.append(addr)
else:
self._ipv6_addresses.append(addr)
Expand Down Expand Up @@ -339,6 +339,35 @@ def get_name(self) -> str:
"""Name accessor"""
return self.name[: len(self.name) - len(self.type) - 1]

def _get_ip_addresses_from_cache_lifo(
self, zc: 'Zeroconf', now: float, type: int
) -> List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]:
"""Set IPv6 addresses from the cache."""
address_list: List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]] = []
for record in self._get_address_records_from_cache_by_type(zc, type):
if record.is_expired(now):
continue
try:
ip_address = _cached_ip_addresses(record.address)
except ValueError:
continue
else:
address_list.append(ip_address)
address_list.reverse() # Reverse to get LIFO order
return address_list

def _set_ipv6_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
"""Set IPv6 addresses from the cache."""
self._ipv6_addresses = cast(
"List[ipaddress.IPv6Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA)
)

def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
"""Set IPv4 addresses from the cache."""
self._ipv4_addresses = cast(
"List[ipaddress.IPv4Address]", self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)
)

def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None:
"""Updates service information from a DNS record.

Expand All @@ -348,7 +377,7 @@ def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord])
This method will be run in the event loop.
"""
if record is not None:
self._process_records_threadsafe(zc, now, [RecordUpdate(record, None)])
self._process_record_threadsafe(zc, record, now)

def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
"""Updates service information from a DNS record.
Expand All @@ -357,55 +386,77 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU
"""
self._process_records_threadsafe(zc, now, records)

def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
"""Thread safe record updating."""
seen_addresses: Set[bytes] = set()
def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool:
"""Thread safe record updating.

Returns True if new records were added.
"""
updated: bool = False
for record_update in records:
record = record_update.new
if isinstance(record, DNSAddress):
seen_addresses.add(record.address)
self._process_record_threadsafe(record, now)
for record in self._get_address_records_from_cache(zc):
if record.address not in seen_addresses:
self._process_record_threadsafe(record, now)

def _process_record_threadsafe(self, record: DNSRecord, now: float) -> None:
"""Thread safe record updating."""
updated |= self._process_record_threadsafe(zc, record_update.new, now)
return updated

def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float) -> bool:
"""Thread safe record updating.

Returns True if a new record was added.
"""
if record.is_expired(now):
return
return False

if isinstance(record, DNSAddress):
if record.key != self.server_key:
return
if record.key == self.server_key and isinstance(record, DNSAddress):
try:
ip_addr = _cached_ip_addresses(record.address)
except ValueError as ex:
log.warning("Encountered invalid address while processing %s: %s", record, ex)
return
if isinstance(ip_addr, ipaddress.IPv4Address):
return False

if ip_addr.version == 4:
if not self._ipv4_addresses:
self._set_ipv4_addresses_from_cache(zc, now)

if ip_addr not in self._ipv4_addresses:
self._ipv4_addresses.insert(0, ip_addr)
return
return True
elif ip_addr != self._ipv4_addresses[0]:
self._ipv4_addresses.remove(ip_addr)
self._ipv4_addresses.insert(0, ip_addr)

return False

if not self._ipv6_addresses:
self._set_ipv6_addresses_from_cache(zc, now)

if ip_addr not in self._ipv6_addresses:
self._ipv6_addresses.insert(0, ip_addr)
if ip_addr.is_link_local:
self.interface_index = record.scope_id
return
return True
elif ip_addr != self._ipv6_addresses[0]:
self._ipv6_addresses.remove(ip_addr)
self._ipv6_addresses.insert(0, ip_addr)

if isinstance(record, DNSText):
if record.key == self.key:
self._set_text(record.text)
return
return False

if record.key != self.key:
return False

if record.type == _TYPE_TXT and isinstance(record, DNSText):
self._set_text(record.text)
return True

if isinstance(record, DNSService):
if record.key != self.key:
return
if record.type == _TYPE_SRV and isinstance(record, DNSService):
old_server_key = self.server_key
self.name = record.name
self.server = record.server
self.server_key = record.server.lower()
self.port = record.port
self.weight = record.weight
self.priority = record.priority
if old_server_key != self.server_key:
self._set_ipv4_addresses_from_cache(zc, now)
self._set_ipv6_addresses_from_cache(zc, now)
return True

return False

def dns_addresses(
self,
Expand All @@ -416,7 +467,7 @@ def dns_addresses(
"""Return matching DNSAddress from ServiceInfo."""
return [
DNSAddress(
self.server,
self.server or self.name,
_TYPE_AAAA if address.version == 6 else _TYPE_A,
_CLASS_IN | _CLASS_UNIQUE,
override_ttl if override_ttl is not None else self.host_ttl,
Expand Down Expand Up @@ -447,7 +498,7 @@ def dns_service(self, override_ttl: Optional[int] = None, created: Optional[floa
self.priority,
self.weight,
cast(int, self.port),
self.server,
self.server or self.name,
created,
)

Expand All @@ -462,35 +513,43 @@ def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float]
created,
)

def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSAddress]:
"""Get the address records from the cache."""
return cast(
"List[DNSAddress]",
[
*zc.cache.get_all_by_details(self.server, _TYPE_A, _CLASS_IN),
*zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN),
],
)
def _get_address_records_from_cache_by_type(self, zc: 'Zeroconf', _type: int) -> List[DNSAddress]:
"""Get the addresses from the cache."""
if self.server_key is None:
return []
return cast("List[DNSAddress]", zc.cache.get_all_by_details(self.server_key, _type, _CLASS_IN))

def set_server_if_missing(self) -> None:
"""Set the server if it is missing.

This function is for backwards compatibility.
"""
if self.server is None:
self.server = self.name
self.server_key = self.server.lower()

def load_from_cache(self, zc: 'Zeroconf') -> bool:
"""Populate the service info from the cache.

This method is designed to be threadsafe.
"""
now = current_time_millis()
record_updates: List[RecordUpdate] = []
original_server_key = self.server_key
cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN)
if cached_srv_record:
# If there is a srv record, A and AAAA will already
# be called and we do not want to do it twice
record_updates.append(RecordUpdate(cached_srv_record, None))
else:
for record in self._get_address_records_from_cache(zc):
record_updates.append(RecordUpdate(record, None))
self._process_record_threadsafe(zc, cached_srv_record, now)
cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN)
if cached_txt_record:
record_updates.append(RecordUpdate(cached_txt_record, None))
self._process_records_threadsafe(zc, now, record_updates)
self._process_record_threadsafe(zc, cached_txt_record, now)
if original_server_key == self.server_key:
# If there is a srv which changes the server_key,
# A and AAAA will already be loaded from the cache
# and we do not want to do it twice
for record in [
*self._get_address_records_from_cache_by_type(zc, _TYPE_A),
*self._get_address_records_from_cache_by_type(zc, _TYPE_AAAA),
]:
self._process_record_threadsafe(zc, record, now)
return self._is_complete

@property
Expand Down Expand Up @@ -560,8 +619,8 @@ def generate_request_query(
out = DNSOutgoing(_FLAGS_QR_QUERY)
out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN)
out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN)
out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN)
out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN)
out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_A, _CLASS_IN)
out.add_question_or_all_cache(zc.cache, now, self.server or self.name, _TYPE_AAAA, _CLASS_IN)
if question_type == DNSQuestionType.QU:
for question in out.questions:
question.unicast = True
Expand Down
2 changes: 2 additions & 0 deletions src/zeroconf/_services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def _async_get_by_index(self, records: Dict[str, List], key: str) -> List[Servic

def _add(self, info: ServiceInfo) -> None:
"""Add a new service under the lock."""
assert info.server_key is not None, "ServiceInfo must have a server"
if info.key in self._services:
raise ServiceNameAlreadyRegistered

Expand All @@ -93,6 +94,7 @@ def _remove(self, infos: List[ServiceInfo]) -> None:
if info.key not in self._services:
continue
old_service_info = self._services[info.key]
assert old_service_info.server_key is not None
self.types[old_service_info.type.lower()].remove(info.key)
self.servers[old_service_info.server_key].remove(info.key)
del self._services[info.key]
Loading