Skip to content

Commit 0deb56b

Browse files
authored
fix(core): release sockets when close runs before engine setup completes (#1706)
1 parent cab5fa8 commit 0deb56b

8 files changed

Lines changed: 83 additions & 8 deletions

File tree

src/zeroconf/_engine.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,17 @@ async def _async_create_endpoints(self) -> None:
114114
lambda: AsyncListener(self.zc), # type: ignore[arg-type, return-value]
115115
sock=s,
116116
)
117+
# Register the wrapped transport before releasing the engine's
118+
# handle so a concurrent shutdown always sees ``s`` in exactly
119+
# one place; do not add an ``await`` between these two steps.
117120
self.protocols.append(cast(AsyncListener, protocol))
118121
self.readers.append(make_wrapped_transport(cast(asyncio.DatagramTransport, transport)))
119122
if s in sender_sockets:
120123
self.senders.append(make_wrapped_transport(cast(asyncio.DatagramTransport, transport)))
124+
if s is self._listen_socket:
125+
self._listen_socket = None
126+
if s in self._respond_sockets:
127+
self._respond_sockets.remove(s)
121128

122129
def _async_cache_cleanup(self) -> None:
123130
"""Periodic cache cleanup."""
@@ -139,19 +146,37 @@ def _async_schedule_next_cache_cleanup(self) -> None:
139146
async def _async_close(self) -> None:
140147
"""Cancel and wait for the cleanup task to finish."""
141148
assert self._setup_task is not None
142-
await self._setup_task
149+
# Swallow CancelledError only if the setup task itself was
150+
# cancelled (close-before-start); outer-task cancellation must
151+
# propagate.
152+
try:
153+
await self._setup_task
154+
except asyncio.CancelledError:
155+
if not self._setup_task.cancelled():
156+
raise
143157
self._async_shutdown()
144158
await asyncio.sleep(0) # flush out any call soons
145-
assert self._cleanup_timer is not None
146-
self._cleanup_timer.cancel()
159+
if self._cleanup_timer is not None:
160+
self._cleanup_timer.cancel()
147161

148162
def _async_shutdown(self) -> None:
149-
"""Shutdown transports and sockets."""
163+
"""Shutdown transports and sockets; safe to call repeatedly."""
150164
assert self.running_future is not None
151165
assert self.loop is not None
152166
self.running_future = self.loop.create_future()
167+
# Cancel pending setup so it can't wrap fresh transports after
168+
# shutdown has started.
169+
if self._setup_task is not None and not self._setup_task.done():
170+
self._setup_task.cancel()
153171
for wrapped_transport in itertools.chain(self.senders, self.readers):
154172
wrapped_transport.transport.close()
173+
# Anything still here was never adopted by a transport.
174+
if self._listen_socket is not None:
175+
self._listen_socket.close()
176+
self._listen_socket = None
177+
for s in self._respond_sockets:
178+
s.close()
179+
self._respond_sockets = []
155180

