|
22 | 22 |
|
23 | 23 | from __future__ import annotations |
24 | 24 |
|
| 25 | +import asyncio |
| 26 | +from functools import partial |
25 | 27 | from typing import TYPE_CHECKING, cast |
26 | 28 |
|
27 | 29 | from .._cache import _UniqueRecordsType |
28 | 30 | from .._dns import DNSQuestion, DNSRecord |
29 | 31 | from .._logger import log |
30 | 32 | from .._protocol.incoming import DNSIncoming |
| 33 | +from .._protocol.outgoing import DNSOutgoing |
31 | 34 | from .._record_update import RecordUpdate |
32 | 35 | from .._updates import RecordUpdateListener |
33 | | -from .._utils.time import current_time_millis |
34 | | -from ..const import _ADDRESS_RECORD_TYPES, _DNS_PTR_MIN_TTL, _TYPE_PTR |
| 36 | +from .._utils.time import current_time_millis, millis_to_seconds |
| 37 | +from ..const import ( |
| 38 | + _ADDRESS_RECORD_TYPES, |
| 39 | + _DNS_PTR_MIN_TTL, |
| 40 | + _FLAGS_QR_QUERY, |
| 41 | + _RECONFIRM_QUERY_INTERVALS_MS, |
| 42 | + _RECONFIRM_TIMEOUT_MS, |
| 43 | + _TYPE_PTR, |
| 44 | +) |
35 | 45 |
|
36 | 46 | if TYPE_CHECKING: |
37 | 47 | from .._core import Zeroconf |
|
42 | 52 | class RecordManager: |
43 | 53 | """Process records into the cache and notify listeners.""" |
44 | 54 |
|
45 | | - __slots__ = ("cache", "listeners", "zc") |
| 55 | + __slots__ = ("_reconfirm_tasks", "cache", "listeners", "zc") |
46 | 56 |
|
47 | 57 | def __init__(self, zeroconf: Zeroconf) -> None: |
48 | 58 | """Init the record manager.""" |
49 | 59 | self.zc = zeroconf |
50 | 60 | self.cache = zeroconf.cache |
51 | 61 | self.listeners: set[RecordUpdateListener] = set() |
| 62 | + # Active per-record reconfirmations. Keyed by the cache entry |
| 63 | + # so that repeated calls for the same record while one is in |
| 64 | + # flight are no-ops (RFC 6762 §10.4). |
| 65 | + self._reconfirm_tasks: dict[DNSRecord, asyncio.Task] = {} |
52 | 66 |
|
53 | 67 | def async_updates(self, now: _float, records: list[RecordUpdate]) -> None: |
54 | 68 | """Used to notify listeners of new information that has updated |
@@ -219,3 +233,83 @@ def async_remove_listener(self, listener: RecordUpdateListener) -> None: |
219 | 233 | self.zc.async_notify_all() |
220 | 234 | except ValueError as e: |
221 | 235 | log.exception("Failed to remove listener: %r", e) |
| 236 | + |
| 237 | + def async_reconfirm_record(self, record: DNSRecord) -> bool: |
| 238 | + """Schedule RFC 6762 §10.4 reconfirmation for ``record``.""" |
| 239 | + cached = self.cache.get(record) |
| 240 | + if cached is None: |
| 241 | + return False |
| 242 | + if cached in self._reconfirm_tasks: |
| 243 | + return False |
| 244 | + loop = self.zc.loop |
| 245 | + if loop is None: |
| 246 | + return False |
| 247 | + task = loop.create_task(self._async_reconfirm(cached)) |
| 248 | + self._reconfirm_tasks[cached] = task |
| 249 | + task.add_done_callback(partial(self._reconfirm_done, cached)) |
| 250 | + return True |
| 251 | + |
| 252 | + def _reconfirm_done(self, record: DNSRecord, _task: asyncio.Task) -> None: |
| 253 | + """Drop ``record`` from the active reconfirmation set.""" |
| 254 | + self._reconfirm_tasks.pop(record, None) |
| 255 | + |
| 256 | + async def _async_reconfirm(self, record: DNSRecord) -> None: |
| 257 | + """Re-query ``record`` and flush from cache if not refreshed. |
| 258 | +
|
| 259 | + RFC 6762 §10.4: send two or more queries, then flush the |
| 260 | + record if no response arrives within ten seconds. |
| 261 | + """ |
| 262 | + start = current_time_millis() |
| 263 | + original_created = record.created |
| 264 | + zc = self.zc |
| 265 | + question = DNSQuestion(record.name, record.type, record.class_) |
| 266 | + |
| 267 | + prev_delay_ms = 0 |
| 268 | + for delay_ms in _RECONFIRM_QUERY_INTERVALS_MS: |
| 269 | + wait_ms = delay_ms - prev_delay_ms |
| 270 | + if wait_ms > 0: |
| 271 | + await asyncio.sleep(millis_to_seconds(wait_ms)) |
| 272 | + prev_delay_ms = delay_ms |
| 273 | + if zc.done: |
| 274 | + return |
| 275 | + if self._record_refreshed_since(record, original_created): |
| 276 | + return |
| 277 | + out = DNSOutgoing(_FLAGS_QR_QUERY) |
| 278 | + out.add_question(question) |
| 279 | + zc.async_send(out) |
| 280 | + |
| 281 | + remaining_ms = _RECONFIRM_TIMEOUT_MS - prev_delay_ms |
| 282 | + if remaining_ms > 0: |
| 283 | + await asyncio.sleep(millis_to_seconds(remaining_ms)) |
| 284 | + if zc.done: |
| 285 | + return |
| 286 | + if self._record_refreshed_since(record, original_created): |
| 287 | + return |
| 288 | + |
| 289 | + now = current_time_millis() |
| 290 | + elapsed_secs = max(0, int((now - start) / 1000)) |
| 291 | + log.debug( |
| 292 | + "Reconfirmation of %s timed out after %ds; flushing from cache", |
| 293 | + record, |
| 294 | + elapsed_secs, |
| 295 | + ) |
| 296 | + cached = self.cache.get(record) |
| 297 | + if cached is None: |
| 298 | + return |
| 299 | + # Mark expired so listeners interpret this as a goodbye when |
| 300 | + # they re-check ``is_expired(now)`` from inside |
| 301 | + # ``async_update_records``. Mirrors the goodbye path in |
| 302 | + # ``async_updates_from_response``. |
| 303 | + cached._set_created_ttl(now - 1000, 0) |
| 304 | + update = RecordUpdate.__new__(RecordUpdate) |
| 305 | + update._fast_init(cached, cached) |
| 306 | + self.async_updates(now, [update]) |
| 307 | + self.cache.async_remove_records([cached]) |
| 308 | + self.async_updates_complete(True) |
| 309 | + |
| 310 | + def _record_refreshed_since(self, record: DNSRecord, original_created: float) -> bool: |
| 311 | + """Return True if the cache holds a newer copy of ``record``.""" |
| 312 | + cached = self.cache.get(record) |
| 313 | + if cached is None: |
| 314 | + return True |
| 315 | + return cached.created > original_created |
0 commit comments