diff --git a/poetry.lock b/poetry.lock index 14c79f61..962899b2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -12,19 +12,6 @@ files = [ {file = "alabaster-0.7.16.tar.gz", hash = "sha256:75a8b99c28a5dad50dd7f8ccdd447a121ddb3892da9e53d1ca5cca3106d58d65"}, ] -[[package]] -name = "async-timeout" -version = "5.0.1" -description = "Timeout context manager for asyncio programs" -optional = false -python-versions = ">=3.8" -groups = ["main"] -markers = "python_version < \"3.11\"" -files = [ - {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, - {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, -] - [[package]] name = "babel" version = "2.16.0" @@ -1140,4 +1127,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = "^3.9" -content-hash = "eb91a0dd1c260f37d2579b4793f537f8017f9e1801e2a372849439f5c9132245" +content-hash = "ea903296f015035c594eb8cce08d4dedc716074e33644033938dfdb5f047d72e" diff --git a/pyproject.toml b/pyproject.toml index f5084253..7514d9a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,6 @@ prerelease = true [tool.poetry.dependencies] python = "^3.9" -async-timeout = {version = ">=3.0.0", python = "<3.11"} ifaddr = ">=0.1.7" [tool.poetry.group.dev.dependencies] diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 01e98e8f..3f007c17 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -55,8 +55,8 @@ get_running_loop, run_coro_with_timeout, shutdown_loop, - wait_event_or_timeout, wait_for_future_set_or_timeout, + wait_future_or_timeout, ) from ._utils.name import service_type_name from ._utils.net import ( @@ -203,7 +203,15 @@ def __init__( @property def started(self) -> bool: """Check if the instance has started.""" - return bool(not self.done and self.engine.running_event and self.engine.running_event.is_set()) + running_future = self.engine.running_future + return bool( + not self.done + and running_future + and running_future.done() + and not running_future.cancelled() + and not running_future.exception() + and running_future.result() + ) def start(self) -> None: """Start Zeroconf.""" @@ -227,7 +235,7 @@ def _run_loop() -> None: self._loop_thread.start() loop_thread_ready.wait() - async def async_wait_for_start(self) -> None: + async def async_wait_for_start(self, timeout: float = _STARTUP_TIMEOUT) -> None: """Wait for start up for actions that require a running Zeroconf instance. Throws NotRunningException if the instance is not running or could @@ -235,9 +243,9 @@ async def async_wait_for_start(self) -> None: """ if self.done: # If the instance was shutdown from under us, raise immediately raise NotRunningException - assert self.engine.running_event is not None - await wait_event_or_timeout(self.engine.running_event, timeout=_STARTUP_TIMEOUT) - if not self.engine.running_event.is_set() or self.done: + assert self.engine.running_future is not None + await wait_future_or_timeout(self.engine.running_future, timeout=timeout) + if not self.started: raise NotRunningException @property diff --git a/src/zeroconf/_engine.py b/src/zeroconf/_engine.py index 7b22f788..8c800a33 100644 --- a/src/zeroconf/_engine.py +++ b/src/zeroconf/_engine.py @@ -53,7 +53,7 @@ class AsyncEngine: "loop", "protocols", "readers", - "running_event", + "running_future", "senders", "zc", ) @@ -69,7 +69,7 @@ def __init__( self.protocols: list[AsyncListener] = [] self.readers: list[_WrappedTransport] = [] self.senders: list[_WrappedTransport] = [] - self.running_event: asyncio.Event | None = None + self.running_future: asyncio.Future[bool | None] | None = None self._listen_socket = listen_socket self._respond_sockets = respond_sockets self._cleanup_timer: asyncio.TimerHandle | None = None @@ -81,15 +81,15 @@ def setup( ) -> None: """Set up the instance.""" self.loop = loop - self.running_event = asyncio.Event() + self.running_future = loop.create_future() self.loop.create_task(self._async_setup(loop_thread_ready)) async def _async_setup(self, loop_thread_ready: threading.Event | None) -> None: """Set up the instance.""" self._async_schedule_next_cache_cleanup() await self._async_create_endpoints() - assert self.running_event is not None - self.running_event.set() + assert self.running_future is not None + self.running_future.set_result(True) if loop_thread_ready: loop_thread_ready.set() @@ -142,8 +142,9 @@ async def _async_close(self) -> None: def _async_shutdown(self) -> None: """Shutdown transports and sockets.""" - assert self.running_event is not None - self.running_event.clear() + assert self.running_future is not None + assert self.loop is not None + self.running_future = self.loop.create_future() for wrapped_transport in itertools.chain(self.senders, self.readers): wrapped_transport.transport.close() diff --git a/src/zeroconf/_utils/asyncio.py b/src/zeroconf/_utils/asyncio.py index 07b3f422..c92d99d5 100644 --- a/src/zeroconf/_utils/asyncio.py +++ b/src/zeroconf/_utils/asyncio.py @@ -28,11 +28,6 @@ import sys from typing import Any, Awaitable, Coroutine -if sys.version_info[:2] < (3, 11): - from async_timeout import timeout as asyncio_timeout -else: - from asyncio import timeout as asyncio_timeout # type: ignore[attr-defined] - from .._exceptions import EventLoopBlocked from ..const import _LOADED_SYSTEM_TIMEOUT from .time import millis_to_seconds @@ -70,11 +65,17 @@ async def wait_for_future_set_or_timeout( future_set.discard(future) -async def wait_event_or_timeout(event: asyncio.Event, timeout: float) -> None: - """Wait for an event or timeout.""" - with contextlib.suppress(asyncio.TimeoutError): - async with asyncio_timeout(timeout): - await event.wait() +async def wait_future_or_timeout(future: asyncio.Future[bool | None], timeout: float) -> None: + """Wait for a future or timeout.""" + loop = asyncio.get_running_loop() + handle = loop.call_later(timeout, _set_future_none_if_not_done, future) + try: + await future + except asyncio.CancelledError: + if sys.version_info >= (3, 11) and (task := asyncio.current_task()) and task.cancelling(): + raise + finally: + handle.cancel() async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> set[asyncio.Task]: diff --git a/src/zeroconf/asyncio.py b/src/zeroconf/asyncio.py index 2a29a4bb..ce5a43eb 100644 --- a/src/zeroconf/asyncio.py +++ b/src/zeroconf/asyncio.py @@ -29,6 +29,7 @@ from ._core import Zeroconf from ._dns import DNSQuestionType +from ._exceptions import NotRunningException from ._services import ServiceListener from ._services.browser import _ServiceBrowserBase from ._services.info import AsyncServiceInfo, ServiceInfo @@ -227,8 +228,8 @@ async def async_close(self) -> None: """Ends the background threads, and prevent this instance from servicing further queries.""" if not self.zeroconf.done: - with contextlib.suppress(asyncio.TimeoutError): - await asyncio.wait_for(self.zeroconf.async_wait_for_start(), timeout=1) + with contextlib.suppress(NotRunningException): + await self.zeroconf.async_wait_for_start(timeout=1.0) await self.async_remove_all_service_listeners() await self.async_unregister_all_services() await self.zeroconf._async_close() # pylint: disable=protected-access diff --git a/tests/utils/test_asyncio.py b/tests/utils/test_asyncio.py index 09137a71..7989a82c 100644 --- a/tests/utils/test_asyncio.py +++ b/tests/utils/test_asyncio.py @@ -45,16 +45,17 @@ def test_get_running_loop_no_loop() -> None: @pytest.mark.asyncio -async def test_wait_event_or_timeout_times_out() -> None: - """Test wait_event_or_timeout will timeout.""" - test_event = asyncio.Event() - await aioutils.wait_event_or_timeout(test_event, 0.1) +async def test_wait_future_or_timeout_times_out() -> None: + """Test wait_future_or_timeout will timeout.""" + loop = asyncio.get_running_loop() + test_future = loop.create_future() + await aioutils.wait_future_or_timeout(test_future, 0.1) - task = asyncio.ensure_future(test_event.wait()) + task = asyncio.ensure_future(test_future) await asyncio.sleep(0.1) async def _async_wait_or_timeout(): - await aioutils.wait_event_or_timeout(test_event, 0.1) + await aioutils.wait_future_or_timeout(test_future, 0.1) # Test high lock contention await asyncio.gather(*[_async_wait_or_timeout() for _ in range(100)])