156181
def close(self) -> None:
157182
"""Close from sync context.

tests/services/test_browser.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,7 @@ def update_service(self, zc, type_, name) -> None: # type: ignore[no-untyped-de
10331033

10341034

10351035
def test_service_browser_listeners_no_update_service():
1036-
"""Test that the ServiceBrowser ServiceListener that does not implement update_service."""
1036+
"""A listener that ignores update events records only add/remove callbacks."""
10371037

10381038
# instantiate a zeroconf instance
10391039
zc = Zeroconf(interfaces=["127.0.0.1"])
@@ -1051,6 +1051,9 @@ def remove_service(self, zc, type_, name) -> None: # type: ignore[no-untyped-de
10511051
if name == registration_name:
10521052
callbacks.append(("remove", type_, name))
10531053

1054+
def update_service(self, zc, type_, name) -> None: # type: ignore[no-untyped-def]
1055+
pass
1056+
10541057
listener = MyServiceListener()
10551058

10561059
browser = r.ServiceBrowser(zc, type_, None, listener)

tests/services/test_info.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,6 +1529,7 @@ async def test_bad_ip_addresses_ignored_in_cache():
15291529
info = ServiceInfo(type_, registration_name)
15301530
info.load_from_cache(aiozc.zeroconf)
15311531
assert info.addresses_by_version(IPVersion.V4Only) == [b"\x7f\x00\x00\x01"]
1532+
await aiozc.async_close()
15321533

15331534

15341535
@pytest.mark.asyncio
@@ -1804,6 +1805,7 @@ async def test_address_resolver():
18041805
aiozc.zeroconf.async_send(outgoing)
18051806
assert await resolve_task
18061807
assert resolver.addresses == [b"\x7f\x00\x00\x01"]
1808+
await aiozc.async_close()
18071809

18081810

18091811
@pytest.mark.asyncio
@@ -1828,6 +1830,7 @@ async def test_address_resolver_ipv4():
18281830
aiozc.zeroconf.async_send(outgoing)
18291831
assert await resolve_task
18301832
assert resolver.addresses == [b"\x7f\x00\x00\x01"]
1833+
await aiozc.async_close()
18311834

18321835

18331836
@pytest.mark.asyncio
@@ -1854,6 +1857,7 @@ async def test_address_resolver_ipv6():
18541857
aiozc.zeroconf.async_send(outgoing)
18551858
assert await resolve_task
18561859
assert resolver.ip_addresses_by_version(IPVersion.All) == [ip_address("fe80::52e:c2f2:bc5f:e9c6")]
1860+
await aiozc.async_close()
18571861

18581862

18591863
@pytest.mark.asyncio

tests/test_asyncio.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@ async def test_async_service_registration_name_strict_check(quick_timing: None)
500500

501501
await aiozc.async_unregister_service(info)
502502
await aiozc.async_close()
503+
zc.close()
503504

504505

505506
@pytest.mark.asyncio

tests/test_core.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,9 @@ def _background_register():
824824
async def test_event_loop_blocked(mock_start):
825825
"""Test we raise NotRunningException when waiting for startup that times out."""
826826
aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])
827-
with pytest.raises(NotRunningException):
828-
await aiozc.zeroconf.async_wait_for_start(timeout=0)
829-
assert aiozc.zeroconf.started is False
827+
try:
828+
with pytest.raises(NotRunningException):
829+
await aiozc.zeroconf.async_wait_for_start(timeout=0)
830+
assert aiozc.zeroconf.started is False
831+
finally:
832+
await aiozc.async_close()

tests/test_engine.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,42 @@ async def test_reaper():
7373
assert record_with_1s_ttl not in entries
7474

7575

76+
@pytest.mark.asyncio
77+
async def test_setup_releases_socket_ownership() -> None:
78+
"""Engine releases its pending-socket refs once each socket has a transport."""
79+
aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])
80+
try:
81+
await aiozc.zeroconf.async_wait_for_start()
82+
engine = aiozc.zeroconf.engine
83+
assert engine._listen_socket is None
84+
assert engine._respond_sockets == []
85+
assert engine.readers
86+
assert engine.senders
87+
finally:
88+
await aiozc.async_close()
89+
90+
91+
@pytest.mark.asyncio
92+
async def test_async_close_propagates_outer_cancellation() -> None:
93+
"""Outer-task cancellation while awaiting setup propagates to the caller."""
94+
aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])
95+
try:
96+
await aiozc.zeroconf.async_wait_for_start()
97+
engine = aiozc.zeroconf.engine
98+
loop = asyncio.get_running_loop()
99+
original_task = engine._setup_task
100+
fake_task = loop.create_future()
101+
fake_task.set_exception(asyncio.CancelledError())
102+
engine._setup_task = fake_task # type: ignore[assignment]
103+
try:
104+
with pytest.raises(asyncio.CancelledError):
105+
await engine._async_close()
106+
finally:
107+
engine._setup_task = original_task
108+
finally:
109+
await aiozc.async_close()
110+
111+
76112
@pytest.mark.asyncio
77113
async def test_reaper_aborts_when_done():
78114
"""Ensure cache cleanup stops when zeroconf is done."""

tests/test_handlers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,6 +1799,8 @@ async def test_response_aggregation_timings_multiple(run_isolated, disable_dupli
17991799
zc.record_manager.async_updates_from_response(incoming)
18001800
assert info2.dns_pointer() in incoming.answers()
18011801

1802+
await aiozc.async_close()
1803+
18021804

18031805
@pytest.mark.asyncio
18041806
async def test_response_aggregation_random_delay():

tests/utils/test_asyncio.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def _run_coro() -> None:
105105

106106
assert loop.is_running() is False
107107
runcoro_thread.join()
108+
loop.close()
108109

109110

110111
def test_cumulative_timeouts_less_than_close_plus_buffer():

0 commit comments

Comments
 (0)