From 7fa40fb5efb10bd4fc3bd14749823867f833339f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 31 Jan 2025 12:55:31 -0600 Subject: [PATCH 01/13] feat: eliminate async_timeout dep on python less than 3.11 --- poetry.lock | 15 +-------------- pyproject.toml | 1 - src/zeroconf/_core.py | 10 +++++----- src/zeroconf/_engine.py | 14 +++++++------- src/zeroconf/_utils/asyncio.py | 19 ++++++++----------- tests/utils/test_asyncio.py | 13 +++++++------ 6 files changed, 28 insertions(+), 44 deletions(-) diff --git a/poetry.lock b/poetry.lock index 14c79f618..962899b2b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -12,19 +12,6 @@ files = [ {file = "alabaster-0.7.16.tar.gz", hash = "sha256:75a8b99c28a5dad50dd7f8ccdd447a121ddb3892da9e53d1ca5cca3106d58d65"}, ] -[[package]] -name = "async-timeout" -version = "5.0.1" -description = "Timeout context manager for asyncio programs" -optional = false -python-versions = ">=3.8" -groups = ["main"] -markers = "python_version < \"3.11\"" -files = [ - {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, - {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, -] - [[package]] name = "babel" version = "2.16.0" @@ -1140,4 +1127,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = "^3.9" -content-hash = "eb91a0dd1c260f37d2579b4793f537f8017f9e1801e2a372849439f5c9132245" +content-hash = "ea903296f015035c594eb8cce08d4dedc716074e33644033938dfdb5f047d72e" diff --git a/pyproject.toml b/pyproject.toml index f5084253e..7514d9a5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,6 @@ prerelease = true [tool.poetry.dependencies] python = "^3.9" -async-timeout = {version = ">=3.0.0", python = "<3.11"} ifaddr = ">=0.1.7" [tool.poetry.group.dev.dependencies] diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 68cb8a9ac..6a3760df2 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -53,8 +53,8 @@ get_running_loop, run_coro_with_timeout, shutdown_loop, - wait_event_or_timeout, wait_for_future_set_or_timeout, + wait_future_or_timeout, ) from ._utils.name import service_type_name from ._utils.net import ( @@ -201,7 +201,7 @@ def __init__( @property def started(self) -> bool: """Check if the instance has started.""" - return bool(not self.done and self.engine.running_event and self.engine.running_event.is_set()) + return bool(not self.done and self.engine.running_future and self.engine.running_future.result()) def start(self) -> None: """Start Zeroconf.""" @@ -233,9 +233,9 @@ async def async_wait_for_start(self) -> None: """ if self.done: # If the instance was shutdown from under us, raise immediately raise NotRunningException - assert self.engine.running_event is not None - await wait_event_or_timeout(self.engine.running_event, timeout=_STARTUP_TIMEOUT) - if not self.engine.running_event.is_set() or self.done: + assert self.engine.running_future is not None + await wait_future_or_timeout(self.engine.running_future, timeout=_STARTUP_TIMEOUT) + if not self.engine.running_future.result() or self.done: raise NotRunningException @property diff --git a/src/zeroconf/_engine.py b/src/zeroconf/_engine.py index 05f8c948c..0e88746ed 100644 --- a/src/zeroconf/_engine.py +++ b/src/zeroconf/_engine.py @@ -51,7 +51,7 @@ class AsyncEngine: "loop", "protocols", "readers", - "running_event", + "running_future", "senders", "zc", ) @@ -67,7 +67,7 @@ def __init__( self.protocols: List[AsyncListener] = [] self.readers: List[_WrappedTransport] = [] self.senders: List[_WrappedTransport] = [] - self.running_event: Optional[asyncio.Event] = None + self.running_future: Optional[asyncio.Future[bool | None]] = None self._listen_socket = listen_socket self._respond_sockets = respond_sockets self._cleanup_timer: Optional[asyncio.TimerHandle] = None @@ -79,15 +79,15 @@ def setup( ) -> None: """Set up the instance.""" self.loop = loop - self.running_event = asyncio.Event() + self.running_future = loop.create_future() self.loop.create_task(self._async_setup(loop_thread_ready)) async def _async_setup(self, loop_thread_ready: Optional[threading.Event]) -> None: """Set up the instance.""" self._async_schedule_next_cache_cleanup() await self._async_create_endpoints() - assert self.running_event is not None - self.running_event.set() + assert self.running_future is not None + self.running_future.set_result(True) if loop_thread_ready: loop_thread_ready.set() @@ -140,8 +140,8 @@ async def _async_close(self) -> None: def _async_shutdown(self) -> None: """Shutdown transports and sockets.""" - assert self.running_event is not None - self.running_event.clear() + assert self.running_future is not None + self.running_future = None for wrapped_transport in itertools.chain(self.senders, self.readers): wrapped_transport.transport.close() diff --git a/src/zeroconf/_utils/asyncio.py b/src/zeroconf/_utils/asyncio.py index 6d070e306..b8b2548ce 100644 --- a/src/zeroconf/_utils/asyncio.py +++ b/src/zeroconf/_utils/asyncio.py @@ -23,14 +23,8 @@ import asyncio import concurrent.futures import contextlib -import sys from typing import Any, Awaitable, Coroutine, Optional, Set -if sys.version_info[:2] < (3, 11): - from async_timeout import timeout as asyncio_timeout -else: - from asyncio import timeout as asyncio_timeout # type: ignore[attr-defined] - from .._exceptions import EventLoopBlocked from ..const import _LOADED_SYSTEM_TIMEOUT from .time import millis_to_seconds @@ -68,11 +62,14 @@ async def wait_for_future_set_or_timeout( future_set.discard(future) -async def wait_event_or_timeout(event: asyncio.Event, timeout: float) -> None: - """Wait for an event or timeout.""" - with contextlib.suppress(asyncio.TimeoutError): - async with asyncio_timeout(timeout): - await event.wait() +async def wait_future_or_timeout(future: asyncio.Future[bool | None], timeout: float) -> None: + """Wait for a future or timeout.""" + loop = asyncio.get_running_loop() + handle = loop.call_later(timeout, _set_future_none_if_not_done, future) + try: + await future + finally: + handle.cancel() async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> Set[asyncio.Task]: diff --git a/tests/utils/test_asyncio.py b/tests/utils/test_asyncio.py index 09137a719..7989a82cf 100644 --- a/tests/utils/test_asyncio.py +++ b/tests/utils/test_asyncio.py @@ -45,16 +45,17 @@ def test_get_running_loop_no_loop() -> None: @pytest.mark.asyncio -async def test_wait_event_or_timeout_times_out() -> None: - """Test wait_event_or_timeout will timeout.""" - test_event = asyncio.Event() - await aioutils.wait_event_or_timeout(test_event, 0.1) +async def test_wait_future_or_timeout_times_out() -> None: + """Test wait_future_or_timeout will timeout.""" + loop = asyncio.get_running_loop() + test_future = loop.create_future() + await aioutils.wait_future_or_timeout(test_future, 0.1) - task = asyncio.ensure_future(test_event.wait()) + task = asyncio.ensure_future(test_future) await asyncio.sleep(0.1) async def _async_wait_or_timeout(): - await aioutils.wait_event_or_timeout(test_event, 0.1) + await aioutils.wait_future_or_timeout(test_future, 0.1) # Test high lock contention await asyncio.gather(*[_async_wait_or_timeout() for _ in range(100)]) From 48c02739b9e93ab3ac64f2625cf778e6a4725a47 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 31 Jan 2025 12:57:31 -0600 Subject: [PATCH 02/13] feat: eliminate async_timeout dep on python less than 3.11 --- src/zeroconf/_utils/asyncio.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/zeroconf/_utils/asyncio.py b/src/zeroconf/_utils/asyncio.py index b8b2548ce..5d3beb7cb 100644 --- a/src/zeroconf/_utils/asyncio.py +++ b/src/zeroconf/_utils/asyncio.py @@ -20,10 +20,12 @@ USA """ +from __future__ import annotations + import asyncio import concurrent.futures import contextlib -from typing import Any, Awaitable, Coroutine, Optional, Set +from typing import Any, Awaitable, Coroutine from .._exceptions import EventLoopBlocked from ..const import _LOADED_SYSTEM_TIMEOUT @@ -41,7 +43,7 @@ def _set_future_none_if_not_done(fut: asyncio.Future) -> None: fut.set_result(None) -def _resolve_all_futures_to_none(futures: Set[asyncio.Future]) -> None: +def _resolve_all_futures_to_none(futures: set[asyncio.Future]) -> None: """Resolve all futures to None.""" for fut in futures: _set_future_none_if_not_done(fut) @@ -49,7 +51,7 @@ def _resolve_all_futures_to_none(futures: Set[asyncio.Future]) -> None: async def wait_for_future_set_or_timeout( - loop: asyncio.AbstractEventLoop, future_set: Set[asyncio.Future], timeout: float + loop: asyncio.AbstractEventLoop, future_set: set[asyncio.Future], timeout: float ) -> None: """Wait for a future or timeout (in milliseconds).""" future = loop.create_future() @@ -72,7 +74,7 @@ async def wait_future_or_timeout(future: asyncio.Future[bool | None], timeout: f handle.cancel() -async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> Set[asyncio.Task]: +async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> set[asyncio.Task]: """Return all tasks running.""" await asyncio.sleep(0) # flush out any call_soon_threadsafe # If there are multiple event loops running, all_tasks is not @@ -84,7 +86,7 @@ async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> Set[asyncio.T return set() -async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None: +async def _wait_for_loop_tasks(wait_tasks: set[asyncio.Task]) -> None: """Wait for the event loop thread we started to shutdown.""" await asyncio.wait(wait_tasks, timeout=_TASK_AWAIT_TIMEOUT) @@ -127,7 +129,7 @@ def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None: loop.call_soon_threadsafe(loop.stop) -def get_running_loop() -> Optional[asyncio.AbstractEventLoop]: +def get_running_loop() -> asyncio.AbstractEventLoop | None: """Check if an event loop is already running.""" with contextlib.suppress(RuntimeError): return asyncio.get_running_loop() From e25246c7269d6fae7f3957a65c75c61ef283810e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 31 Jan 2025 12:58:14 -0600 Subject: [PATCH 03/13] feat: eliminate async_timeout dep on python less than 3.11 --- src/zeroconf/__init__.py | 4 ++- src/zeroconf/_core.py | 66 +++++++++++++++++++++------------------- src/zeroconf/_engine.py | 26 ++++++++-------- 3 files changed, 51 insertions(+), 45 deletions(-) diff --git a/src/zeroconf/__init__.py b/src/zeroconf/__init__.py index 1a41ddd3b..26f60cde2 100644 --- a/src/zeroconf/__init__.py +++ b/src/zeroconf/__init__.py @@ -20,6 +20,8 @@ USA """ +from __future__ import annotations + from ._cache import DNSCache # noqa # import needed for backwards compat from ._core import Zeroconf from ._dns import ( # noqa # import needed for backwards compat @@ -57,10 +59,10 @@ ) from ._services.browser import ServiceBrowser from ._services.info import ( # noqa # import needed for backwards compat - ServiceInfo, AddressResolver, AddressResolverIPv4, AddressResolverIPv6, + ServiceInfo, instance_name_from_service_info, ) from ._services.registry import ( # noqa # import needed for backwards compat diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 6a3760df2..82e0fad20 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -20,12 +20,14 @@ USA """ +from __future__ import annotations + import asyncio import logging import sys import threading from types import TracebackType -from typing import Awaitable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Awaitable from ._cache import DNSCache from ._dns import DNSQuestion, DNSQuestionType @@ -108,9 +110,9 @@ def async_send_with_transport( packet: bytes, packet_num: int, out: DNSOutgoing, - addr: Optional[str], + addr: str | None, port: int, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + v6_flow_scope: tuple[()] | tuple[int, int] = (), ) -> None: ipv6_socket = transport.is_ipv6 if addr is None: @@ -149,7 +151,7 @@ def __init__( self, interfaces: InterfacesType = InterfaceChoice.All, unicast: bool = False, - ip_version: Optional[IPVersion] = None, + ip_version: IPVersion | None = None, apple_p2p: bool = False, ) -> None: """Creates an instance of the Zeroconf class, establishing @@ -181,7 +183,7 @@ def __init__( self.engine = AsyncEngine(self, listen_socket, respond_sockets) - self.browsers: Dict[ServiceListener, ServiceBrowser] = {} + self.browsers: dict[ServiceListener, ServiceBrowser] = {} self.registry = ServiceRegistry() self.cache = DNSCache() self.question_history = QuestionHistory() @@ -192,9 +194,9 @@ def __init__( self.query_handler = QueryHandler(self) self.record_manager = RecordManager(self) - self._notify_futures: Set[asyncio.Future] = set() - self.loop: Optional[asyncio.AbstractEventLoop] = None - self._loop_thread: Optional[threading.Thread] = None + self._notify_futures: set[asyncio.Future] = set() + self.loop: asyncio.AbstractEventLoop | None = None + self._loop_thread: threading.Thread | None = None self.start() @@ -239,7 +241,7 @@ async def async_wait_for_start(self) -> None: raise NotRunningException @property - def listeners(self) -> Set[RecordUpdateListener]: + def listeners(self) -> set[RecordUpdateListener]: return self.record_manager.listeners async def async_wait(self, timeout: float) -> None: @@ -264,8 +266,8 @@ def get_service_info( type_: str, name: str, timeout: int = 3000, - question_type: Optional[DNSQuestionType] = None, - ) -> Optional[ServiceInfo]: + question_type: DNSQuestionType | None = None, + ) -> ServiceInfo | None: """Returns network's service information for a particular name and type, or None if no service matches by the timeout, which defaults to 3 seconds. @@ -301,7 +303,7 @@ def remove_all_service_listeners(self) -> None: def register_service( self, info: ServiceInfo, - ttl: Optional[int] = None, + ttl: int | None = None, allow_name_change: bool = False, cooperating_responders: bool = False, strict: bool = True, @@ -329,7 +331,7 @@ def register_service( async def async_register_service( self, info: ServiceInfo, - ttl: Optional[int] = None, + ttl: int | None = None, allow_name_change: bool = False, cooperating_responders: bool = False, strict: bool = True, @@ -380,8 +382,8 @@ async def async_get_service_info( type_: str, name: str, timeout: int = 3000, - question_type: Optional[DNSQuestionType] = None, - ) -> Optional[AsyncServiceInfo]: + question_type: DNSQuestionType | None = None, + ) -> AsyncServiceInfo | None: """Returns network's service information for a particular name and type, or None if no service matches by the timeout, which defaults to 3 seconds. @@ -400,7 +402,7 @@ async def _async_broadcast_service( self, info: ServiceInfo, interval: int, - ttl: Optional[int], + ttl: int | None, broadcast_addresses: bool = True, ) -> None: """Send a broadcasts to announce a service at intervals.""" @@ -412,7 +414,7 @@ async def _async_broadcast_service( def generate_service_broadcast( self, info: ServiceInfo, - ttl: Optional[int], + ttl: int | None, broadcast_addresses: bool = True, ) -> DNSOutgoing: """Generate a broadcast to announce a service.""" @@ -439,7 +441,7 @@ def _add_broadcast_answer( # pylint: disable=no-self-use self, out: DNSOutgoing, info: ServiceInfo, - override_ttl: Optional[int], + override_ttl: int | None, broadcast_addresses: bool = True, ) -> None: """Add answers to broadcast a service.""" @@ -481,7 +483,7 @@ async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: self._async_broadcast_service(info, _UNREGISTER_TIME, 0, broadcast_addresses) ) - def generate_unregister_all_services(self) -> Optional[DNSOutgoing]: + def generate_unregister_all_services(self) -> DNSOutgoing | None: """Generate a DNSOutgoing goodbye for all services and remove them from the registry.""" service_infos = self.registry.async_get_service_infos() if not service_infos: @@ -562,7 +564,7 @@ async def async_check_service( def add_listener( self, listener: RecordUpdateListener, - question: Optional[Union[DNSQuestion, List[DNSQuestion]]], + question: DNSQuestion | list[DNSQuestion] | None, ) -> None: """Adds a listener for a given question. The listener will have its update_record method called when information is available to @@ -584,7 +586,7 @@ def remove_listener(self, listener: RecordUpdateListener) -> None: def async_add_listener( self, listener: RecordUpdateListener, - question: Optional[Union[DNSQuestion, List[DNSQuestion]]], + question: DNSQuestion | list[DNSQuestion] | None, ) -> None: """Adds a listener for a given question. The listener will have its update_record method called when information is available to @@ -604,10 +606,10 @@ def async_remove_listener(self, listener: RecordUpdateListener) -> None: def send( self, out: DNSOutgoing, - addr: Optional[str] = None, + addr: str | None = None, port: int = _MDNS_PORT, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), - transport: Optional[_WrappedTransport] = None, + v6_flow_scope: tuple[()] | tuple[int, int] = (), + transport: _WrappedTransport | None = None, ) -> None: """Sends an outgoing packet threadsafe.""" assert self.loop is not None @@ -616,10 +618,10 @@ def send( def async_send( self, out: DNSOutgoing, - addr: Optional[str] = None, + addr: str | None = None, port: int = _MDNS_PORT, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), - transport: Optional[_WrappedTransport] = None, + v6_flow_scope: tuple[()] | tuple[int, int] = (), + transport: _WrappedTransport | None = None, ) -> None: """Sends an outgoing packet.""" if self.done: @@ -701,14 +703,14 @@ async def _async_close(self) -> None: await self.engine._async_close() # pylint: disable=protected-access self._shutdown_threads() - def __enter__(self) -> "Zeroconf": + def __enter__(self) -> Zeroconf: return self def __exit__( # pylint: disable=useless-return self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: self.close() return None diff --git a/src/zeroconf/_engine.py b/src/zeroconf/_engine.py index 0e88746ed..b9f5d9a57 100644 --- a/src/zeroconf/_engine.py +++ b/src/zeroconf/_engine.py @@ -20,11 +20,13 @@ USA """ +from __future__ import annotations + import asyncio import itertools import socket import threading -from typing import TYPE_CHECKING, List, Optional, cast +from typing import TYPE_CHECKING, cast from ._record_update import RecordUpdate from ._utils.asyncio import get_running_loop, run_coro_with_timeout @@ -58,31 +60,31 @@ class AsyncEngine: def __init__( self, - zeroconf: "Zeroconf", - listen_socket: Optional[socket.socket], - respond_sockets: List[socket.socket], + zeroconf: Zeroconf, + listen_socket: socket.socket | None, + respond_sockets: list[socket.socket], ) -> None: - self.loop: Optional[asyncio.AbstractEventLoop] = None + self.loop: asyncio.AbstractEventLoop | None = None self.zc = zeroconf - self.protocols: List[AsyncListener] = [] - self.readers: List[_WrappedTransport] = [] - self.senders: List[_WrappedTransport] = [] - self.running_future: Optional[asyncio.Future[bool | None]] = None + self.protocols: list[AsyncListener] = [] + self.readers: list[_WrappedTransport] = [] + self.senders: list[_WrappedTransport] = [] + self.running_future: asyncio.Future[bool | None] | None = None self._listen_socket = listen_socket self._respond_sockets = respond_sockets - self._cleanup_timer: Optional[asyncio.TimerHandle] = None + self._cleanup_timer: asyncio.TimerHandle | None = None def setup( self, loop: asyncio.AbstractEventLoop, - loop_thread_ready: Optional[threading.Event], + loop_thread_ready: threading.Event | None, ) -> None: """Set up the instance.""" self.loop = loop self.running_future = loop.create_future() self.loop.create_task(self._async_setup(loop_thread_ready)) - async def _async_setup(self, loop_thread_ready: Optional[threading.Event]) -> None: + async def _async_setup(self, loop_thread_ready: threading.Event | None) -> None: """Set up the instance.""" self._async_schedule_next_cache_cleanup() await self._async_create_endpoints() From a5cb2bc39ff07f18e02a1583271321bbd0f97f4e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 31 Jan 2025 12:58:24 -0600 Subject: [PATCH 04/13] feat: eliminate async_timeout dep on python less than 3.11 --- src/zeroconf/_cache.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/zeroconf/_cache.py b/src/zeroconf/_cache.py index 1b7aae38f..5ac43f307 100644 --- a/src/zeroconf/_cache.py +++ b/src/zeroconf/_cache.py @@ -20,8 +20,10 @@ USA """ +from __future__ import annotations + from heapq import heapify, heappop, heappush -from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast +from typing import Dict, Iterable, Union, cast from ._dns import ( DNSAddress, @@ -66,8 +68,8 @@ class DNSCache: def __init__(self) -> None: self.cache: _DNSRecordCacheType = {} - self._expire_heap: List[Tuple[float, DNSRecord]] = [] - self._expirations: Dict[DNSRecord, float] = {} + self._expire_heap: list[tuple[float, DNSRecord]] = [] + self._expirations: dict[DNSRecord, float] = {} self.service_cache: _DNSRecordCacheType = {} # Functions prefixed with async_ are NOT threadsafe and must @@ -135,7 +137,7 @@ def async_remove_records(self, entries: Iterable[DNSRecord]) -> None: for entry in entries: self._async_remove(entry) - def async_expire(self, now: _float) -> List[DNSRecord]: + def async_expire(self, now: _float) -> list[DNSRecord]: """Purge expired entries from the cache. This function must be run in from event loop. @@ -145,7 +147,7 @@ def async_expire(self, now: _float) -> List[DNSRecord]: if not (expire_heap_len := len(self._expire_heap)): return [] - expired: List[DNSRecord] = [] + expired: list[DNSRecord] = [] # Find any expired records and add them to the to-delete list while self._expire_heap: when_record = self._expire_heap[0] @@ -182,7 +184,7 @@ def async_expire(self, now: _float) -> List[DNSRecord]: self.async_remove_records(expired) return expired - def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: + def async_get_unique(self, entry: _UniqueRecordsType) -> DNSRecord | None: """Gets a unique entry by key. Will return None if there is no matching entry. @@ -194,7 +196,7 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: return None return store.get(entry) - def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DNSRecord]: + def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> list[DNSRecord]: """Gets all matching entries by details. This function is not thread-safe and must be called from @@ -202,7 +204,7 @@ def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DN """ key = name.lower() records = self.cache.get(key) - matches: List[DNSRecord] = [] + matches: list[DNSRecord] = [] if records is None: return matches for record in records.values(): @@ -210,7 +212,7 @@ def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DN matches.append(record) return matches - def async_entries_with_name(self, name: str) -> List[DNSRecord]: + def async_entries_with_name(self, name: str) -> list[DNSRecord]: """Returns a dict of entries whose key matches the name. This function is not threadsafe and must be called from @@ -218,7 +220,7 @@ def async_entries_with_name(self, name: str) -> List[DNSRecord]: """ return self.entries_with_name(name) - def async_entries_with_server(self, name: str) -> List[DNSRecord]: + def async_entries_with_server(self, name: str) -> list[DNSRecord]: """Returns a dict of entries whose key matches the server. This function is not threadsafe and must be called from @@ -230,7 +232,7 @@ def async_entries_with_server(self, name: str) -> List[DNSRecord]: # event loop, however they all make copies so they significantly # inefficient. - def get(self, entry: DNSEntry) -> Optional[DNSRecord]: + def get(self, entry: DNSEntry) -> DNSRecord | None: """Gets an entry by key. Will return None if there is no matching entry.""" if isinstance(entry, _UNIQUE_RECORD_TYPES): @@ -240,7 +242,7 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]: return cached_entry return None - def get_by_details(self, name: str, type_: _int, class_: _int) -> Optional[DNSRecord]: + def get_by_details(self, name: str, type_: _int, class_: _int) -> DNSRecord | None: """Gets the first matching entry by details. Returns None if no entries match. Calling this function is not recommended as it will only @@ -261,7 +263,7 @@ def get_by_details(self, name: str, type_: _int, class_: _int) -> Optional[DNSRe return cached_entry return None - def get_all_by_details(self, name: str, type_: _int, class_: _int) -> List[DNSRecord]: + def get_all_by_details(self, name: str, type_: _int, class_: _int) -> list[DNSRecord]: """Gets all matching entries by details.""" key = name.lower() records = self.cache.get(key) @@ -269,19 +271,19 @@ def get_all_by_details(self, name: str, type_: _int, class_: _int) -> List[DNSRe return [] return [entry for entry in list(records.values()) if type_ == entry.type and class_ == entry.class_] - def entries_with_server(self, server: str) -> List[DNSRecord]: + def entries_with_server(self, server: str) -> list[DNSRecord]: """Returns a list of entries whose server matches the name.""" if entries := self.service_cache.get(server.lower()): return list(entries.values()) return [] - def entries_with_name(self, name: str) -> List[DNSRecord]: + def entries_with_name(self, name: str) -> list[DNSRecord]: """Returns a list of entries whose key matches the name.""" if entries := self.cache.get(name.lower()): return list(entries.values()) return [] - def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: + def current_entry_with_name_and_alias(self, name: str, alias: str) -> DNSRecord | None: now = current_time_millis() for record in reversed(self.entries_with_name(name)): if ( @@ -292,13 +294,13 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D return record return None - def names(self) -> List[str]: + def names(self) -> list[str]: """Return a copy of the list of current cache names.""" return list(self.cache) def async_mark_unique_records_older_than_1s_to_expire( self, - unique_types: Set[Tuple[_str, _int, _int]], + unique_types: set[tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float, ) -> None: From bf1eb21a77ac4d31d062265d93ab83c670b9d47c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 31 Jan 2025 13:00:27 -0600 Subject: [PATCH 05/13] feat: eliminate async_timeout dep on python less than 3.11 --- src/zeroconf/_handlers/__init__.py | 2 + src/zeroconf/_handlers/answers.py | 8 +- .../_handlers/multicast_outgoing_queue.py | 4 +- src/zeroconf/_handlers/query_handler.py | 60 ++++---- src/zeroconf/_handlers/record_manager.py | 26 ++-- src/zeroconf/_history.py | 10 +- src/zeroconf/_listener.py | 34 ++--- src/zeroconf/_protocol/__init__.py | 2 + src/zeroconf/_protocol/incoming.py | 32 +++-- src/zeroconf/_protocol/outgoing.py | 26 ++-- src/zeroconf/_record_update.py | 8 +- src/zeroconf/_services/__init__.py | 20 +-- src/zeroconf/_services/browser.py | 104 +++++++------- src/zeroconf/_services/info.py | 130 +++++++++--------- src/zeroconf/_services/registry.py | 24 ++-- src/zeroconf/_services/types.py | 13 +- src/zeroconf/_transport.py | 5 +- src/zeroconf/_updates.py | 8 +- src/zeroconf/_utils/__init__.py | 2 + src/zeroconf/_utils/ipaddress.py | 14 +- src/zeroconf/_utils/name.py | 5 +- src/zeroconf/_utils/net.py | 34 ++--- src/zeroconf/_utils/time.py | 2 + src/zeroconf/const.py | 2 + 24 files changed, 304 insertions(+), 271 deletions(-) diff --git a/src/zeroconf/_handlers/__init__.py b/src/zeroconf/_handlers/__init__.py index 30920c6aa..584a74eca 100644 --- a/src/zeroconf/_handlers/__init__.py +++ b/src/zeroconf/_handlers/__init__.py @@ -19,3 +19,5 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA """ + +from __future__ import annotations diff --git a/src/zeroconf/_handlers/answers.py b/src/zeroconf/_handlers/answers.py index 7ddde1976..ec53eb842 100644 --- a/src/zeroconf/_handlers/answers.py +++ b/src/zeroconf/_handlers/answers.py @@ -20,8 +20,10 @@ USA """ +from __future__ import annotations + from operator import attrgetter -from typing import Dict, List, Set +from typing import Dict, Set from .._dns import DNSQuestion, DNSRecord from .._protocol.outgoing import DNSOutgoing @@ -96,7 +98,7 @@ def construct_outgoing_multicast_answers( def construct_outgoing_unicast_answers( answers: _AnswerWithAdditionalsType, ucast_source: bool, - questions: List[DNSQuestion], + questions: list[DNSQuestion], id_: int_, ) -> DNSOutgoing: """Add answers and additionals to a DNSOutgoing.""" @@ -111,7 +113,7 @@ def construct_outgoing_unicast_answers( def _add_answers_additionals(out: DNSOutgoing, answers: _AnswerWithAdditionalsType) -> None: # Find additionals and suppress any additionals that are already in answers - sending: Set[DNSRecord] = set(answers) + sending: set[DNSRecord] = set(answers) # Answers are sorted to group names together to increase the chance # that similar names will end up in the same packet and can reduce the # overall size of the outgoing response via name compression diff --git a/src/zeroconf/_handlers/multicast_outgoing_queue.py b/src/zeroconf/_handlers/multicast_outgoing_queue.py index caf6470b1..73d5ee431 100644 --- a/src/zeroconf/_handlers/multicast_outgoing_queue.py +++ b/src/zeroconf/_handlers/multicast_outgoing_queue.py @@ -20,6 +20,8 @@ USA """ +from __future__ import annotations + import random from collections import deque from typing import TYPE_CHECKING @@ -53,7 +55,7 @@ class MulticastOutgoingQueue: "zc", ) - def __init__(self, zeroconf: "Zeroconf", additional_delay: _int, max_aggregation_delay: _int) -> None: + def __init__(self, zeroconf: Zeroconf, additional_delay: _int, max_aggregation_delay: _int) -> None: self.zc = zeroconf self.queue: deque[AnswerGroup] = deque() # Additional delay is used to implement diff --git a/src/zeroconf/_handlers/query_handler.py b/src/zeroconf/_handlers/query_handler.py index ccfc7a771..60209568a 100644 --- a/src/zeroconf/_handlers/query_handler.py +++ b/src/zeroconf/_handlers/query_handler.py @@ -20,7 +20,9 @@ USA """ -from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union, cast +from __future__ import annotations + +from typing import TYPE_CHECKING, cast from .._cache import DNSCache, _UniqueRecordsType from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet @@ -52,8 +54,8 @@ _RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES} -_EMPTY_SERVICES_LIST: List[ServiceInfo] = [] -_EMPTY_TYPES_LIST: List[str] = [] +_EMPTY_SERVICES_LIST: list[ServiceInfo] = [] +_EMPTY_TYPES_LIST: list[str] = [] _IPVersion_ALL = IPVersion.All @@ -77,8 +79,8 @@ def __init__( self, question: DNSQuestion, strategy_type: _int, - types: List[str], - services: List[ServiceInfo], + types: list[str], + services: list[ServiceInfo], ) -> None: """Create an answer strategy.""" self.question = question @@ -102,17 +104,17 @@ class _QueryResponse: "_ucast", ) - def __init__(self, cache: DNSCache, questions: List[DNSQuestion], is_probe: bool, now: float) -> None: + def __init__(self, cache: DNSCache, questions: list[DNSQuestion], is_probe: bool, now: float) -> None: """Build a query response.""" self._is_probe = is_probe self._questions = questions self._now = now self._cache = cache self._additionals: _AnswerWithAdditionalsType = {} - self._ucast: Set[DNSRecord] = set() - self._mcast_now: Set[DNSRecord] = set() - self._mcast_aggregate: Set[DNSRecord] = set() - self._mcast_aggregate_last_second: Set[DNSRecord] = set() + self._ucast: set[DNSRecord] = set() + self._mcast_now: set[DNSRecord] = set() + self._mcast_aggregate: set[DNSRecord] = set() + self._mcast_aggregate_last_second: set[DNSRecord] = set() def add_qu_question_response(self, answers: _AnswerWithAdditionalsType) -> None: """Generate a response to a multicast QU query.""" @@ -199,7 +201,7 @@ class QueryHandler: "zc", ) - def __init__(self, zc: "Zeroconf") -> None: + def __init__(self, zc: Zeroconf) -> None: """Init the query handler.""" self.zc = zc self.registry = zc.registry @@ -210,7 +212,7 @@ def __init__(self, zc: "Zeroconf") -> None: def _add_service_type_enumeration_query_answers( self, - types: List[str], + types: list[str], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, ) -> None: @@ -232,7 +234,7 @@ def _add_service_type_enumeration_query_answers( def _add_pointer_answers( self, - services: List[ServiceInfo], + services: list[ServiceInfo], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, ) -> None: @@ -251,23 +253,23 @@ def _add_pointer_answers( def _add_address_answers( self, - services: List[ServiceInfo], + services: list[ServiceInfo], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, type_: _int, ) -> None: """Answer A/AAAA/ANY question.""" for service in services: - answers: List[DNSAddress] = [] - additionals: Set[DNSRecord] = set() - seen_types: Set[int] = set() + answers: list[DNSAddress] = [] + additionals: set[DNSRecord] = set() + seen_types: set[int] = set() for dns_address in service._dns_addresses(None, _IPVersion_ALL): seen_types.add(dns_address.type) if dns_address.type != type_: additionals.add(dns_address) elif not known_answers.suppresses(dns_address): answers.append(dns_address) - missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types + missing_types: set[int] = _ADDRESS_RECORD_TYPES - seen_types if answers: if missing_types: assert service.server is not None, "Service server must be set for NSEC record." @@ -282,8 +284,8 @@ def _answer_question( self, question: DNSQuestion, strategy_type: _int, - types: List[str], - services: List[ServiceInfo], + types: list[str], + services: list[ServiceInfo], known_answers: DNSRRSet, ) -> _AnswerWithAdditionalsType: """Answer a question.""" @@ -311,14 +313,14 @@ def _answer_question( return answer_set def async_response( # pylint: disable=unused-argument - self, msgs: List[DNSIncoming], ucast_source: bool - ) -> Optional[QuestionAnswers]: + self, msgs: list[DNSIncoming], ucast_source: bool + ) -> QuestionAnswers | None: """Deal with incoming query packets. Provides a response if possible. This function must be run in the event loop as it is not threadsafe. """ - strategies: List[_AnswerStrategy] = [] + strategies: list[_AnswerStrategy] = [] for msg in msgs: for question in msg._questions: strategies.extend(self._get_answer_strategies(question)) @@ -334,7 +336,7 @@ def async_response( # pylint: disable=unused-argument questions = msg._questions # Only decode known answers if we are not a probe and we have # at least one answer strategy - answers: List[DNSRecord] = [] + answers: list[DNSRecord] = [] for msg in msgs: if msg.is_probe(): is_probe = True @@ -343,7 +345,7 @@ def async_response( # pylint: disable=unused-argument query_res = _QueryResponse(self.cache, questions, is_probe, msg.now) known_answers = DNSRRSet(answers) - known_answers_set: Optional[Set[DNSRecord]] = None + known_answers_set: set[DNSRecord] | None = None now = msg.now for strategy in strategies: question = strategy.question @@ -373,12 +375,12 @@ def async_response( # pylint: disable=unused-argument def _get_answer_strategies( self, question: DNSQuestion, - ) -> List[_AnswerStrategy]: + ) -> list[_AnswerStrategy]: """Collect strategies to answer a question.""" name = question.name question_lower_name = name.lower() type_ = question.type - strategies: List[_AnswerStrategy] = [] + strategies: list[_AnswerStrategy] = [] if type_ == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME: types = self.registry.async_get_types() @@ -433,11 +435,11 @@ def _get_answer_strategies( def handle_assembled_query( self, - packets: List[DNSIncoming], + packets: list[DNSIncoming], addr: _str, port: _int, transport: _WrappedTransport, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]], + v6_flow_scope: tuple[()] | tuple[int, int], ) -> None: """Respond to a (re)assembled query. diff --git a/src/zeroconf/_handlers/record_manager.py b/src/zeroconf/_handlers/record_manager.py index d4e2792c8..566f0e8c9 100644 --- a/src/zeroconf/_handlers/record_manager.py +++ b/src/zeroconf/_handlers/record_manager.py @@ -20,7 +20,9 @@ USA """ -from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union, cast +from __future__ import annotations + +from typing import TYPE_CHECKING, cast from .._cache import _UniqueRecordsType from .._dns import DNSQuestion, DNSRecord @@ -42,13 +44,13 @@ class RecordManager: __slots__ = ("cache", "listeners", "zc") - def __init__(self, zeroconf: "Zeroconf") -> None: + def __init__(self, zeroconf: Zeroconf) -> None: """Init the record manager.""" self.zc = zeroconf self.cache = zeroconf.cache - self.listeners: Set[RecordUpdateListener] = set() + self.listeners: set[RecordUpdateListener] = set() - def async_updates(self, now: _float, records: List[RecordUpdate]) -> None: + def async_updates(self, now: _float, records: list[RecordUpdate]) -> None: """Used to notify listeners of new information that has updated a record. @@ -79,12 +81,12 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: This function must be run in the event loop as it is not threadsafe. """ - updates: List[RecordUpdate] = [] - address_adds: List[DNSRecord] = [] - other_adds: List[DNSRecord] = [] - removes: Set[DNSRecord] = set() + updates: list[RecordUpdate] = [] + address_adds: list[DNSRecord] = [] + other_adds: list[DNSRecord] = [] + removes: set[DNSRecord] = set() now = msg.now - unique_types: Set[Tuple[str, int, int]] = set() + unique_types: set[tuple[str, int, int]] = set() cache = self.cache answers = msg.answers() @@ -165,7 +167,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: def async_add_listener( self, listener: RecordUpdateListener, - question: Optional[Union[DNSQuestion, List[DNSQuestion]]], + question: DNSQuestion | list[DNSQuestion] | None, ) -> None: """Adds a listener for a given question. The listener will have its update_record method called when information is available to @@ -188,14 +190,14 @@ def async_add_listener( self._async_update_matching_records(listener, questions) def _async_update_matching_records( - self, listener: RecordUpdateListener, questions: List[DNSQuestion] + self, listener: RecordUpdateListener, questions: list[DNSQuestion] ) -> None: """Calls back any existing entries in the cache that answer the question. This function must be run from the event loop. """ now = current_time_millis() - records: List[RecordUpdate] = [ + records: list[RecordUpdate] = [ RecordUpdate(record, None) for question in questions for record in self.cache.async_entries_with_name(question.name) diff --git a/src/zeroconf/_history.py b/src/zeroconf/_history.py index aa28519c5..5bae7be04 100644 --- a/src/zeroconf/_history.py +++ b/src/zeroconf/_history.py @@ -20,7 +20,7 @@ USA """ -from typing import Dict, List, Set, Tuple +from __future__ import annotations from ._dns import DNSQuestion, DNSRecord from .const import _DUPLICATE_QUESTION_INTERVAL @@ -36,13 +36,13 @@ class QuestionHistory: def __init__(self) -> None: """Init a new QuestionHistory.""" - self._history: Dict[DNSQuestion, Tuple[float, Set[DNSRecord]]] = {} + self._history: dict[DNSQuestion, tuple[float, set[DNSRecord]]] = {} - def add_question_at_time(self, question: DNSQuestion, now: _float, known_answers: Set[DNSRecord]) -> None: + def add_question_at_time(self, question: DNSQuestion, now: _float, known_answers: set[DNSRecord]) -> None: """Remember a question with known answers.""" self._history[question] = (now, known_answers) - def suppresses(self, question: DNSQuestion, now: _float, known_answers: Set[DNSRecord]) -> bool: + def suppresses(self, question: DNSQuestion, now: _float, known_answers: set[DNSRecord]) -> bool: """Check to see if a question should be suppressed. https://datatracker.ietf.org/doc/html/rfc6762#section-7.3 @@ -66,7 +66,7 @@ def suppresses(self, question: DNSQuestion, now: _float, known_answers: Set[DNSR def async_expire(self, now: _float) -> None: """Expire the history of old questions.""" - removes: List[DNSQuestion] = [] + removes: list[DNSQuestion] = [] for question, now_known_answers in self._history.items(): than, _ = now_known_answers if now - than > _DUPLICATE_QUESTION_INTERVAL: diff --git a/src/zeroconf/_listener.py b/src/zeroconf/_listener.py index 1980a8201..925c689e0 100644 --- a/src/zeroconf/_listener.py +++ b/src/zeroconf/_listener.py @@ -20,11 +20,13 @@ USA """ +from __future__ import annotations + import asyncio import logging import random from functools import partial -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Tuple, cast from ._logger import QuietLogger, log from ._protocol.incoming import DNSIncoming @@ -68,23 +70,21 @@ class AsyncListener: "zc", ) - def __init__(self, zc: "Zeroconf") -> None: + def __init__(self, zc: Zeroconf) -> None: self.zc = zc self._registry = zc.registry self._record_manager = zc.record_manager self._query_handler = zc.query_handler - self.data: Optional[bytes] = None + self.data: bytes | None = None self.last_time: float = 0 - self.last_message: Optional[DNSIncoming] = None - self.transport: Optional[_WrappedTransport] = None - self.sock_description: Optional[str] = None - self._deferred: Dict[str, List[DNSIncoming]] = {} - self._timers: Dict[str, asyncio.TimerHandle] = {} + self.last_message: DNSIncoming | None = None + self.transport: _WrappedTransport | None = None + self.sock_description: str | None = None + self._deferred: dict[str, list[DNSIncoming]] = {} + self._timers: dict[str, asyncio.TimerHandle] = {} super().__init__() - def datagram_received( - self, data: _bytes, addrs: Union[Tuple[str, int], Tuple[str, int, int, int]] - ) -> None: + def datagram_received(self, data: _bytes, addrs: tuple[str, int] | tuple[str, int, int, int]) -> None: data_len = len(data) debug = DEBUG_ENABLED() @@ -108,7 +108,7 @@ def _process_datagram_at_time( data_len: _int, now: _float, data: _bytes, - addrs: Union[Tuple[str, int], Tuple[str, int, int, int]], + addrs: tuple[str, int] | tuple[str, int, int, int], ) -> None: if ( self.data == data @@ -129,7 +129,7 @@ def _process_datagram_at_time( return if len(addrs) == 2: - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = () + v6_flow_scope: tuple[()] | tuple[int, int] = () # https://github.com/python/mypy/issues/1178 addr, port = addrs # type: ignore addr_port = addrs @@ -189,7 +189,7 @@ def handle_query_or_defer( addr: _str, port: _int, transport: _WrappedTransport, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]], + v6_flow_scope: tuple[()] | tuple[int, int], ) -> None: """Deal with incoming query packets. Provides a response if possible.""" @@ -224,11 +224,11 @@ def _cancel_any_timers_for_addr(self, addr: _str) -> None: def _respond_query( self, - msg: Optional[DNSIncoming], + msg: DNSIncoming | None, addr: _str, port: _int, transport: _WrappedTransport, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]], + v6_flow_scope: tuple[()] | tuple[int, int], ) -> None: """Respond to a query and reassemble any truncated deferred packets.""" self._cancel_any_timers_for_addr(addr) @@ -252,5 +252,5 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = wrapped_transport self.sock_description = f"{wrapped_transport.fileno} ({wrapped_transport.sock_name})" - def connection_lost(self, exc: Optional[Exception]) -> None: + def connection_lost(self, exc: Exception | None) -> None: """Handle connection lost.""" diff --git a/src/zeroconf/_protocol/__init__.py b/src/zeroconf/_protocol/__init__.py index 30920c6aa..584a74eca 100644 --- a/src/zeroconf/_protocol/__init__.py +++ b/src/zeroconf/_protocol/__init__.py @@ -19,3 +19,5 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA """ + +from __future__ import annotations diff --git a/src/zeroconf/_protocol/incoming.py b/src/zeroconf/_protocol/incoming.py index 6e009b293..7f4a8eec1 100644 --- a/src/zeroconf/_protocol/incoming.py +++ b/src/zeroconf/_protocol/incoming.py @@ -20,9 +20,11 @@ USA """ +from __future__ import annotations + import struct import sys -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any from .._dns import ( DNSAddress, @@ -61,7 +63,7 @@ DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError) -_seen_logs: Dict[str, Union[int, tuple]] = {} +_seen_logs: dict[str, int | tuple] = {} _str = str _int = int @@ -94,9 +96,9 @@ class DNSIncoming: def __init__( self, data: bytes, - source: Optional[Tuple[str, int]] = None, - scope_id: Optional[int] = None, - now: Optional[float] = None, + source: tuple[str, int] | None = None, + scope_id: int | None = None, + now: float | None = None, ) -> None: """Constructor from string holding bytes of packet""" self.flags = 0 @@ -104,9 +106,9 @@ def __init__( self.data = data self.view = data self._data_len = len(data) - self._name_cache: Dict[int, List[str]] = {} - self._questions: List[DNSQuestion] = [] - self._answers: List[DNSRecord] = [] + self._name_cache: dict[int, list[str]] = {} + self._questions: list[DNSQuestion] = [] + self._answers: list[DNSRecord] = [] self.id = 0 self._num_questions = 0 self._num_answers = 0 @@ -146,7 +148,7 @@ def truncated(self) -> bool: return (self.flags & _FLAGS_TC) == _FLAGS_TC @property - def questions(self) -> List[DNSQuestion]: + def questions(self) -> list[DNSQuestion]: """Questions in the packet.""" return self._questions @@ -189,7 +191,7 @@ def _log_exception_debug(cls, *logger_data: Any) -> None: log_exc_info = True log.debug(*(logger_data or ["Exception occurred"]), exc_info=log_exc_info) - def answers(self) -> List[DNSRecord]: + def answers(self) -> list[DNSRecord]: """Answers in the packet.""" if not self._did_read_others: try: @@ -306,7 +308,7 @@ def _read_others(self) -> None: def _read_record( self, domain: _str, type_: _int, class_: _int, ttl: _int, length: _int - ) -> Optional[DNSRecord]: + ) -> DNSRecord | None: """Read known records types and skip unknown ones.""" if type_ == _TYPE_A: address_rec = DNSAddress.__new__(DNSAddress) @@ -384,7 +386,7 @@ def _read_record( self.offset += length return None - def _read_bitmap(self, end: _int) -> List[int]: + def _read_bitmap(self, end: _int) -> list[int]: """Reads an NSEC bitmap from the packet.""" rdtypes = [] view = self.view @@ -404,8 +406,8 @@ def _read_bitmap(self, end: _int) -> List[int]: def _read_name(self) -> str: """Reads a domain name from the packet.""" - labels: List[str] = [] - seen_pointers: Set[int] = set() + labels: list[str] = [] + seen_pointers: set[int] = set() original_offset = self.offset self.offset = self._decode_labels_at_offset(original_offset, labels, seen_pointers) self._name_cache[original_offset] = labels @@ -416,7 +418,7 @@ def _read_name(self) -> str: ) return name - def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: Set[int]) -> int: + def _decode_labels_at_offset(self, off: _int, labels: list[str], seen_pointers: set[int]) -> int: # This is a tight loop that is called frequently, small optimizations can make a difference. view = self.view while off < self._data_len: diff --git a/src/zeroconf/_protocol/outgoing.py b/src/zeroconf/_protocol/outgoing.py index c937350ed..f5d098211 100644 --- a/src/zeroconf/_protocol/outgoing.py +++ b/src/zeroconf/_protocol/outgoing.py @@ -20,10 +20,12 @@ USA """ +from __future__ import annotations + import enum import logging from struct import Struct -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence from .._dns import DNSPointer, DNSQuestion, DNSRecord from .._exceptions import NamePartTooLongException @@ -98,20 +100,20 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None: self.finished = False self.id = id_ self.multicast = multicast - self.packets_data: List[bytes] = [] + self.packets_data: list[bytes] = [] # these 3 are per-packet -- see also _reset_for_next_packet() - self.names: Dict[str, int] = {} - self.data: List[bytes] = [] + self.names: dict[str, int] = {} + self.data: list[bytes] = [] self.size: int = _DNS_PACKET_HEADER_LEN self.allow_long: bool = True self.state = STATE_INIT - self.questions: List[DNSQuestion] = [] - self.answers: List[Tuple[DNSRecord, float]] = [] - self.authorities: List[DNSPointer] = [] - self.additionals: List[DNSRecord] = [] + self.questions: list[DNSQuestion] = [] + self.answers: list[tuple[DNSRecord, float]] = [] + self.authorities: list[DNSPointer] = [] + self.additionals: list[DNSRecord] = [] def is_query(self) -> bool: """Returns true if this is a query.""" @@ -150,7 +152,7 @@ def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None: if not record.suppressed_by(inp): self.add_answer_at_time(record, 0.0) - def add_answer_at_time(self, record: Optional[DNSRecord], now: float_) -> None: + def add_answer_at_time(self, record: DNSRecord | None, now: float_) -> None: """Adds an answer if it does not expire by a certain time""" now_double = now if record is not None and (now_double == 0 or not record.is_expired(now_double)): @@ -220,7 +222,7 @@ def write_short(self, value: int_) -> None: self.data.append(self._get_short(value)) self.size += 2 - def _write_int(self, value: Union[float, int]) -> None: + def _write_int(self, value: float | int) -> None: """Writes an unsigned integer to the packet""" value_as_int = int(value) long_bytes = LONG_LOOKUP.get(value_as_int) @@ -313,7 +315,7 @@ def _write_question(self, question: DNSQuestion_) -> bool: self._write_record_class(question) return self._check_data_limit_or_rollback(start_data_length, start_size) - def _write_record_class(self, record: Union[DNSQuestion_, DNSRecord_]) -> None: + def _write_record_class(self, record: DNSQuestion_ | DNSRecord_) -> None: """Write out the record class including the unique/unicast (QU) bit.""" class_ = record.class_ if record.unique is True and self.multicast: @@ -409,7 +411,7 @@ def _has_more_to_add( or additional_offset < len(self.additionals) ) - def packets(self) -> List[bytes]: + def packets(self) -> list[bytes]: """Returns a list of bytestrings containing the packets' bytes No further parts should be added to the packet once this diff --git a/src/zeroconf/_record_update.py b/src/zeroconf/_record_update.py index 912ab6f1d..5f8175113 100644 --- a/src/zeroconf/_record_update.py +++ b/src/zeroconf/_record_update.py @@ -20,7 +20,7 @@ USA """ -from typing import Optional +from __future__ import annotations from ._dns import DNSRecord @@ -30,16 +30,16 @@ class RecordUpdate: __slots__ = ("new", "old") - def __init__(self, new: DNSRecord, old: Optional[DNSRecord] = None) -> None: + def __init__(self, new: DNSRecord, old: DNSRecord | None = None) -> None: """RecordUpdate represents a change in a DNS record.""" self._fast_init(new, old) - def _fast_init(self, new: _DNSRecord, old: Optional[_DNSRecord]) -> None: + def _fast_init(self, new: _DNSRecord, old: _DNSRecord | None) -> None: """Fast init for RecordUpdate.""" self.new = new self.old = old - def __getitem__(self, index: int) -> Optional[DNSRecord]: + def __getitem__(self, index: int) -> DNSRecord | None: """Get the new or old record.""" if index == 0: return self.new diff --git a/src/zeroconf/_services/__init__.py b/src/zeroconf/_services/__init__.py index 7a6bddebb..6936aed61 100644 --- a/src/zeroconf/_services/__init__.py +++ b/src/zeroconf/_services/__init__.py @@ -20,8 +20,10 @@ USA """ +from __future__ import annotations + import enum -from typing import TYPE_CHECKING, Any, Callable, List +from typing import TYPE_CHECKING, Any, Callable if TYPE_CHECKING: from .._core import Zeroconf @@ -35,13 +37,13 @@ class ServiceStateChange(enum.Enum): class ServiceListener: - def add_service(self, zc: "Zeroconf", type_: str, name: str) -> None: + def add_service(self, zc: Zeroconf, type_: str, name: str) -> None: raise NotImplementedError() - def remove_service(self, zc: "Zeroconf", type_: str, name: str) -> None: + def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None: raise NotImplementedError() - def update_service(self, zc: "Zeroconf", type_: str, name: str) -> None: + def update_service(self, zc: Zeroconf, type_: str, name: str) -> None: raise NotImplementedError() @@ -49,27 +51,27 @@ class Signal: __slots__ = ("_handlers",) def __init__(self) -> None: - self._handlers: List[Callable[..., None]] = [] + self._handlers: list[Callable[..., None]] = [] def fire(self, **kwargs: Any) -> None: for h in self._handlers[:]: h(**kwargs) @property - def registration_interface(self) -> "SignalRegistrationInterface": + def registration_interface(self) -> SignalRegistrationInterface: return SignalRegistrationInterface(self._handlers) class SignalRegistrationInterface: __slots__ = ("_handlers",) - def __init__(self, handlers: List[Callable[..., None]]) -> None: + def __init__(self, handlers: list[Callable[..., None]]) -> None: self._handlers = handlers - def register_handler(self, handler: Callable[..., None]) -> "SignalRegistrationInterface": + def register_handler(self, handler: Callable[..., None]) -> SignalRegistrationInterface: self._handlers.append(handler) return self - def unregister_handler(self, handler: Callable[..., None]) -> "SignalRegistrationInterface": + def unregister_handler(self, handler: Callable[..., None]) -> SignalRegistrationInterface: self._handlers.remove(handler) return self diff --git a/src/zeroconf/_services/browser.py b/src/zeroconf/_services/browser.py index 42aaa1ac8..c2ab115b0 100644 --- a/src/zeroconf/_services/browser.py +++ b/src/zeroconf/_services/browser.py @@ -20,6 +20,8 @@ USA """ +from __future__ import annotations + import asyncio import heapq import queue @@ -36,11 +38,7 @@ Dict, Iterable, List, - Optional, Set, - Tuple, - Type, - Union, cast, ) @@ -155,13 +153,13 @@ def __repr__(self) -> str: ">" ) - def __lt__(self, other: "_ScheduledPTRQuery") -> bool: + def __lt__(self, other: _ScheduledPTRQuery) -> bool: """Compare two scheduled queries.""" if type(other) is _ScheduledPTRQuery: return self.when_millis < other.when_millis return NotImplemented - def __le__(self, other: "_ScheduledPTRQuery") -> bool: + def __le__(self, other: _ScheduledPTRQuery) -> bool: """Compare two scheduled queries.""" if type(other) is _ScheduledPTRQuery: return self.when_millis < other.when_millis or self.__eq__(other) @@ -173,13 +171,13 @@ def __eq__(self, other: Any) -> bool: return self.when_millis == other.when_millis return NotImplemented - def __ge__(self, other: "_ScheduledPTRQuery") -> bool: + def __ge__(self, other: _ScheduledPTRQuery) -> bool: """Compare two scheduled queries.""" if type(other) is _ScheduledPTRQuery: return self.when_millis > other.when_millis or self.__eq__(other) return NotImplemented - def __gt__(self, other: "_ScheduledPTRQuery") -> bool: + def __gt__(self, other: _ScheduledPTRQuery) -> bool: """Compare two scheduled queries.""" if type(other) is _ScheduledPTRQuery: return self.when_millis > other.when_millis @@ -197,7 +195,7 @@ def __init__(self, now_millis: float, multicast: bool) -> None: self.out = DNSOutgoing(_FLAGS_QR_QUERY, multicast) self.bytes = 0 - def add(self, max_compressed_size: int_, question: DNSQuestion, answers: Set[DNSPointer]) -> None: + def add(self, max_compressed_size: int_, question: DNSQuestion, answers: set[DNSPointer]) -> None: """Add a new set of questions and known answers to the outgoing.""" self.out.add_question(question) for answer in answers: @@ -209,7 +207,7 @@ def group_ptr_queries_with_known_answers( now: float_, multicast: bool_, question_with_known_answers: _QuestionWithKnownAnswers, -) -> List[DNSOutgoing]: +) -> list[DNSOutgoing]: """Aggregate queries so that as many known answers as possible fit in the same packet without having known answers spill over into the next packet unless the question and known answers are always going to exceed the packet size. @@ -225,19 +223,19 @@ def _group_ptr_queries_with_known_answers( now_millis: float_, multicast: bool_, question_with_known_answers: _QuestionWithKnownAnswers, -) -> List[DNSOutgoing]: +) -> list[DNSOutgoing]: """Inner wrapper for group_ptr_queries_with_known_answers.""" # This is the maximum size the query + known answers can be with name compression. # The actual size of the query + known answers may be a bit smaller since other # parts may be shared when the final DNSOutgoing packets are constructed. The # goal of this algorithm is to quickly bucket the query + known answers without # the overhead of actually constructing the packets. - query_by_size: Dict[DNSQuestion, int] = { + query_by_size: dict[DNSQuestion, int] = { question: (question.max_size + sum(answer.max_size_compressed for answer in known_answers)) for question, known_answers in question_with_known_answers.items() } max_bucket_size = _MAX_MSG_TYPICAL - _DNS_PACKET_HEADER_LEN - query_buckets: List[_DNSPointerOutgoingBucket] = [] + query_buckets: list[_DNSPointerOutgoingBucket] = [] for question in sorted( query_by_size, key=query_by_size.get, # type: ignore @@ -261,12 +259,12 @@ def _group_ptr_queries_with_known_answers( def generate_service_query( - zc: "Zeroconf", + zc: Zeroconf, now_millis: float_, - types_: Set[str], + types_: set[str], multicast: bool, - question_type: Optional[DNSQuestionType], -) -> List[DNSOutgoing]: + question_type: DNSQuestionType | None, +) -> list[DNSOutgoing]: """Generate a service query for sending with zeroconf.send.""" questions_with_known_answers: _QuestionWithKnownAnswers = {} qu_question = not multicast if question_type is None else question_type is QU_QUESTION @@ -296,7 +294,7 @@ def generate_service_query( def _on_change_dispatcher( listener: ServiceListener, - zeroconf: "Zeroconf", + zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange, @@ -346,14 +344,14 @@ class QueryScheduler: def __init__( self, - zc: "Zeroconf", - types: Set[str], - addr: Optional[str], + zc: Zeroconf, + types: set[str], + addr: str | None, port: int, multicast: bool, delay: int, - first_random_delay_interval: Tuple[int, int], - question_type: Optional[DNSQuestionType], + first_random_delay_interval: tuple[int, int], + question_type: DNSQuestionType | None, ) -> None: self._zc = zc self._types = types @@ -362,11 +360,11 @@ def __init__( self._multicast = multicast self._first_random_delay_interval = first_random_delay_interval self._min_time_between_queries_millis = delay - self._loop: Optional[asyncio.AbstractEventLoop] = None + self._loop: asyncio.AbstractEventLoop | None = None self._startup_queries_sent = 0 - self._next_scheduled_for_alias: Dict[str, _ScheduledPTRQuery] = {} + self._next_scheduled_for_alias: dict[str, _ScheduledPTRQuery] = {} self._query_heap: list[_ScheduledPTRQuery] = [] - self._next_run: Optional[asyncio.TimerHandle] = None + self._next_run: asyncio.TimerHandle | None = None self._clock_resolution_millis = time.get_clock_info("monotonic").resolution * 1000 self._question_type = question_type @@ -500,10 +498,10 @@ def _process_ready_types(self) -> None: # with a minimum time between queries of _min_time_between_queries # which defaults to 10s - ready_types: Set[str] = set() - next_scheduled: Optional[_ScheduledPTRQuery] = None + ready_types: set[str] = set() + next_scheduled: _ScheduledPTRQuery | None = None end_time_millis = now_millis + self._clock_resolution_millis - schedule_rescue: List[_ScheduledPTRQuery] = [] + schedule_rescue: list[_ScheduledPTRQuery] = [] while self._query_heap: query = self._query_heap[0] @@ -538,7 +536,7 @@ def _process_ready_types(self) -> None: self._next_run = self._loop.call_at(millis_to_seconds(next_when_millis), self._process_ready_types) def async_send_ready_queries( - self, first_request: bool, now_millis: float_, ready_types: Set[str] + self, first_request: bool, now_millis: float_, ready_types: set[str] ) -> None: """Send any ready queries.""" # If they did not specify and this is the first request, ask QU questions @@ -569,14 +567,14 @@ class _ServiceBrowserBase(RecordUpdateListener): def __init__( self, - zc: "Zeroconf", - type_: Union[str, list], - handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, - listener: Optional[ServiceListener] = None, - addr: Optional[str] = None, + zc: Zeroconf, + type_: str | list, + handlers: ServiceListener | list[Callable[..., None]] | None = None, + listener: ServiceListener | None = None, + addr: str | None = None, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, - question_type: Optional[DNSQuestionType] = None, + question_type: DNSQuestionType | None = None, ) -> None: """Used to browse for a service for specific type(s). @@ -596,7 +594,7 @@ def __init__( discovers changes in the services availability. """ assert handlers or listener, "You need to specify at least one handler" - self.types: Set[str] = set(type_ if isinstance(type_, list) else [type_]) + self.types: set[str] = set(type_ if isinstance(type_, list) else [type_]) for check_type_ in self.types: # Will generate BadTypeInNameException on a bad name service_type_name(check_type_, strict=False) @@ -604,7 +602,7 @@ def __init__( self._cache = zc.cache assert zc.loop is not None self._loop = zc.loop - self._pending_handlers: Dict[Tuple[str, str], ServiceStateChange] = {} + self._pending_handlers: dict[tuple[str, str], ServiceStateChange] = {} self._service_state_changed = Signal() self.query_scheduler = QueryScheduler( zc, @@ -617,7 +615,7 @@ def __init__( question_type, ) self.done = False - self._query_sender_task: Optional[asyncio.Task] = None + self._query_sender_task: asyncio.Task | None = None if hasattr(handlers, "add_service"): listener = cast("ServiceListener", handlers) @@ -645,7 +643,7 @@ def _async_start(self) -> None: def service_state_changed(self) -> SignalRegistrationInterface: return self._service_state_changed.registration_interface - def _names_matching_types(self, names: Iterable[str]) -> List[Tuple[str, str]]: + def _names_matching_types(self, names: Iterable[str]) -> list[tuple[str, str]]: """Return the type and name for records matching the types we are browsing.""" return [ (type_, name) for name in names for type_ in self.types.intersection(cached_possible_types(name)) @@ -670,7 +668,7 @@ def _enqueue_callback( ): self._pending_handlers[key] = state_change - def async_update_records(self, zc: "Zeroconf", now: float_, records: List[RecordUpdate]) -> None: + def async_update_records(self, zc: Zeroconf, now: float_, records: list[RecordUpdate]) -> None: """Callback invoked by Zeroconf when new information arrives. Updates information required by browser in the Zeroconf cache. @@ -727,7 +725,7 @@ def async_update_records_complete(self) -> None: self._fire_service_state_changed_event(pending) self._pending_handlers.clear() - def _fire_service_state_changed_event(self, event: Tuple[Tuple[str, str], ServiceStateChange]) -> None: + def _fire_service_state_changed_event(self, event: tuple[tuple[str, str], ServiceStateChange]) -> None: """Fire a service state changed event. When running with ServiceBrowser, this will happen in the dedicated @@ -769,14 +767,14 @@ class ServiceBrowser(_ServiceBrowserBase, threading.Thread): def __init__( self, - zc: "Zeroconf", - type_: Union[str, list], - handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, - listener: Optional[ServiceListener] = None, - addr: Optional[str] = None, + zc: Zeroconf, + type_: str | list, + handlers: ServiceListener | list[Callable[..., None]] | None = None, + listener: ServiceListener | None = None, + addr: str | None = None, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, - question_type: Optional[DNSQuestionType] = None, + question_type: DNSQuestionType | None = None, ) -> None: assert zc.loop is not None if not zc.loop.is_running(): @@ -821,14 +819,14 @@ def async_update_records_complete(self) -> None: self.queue.put(pending) self._pending_handlers.clear() - def __enter__(self) -> "ServiceBrowser": + def __enter__(self) -> ServiceBrowser: return self def __exit__( # pylint: disable=useless-return self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: self.cancel() return None diff --git a/src/zeroconf/_services/info.py b/src/zeroconf/_services/info.py index a6e815b51..677774594 100644 --- a/src/zeroconf/_services/info.py +++ b/src/zeroconf/_services/info.py @@ -20,9 +20,11 @@ USA """ +from __future__ import annotations + import asyncio import random -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast +from typing import TYPE_CHECKING, Dict, List, Optional, cast from .._cache import DNSCache from .._dns import ( @@ -106,7 +108,7 @@ from .._core import Zeroconf -def instance_name_from_service_info(info: "ServiceInfo", strict: bool = True) -> str: +def instance_name_from_service_info(info: ServiceInfo, strict: bool = True) -> str: """Calculate the instance name from the ServiceInfo.""" # This is kind of funky because of the subtype based tests # need to make subtypes a first class citizen @@ -168,17 +170,17 @@ def __init__( self, type_: str, name: str, - port: Optional[int] = None, + port: int | None = None, weight: int = 0, priority: int = 0, - properties: Union[bytes, Dict] = b"", - server: Optional[str] = None, + properties: bytes | dict = b"", + server: str | None = None, host_ttl: int = _DNS_HOST_TTL, other_ttl: int = _DNS_OTHER_TTL, *, - addresses: Optional[List[bytes]] = None, - parsed_addresses: Optional[List[str]] = None, - interface_index: Optional[int] = None, + addresses: list[bytes] | None = None, + parsed_addresses: list[str] | None = None, + interface_index: int | None = None, ) -> None: # Accept both none, or one, but not both. if addresses is not None and parsed_addresses is not None: @@ -190,8 +192,8 @@ def __init__( self.type = type_ self._name = name self.key = name.lower() - self._ipv4_addresses: List[ZeroconfIPv4Address] = [] - self._ipv6_addresses: List[ZeroconfIPv6Address] = [] + self._ipv4_addresses: list[ZeroconfIPv4Address] = [] + self._ipv6_addresses: list[ZeroconfIPv6Address] = [] if addresses is not None: self.addresses = addresses elif parsed_addresses is not None: @@ -201,20 +203,20 @@ def __init__( self.priority = priority self.server = server if server else None self.server_key = server.lower() if server else None - self._properties: Optional[Dict[bytes, Optional[bytes]]] = None - self._decoded_properties: Optional[Dict[str, Optional[str]]] = None + self._properties: dict[bytes, bytes | None] | None = None + self._decoded_properties: dict[str, str | None] | None = None if isinstance(properties, bytes): self._set_text(properties) else: self._set_properties(properties) self.host_ttl = host_ttl self.other_ttl = other_ttl - self._new_records_futures: Optional[Set[asyncio.Future]] = None - self._dns_address_cache: Optional[List[DNSAddress]] = None - self._dns_pointer_cache: Optional[DNSPointer] = None - self._dns_service_cache: Optional[DNSService] = None - self._dns_text_cache: Optional[DNSText] = None - self._get_address_and_nsec_records_cache: Optional[Set[DNSRecord]] = None + self._new_records_futures: set[asyncio.Future] | None = None + self._dns_address_cache: list[DNSAddress] | None = None + self._dns_pointer_cache: DNSPointer | None = None + self._dns_service_cache: DNSService | None = None + self._dns_text_cache: DNSText | None = None + self._get_address_and_nsec_records_cache: set[DNSRecord] | None = None self._query_record_types = {_TYPE_SRV, _TYPE_TXT, _TYPE_A, _TYPE_AAAA} @property @@ -232,7 +234,7 @@ def name(self, name: str) -> None: self._dns_text_cache = None @property - def addresses(self) -> List[bytes]: + def addresses(self) -> list[bytes]: """IPv4 addresses of this service. Only IPv4 addresses are returned for backward compatibility. @@ -242,7 +244,7 @@ def addresses(self) -> List[bytes]: return self.addresses_by_version(IPVersion.V4Only) @addresses.setter - def addresses(self, value: List[bytes]) -> None: + def addresses(self, value: list[bytes]) -> None: """Replace the addresses list. This replaces all currently stored addresses, both IPv4 and IPv6. @@ -272,7 +274,7 @@ def addresses(self, value: List[bytes]) -> None: self._ipv6_addresses.append(addr) @property - def properties(self) -> Dict[bytes, Optional[bytes]]: + def properties(self) -> dict[bytes, bytes | None]: """Return properties as bytes.""" if self._properties is None: self._unpack_text_into_properties() @@ -281,7 +283,7 @@ def properties(self) -> Dict[bytes, Optional[bytes]]: return self._properties @property - def decoded_properties(self) -> Dict[str, Optional[str]]: + def decoded_properties(self) -> dict[str, str | None]: """Return properties as strings.""" if self._decoded_properties is None: self._generate_decoded_properties() @@ -297,7 +299,7 @@ def async_clear_cache(self) -> None: self._dns_text_cache = None self._get_address_and_nsec_records_cache = None - async def async_wait(self, timeout: float, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + async def async_wait(self, timeout: float, loop: asyncio.AbstractEventLoop | None = None) -> None: """Calling task waits for a given number of milliseconds or until notified.""" if not self._new_records_futures: self._new_records_futures = set() @@ -305,7 +307,7 @@ async def async_wait(self, timeout: float, loop: Optional[asyncio.AbstractEventL loop or asyncio.get_running_loop(), self._new_records_futures, timeout ) - def addresses_by_version(self, version: IPVersion) -> List[bytes]: + def addresses_by_version(self, version: IPVersion) -> list[bytes]: """List addresses matching IP version. Addresses are guaranteed to be returned in LIFO (last in, first out) @@ -325,7 +327,7 @@ def addresses_by_version(self, version: IPVersion) -> List[bytes]: def ip_addresses_by_version( self, version: IPVersion - ) -> Union[List[ZeroconfIPv4Address], List[ZeroconfIPv6Address]]: + ) -> list[ZeroconfIPv4Address] | list[ZeroconfIPv6Address]: """List ip_address objects matching IP version. Addresses are guaranteed to be returned in LIFO (last in, first out) @@ -338,7 +340,7 @@ def ip_addresses_by_version( def _ip_addresses_by_version_value( self, version_value: int_ - ) -> Union[List[ZeroconfIPv4Address], List[ZeroconfIPv6Address]]: + ) -> list[ZeroconfIPv4Address] | list[ZeroconfIPv6Address]: """Backend for addresses_by_version that uses the raw value.""" if version_value == _IPVersion_All_value: return [*self._ipv4_addresses, *self._ipv6_addresses] # type: ignore[return-value] @@ -346,7 +348,7 @@ def _ip_addresses_by_version_value( return self._ipv4_addresses return self._ipv6_addresses - def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: + def parsed_addresses(self, version: IPVersion = IPVersion.All) -> list[str]: """List addresses in their parsed string form. Addresses are guaranteed to be returned in LIFO (last in, first out) @@ -357,7 +359,7 @@ def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: """ return [str_without_scope_id(addr) for addr in self._ip_addresses_by_version_value(version.value)] - def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: + def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> list[str]: """Equivalent to parsed_addresses, with the exception that IPv6 Link-Local addresses are qualified with % when available @@ -369,9 +371,9 @@ def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[st """ return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)] - def _set_properties(self, properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]]) -> None: + def _set_properties(self, properties: dict[str | bytes, str | bytes | None]) -> None: """Sets properties and text of this info from a dictionary""" - list_: List[bytes] = [] + list_: list[bytes] = [] properties_contain_str = False result = b"" for key, value in properties.items(): @@ -425,7 +427,7 @@ def _unpack_text_into_properties(self) -> None: return index = 0 - properties: Dict[bytes, Optional[bytes]] = {} + properties: dict[bytes, bytes | None] = {} while index < end: length = text[index] index += 1 @@ -443,10 +445,10 @@ def get_name(self) -> str: return self._name[: len(self._name) - len(self.type) - 1] def _get_ip_addresses_from_cache_lifo( - self, zc: "Zeroconf", now: float_, type: int_ - ) -> List[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]]: + self, zc: Zeroconf, now: float_, type: int_ + ) -> list[ZeroconfIPv4Address | ZeroconfIPv6Address]: """Set IPv6 addresses from the cache.""" - address_list: List[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]] = [] + address_list: list[ZeroconfIPv4Address | ZeroconfIPv6Address] = [] for record in self._get_address_records_from_cache_by_type(zc, type): if record.is_expired(now): continue @@ -456,7 +458,7 @@ def _get_ip_addresses_from_cache_lifo( address_list.reverse() # Reverse to get LIFO order return address_list - def _set_ipv6_addresses_from_cache(self, zc: "Zeroconf", now: float_) -> None: + def _set_ipv6_addresses_from_cache(self, zc: Zeroconf, now: float_) -> None: """Set IPv6 addresses from the cache.""" if TYPE_CHECKING: self._ipv6_addresses = cast( @@ -466,7 +468,7 @@ def _set_ipv6_addresses_from_cache(self, zc: "Zeroconf", now: float_) -> None: else: self._ipv6_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) - def _set_ipv4_addresses_from_cache(self, zc: "Zeroconf", now: float_) -> None: + def _set_ipv4_addresses_from_cache(self, zc: Zeroconf, now: float_) -> None: """Set IPv4 addresses from the cache.""" if TYPE_CHECKING: self._ipv4_addresses = cast( @@ -476,7 +478,7 @@ def _set_ipv4_addresses_from_cache(self, zc: "Zeroconf", now: float_) -> None: else: self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) - def async_update_records(self, zc: "Zeroconf", now: float_, records: List[RecordUpdate]) -> None: + def async_update_records(self, zc: Zeroconf, now: float_, records: list[RecordUpdate]) -> None: """Updates service information from a DNS record. This method will be run in the event loop. @@ -488,7 +490,7 @@ def async_update_records(self, zc: "Zeroconf", now: float_, records: List[Record if updated and new_records_futures: _resolve_all_futures_to_none(new_records_futures) - def _process_record_threadsafe(self, zc: "Zeroconf", record: DNSRecord, now: float_) -> bool: + def _process_record_threadsafe(self, zc: Zeroconf, record: DNSRecord, now: float_) -> bool: """Thread safe record updating. Returns True if a new record was added. @@ -575,17 +577,17 @@ def _process_record_threadsafe(self, zc: "Zeroconf", record: DNSRecord, now: flo def dns_addresses( self, - override_ttl: Optional[int] = None, + override_ttl: int | None = None, version: IPVersion = IPVersion.All, - ) -> List[DNSAddress]: + ) -> list[DNSAddress]: """Return matching DNSAddress from ServiceInfo.""" return self._dns_addresses(override_ttl, version) def _dns_addresses( self, - override_ttl: Optional[int], + override_ttl: int | None, version: IPVersion, - ) -> List[DNSAddress]: + ) -> list[DNSAddress]: """Return matching DNSAddress from ServiceInfo.""" cacheable = version is IPVersion.All and override_ttl is None if self._dns_address_cache is not None and cacheable: @@ -609,11 +611,11 @@ def _dns_addresses( self._dns_address_cache = records return records - def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: + def dns_pointer(self, override_ttl: int | None = None) -> DNSPointer: """Return DNSPointer from ServiceInfo.""" return self._dns_pointer(override_ttl) - def _dns_pointer(self, override_ttl: Optional[int]) -> DNSPointer: + def _dns_pointer(self, override_ttl: int | None) -> DNSPointer: """Return DNSPointer from ServiceInfo.""" cacheable = override_ttl is None if self._dns_pointer_cache is not None and cacheable: @@ -630,11 +632,11 @@ def _dns_pointer(self, override_ttl: Optional[int]) -> DNSPointer: self._dns_pointer_cache = record return record - def dns_service(self, override_ttl: Optional[int] = None) -> DNSService: + def dns_service(self, override_ttl: int | None = None) -> DNSService: """Return DNSService from ServiceInfo.""" return self._dns_service(override_ttl) - def _dns_service(self, override_ttl: Optional[int]) -> DNSService: + def _dns_service(self, override_ttl: int | None) -> DNSService: """Return DNSService from ServiceInfo.""" cacheable = override_ttl is None if self._dns_service_cache is not None and cacheable: @@ -657,11 +659,11 @@ def _dns_service(self, override_ttl: Optional[int]) -> DNSService: self._dns_service_cache = record return record - def dns_text(self, override_ttl: Optional[int] = None) -> DNSText: + def dns_text(self, override_ttl: int | None = None) -> DNSText: """Return DNSText from ServiceInfo.""" return self._dns_text(override_ttl) - def _dns_text(self, override_ttl: Optional[int]) -> DNSText: + def _dns_text(self, override_ttl: int | None) -> DNSText: """Return DNSText from ServiceInfo.""" cacheable = override_ttl is None if self._dns_text_cache is not None and cacheable: @@ -678,11 +680,11 @@ def _dns_text(self, override_ttl: Optional[int]) -> DNSText: self._dns_text_cache = record return record - def dns_nsec(self, missing_types: List[int], override_ttl: Optional[int] = None) -> DNSNsec: + def dns_nsec(self, missing_types: list[int], override_ttl: int | None = None) -> DNSNsec: """Return DNSNsec from ServiceInfo.""" return self._dns_nsec(missing_types, override_ttl) - def _dns_nsec(self, missing_types: List[int], override_ttl: Optional[int]) -> DNSNsec: + def _dns_nsec(self, missing_types: list[int], override_ttl: int | None) -> DNSNsec: """Return DNSNsec from ServiceInfo.""" return DNSNsec( self._name, @@ -694,17 +696,17 @@ def _dns_nsec(self, missing_types: List[int], override_ttl: Optional[int]) -> DN 0.0, ) - def get_address_and_nsec_records(self, override_ttl: Optional[int] = None) -> Set[DNSRecord]: + def get_address_and_nsec_records(self, override_ttl: int | None = None) -> set[DNSRecord]: """Build a set of address records and NSEC records for non-present record types.""" return self._get_address_and_nsec_records(override_ttl) - def _get_address_and_nsec_records(self, override_ttl: Optional[int]) -> Set[DNSRecord]: + def _get_address_and_nsec_records(self, override_ttl: int | None) -> set[DNSRecord]: """Build a set of address records and NSEC records for non-present record types.""" cacheable = override_ttl is None if self._get_address_and_nsec_records_cache is not None and cacheable: return self._get_address_and_nsec_records_cache - missing_types: Set[int] = _ADDRESS_RECORD_TYPES.copy() - records: Set[DNSRecord] = set() + missing_types: set[int] = _ADDRESS_RECORD_TYPES.copy() + records: set[DNSRecord] = set() for dns_address in self._dns_addresses(override_ttl, IPVersion.All): missing_types.discard(dns_address.type) records.add(dns_address) @@ -715,7 +717,7 @@ def _get_address_and_nsec_records(self, override_ttl: Optional[int]) -> Set[DNSR self._get_address_and_nsec_records_cache = records return records - def _get_address_records_from_cache_by_type(self, zc: "Zeroconf", _type: int_) -> List[DNSAddress]: + def _get_address_records_from_cache_by_type(self, zc: Zeroconf, _type: int_) -> list[DNSAddress]: """Get the addresses from the cache.""" if self.server_key is None: return [] @@ -738,14 +740,14 @@ def set_server_if_missing(self) -> None: self.server = self._name self.server_key = self.key - def load_from_cache(self, zc: "Zeroconf", now: Optional[float_] = None) -> bool: + def load_from_cache(self, zc: Zeroconf, now: float_ | None = None) -> bool: """Populate the service info from the cache. This method is designed to be threadsafe. """ return self._load_from_cache(zc, now or current_time_millis()) - def _load_from_cache(self, zc: "Zeroconf", now: float_) -> bool: + def _load_from_cache(self, zc: Zeroconf, now: float_) -> bool: """Populate the service info from the cache. This method is designed to be threadsafe. @@ -775,10 +777,10 @@ def _is_complete(self) -> bool: def request( self, - zc: "Zeroconf", + zc: Zeroconf, timeout: float, - question_type: Optional[DNSQuestionType] = None, - addr: Optional[str] = None, + question_type: DNSQuestionType | None = None, + addr: str | None = None, port: int = _MDNS_PORT, ) -> bool: """Returns true if the service could be discovered on the @@ -814,10 +816,10 @@ def _get_random_delay(self) -> int_: async def async_request( self, - zc: "Zeroconf", + zc: Zeroconf, timeout: float, - question_type: Optional[DNSQuestionType] = None, - addr: Optional[str] = None, + question_type: DNSQuestionType | None = None, + addr: str | None = None, port: int = _MDNS_PORT, ) -> bool: """Returns true if the service could be discovered on the @@ -914,7 +916,7 @@ def _add_question_with_known_answers( out.add_answer_at_time(answer, now) def _generate_request_query( - self, zc: "Zeroconf", now: float_, question_type: DNSQuestionType + self, zc: Zeroconf, now: float_, question_type: DNSQuestionType ) -> DNSOutgoing: """Generate the request query.""" out = DNSOutgoing(_FLAGS_QR_QUERY) diff --git a/src/zeroconf/_services/registry.py b/src/zeroconf/_services/registry.py index 4100c690e..937992eb0 100644 --- a/src/zeroconf/_services/registry.py +++ b/src/zeroconf/_services/registry.py @@ -20,7 +20,7 @@ USA """ -from typing import Dict, List, Optional, Union +from __future__ import annotations from .._exceptions import ServiceNameAlreadyRegistered from .info import ServiceInfo @@ -41,16 +41,16 @@ def __init__( self, ) -> None: """Create the ServiceRegistry class.""" - self._services: Dict[str, ServiceInfo] = {} - self.types: Dict[str, List] = {} - self.servers: Dict[str, List] = {} + self._services: dict[str, ServiceInfo] = {} + self.types: dict[str, list] = {} + self.servers: dict[str, list] = {} self.has_entries: bool = False def async_add(self, info: ServiceInfo) -> None: """Add a new service to the registry.""" self._add(info) - def async_remove(self, info: Union[List[ServiceInfo], ServiceInfo]) -> None: + def async_remove(self, info: list[ServiceInfo] | ServiceInfo) -> None: """Remove a new service from the registry.""" self._remove(info if isinstance(info, list) else [info]) @@ -59,27 +59,27 @@ def async_update(self, info: ServiceInfo) -> None: self._remove([info]) self._add(info) - def async_get_service_infos(self) -> List[ServiceInfo]: + def async_get_service_infos(self) -> list[ServiceInfo]: """Return all ServiceInfo.""" return list(self._services.values()) - def async_get_info_name(self, name: str) -> Optional[ServiceInfo]: + def async_get_info_name(self, name: str) -> ServiceInfo | None: """Return all ServiceInfo for the name.""" return self._services.get(name) - def async_get_types(self) -> List[str]: + def async_get_types(self) -> list[str]: """Return all types.""" return list(self.types) - def async_get_infos_type(self, type_: str) -> List[ServiceInfo]: + def async_get_infos_type(self, type_: str) -> list[ServiceInfo]: """Return all ServiceInfo matching type.""" return self._async_get_by_index(self.types, type_) - def async_get_infos_server(self, server: str) -> List[ServiceInfo]: + def async_get_infos_server(self, server: str) -> list[ServiceInfo]: """Return all ServiceInfo matching server.""" return self._async_get_by_index(self.servers, server) - def _async_get_by_index(self, records: Dict[str, List], key: _str) -> List[ServiceInfo]: + def _async_get_by_index(self, records: dict[str, list], key: _str) -> list[ServiceInfo]: """Return all ServiceInfo matching the index.""" record_list = records.get(key) if record_list is None: @@ -98,7 +98,7 @@ def _add(self, info: ServiceInfo) -> None: self.servers.setdefault(info.server_key, []).append(info.key) self.has_entries = True - def _remove(self, infos: List[ServiceInfo]) -> None: + def _remove(self, infos: list[ServiceInfo]) -> None: """Remove a services under the lock.""" for info in infos: old_service_info = self._services.get(info.key) diff --git a/src/zeroconf/_services/types.py b/src/zeroconf/_services/types.py index 63b6d19a1..af25dc6db 100644 --- a/src/zeroconf/_services/types.py +++ b/src/zeroconf/_services/types.py @@ -20,8 +20,9 @@ USA """ +from __future__ import annotations + import time -from typing import Optional, Set, Tuple, Union from .._core import Zeroconf from .._services import ServiceListener @@ -37,7 +38,7 @@ class ZeroconfServiceTypes(ServiceListener): def __init__(self) -> None: """Keep track of found services in a set.""" - self.found_services: Set[str] = set() + self.found_services: set[str] = set() def add_service(self, zc: Zeroconf, type_: str, name: str) -> None: """Service added.""" @@ -52,11 +53,11 @@ def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None: @classmethod def find( cls, - zc: Optional[Zeroconf] = None, - timeout: Union[int, float] = 5, + zc: Zeroconf | None = None, + timeout: int | float = 5, interfaces: InterfacesType = InterfaceChoice.All, - ip_version: Optional[IPVersion] = None, - ) -> Tuple[str, ...]: + ip_version: IPVersion | None = None, + ) -> tuple[str, ...]: """ Return all of the advertised services on any local networks. diff --git a/src/zeroconf/_transport.py b/src/zeroconf/_transport.py index b08110943..c8d7699b9 100644 --- a/src/zeroconf/_transport.py +++ b/src/zeroconf/_transport.py @@ -20,9 +20,10 @@ USA """ +from __future__ import annotations + import asyncio import socket -from typing import Tuple class _WrappedTransport: @@ -42,7 +43,7 @@ def __init__( is_ipv6: bool, sock: socket.socket, fileno: int, - sock_name: Tuple, + sock_name: tuple, ) -> None: """Initialize the wrapped transport. diff --git a/src/zeroconf/_updates.py b/src/zeroconf/_updates.py index 58be33d8c..c0bf9b8c9 100644 --- a/src/zeroconf/_updates.py +++ b/src/zeroconf/_updates.py @@ -20,7 +20,9 @@ USA """ -from typing import TYPE_CHECKING, List +from __future__ import annotations + +from typing import TYPE_CHECKING from ._dns import DNSRecord from ._record_update import RecordUpdate @@ -40,7 +42,7 @@ class RecordUpdateListener: """ def update_record( # pylint: disable=no-self-use - self, zc: "Zeroconf", now: float, record: DNSRecord + self, zc: Zeroconf, now: float, record: DNSRecord ) -> None: """Update a single record. @@ -49,7 +51,7 @@ def update_record( # pylint: disable=no-self-use """ raise RuntimeError("update_record is deprecated and will be removed in a future version.") - def async_update_records(self, zc: "Zeroconf", now: float_, records: List[RecordUpdate]) -> None: + def async_update_records(self, zc: Zeroconf, now: float_, records: list[RecordUpdate]) -> None: """Update multiple records in one shot. All records that are received in a single packet are passed diff --git a/src/zeroconf/_utils/__init__.py b/src/zeroconf/_utils/__init__.py index 30920c6aa..584a74eca 100644 --- a/src/zeroconf/_utils/__init__.py +++ b/src/zeroconf/_utils/__init__.py @@ -19,3 +19,5 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA """ + +from __future__ import annotations diff --git a/src/zeroconf/_utils/ipaddress.py b/src/zeroconf/_utils/ipaddress.py index 64cdfb638..d172d0c9f 100644 --- a/src/zeroconf/_utils/ipaddress.py +++ b/src/zeroconf/_utils/ipaddress.py @@ -20,9 +20,11 @@ USA """ +from __future__ import annotations + from functools import cache, lru_cache from ipaddress import AddressValueError, IPv4Address, IPv6Address, NetmaskValueError -from typing import Any, Optional, Union +from typing import Any from .._dns import DNSAddress from ..const import _TYPE_AAAA @@ -99,8 +101,8 @@ def is_loopback(self) -> bool: @lru_cache(maxsize=512) def _cached_ip_addresses( - address: Union[str, bytes, int], -) -> Optional[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]]: + address: str | bytes | int, +) -> ZeroconfIPv4Address | ZeroconfIPv6Address | None: """Cache IP addresses.""" try: return ZeroconfIPv4Address(address) @@ -119,7 +121,7 @@ def _cached_ip_addresses( def get_ip_address_object_from_record( record: DNSAddress, -) -> Optional[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]]: +) -> ZeroconfIPv4Address | ZeroconfIPv6Address | None: """Get the IP address object from the record.""" if record.type == _TYPE_AAAA and record.scope_id: return ip_bytes_and_scope_to_address(record.address, record.scope_id) @@ -128,7 +130,7 @@ def get_ip_address_object_from_record( def ip_bytes_and_scope_to_address( address: bytes_, scope: int_ -) -> Optional[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]]: +) -> ZeroconfIPv4Address | ZeroconfIPv6Address | None: """Convert the bytes and scope to an IP address object.""" base_address = cached_ip_addresses_wrapper(address) if base_address is not None and base_address.is_link_local: @@ -137,7 +139,7 @@ def ip_bytes_and_scope_to_address( return base_address -def str_without_scope_id(addr: Union[ZeroconfIPv4Address, ZeroconfIPv6Address]) -> str: +def str_without_scope_id(addr: ZeroconfIPv4Address | ZeroconfIPv6Address) -> str: """Return the string representation of the address without the scope id.""" if addr.version == 6: address_str = str(addr) diff --git a/src/zeroconf/_utils/name.py b/src/zeroconf/_utils/name.py index cda01b28e..de35f7afb 100644 --- a/src/zeroconf/_utils/name.py +++ b/src/zeroconf/_utils/name.py @@ -20,8 +20,9 @@ USA """ +from __future__ import annotations + from functools import lru_cache -from typing import Set from .._exceptions import BadTypeInNameException from ..const import ( @@ -162,7 +163,7 @@ def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: dis return service_name + trailer -def possible_types(name: str) -> Set[str]: +def possible_types(name: str) -> set[str]: """Build a set of all possible types from a fully qualified name.""" labels = name.split(".") label_count = len(labels) diff --git a/src/zeroconf/_utils/net.py b/src/zeroconf/_utils/net.py index 7298bec4d..3cc4336bf 100644 --- a/src/zeroconf/_utils/net.py +++ b/src/zeroconf/_utils/net.py @@ -20,13 +20,15 @@ USA """ +from __future__ import annotations + import enum import errno import ipaddress import socket import struct import sys -from typing import Any, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, Sequence, Tuple, Union, cast import ifaddr @@ -70,11 +72,11 @@ def _encode_address(address: str) -> bytes: return socket.inet_pton(address_family, address) -def get_all_addresses() -> List[str]: +def get_all_addresses() -> list[str]: return list({addr.ip for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv4}) -def get_all_addresses_v6() -> List[Tuple[Tuple[str, int, int], int]]: +def get_all_addresses_v6() -> list[tuple[tuple[str, int, int], int]]: # IPv6 multicast uses positive indexes for interfaces # TODO: What about multi-address interfaces? return list( @@ -82,7 +84,7 @@ def get_all_addresses_v6() -> List[Tuple[Tuple[str, int, int], int]]: ) -def ip6_to_address_and_index(adapters: List[Any], ip: str) -> Tuple[Tuple[str, int, int], int]: +def ip6_to_address_and_index(adapters: list[Any], ip: str) -> tuple[tuple[str, int, int], int]: if "%" in ip: ip = ip[: ip.index("%")] # Strip scope_id. ipaddr = ipaddress.ip_address(ip) @@ -98,7 +100,7 @@ def ip6_to_address_and_index(adapters: List[Any], ip: str) -> Tuple[Tuple[str, i raise RuntimeError(f"No adapter found for IP address {ip}") -def interface_index_to_ip6_address(adapters: List[Any], index: int) -> Tuple[str, int, int]: +def interface_index_to_ip6_address(adapters: list[Any], index: int) -> tuple[str, int, int]: for adapter in adapters: if adapter.index == index: for adapter_ip in adapter.ips: @@ -110,8 +112,8 @@ def interface_index_to_ip6_address(adapters: List[Any], index: int) -> Tuple[str def ip6_addresses_to_indexes( - interfaces: Sequence[Union[str, int, Tuple[Tuple[str, int, int], int]]], -) -> List[Tuple[Tuple[str, int, int], int]]: + interfaces: Sequence[str | int | tuple[tuple[str, int, int], int]], +) -> list[tuple[tuple[str, int, int], int]]: """Convert IPv6 interface addresses to interface indexes. IPv4 addresses are ignored. @@ -133,14 +135,14 @@ def ip6_addresses_to_indexes( def normalize_interface_choice( choice: InterfacesType, ip_version: IPVersion = IPVersion.V4Only -) -> List[Union[str, Tuple[Tuple[str, int, int], int]]]: +) -> list[str | tuple[tuple[str, int, int], int]]: """Convert the interfaces choice into internal representation. :param choice: `InterfaceChoice` or list of interface addresses or indexes (IPv6 only). :param ip_address: IP version to use (ignored if `choice` is a list). :returns: List of IP addresses (for IPv4) and indexes (for IPv6). """ - result: List[Union[str, Tuple[Tuple[str, int, int], int]]] = [] + result: list[str | tuple[tuple[str, int, int], int]] = [] if choice is InterfaceChoice.Default: if ip_version != IPVersion.V4Only: # IPv6 multicast uses interface 0 to mean the default @@ -196,7 +198,7 @@ def set_so_reuseport_if_available(s: socket.socket) -> None: def set_mdns_port_socket_options_for_ip_version( s: socket.socket, - bind_addr: Union[Tuple[str], Tuple[str, int, int]], + bind_addr: tuple[str] | tuple[str, int, int], ip_version: IPVersion, ) -> None: """Set ttl/hops and loop for mdns port.""" @@ -219,11 +221,11 @@ def set_mdns_port_socket_options_for_ip_version( def new_socket( - bind_addr: Union[Tuple[str], Tuple[str, int, int]], + bind_addr: tuple[str] | tuple[str, int, int], port: int = _MDNS_PORT, ip_version: IPVersion = IPVersion.V4Only, apple_p2p: bool = False, -) -> Optional[socket.socket]: +) -> socket.socket | None: log.debug( "Creating new socket with port %s, ip_version %s, apple_p2p %s and bind_addr %r", port, @@ -265,7 +267,7 @@ def new_socket( def add_multicast_member( listen_socket: socket.socket, - interface: Union[str, Tuple[Tuple[str, int, int], int]], + interface: str | tuple[tuple[str, int, int], int], ) -> bool: # This is based on assumptions in normalize_interface_choice is_v6 = isinstance(interface, tuple) @@ -331,9 +333,9 @@ def add_multicast_member( def new_respond_socket( - interface: Union[str, Tuple[Tuple[str, int, int], int]], + interface: str | tuple[tuple[str, int, int], int], apple_p2p: bool = False, -) -> Optional[socket.socket]: +) -> socket.socket | None: is_v6 = isinstance(interface, tuple) respond_socket = new_socket( ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only), @@ -360,7 +362,7 @@ def create_sockets( unicast: bool = False, ip_version: IPVersion = IPVersion.V4Only, apple_p2p: bool = False, -) -> Tuple[Optional[socket.socket], List[socket.socket]]: +) -> tuple[socket.socket | None, list[socket.socket]]: if unicast: listen_socket = None else: diff --git a/src/zeroconf/_utils/time.py b/src/zeroconf/_utils/time.py index 055e0658a..4057f0630 100644 --- a/src/zeroconf/_utils/time.py +++ b/src/zeroconf/_utils/time.py @@ -20,6 +20,8 @@ USA """ +from __future__ import annotations + import time _float = float diff --git a/src/zeroconf/const.py b/src/zeroconf/const.py index d84cb73ba..3b4b3abcc 100644 --- a/src/zeroconf/const.py +++ b/src/zeroconf/const.py @@ -20,6 +20,8 @@ USA """ +from __future__ import annotations + import re import socket From 6399b44d52f5d4e8ff8c0eba6e73ddd34416e755 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 31 Jan 2025 13:01:30 -0600 Subject: [PATCH 06/13] feat: eliminate async_timeout dep on python less than 3.11 --- src/zeroconf/_dns.py | 70 +++++++++++++++++++------------------ src/zeroconf/_exceptions.py | 2 ++ src/zeroconf/_logger.py | 6 ++-- src/zeroconf/asyncio.py | 56 +++++++++++++++-------------- 4 files changed, 71 insertions(+), 63 deletions(-) diff --git a/src/zeroconf/_dns.py b/src/zeroconf/_dns.py index c22f8b170..bc0a3948e 100644 --- a/src/zeroconf/_dns.py +++ b/src/zeroconf/_dns.py @@ -20,9 +20,11 @@ USA """ +from __future__ import annotations + import enum import socket -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast +from typing import TYPE_CHECKING, Any, cast from ._exceptions import AbstractMethodException from ._utils.net import _is_v6_address @@ -94,7 +96,7 @@ def get_type(t: int) -> str: """Type accessor""" return _TYPES.get(t, f"?({t})") - def entry_to_string(self, hdr: str, other: Optional[Union[bytes, str]]) -> str: + def entry_to_string(self, hdr: str, other: bytes | str | None) -> str: """String representation with additional information""" return "{}[{},{}{},{}]{}".format( hdr, @@ -119,7 +121,7 @@ def _fast_init(self, name: str, type_: _int, class_: _int) -> None: self._fast_init_entry(name, type_, class_) self._hash = hash((self.key, type_, self.class_)) - def answered_by(self, rec: "DNSRecord") -> bool: + def answered_by(self, rec: DNSRecord) -> bool: """Returns true if the question is answered by the record""" return self.class_ == rec.class_ and self.type in (rec.type, _TYPE_ANY) and self.name == rec.name @@ -170,8 +172,8 @@ def __init__( name: str, type_: int, class_: int, - ttl: Union[float, int], - created: Optional[float] = None, + ttl: float | int, + created: float | None = None, ) -> None: self._fast_init_record(name, type_, class_, ttl, created or current_time_millis()) @@ -185,10 +187,10 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use """Abstract method""" raise AbstractMethodException - def __lt__(self, other: "DNSRecord") -> bool: + def __lt__(self, other: DNSRecord) -> bool: return self.ttl < other.ttl - def suppressed_by(self, msg: "DNSIncoming") -> bool: + def suppressed_by(self, msg: DNSIncoming) -> bool: """Returns true if any answer in a message can suffice for the information held in this record.""" answers = msg.answers() @@ -208,7 +210,7 @@ def get_expiration_time(self, percent: _int) -> float: return self.created + (percent * self.ttl * 10) # TODO: Switch to just int here - def get_remaining_ttl(self, now: _float) -> Union[int, float]: + def get_remaining_ttl(self, now: _float) -> int | float: """Returns the remaining TTL in seconds.""" remain = (self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) - now) / 1000.0 return 0 if remain < 0 else remain @@ -225,18 +227,18 @@ def is_recent(self, now: _float) -> bool: """Returns true if the record more than one quarter of its TTL remaining.""" return self.created + (_RECENT_TIME_MS * self.ttl) > now - def _set_created_ttl(self, created: _float, ttl: Union[float, int]) -> None: + def _set_created_ttl(self, created: _float, ttl: float | int) -> None: """Set the created and ttl of a record.""" # It would be better if we made a copy instead of mutating the record # in place, but records currently don't have a copy method. self.created = created self.ttl = ttl - def write(self, out: "DNSOutgoing") -> None: # pylint: disable=no-self-use + def write(self, out: DNSOutgoing) -> None: # pylint: disable=no-self-use """Abstract method""" raise AbstractMethodException - def to_string(self, other: Union[bytes, str]) -> str: + def to_string(self, other: bytes | str) -> str: """String representation with additional information""" arg = f"{self.ttl}/{int(self.get_remaining_ttl(current_time_millis()))},{cast(Any, other)}" return DNSEntry.entry_to_string(self, "record", arg) @@ -254,8 +256,8 @@ def __init__( class_: int, ttl: int, address: bytes, - scope_id: Optional[int] = None, - created: Optional[float] = None, + scope_id: int | None = None, + created: float | None = None, ) -> None: self._fast_init(name, type_, class_, ttl, address, scope_id, created or current_time_millis()) @@ -266,7 +268,7 @@ def _fast_init( class_: _int, ttl: _float, address: bytes, - scope_id: Optional[_int], + scope_id: _int | None, created: _float, ) -> None: """Fast init for reuse.""" @@ -275,7 +277,7 @@ def _fast_init( self.scope_id = scope_id self._hash = hash((self.key, type_, self.class_, address, scope_id)) - def write(self, out: "DNSOutgoing") -> None: + def write(self, out: DNSOutgoing) -> None: """Used in constructing an outgoing packet""" out.write_string(self.address) @@ -320,7 +322,7 @@ def __init__( ttl: int, cpu: str, os: str, - created: Optional[float] = None, + created: float | None = None, ) -> None: self._fast_init(name, type_, class_, ttl, cpu, os, created or current_time_millis()) @@ -333,7 +335,7 @@ def _fast_init( self.os = os self._hash = hash((self.key, type_, self.class_, cpu, os)) - def write(self, out: "DNSOutgoing") -> None: + def write(self, out: DNSOutgoing) -> None: """Used in constructing an outgoing packet""" out.write_character_string(self.cpu.encode("utf-8")) out.write_character_string(self.os.encode("utf-8")) @@ -367,7 +369,7 @@ def __init__( class_: int, ttl: int, alias: str, - created: Optional[float] = None, + created: float | None = None, ) -> None: self._fast_init(name, type_, class_, ttl, alias, created or current_time_millis()) @@ -389,7 +391,7 @@ def max_size_compressed(self) -> int: + _NAME_COMPRESSION_MIN_SIZE ) - def write(self, out: "DNSOutgoing") -> None: + def write(self, out: DNSOutgoing) -> None: """Used in constructing an outgoing packet""" out.write_name(self.alias) @@ -422,7 +424,7 @@ def __init__( class_: int, ttl: int, text: bytes, - created: Optional[float] = None, + created: float | None = None, ) -> None: self._fast_init(name, type_, class_, ttl, text, created or current_time_millis()) @@ -433,7 +435,7 @@ def _fast_init( self.text = text self._hash = hash((self.key, type_, self.class_, text)) - def write(self, out: "DNSOutgoing") -> None: + def write(self, out: DNSOutgoing) -> None: """Used in constructing an outgoing packet""" out.write_string(self.text) @@ -466,12 +468,12 @@ def __init__( name: str, type_: int, class_: int, - ttl: Union[float, int], + ttl: float | int, priority: int, weight: int, port: int, server: str, - created: Optional[float] = None, + created: float | None = None, ) -> None: self._fast_init( name, type_, class_, ttl, priority, weight, port, server, created or current_time_millis() @@ -497,7 +499,7 @@ def _fast_init( self.server_key = server.lower() self._hash = hash((self.key, type_, self.class_, priority, weight, port, self.server_key)) - def write(self, out: "DNSOutgoing") -> None: + def write(self, out: DNSOutgoing) -> None: """Used in constructing an outgoing packet""" out.write_short(self.priority) out.write_short(self.weight) @@ -537,10 +539,10 @@ def __init__( name: str, type_: int, class_: int, - ttl: Union[int, float], + ttl: int | float, next_name: str, - rdtypes: List[int], - created: Optional[float] = None, + rdtypes: list[int], + created: float | None = None, ) -> None: self._fast_init(name, type_, class_, ttl, next_name, rdtypes, created or current_time_millis()) @@ -551,7 +553,7 @@ def _fast_init( class_: _int, ttl: _float, next_name: str, - rdtypes: List[_int], + rdtypes: list[_int], created: _float, ) -> None: self._fast_init_record(name, type_, class_, ttl, created) @@ -559,7 +561,7 @@ def _fast_init( self.rdtypes = sorted(rdtypes) self._hash = hash((self.key, type_, self.class_, next_name, *self.rdtypes)) - def write(self, out: "DNSOutgoing") -> None: + def write(self, out: DNSOutgoing) -> None: """Used in constructing an outgoing packet.""" bitmap = bytearray(b"\0" * 32) total_octets = 0 @@ -610,21 +612,21 @@ class DNSRRSet: __slots__ = ("_lookup", "_records") - def __init__(self, records: List[DNSRecord]) -> None: + def __init__(self, records: list[DNSRecord]) -> None: """Create an RRset from records sets.""" self._records = records - self._lookup: Optional[Dict[DNSRecord, DNSRecord]] = None + self._lookup: dict[DNSRecord, DNSRecord] | None = None @property - def lookup(self) -> Dict[DNSRecord, DNSRecord]: + def lookup(self) -> dict[DNSRecord, DNSRecord]: """Return the lookup table.""" return self._get_lookup() - def lookup_set(self) -> Set[DNSRecord]: + def lookup_set(self) -> set[DNSRecord]: """Return the lookup table as aset.""" return set(self._get_lookup()) - def _get_lookup(self) -> Dict[DNSRecord, DNSRecord]: + def _get_lookup(self) -> dict[DNSRecord, DNSRecord]: """Return the lookup table, building it if needed.""" if self._lookup is None: # Build the hash table so we can lookup the record ttl diff --git a/src/zeroconf/_exceptions.py b/src/zeroconf/_exceptions.py index 5eb58f793..5fc812593 100644 --- a/src/zeroconf/_exceptions.py +++ b/src/zeroconf/_exceptions.py @@ -20,6 +20,8 @@ USA """ +from __future__ import annotations + class Error(Exception): """Base class for all zeroconf exceptions.""" diff --git a/src/zeroconf/_logger.py b/src/zeroconf/_logger.py index 1556522eb..0d734dfde 100644 --- a/src/zeroconf/_logger.py +++ b/src/zeroconf/_logger.py @@ -21,9 +21,11 @@ USA """ +from __future__ import annotations + import logging import sys -from typing import Any, ClassVar, Dict, Union, cast +from typing import Any, ClassVar, cast log = logging.getLogger(__name__.split(".", maxsplit=1)[0]) log.addHandler(logging.NullHandler()) @@ -38,7 +40,7 @@ def set_logger_level_if_unset() -> None: class QuietLogger: - _seen_logs: ClassVar[Dict[str, Union[int, tuple]]] = {} + _seen_logs: ClassVar[dict[str, int | tuple]] = {} @classmethod def log_exception_warning(cls, *logger_data: Any) -> None: diff --git a/src/zeroconf/asyncio.py b/src/zeroconf/asyncio.py index 926ef5099..2a29a4bb7 100644 --- a/src/zeroconf/asyncio.py +++ b/src/zeroconf/asyncio.py @@ -20,10 +20,12 @@ USA """ +from __future__ import annotations + import asyncio import contextlib from types import TracebackType # used in type hints -from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Awaitable, Callable from ._core import Zeroconf from ._dns import DNSQuestionType @@ -63,14 +65,14 @@ class AsyncServiceBrowser(_ServiceBrowserBase): def __init__( self, - zeroconf: "Zeroconf", - type_: Union[str, list], - handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, - listener: Optional[ServiceListener] = None, - addr: Optional[str] = None, + zeroconf: Zeroconf, + type_: str | list, + handlers: ServiceListener | list[Callable[..., None]] | None = None, + listener: ServiceListener | None = None, + addr: str | None = None, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, - question_type: Optional[DNSQuestionType] = None, + question_type: DNSQuestionType | None = None, ) -> None: super().__init__(zeroconf, type_, handlers, listener, addr, port, delay, question_type) self._async_start() @@ -79,15 +81,15 @@ async def async_cancel(self) -> None: """Cancel the browser.""" self._async_cancel() - async def __aenter__(self) -> "AsyncServiceBrowser": + async def __aenter__(self) -> AsyncServiceBrowser: return self async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: await self.async_cancel() return None @@ -98,11 +100,11 @@ class AsyncZeroconfServiceTypes(ZeroconfServiceTypes): @classmethod async def async_find( cls, - aiozc: Optional["AsyncZeroconf"] = None, - timeout: Union[int, float] = 5, + aiozc: AsyncZeroconf | None = None, + timeout: int | float = 5, interfaces: InterfacesType = InterfaceChoice.All, - ip_version: Optional[IPVersion] = None, - ) -> Tuple[str, ...]: + ip_version: IPVersion | None = None, + ) -> tuple[str, ...]: """ Return all of the advertised services on any local networks. @@ -145,9 +147,9 @@ def __init__( self, interfaces: InterfacesType = InterfaceChoice.All, unicast: bool = False, - ip_version: Optional[IPVersion] = None, + ip_version: IPVersion | None = None, apple_p2p: bool = False, - zc: Optional[Zeroconf] = None, + zc: Zeroconf | None = None, ) -> None: """Creates an instance of the Zeroconf class, establishing multicast communications, and listening. @@ -170,12 +172,12 @@ def __init__( ip_version=ip_version, apple_p2p=apple_p2p, ) - self.async_browsers: Dict[ServiceListener, AsyncServiceBrowser] = {} + self.async_browsers: dict[ServiceListener, AsyncServiceBrowser] = {} async def async_register_service( self, info: ServiceInfo, - ttl: Optional[int] = None, + ttl: int | None = None, allow_name_change: bool = False, cooperating_responders: bool = False, strict: bool = True, @@ -236,8 +238,8 @@ async def async_get_service_info( type_: str, name: str, timeout: int = 3000, - question_type: Optional[DNSQuestionType] = None, - ) -> Optional[AsyncServiceInfo]: + question_type: DNSQuestionType | None = None, + ) -> AsyncServiceInfo | None: """Returns network's service information for a particular name and type, or None if no service matches by the timeout, which defaults to 3 seconds. @@ -268,14 +270,14 @@ async def async_remove_all_service_listeners(self) -> None: *(self.async_remove_service_listener(listener) for listener in list(self.async_browsers)) ) - async def __aenter__(self) -> "AsyncZeroconf": + async def __aenter__(self) -> AsyncZeroconf: return self async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: await self.async_close() return None From 33dddd3ea37c1166949b2d89f75b70d62abd81ab Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 31 Jan 2025 13:05:18 -0600 Subject: [PATCH 07/13] chore: fix incompatible objects --- src/zeroconf/_handlers/record_manager.pxd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zeroconf/_handlers/record_manager.pxd b/src/zeroconf/_handlers/record_manager.pxd index d4e068c2e..37232b131 100644 --- a/src/zeroconf/_handlers/record_manager.pxd +++ b/src/zeroconf/_handlers/record_manager.pxd @@ -21,7 +21,7 @@ cdef class RecordManager: cdef public DNSCache cache cdef public cython.set listeners - cpdef void async_updates(self, object now, object records) + cpdef void async_updates(self, object now, list records) cpdef void async_updates_complete(self, bint notify) From 074cbcf6458b8a8da505741c21f178a90af529f1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 31 Jan 2025 13:10:03 -0600 Subject: [PATCH 08/13] chore: fix handling of future --- src/zeroconf/_core.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 82e0fad20..9b1f2de6e 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -203,7 +203,12 @@ def __init__( @property def started(self) -> bool: """Check if the instance has started.""" - return bool(not self.done and self.engine.running_future and self.engine.running_future.result()) + return bool( + not self.done + and self.engine.running_future + and self.engine.running_future.done() + and not self.engine.running_future.exception() + ) def start(self) -> None: """Start Zeroconf.""" @@ -237,7 +242,7 @@ async def async_wait_for_start(self) -> None: raise NotRunningException assert self.engine.running_future is not None await wait_future_or_timeout(self.engine.running_future, timeout=_STARTUP_TIMEOUT) - if not self.engine.running_future.result() or self.done: + if not self.engine.running_future.done() or self.engine.running_future.exception() or self.done: raise NotRunningException @property From 7d7436a8d24db266271a4b40915ad01711529aa1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 31 Jan 2025 13:16:52 -0600 Subject: [PATCH 09/13] chore: fix shutdown --- src/zeroconf/_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/zeroconf/_engine.py b/src/zeroconf/_engine.py index b9f5d9a57..8c800a33a 100644 --- a/src/zeroconf/_engine.py +++ b/src/zeroconf/_engine.py @@ -143,7 +143,8 @@ async def _async_close(self) -> None: def _async_shutdown(self) -> None: """Shutdown transports and sockets.""" assert self.running_future is not None - self.running_future = None + assert self.loop is not None + self.running_future = self.loop.create_future() for wrapped_transport in itertools.chain(self.senders, self.readers): wrapped_transport.transport.close() From 05577430bcb34e4f91560fa09fc2f566f665559c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 31 Jan 2025 13:27:43 -0600 Subject: [PATCH 10/13] fix: cancellation handling --- src/zeroconf/_core.py | 8 +++++++- src/zeroconf/_utils/asyncio.py | 4 ++++ src/zeroconf/asyncio.py | 3 ++- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 9b1f2de6e..1c946f52c 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -207,6 +207,7 @@ def started(self) -> bool: not self.done and self.engine.running_future and self.engine.running_future.done() + and not self.engine.running_future.cancelled() and not self.engine.running_future.exception() ) @@ -242,7 +243,12 @@ async def async_wait_for_start(self) -> None: raise NotRunningException assert self.engine.running_future is not None await wait_future_or_timeout(self.engine.running_future, timeout=_STARTUP_TIMEOUT) - if not self.engine.running_future.done() or self.engine.running_future.exception() or self.done: + if ( + not self.engine.running_future.done() + or self.engine.running_future.cancelled() + or self.engine.running_future.exception() + or self.done + ): raise NotRunningException @property diff --git a/src/zeroconf/_utils/asyncio.py b/src/zeroconf/_utils/asyncio.py index 5d3beb7cb..c92d99d56 100644 --- a/src/zeroconf/_utils/asyncio.py +++ b/src/zeroconf/_utils/asyncio.py @@ -25,6 +25,7 @@ import asyncio import concurrent.futures import contextlib +import sys from typing import Any, Awaitable, Coroutine from .._exceptions import EventLoopBlocked @@ -70,6 +71,9 @@ async def wait_future_or_timeout(future: asyncio.Future[bool | None], timeout: f handle = loop.call_later(timeout, _set_future_none_if_not_done, future) try: await future + except asyncio.CancelledError: + if sys.version_info >= (3, 11) and (task := asyncio.current_task()) and task.cancelling(): + raise finally: handle.cancel() diff --git a/src/zeroconf/asyncio.py b/src/zeroconf/asyncio.py index 2a29a4bb7..6e9d6e9bd 100644 --- a/src/zeroconf/asyncio.py +++ b/src/zeroconf/asyncio.py @@ -29,6 +29,7 @@ from ._core import Zeroconf from ._dns import DNSQuestionType +from ._exceptions import NotRunningException from ._services import ServiceListener from ._services.browser import _ServiceBrowserBase from ._services.info import AsyncServiceInfo, ServiceInfo @@ -227,7 +228,7 @@ async def async_close(self) -> None: """Ends the background threads, and prevent this instance from servicing further queries.""" if not self.zeroconf.done: - with contextlib.suppress(asyncio.TimeoutError): + with contextlib.suppress(asyncio.TimeoutError, NotRunningException): await asyncio.wait_for(self.zeroconf.async_wait_for_start(), timeout=1) await self.async_remove_all_service_listeners() await self.async_unregister_all_services() From 9fd11cfaedf3eb6c872f69acbb895196cb7b90d0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 31 Jan 2025 13:31:31 -0600 Subject: [PATCH 11/13] chore: refactor --- src/zeroconf/_core.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 1c946f52c..e3a209c41 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -203,12 +203,14 @@ def __init__( @property def started(self) -> bool: """Check if the instance has started.""" + running_future = self.engine.running_future return bool( not self.done - and self.engine.running_future - and self.engine.running_future.done() - and not self.engine.running_future.cancelled() - and not self.engine.running_future.exception() + and running_future + and running_future.done() + and not running_future.cancelled() + and not running_future.exception() + and running_future.result() ) def start(self) -> None: @@ -247,6 +249,7 @@ async def async_wait_for_start(self) -> None: not self.engine.running_future.done() or self.engine.running_future.cancelled() or self.engine.running_future.exception() + or not self.engine.running_future.result() or self.done ): raise NotRunningException From 410637ec2daba8ea7550dbb844b843b9ae0f74ea Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 31 Jan 2025 13:33:08 -0600 Subject: [PATCH 12/13] chore: refactor --- src/zeroconf/_core.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index e3a209c41..7c73e9498 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -245,13 +245,7 @@ async def async_wait_for_start(self) -> None: raise NotRunningException assert self.engine.running_future is not None await wait_future_or_timeout(self.engine.running_future, timeout=_STARTUP_TIMEOUT) - if ( - not self.engine.running_future.done() - or self.engine.running_future.cancelled() - or self.engine.running_future.exception() - or not self.engine.running_future.result() - or self.done - ): + if not self.started: raise NotRunningException @property From 82caf4c932a85bf2183209fd406a0aa55a6c60c1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 31 Jan 2025 13:34:56 -0600 Subject: [PATCH 13/13] chore: refactor --- src/zeroconf/_core.py | 4 ++-- src/zeroconf/asyncio.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 7c73e9498..3f007c174 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -235,7 +235,7 @@ def _run_loop() -> None: self._loop_thread.start() loop_thread_ready.wait() - async def async_wait_for_start(self) -> None: + async def async_wait_for_start(self, timeout: float = _STARTUP_TIMEOUT) -> None: """Wait for start up for actions that require a running Zeroconf instance. Throws NotRunningException if the instance is not running or could @@ -244,7 +244,7 @@ async def async_wait_for_start(self) -> None: if self.done: # If the instance was shutdown from under us, raise immediately raise NotRunningException assert self.engine.running_future is not None - await wait_future_or_timeout(self.engine.running_future, timeout=_STARTUP_TIMEOUT) + await wait_future_or_timeout(self.engine.running_future, timeout=timeout) if not self.started: raise NotRunningException diff --git a/src/zeroconf/asyncio.py b/src/zeroconf/asyncio.py index 6e9d6e9bd..ce5a43eb9 100644 --- a/src/zeroconf/asyncio.py +++ b/src/zeroconf/asyncio.py @@ -228,8 +228,8 @@ async def async_close(self) -> None: """Ends the background threads, and prevent this instance from servicing further queries.""" if not self.zeroconf.done: - with contextlib.suppress(asyncio.TimeoutError, NotRunningException): - await asyncio.wait_for(self.zeroconf.async_wait_for_start(), timeout=1) + with contextlib.suppress(NotRunningException): + await self.zeroconf.async_wait_for_start(timeout=1.0) await self.async_remove_all_service_listeners() await self.async_unregister_all_services() await self.zeroconf._async_close() # pylint: disable=protected-access