Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
The default is HTTP and HAP; use --find to search for all available services in the network
"""

from __future__ import annotations

import argparse
import logging
from time import sleep
Expand Down
2 changes: 2 additions & 0 deletions examples/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

"""Example of announcing a service (in this case, a fake HTTP server)"""

from __future__ import annotations

import argparse
import logging
import socket
Expand Down
2 changes: 2 additions & 0 deletions examples/resolve_address.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

"""Example of resolving a name to an IP address."""

from __future__ import annotations

import asyncio
import logging
import sys
Expand Down
2 changes: 2 additions & 0 deletions examples/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

"""Example of resolving a service with a known name"""

from __future__ import annotations

import logging
import sys

Expand Down
1 change: 1 addition & 0 deletions examples/self_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python
from __future__ import annotations

import logging
import socket
Expand Down
4 changes: 3 additions & 1 deletion src/zeroconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
USA
"""

from __future__ import annotations

from ._cache import DNSCache # noqa # import needed for backwards compat
from ._core import Zeroconf
from ._dns import ( # noqa # import needed for backwards compat
Expand Down Expand Up @@ -57,10 +59,10 @@
)
from ._services.browser import ServiceBrowser
from ._services.info import ( # noqa # import needed for backwards compat
ServiceInfo,
AddressResolver,
AddressResolverIPv4,
AddressResolverIPv6,
ServiceInfo,
instance_name_from_service_info,
)
from ._services.registry import ( # noqa # import needed for backwards compat
Expand Down
38 changes: 20 additions & 18 deletions src/zeroconf/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
USA
"""

from __future__ import annotations

from heapq import heapify, heappop, heappush
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast
from typing import Dict, Iterable, Union, cast

from ._dns import (
DNSAddress,
Expand Down Expand Up @@ -66,8 +68,8 @@ class DNSCache:

def __init__(self) -> None:
self.cache: _DNSRecordCacheType = {}
self._expire_heap: List[Tuple[float, DNSRecord]] = []
self._expirations: Dict[DNSRecord, float] = {}
self._expire_heap: list[tuple[float, DNSRecord]] = []
self._expirations: dict[DNSRecord, float] = {}
self.service_cache: _DNSRecordCacheType = {}

# Functions prefixed with async_ are NOT threadsafe and must
Expand Down Expand Up @@ -135,7 +137,7 @@ def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
for entry in entries:
self._async_remove(entry)

def async_expire(self, now: _float) -> List[DNSRecord]:
def async_expire(self, now: _float) -> list[DNSRecord]:
"""Purge expired entries from the cache.

This function must be run in from event loop.
Expand All @@ -145,7 +147,7 @@ def async_expire(self, now: _float) -> List[DNSRecord]:
if not (expire_heap_len := len(self._expire_heap)):
return []

expired: List[DNSRecord] = []
expired: list[DNSRecord] = []
# Find any expired records and add them to the to-delete list
while self._expire_heap:
when_record = self._expire_heap[0]
Expand Down Expand Up @@ -182,7 +184,7 @@ def async_expire(self, now: _float) -> List[DNSRecord]:
self.async_remove_records(expired)
return expired

def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:
def async_get_unique(self, entry: _UniqueRecordsType) -> DNSRecord | None:
"""Gets a unique entry by key. Will return None if there is no
matching entry.

Expand All @@ -194,31 +196,31 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:
return None
return store.get(entry)

def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DNSRecord]:
def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> list[DNSRecord]:
"""Gets all matching entries by details.

This function is not thread-safe and must be called from
the event loop.
"""
key = name.lower()
records = self.cache.get(key)
matches: List[DNSRecord] = []
matches: list[DNSRecord] = []
if records is None:
return matches
for record in records.values():
if type_ == record.type and class_ == record.class_:
matches.append(record)
return matches

def async_entries_with_name(self, name: str) -> List[DNSRecord]:
def async_entries_with_name(self, name: str) -> list[DNSRecord]:
"""Returns a dict of entries whose key matches the name.

This function is not threadsafe and must be called from
the event loop.
"""
return self.entries_with_name(name)

def async_entries_with_server(self, name: str) -> List[DNSRecord]:
def async_entries_with_server(self, name: str) -> list[DNSRecord]:
"""Returns a dict of entries whose key matches the server.

This function is not threadsafe and must be called from
Expand All @@ -230,7 +232,7 @@ def async_entries_with_server(self, name: str) -> List[DNSRecord]:
# event loop, however they all make copies so they significantly
# inefficient.

def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
def get(self, entry: DNSEntry) -> DNSRecord | None:
"""Gets an entry by key. Will return None if there is no
matching entry."""
if isinstance(entry, _UNIQUE_RECORD_TYPES):
Expand All @@ -240,7 +242,7 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
return cached_entry
return None

def get_by_details(self, name: str, type_: _int, class_: _int) -> Optional[DNSRecord]:
def get_by_details(self, name: str, type_: _int, class_: _int) -> DNSRecord | None:
"""Gets the first matching entry by details. Returns None if no entries match.

Calling this function is not recommended as it will only
Expand All @@ -261,27 +263,27 @@ def get_by_details(self, name: str, type_: _int, class_: _int) -> Optional[DNSRe
return cached_entry
return None

def get_all_by_details(self, name: str, type_: _int, class_: _int) -> List[DNSRecord]:
def get_all_by_details(self, name: str, type_: _int, class_: _int) -> list[DNSRecord]:
"""Gets all matching entries by details."""
key = name.lower()
records = self.cache.get(key)
if records is None:
return []
return [entry for entry in list(records.values()) if type_ == entry.type and class_ == entry.class_]

def entries_with_server(self, server: str) -> List[DNSRecord]:
def entries_with_server(self, server: str) -> list[DNSRecord]:
"""Returns a list of entries whose server matches the name."""
if entries := self.service_cache.get(server.lower()):
return list(entries.values())
return []

def entries_with_name(self, name: str) -> List[DNSRecord]:
def entries_with_name(self, name: str) -> list[DNSRecord]:
"""Returns a list of entries whose key matches the name."""
if entries := self.cache.get(name.lower()):
return list(entries.values())
return []

def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]:
def current_entry_with_name_and_alias(self, name: str, alias: str) -> DNSRecord | None:
now = current_time_millis()
for record in reversed(self.entries_with_name(name)):
if (
Expand All @@ -292,13 +294,13 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D
return record
return None

def names(self) -> List[str]:
def names(self) -> list[str]:
"""Return a copy of the list of current cache names."""
return list(self.cache)

def async_mark_unique_records_older_than_1s_to_expire(
self,
unique_types: Set[Tuple[_str, _int, _int]],
unique_types: set[tuple[_str, _int, _int]],
answers: Iterable[DNSRecord],
now: _float,
) -> None:
Expand Down
Loading