diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 2f5ef507..a8986211 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -29,7 +29,7 @@ import sys import threading from types import TracebackType # noqa # used in type hints -from typing import Dict, List, Optional, Tuple, Type, Union, cast +from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union, cast from ._cache import DNSCache from ._dns import DNSQuestion, DNSQuestionType @@ -43,7 +43,7 @@ from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.registry import ServiceRegistry from ._updates import RecordUpdate, RecordUpdateListener -from ._utils.asyncio import get_running_loop, shutdown_loop, wait_event_or_timeout +from ._utils.asyncio import await_awaitable, get_running_loop, shutdown_loop, wait_event_or_timeout from ._utils.name import service_type_name from ._utils.net import ( IPVersion, @@ -74,6 +74,7 @@ _TC_DELAY_RANDOM_INTERVAL = (400, 500) _CLOSE_TIMEOUT = 3 +_REGISTER_BROADCASTS = 3 class AsyncEngine: @@ -478,6 +479,27 @@ def register_service( allow_name_change: bool = False, cooperating_responders: bool = False, ) -> None: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. The name of the service may be changed if needed to make + it unique on the network. Additionally multiple cooperating responders + can register the same service on the network for resilience + (if you want this behavior set `cooperating_responders` to `True`).""" + assert self.loop is not None + asyncio.run_coroutine_threadsafe( + await_awaitable( + self.async_register_service(info, ttl, allow_name_change, cooperating_responders) + ), + self.loop, + ).result(millis_to_seconds(_REGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT) + + async def async_register_service( + self, + info: ServiceInfo, + ttl: Optional[int] = None, + allow_name_change: bool = False, + cooperating_responders: bool = False, + ) -> Awaitable: """Registers service information to the network with a default TTL. Zeroconf will then respond to requests for information for that service. The name of the service may be changed if needed to make @@ -489,36 +511,34 @@ def register_service( # Setting TTLs via ServiceInfo is preferred info.host_ttl = ttl info.other_ttl = ttl - self.check_service(info, allow_name_change, cooperating_responders) + + await self.async_wait_for_start() + await self.async_check_service(info, allow_name_change, cooperating_responders) self.registry.add(info) - self._broadcast_service(info, _REGISTER_TIME, None) + return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) def update_service(self, info: ServiceInfo) -> None: """Registers service information to the network with a default TTL. Zeroconf will then respond to requests for information for that service.""" + assert self.loop is not None + asyncio.run_coroutine_threadsafe(await_awaitable(self.async_update_service(info)), self.loop).result( + millis_to_seconds(_REGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT + ) + async def async_update_service(self, info: ServiceInfo) -> Awaitable: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service.""" self.registry.update(info) - self._broadcast_service(info, _REGISTER_TIME, None) + return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) - def _broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: + async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: """Send a broadcasts to announce a service at intervals.""" - now = current_time_millis() - next_time = now - i = 0 - while i < 3: - if now < next_time: - self.wait(next_time - now) - now = current_time_millis() - continue - - self.send_service_broadcast(info, ttl) - i += 1 - next_time += interval - - def send_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> None: - """Send a broadcast to announce a service.""" - self.send(self.generate_service_broadcast(info, ttl)) + for i in range(_REGISTER_BROADCASTS): + if i != 0: + await asyncio.sleep(millis_to_seconds(interval)) + self.async_send(self.generate_service_broadcast(info, ttl)) def generate_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> DNSOutgoing: """Generate a broadcast to announce a service.""" @@ -526,10 +546,6 @@ def generate_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> D self._add_broadcast_answer(out, info, ttl) return out - def send_service_query(self, info: ServiceInfo) -> None: - """Send a query to lookup a service.""" - self.send(self.generate_service_query(info)) - def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: disable=no-self-use """Generate a query to lookup a service.""" out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA) @@ -559,9 +575,16 @@ def _add_broadcast_answer( # pylint: disable=no-self-use out.add_answer_at_time(dns_address, 0) def unregister_service(self, info: ServiceInfo) -> None: + """Unregister a service.""" + assert self.loop is not None + asyncio.run_coroutine_threadsafe( + await_awaitable(self.async_unregister_service(info)), self.loop + ).result(millis_to_seconds(_UNREGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT) + + async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: """Unregister a service.""" self.registry.remove(info) - self._broadcast_service(info, _UNREGISTER_TIME, 0) + return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0)) def generate_unregister_all_services(self) -> Optional[DNSOutgoing]: """Generate a DNSOutgoing goodbye for all services and remove them from the registry.""" @@ -574,6 +597,22 @@ def generate_unregister_all_services(self) -> Optional[DNSOutgoing]: self.registry.remove(service_infos) return out + async def async_unregister_all_services(self) -> None: + """Unregister all registered services. + + Unlike async_register_service and async_unregister_service, this + method does not return a future and is always expected to be + awaited since its only called at shutdown. + """ + # Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1 + out = self.generate_unregister_all_services() + if not out: + return + for i in range(_REGISTER_BROADCASTS): + if i != 0: + await asyncio.sleep(millis_to_seconds(_UNREGISTER_TIME)) + self.async_send(out) + def unregister_all_services(self) -> None: """Unregister all registered services.""" # Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1 @@ -592,7 +631,7 @@ def unregister_all_services(self) -> None: i += 1 next_time += _UNREGISTER_TIME - def check_service( + async def async_check_service( self, info: ServiceInfo, allow_name_change: bool, cooperating_responders: bool = False ) -> None: """Checks the network for a unique service name, modifying the @@ -603,7 +642,7 @@ def check_service( next_instance_number = 2 next_time = now = current_time_millis() i = 0 - while i < 3: + while i < _REGISTER_BROADCASTS: # check for a name conflict while self.cache.current_entry_with_name_and_alias(info.type, info.name): if not allow_name_change: @@ -617,11 +656,11 @@ def check_service( i = 0 if now < next_time: - self.wait(next_time - now) + await self.async_wait(next_time - now) now = current_time_millis() continue - self.send_service_query(info) + self.async_send(self.generate_service_query(info)) i += 1 next_time += _CHECK_TIME diff --git a/zeroconf/_utils/asyncio.py b/zeroconf/_utils/asyncio.py index c68c0f00..395c331b 100644 --- a/zeroconf/_utils/asyncio.py +++ b/zeroconf/_utils/asyncio.py @@ -23,8 +23,9 @@ import asyncio import contextlib import queue -from typing import Any, List, Optional, Set, cast +from typing import Any, Awaitable, List, Optional, Set, cast +# The combined timeouts should be lower than _CLOSE_TIMEOUT + _WAIT_FOR_LOOP_TASKS_TIMEOUT _TASK_AWAIT_TIMEOUT = 1 _GET_ALL_TASKS_TIMEOUT = 3 _WAIT_FOR_LOOP_TASKS_TIMEOUT = 3 # Must be larger than _TASK_AWAIT_TIMEOUT @@ -80,6 +81,12 @@ async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None: await asyncio.wait(wait_tasks, timeout=_TASK_AWAIT_TIMEOUT) +async def await_awaitable(aw: Awaitable) -> None: + """Wait on an awaitable and the task it returns.""" + task = await aw + await task + + def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None: """Wait for pending tasks and stop an event loop.""" pending_tasks = set( diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 67ff1c12..08478044 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -26,20 +26,15 @@ from ._core import Zeroconf from ._dns import DNSQuestionType -from ._exceptions import NonUniqueNameException from ._services import ServiceListener from ._services.browser import _ServiceBrowserBase -from ._services.info import ServiceInfo, instance_name_from_service_info +from ._services.info import ServiceInfo from ._services.types import ZeroconfServiceTypes from ._utils.net import IPVersion, InterfaceChoice, InterfacesType -from ._utils.time import millis_to_seconds from .const import ( _BROWSER_TIME, - _CHECK_TIME, _MDNS_PORT, - _REGISTER_TIME, _SERVICE_TYPE_ENUMERATION_NAME, - _UNREGISTER_TIME, ) @@ -172,16 +167,11 @@ def __init__( ) self.async_browsers: Dict[ServiceListener, AsyncServiceBrowser] = {} - async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: - """Send a broadcasts to announce a service at intervals.""" - for i in range(3): - if i != 0: - await asyncio.sleep(millis_to_seconds(interval)) - self.zeroconf.async_send(self.zeroconf.generate_service_broadcast(info, ttl)) - async def async_register_service( self, info: ServiceInfo, + ttl: Optional[int] = None, + allow_name_change: bool = False, cooperating_responders: bool = False, ) -> Awaitable: """Registers service information to the network with a default TTL. @@ -194,10 +184,9 @@ async def async_register_service( The service will be broadcast in a task. This task is returned and therefore can be awaited if necessary. """ - await self.zeroconf.async_wait_for_start() - await self.async_check_service(info, cooperating_responders) - self.zeroconf.registry.add(info) - return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + return await self.zeroconf.async_register_service( + info, ttl, allow_name_change, cooperating_responders + ) async def async_unregister_all_services(self) -> None: """Unregister all registered services. @@ -206,30 +195,7 @@ async def async_unregister_all_services(self) -> None: method does not return a future and is always expected to be awaited since its only called at shutdown. """ - out = self.zeroconf.generate_unregister_all_services() - if not out: - return - for i in range(3): - if i != 0: - await asyncio.sleep(millis_to_seconds(_UNREGISTER_TIME)) - self.zeroconf.async_send(out) - - async def async_check_service(self, info: ServiceInfo, cooperating_responders: bool = False) -> None: - """Checks the network for a unique service name.""" - instance_name_from_service_info(info) - if cooperating_responders: - return - self._raise_on_name_conflict(info) - for i in range(3): - if i != 0: - await asyncio.sleep(millis_to_seconds(_CHECK_TIME)) - self.zeroconf.async_send(self.zeroconf.generate_service_query(info)) - self._raise_on_name_conflict(info) - - def _raise_on_name_conflict(self, info: ServiceInfo) -> None: - """Raise NonUniqueNameException if the ServiceInfo has a conflict.""" - if self.zeroconf.cache.current_entry_with_name_and_alias(info.type, info.name): - raise NonUniqueNameException + await self.zeroconf.async_unregister_all_services() async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: """Unregister a service. @@ -237,8 +203,7 @@ async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: The service will be broadcast in a task. This task is returned and therefore can be awaited if necessary. """ - self.zeroconf.registry.remove(info) - return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0)) + return await self.zeroconf.async_unregister_service(info) async def async_update_service(self, info: ServiceInfo) -> Awaitable: """Registers service information to the network with a default TTL. @@ -248,8 +213,7 @@ async def async_update_service(self, info: ServiceInfo) -> Awaitable: The service will be broadcast in a task. This task is returned and therefore can be awaited if necessary. """ - self.zeroconf.registry.update(info) - return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + return await self.zeroconf.async_update_service(info) async def async_close(self) -> None: """Ends the background threads, and prevent this instance from