diff --git a/examples/browser.py b/examples/browser.py index 107be452..92adc949 100755 --- a/examples/browser.py +++ b/examples/browser.py @@ -5,6 +5,8 @@ The default is HTTP and HAP; use --find to search for all available services in the network """ +from __future__ import annotations + import argparse import logging from time import sleep diff --git a/examples/registration.py b/examples/registration.py index 1c42d890..1ba19b16 100755 --- a/examples/registration.py +++ b/examples/registration.py @@ -2,6 +2,8 @@ """Example of announcing a service (in this case, a fake HTTP server)""" +from __future__ import annotations + import argparse import logging import socket diff --git a/examples/resolve_address.py b/examples/resolve_address.py index eeecfda0..88ce825b 100755 --- a/examples/resolve_address.py +++ b/examples/resolve_address.py @@ -2,6 +2,8 @@ """Example of resolving a name to an IP address.""" +from __future__ import annotations + import asyncio import logging import sys diff --git a/examples/resolver.py b/examples/resolver.py index 1b74f97e..a52050f4 100755 --- a/examples/resolver.py +++ b/examples/resolver.py @@ -2,6 +2,8 @@ """Example of resolving a service with a known name""" +from __future__ import annotations + import logging import sys diff --git a/examples/self_test.py b/examples/self_test.py index b12a8518..3d1fa050 100755 --- a/examples/self_test.py +++ b/examples/self_test.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +from __future__ import annotations import logging import socket diff --git a/src/zeroconf/__init__.py b/src/zeroconf/__init__.py index 1a41ddd3..26f60cde 100644 --- a/src/zeroconf/__init__.py +++ b/src/zeroconf/__init__.py @@ -20,6 +20,8 @@ USA """ +from __future__ import annotations + from ._cache import DNSCache # noqa # import needed for backwards compat from ._core import Zeroconf from ._dns import ( # noqa # import needed for backwards compat @@ -57,10 +59,10 @@ ) from ._services.browser import ServiceBrowser from ._services.info import ( # noqa # import needed for backwards compat - ServiceInfo, AddressResolver, AddressResolverIPv4, AddressResolverIPv6, + ServiceInfo, instance_name_from_service_info, ) from ._services.registry import ( # noqa # import needed for backwards compat diff --git a/src/zeroconf/_cache.py b/src/zeroconf/_cache.py index 1b7aae38..5ac43f30 100644 --- a/src/zeroconf/_cache.py +++ b/src/zeroconf/_cache.py @@ -20,8 +20,10 @@ USA """ +from __future__ import annotations + from heapq import heapify, heappop, heappush -from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast +from typing import Dict, Iterable, Union, cast from ._dns import ( DNSAddress, @@ -66,8 +68,8 @@ class DNSCache: def __init__(self) -> None: self.cache: _DNSRecordCacheType = {} - self._expire_heap: List[Tuple[float, DNSRecord]] = [] - self._expirations: Dict[DNSRecord, float] = {} + self._expire_heap: list[tuple[float, DNSRecord]] = [] + self._expirations: dict[DNSRecord, float] = {} self.service_cache: _DNSRecordCacheType = {} # Functions prefixed with async_ are NOT threadsafe and must @@ -135,7 +137,7 @@ def async_remove_records(self, entries: Iterable[DNSRecord]) -> None: for entry in entries: self._async_remove(entry) - def async_expire(self, now: _float) -> List[DNSRecord]: + def async_expire(self, now: _float) -> list[DNSRecord]: """Purge expired entries from the cache. This function must be run in from event loop. @@ -145,7 +147,7 @@ def async_expire(self, now: _float) -> List[DNSRecord]: if not (expire_heap_len := len(self._expire_heap)): return [] - expired: List[DNSRecord] = [] + expired: list[DNSRecord] = [] # Find any expired records and add them to the to-delete list while self._expire_heap: when_record = self._expire_heap[0] @@ -182,7 +184,7 @@ def async_expire(self, now: _float) -> List[DNSRecord]: self.async_remove_records(expired) return expired - def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: + def async_get_unique(self, entry: _UniqueRecordsType) -> DNSRecord | None: """Gets a unique entry by key. Will return None if there is no matching entry. @@ -194,7 +196,7 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: return None return store.get(entry) - def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DNSRecord]: + def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> list[DNSRecord]: """Gets all matching entries by details. This function is not thread-safe and must be called from @@ -202,7 +204,7 @@ def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DN """ key = name.lower() records = self.cache.get(key) - matches: List[DNSRecord] = [] + matches: list[DNSRecord] = [] if records is None: return matches for record in records.values(): @@ -210,7 +212,7 @@ def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DN matches.append(record) return matches - def async_entries_with_name(self, name: str) -> List[DNSRecord]: + def async_entries_with_name(self, name: str) -> list[DNSRecord]: """Returns a dict of entries whose key matches the name. This function is not threadsafe and must be called from @@ -218,7 +220,7 @@ def async_entries_with_name(self, name: str) -> List[DNSRecord]: """ return self.entries_with_name(name) - def async_entries_with_server(self, name: str) -> List[DNSRecord]: + def async_entries_with_server(self, name: str) -> list[DNSRecord]: """Returns a dict of entries whose key matches the server. This function is not threadsafe and must be called from @@ -230,7 +232,7 @@ def async_entries_with_server(self, name: str) -> List[DNSRecord]: # event loop, however they all make copies so they significantly # inefficient. - def get(self, entry: DNSEntry) -> Optional[DNSRecord]: + def get(self, entry: DNSEntry) -> DNSRecord | None: """Gets an entry by key. Will return None if there is no matching entry.""" if isinstance(entry, _UNIQUE_RECORD_TYPES): @@ -240,7 +242,7 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]: return cached_entry return None - def get_by_details(self, name: str, type_: _int, class_: _int) -> Optional[DNSRecord]: + def get_by_details(self, name: str, type_: _int, class_: _int) -> DNSRecord | None: """Gets the first matching entry by details. Returns None if no entries match. Calling this function is not recommended as it will only @@ -261,7 +263,7 @@ def get_by_details(self, name: str, type_: _int, class_: _int) -> Optional[DNSRe return cached_entry return None - def get_all_by_details(self, name: str, type_: _int, class_: _int) -> List[DNSRecord]: + def get_all_by_details(self, name: str, type_: _int, class_: _int) -> list[DNSRecord]: """Gets all matching entries by details.""" key = name.lower() records = self.cache.get(key) @@ -269,19 +271,19 @@ def get_all_by_details(self, name: str, type_: _int, class_: _int) -> List[DNSRe return [] return [entry for entry in list(records.values()) if type_ == entry.type and class_ == entry.class_] - def entries_with_server(self, server: str) -> List[DNSRecord]: + def entries_with_server(self, server: str) -> list[DNSRecord]: """Returns a list of entries whose server matches the name.""" if entries := self.service_cache.get(server.lower()): return list(entries.values()) return [] - def entries_with_name(self, name: str) -> List[DNSRecord]: + def entries_with_name(self, name: str) -> list[DNSRecord]: """Returns a list of entries whose key matches the name.""" if entries := self.cache.get(name.lower()): return list(entries.values()) return [] - def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: + def current_entry_with_name_and_alias(self, name: str, alias: str) -> DNSRecord | None: now = current_time_millis() for record in reversed(self.entries_with_name(name)): if ( @@ -292,13 +294,13 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D return record return None - def names(self) -> List[str]: + def names(self) -> list[str]: """Return a copy of the list of current cache names.""" return list(self.cache) def async_mark_unique_records_older_than_1s_to_expire( self, - unique_types: Set[Tuple[_str, _int, _int]], + unique_types: set[tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float, ) -> None: diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 68cb8a9a..01e98e8f 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -20,12 +20,14 @@ USA """ +from __future__ import annotations + import asyncio import logging import sys import threading from types import TracebackType -from typing import Awaitable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Awaitable from ._cache import DNSCache from ._dns import DNSQuestion, DNSQuestionType @@ -108,9 +110,9 @@ def async_send_with_transport( packet: bytes, packet_num: int, out: DNSOutgoing, - addr: Optional[str], + addr: str | None, port: int, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + v6_flow_scope: tuple[()] | tuple[int, int] = (), ) -> None: ipv6_socket = transport.is_ipv6 if addr is None: @@ -149,7 +151,7 @@ def __init__( self, interfaces: InterfacesType = InterfaceChoice.All, unicast: bool = False, - ip_version: Optional[IPVersion] = None, + ip_version: IPVersion | None = None, apple_p2p: bool = False, ) -> None: """Creates an instance of the Zeroconf class, establishing @@ -181,7 +183,7 @@ def __init__( self.engine = AsyncEngine(self, listen_socket, respond_sockets) - self.browsers: Dict[ServiceListener, ServiceBrowser] = {} + self.browsers: dict[ServiceListener, ServiceBrowser] = {} self.registry = ServiceRegistry() self.cache = DNSCache() self.question_history = QuestionHistory() @@ -192,9 +194,9 @@ def __init__( self.query_handler = QueryHandler(self) self.record_manager = RecordManager(self) - self._notify_futures: Set[asyncio.Future] = set() - self.loop: Optional[asyncio.AbstractEventLoop] = None - self._loop_thread: Optional[threading.Thread] = None + self._notify_futures: set[asyncio.Future] = set() + self.loop: asyncio.AbstractEventLoop | None = None + self._loop_thread: threading.Thread | None = None self.start() @@ -239,7 +241,7 @@ async def async_wait_for_start(self) -> None: raise NotRunningException @property - def listeners(self) -> Set[RecordUpdateListener]: + def listeners(self) -> set[RecordUpdateListener]: return self.record_manager.listeners async def async_wait(self, timeout: float) -> None: @@ -264,8 +266,8 @@ def get_service_info( type_: str, name: str, timeout: int = 3000, - question_type: Optional[DNSQuestionType] = None, - ) -> Optional[ServiceInfo]: + question_type: DNSQuestionType | None = None, + ) -> ServiceInfo | None: """Returns network's service information for a particular name and type, or None if no service matches by the timeout, which defaults to 3 seconds. @@ -301,7 +303,7 @@ def remove_all_service_listeners(self) -> None: def register_service( self, info: ServiceInfo, - ttl: Optional[int] = None, + ttl: int | None = None, allow_name_change: bool = False, cooperating_responders: bool = False, strict: bool = True, @@ -329,7 +331,7 @@ def register_service( async def async_register_service( self, info: ServiceInfo, - ttl: Optional[int] = None, + ttl: int | None = None, allow_name_change: bool = False, cooperating_responders: bool = False, strict: bool = True, @@ -380,8 +382,8 @@ async def async_get_service_info( type_: str, name: str, timeout: int = 3000, - question_type: Optional[DNSQuestionType] = None, - ) -> Optional[AsyncServiceInfo]: + question_type: DNSQuestionType | None = None, + ) -> AsyncServiceInfo | None: """Returns network's service information for a particular name and type, or None if no service matches by the timeout, which defaults to 3 seconds. @@ -400,7 +402,7 @@ async def _async_broadcast_service( self, info: ServiceInfo, interval: int, - ttl: Optional[int], + ttl: int | None, broadcast_addresses: bool = True, ) -> None: """Send a broadcasts to announce a service at intervals.""" @@ -412,7 +414,7 @@ async def _async_broadcast_service( def generate_service_broadcast( self, info: ServiceInfo, - ttl: Optional[int], + ttl: int | None, broadcast_addresses: bool = True, ) -> DNSOutgoing: """Generate a broadcast to announce a service.""" @@ -439,7 +441,7 @@ def _add_broadcast_answer( # pylint: disable=no-self-use self, out: DNSOutgoing, info: ServiceInfo, - override_ttl: Optional[int], + override_ttl: int | None, broadcast_addresses: bool = True, ) -> None: """Add answers to broadcast a service.""" @@ -481,7 +483,7 @@ async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: self._async_broadcast_service(info, _UNREGISTER_TIME, 0, broadcast_addresses) ) - def generate_unregister_all_services(self) -> Optional[DNSOutgoing]: + def generate_unregister_all_services(self) -> DNSOutgoing | None: """Generate a DNSOutgoing goodbye for all services and remove them from the registry.""" service_infos = self.registry.async_get_service_infos() if not service_infos: @@ -562,7 +564,7 @@ async def async_check_service( def add_listener( self, listener: RecordUpdateListener, - question: Optional[Union[DNSQuestion, List[DNSQuestion]]], + question: DNSQuestion | list[DNSQuestion] | None, ) -> None: """Adds a listener for a given question. The listener will have its update_record method called when information is available to @@ -584,7 +586,7 @@ def remove_listener(self, listener: RecordUpdateListener) -> None: def async_add_listener( self, listener: RecordUpdateListener, - question: Optional[Union[DNSQuestion, List[DNSQuestion]]], + question: DNSQuestion | list[DNSQuestion] | None, ) -> None: """Adds a listener for a given question. The listener will have its update_record method called when information is available to @@ -604,10 +606,10 @@ def async_remove_listener(self, listener: RecordUpdateListener) -> None: def send( self, out: DNSOutgoing, - addr: Optional[str] = None, + addr: str | None = None, port: int = _MDNS_PORT, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), - transport: Optional[_WrappedTransport] = None, + v6_flow_scope: tuple[()] | tuple[int, int] = (), + transport: _WrappedTransport | None = None, ) -> None: """Sends an outgoing packet threadsafe.""" assert self.loop is not None @@ -616,10 +618,10 @@ def send( def async_send( self, out: DNSOutgoing, - addr: Optional[str] = None, + addr: str | None = None, port: int = _MDNS_PORT, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), - transport: Optional[_WrappedTransport] = None, + v6_flow_scope: tuple[()] | tuple[int, int] = (), + transport: _WrappedTransport | None = None, ) -> None: """Sends an outgoing packet.""" if self.done: @@ -701,14 +703,14 @@ async def _async_close(self) -> None: await self.engine._async_close() # pylint: disable=protected-access self._shutdown_threads() - def __enter__(self) -> "Zeroconf": + def __enter__(self) -> Zeroconf: return self def __exit__( # pylint: disable=useless-return self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: self.close() return None diff --git a/src/zeroconf/_dns.py b/src/zeroconf/_dns.py index c22f8b17..bc0a3948 100644 --- a/src/zeroconf/_dns.py +++ b/src/zeroconf/_dns.py @@ -20,9 +20,11 @@ USA """ +from __future__ import annotations + import enum import socket -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast +from typing import TYPE_CHECKING, Any, cast from ._exceptions import AbstractMethodException from ._utils.net import _is_v6_address @@ -94,7 +96,7 @@ def get_type(t: int) -> str: """Type accessor""" return _TYPES.get(t, f"?({t})") - def entry_to_string(self, hdr: str, other: Optional[Union[bytes, str]]) -> str: + def entry_to_string(self, hdr: str, other: bytes | str | None) -> str: """String representation with additional information""" return "{}[{},{}{},{}]{}".format( hdr, @@ -119,7 +121,7 @@ def _fast_init(self, name: str, type_: _int, class_: _int) -> None: self._fast_init_entry(name, type_, class_) self._hash = hash((self.key, type_, self.class_)) - def answered_by(self, rec: "DNSRecord") -> bool: + def answered_by(self, rec: DNSRecord) -> bool: """Returns true if the question is answered by the record""" return self.class_ == rec.class_ and self.type in (rec.type, _TYPE_ANY) and self.name == rec.name @@ -170,8 +172,8 @@ def __init__( name: str, type_: int, class_: int, - ttl: Union[float, int], - created: Optional[float] = None, + ttl: float | int, + created: float | None = None, ) -> None: self._fast_init_record(name, type_, class_, ttl, created or current_time_millis()) @@ -185,10 +187,10 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use """Abstract method""" raise AbstractMethodException - def __lt__(self, other: "DNSRecord") -> bool: + def __lt__(self, other: DNSRecord) -> bool: return self.ttl < other.ttl - def suppressed_by(self, msg: "DNSIncoming") -> bool: + def suppressed_by(self, msg: DNSIncoming) -> bool: """Returns true if any answer in a message can suffice for the information held in this record.""" answers = msg.answers() @@ -208,7 +210,7 @@ def get_expiration_time(self, percent: _int) -> float: return self.created + (percent * self.ttl * 10) # TODO: Switch to just int here - def get_remaining_ttl(self, now: _float) -> Union[int, float]: + def get_remaining_ttl(self, now: _float) -> int | float: """Returns the remaining TTL in seconds.""" remain = (self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) - now) / 1000.0 return 0 if remain < 0 else remain @@ -225,18 +227,18 @@ def is_recent(self, now: _float) -> bool: """Returns true if the record more than one quarter of its TTL remaining.""" return self.created + (_RECENT_TIME_MS * self.ttl) > now - def _set_created_ttl(self, created: _float, ttl: Union[float, int]) -> None: + def _set_created_ttl(self, created: _float, ttl: float | int) -> None: """Set the created and ttl of a record.""" # It would be better if we made a copy instead of mutating the record # in place, but records currently don't have a copy method. self.created = created self.ttl = ttl - def write(self, out: "DNSOutgoing") -> None: # pylint: disable=no-self-use + def write(self, out: DNSOutgoing) -> None: # pylint: disable=no-self-use """Abstract method""" raise AbstractMethodException - def to_string(self, other: Union[bytes, str]) -> str: + def to_string(self, other: bytes | str) -> str: """String representation with additional information""" arg = f"{self.ttl}/{int(self.get_remaining_ttl(current_time_millis()))},{cast(Any, other)}" return DNSEntry.entry_to_string(self, "record", arg) @@ -254,8 +256,8 @@ def __init__( class_: int, ttl: int, address: bytes, - scope_id: Optional[int] = None, - created: Optional[float] = None, + scope_id: int | None = None, + created: float | None = None, ) -> None: self._fast_init(name, type_, class_, ttl, address, scope_id, created or current_time_millis()) @@ -266,7 +268,7 @@ def _fast_init( class_: _int, ttl: _float, address: bytes, - scope_id: Optional[_int], + scope_id: _int | None, created: _float, ) -> None: """Fast init for reuse.""" @@ -275,7 +277,7 @@ def _fast_init( self.scope_id = scope_id self._hash = hash((self.key, type_, self.class_, address, scope_id)) - def write(self, out: "DNSOutgoing") -> None: + def write(self, out: DNSOutgoing) -> None: """Used in constructing an outgoing packet""" out.write_string(self.address) @@ -320,7 +322,7 @@ def __init__( ttl: int, cpu: str, os: str, - created: Optional[float] = None, + created: float | None = None, ) -> None: self._fast_init(name, type_, class_, ttl, cpu, os, created or current_time_millis()) @@ -333,7 +335,7 @@ def _fast_init( self.os = os self._hash = hash((self.key, type_, self.class_, cpu, os)) - def write(self, out: "DNSOutgoing") -> None: + def write(self, out: DNSOutgoing) -> None: """Used in constructing an outgoing packet""" out.write_character_string(self.cpu.encode("utf-8")) out.write_character_string(self.os.encode("utf-8")) @@ -367,7 +369,7 @@ def __init__( class_: int, ttl: int, alias: str, - created: Optional[float] = None, + created: float | None = None, ) -> None: self._fast_init(name, type_, class_, ttl, alias, created or current_time_millis()) @@ -389,7 +391,7 @@ def max_size_compressed(self) -> int: + _NAME_COMPRESSION_MIN_SIZE ) - def write(self, out: "DNSOutgoing") -> None: + def write(self, out: DNSOutgoing) -> None: """Used in constructing an outgoing packet""" out.write_name(self.alias) @@ -422,7 +424,7 @@ def __init__( class_: int, ttl: int, text: bytes, - created: Optional[float] = None, + created: float | None = None, ) -> None: self._fast_init(name, type_, class_, ttl, text, created or current_time_millis()) @@ -433,7 +435,7 @@ def _fast_init( self.text = text self._hash = hash((self.key, type_, self.class_, text)) - def write(self, out: "DNSOutgoing") -> None: + def write(self, out: DNSOutgoing) -> None: """Used in constructing an outgoing packet""" out.write_string(self.text) @@ -466,12 +468,12 @@ def __init__( name: str, type_: int, class_: int, - ttl: Union[float, int], + ttl: float | int, priority: int, weight: int, port: int, server: str, - created: Optional[float] = None, + created: float | None = None, ) -> None: self._fast_init( name, type_, class_, ttl, priority, weight, port, server, created or current_time_millis() @@ -497,7 +499,7 @@ def _fast_init( self.server_key = server.lower() self._hash = hash((self.key, type_, self.class_, priority, weight, port, self.server_key)) - def write(self, out: "DNSOutgoing") -> None: + def write(self, out: DNSOutgoing) -> None: """Used in constructing an outgoing packet""" out.write_short(self.priority) out.write_short(self.weight) @@ -537,10 +539,10 @@ def __init__( name: str, type_: int, class_: int, - ttl: Union[int, float], + ttl: int | float, next_name: str, - rdtypes: List[int], - created: Optional[float] = None, + rdtypes: list[int], + created: float | None = None, ) -> None: self._fast_init(name, type_, class_, ttl, next_name, rdtypes, created or current_time_millis()) @@ -551,7 +553,7 @@ def _fast_init( class_: _int, ttl: _float, next_name: str, - rdtypes: List[_int], + rdtypes: list[_int], created: _float, ) -> None: self._fast_init_record(name, type_, class_, ttl, created) @@ -559,7 +561,7 @@ def _fast_init( self.rdtypes = sorted(rdtypes) self._hash = hash((self.key, type_, self.class_, next_name, *self.rdtypes)) - def write(self, out: "DNSOutgoing") -> None: + def write(self, out: DNSOutgoing) -> None: """Used in constructing an outgoing packet.""" bitmap = bytearray(b"\0" * 32) total_octets = 0 @@ -610,21 +612,21 @@ class DNSRRSet: __slots__ = ("_lookup", "_records") - def __init__(self, records: List[DNSRecord]) -> None: + def __init__(self, records: list[DNSRecord]) -> None: """Create an RRset from records sets.""" self._records = records - self._lookup: Optional[Dict[DNSRecord, DNSRecord]] = None + self._lookup: dict[DNSRecord, DNSRecord] | None = None @property - def lookup(self) -> Dict[DNSRecord, DNSRecord]: + def lookup(self) -> dict[DNSRecord, DNSRecord]: """Return the lookup table.""" return self._get_lookup() - def lookup_set(self) -> Set[DNSRecord]: + def lookup_set(self) -> set[DNSRecord]: """Return the lookup table as aset.""" return set(self._get_lookup()) - def _get_lookup(self) -> Dict[DNSRecord, DNSRecord]: + def _get_lookup(self) -> dict[DNSRecord, DNSRecord]: """Return the lookup table, building it if needed.""" if self._lookup is None: # Build the hash table so we can lookup the record ttl diff --git a/src/zeroconf/_engine.py b/src/zeroconf/_engine.py index 05f8c948..7b22f788 100644 --- a/src/zeroconf/_engine.py +++ b/src/zeroconf/_engine.py @@ -20,11 +20,13 @@ USA """ +from __future__ import annotations + import asyncio import itertools import socket import threading -from typing import TYPE_CHECKING, List, Optional, cast +from typing import TYPE_CHECKING, cast from ._record_update import RecordUpdate from ._utils.asyncio import get_running_loop, run_coro_with_timeout @@ -58,31 +60,31 @@ class AsyncEngine: def __init__( self, - zeroconf: "Zeroconf", - listen_socket: Optional[socket.socket], - respond_sockets: List[socket.socket], + zeroconf: Zeroconf, + listen_socket: socket.socket | None, + respond_sockets: list[socket.socket], ) -> None: - self.loop: Optional[asyncio.AbstractEventLoop] = None + self.loop: asyncio.AbstractEventLoop | None = None self.zc = zeroconf - self.protocols: List[AsyncListener] = [] - self.readers: List[_WrappedTransport] = [] - self.senders: List[_WrappedTransport] = [] - self.running_event: Optional[asyncio.Event] = None + self.protocols: list[AsyncListener] = [] + self.readers: list[_WrappedTransport] = [] + self.senders: list[_WrappedTransport] = [] + self.running_event: asyncio.Event | None = None self._listen_socket = listen_socket self._respond_sockets = respond_sockets - self._cleanup_timer: Optional[asyncio.TimerHandle] = None + self._cleanup_timer: asyncio.TimerHandle | None = None def setup( self, loop: asyncio.AbstractEventLoop, - loop_thread_ready: Optional[threading.Event], + loop_thread_ready: threading.Event | None, ) -> None: """Set up the instance.""" self.loop = loop self.running_event = asyncio.Event() self.loop.create_task(self._async_setup(loop_thread_ready)) - async def _async_setup(self, loop_thread_ready: Optional[threading.Event]) -> None: + async def _async_setup(self, loop_thread_ready: threading.Event | None) -> None: """Set up the instance.""" self._async_schedule_next_cache_cleanup() await self._async_create_endpoints() diff --git a/src/zeroconf/_exceptions.py b/src/zeroconf/_exceptions.py index 5eb58f79..5fc81259 100644 --- a/src/zeroconf/_exceptions.py +++ b/src/zeroconf/_exceptions.py @@ -20,6 +20,8 @@ USA """ +from __future__ import annotations + class Error(Exception): """Base class for all zeroconf exceptions.""" diff --git a/src/zeroconf/_handlers/__init__.py b/src/zeroconf/_handlers/__init__.py index 30920c6a..584a74ec 100644 --- a/src/zeroconf/_handlers/__init__.py +++ b/src/zeroconf/_handlers/__init__.py @@ -19,3 +19,5 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA """ + +from __future__ import annotations diff --git a/src/zeroconf/_handlers/answers.py b/src/zeroconf/_handlers/answers.py index 7ddde197..ec53eb84 100644 --- a/src/zeroconf/_handlers/answers.py +++ b/src/zeroconf/_handlers/answers.py @@ -20,8 +20,10 @@ USA """ +from __future__ import annotations + from operator import attrgetter -from typing import Dict, List, Set +from typing import Dict, Set from .._dns import DNSQuestion, DNSRecord from .._protocol.outgoing import DNSOutgoing @@ -96,7 +98,7 @@ def construct_outgoing_multicast_answers( def construct_outgoing_unicast_answers( answers: _AnswerWithAdditionalsType, ucast_source: bool, - questions: List[DNSQuestion], + questions: list[DNSQuestion], id_: int_, ) -> DNSOutgoing: """Add answers and additionals to a DNSOutgoing.""" @@ -111,7 +113,7 @@ def construct_outgoing_unicast_answers( def _add_answers_additionals(out: DNSOutgoing, answers: _AnswerWithAdditionalsType) -> None: # Find additionals and suppress any additionals that are already in answers - sending: Set[DNSRecord] = set(answers) + sending: set[DNSRecord] = set(answers) # Answers are sorted to group names together to increase the chance # that similar names will end up in the same packet and can reduce the # overall size of the outgoing response via name compression diff --git a/src/zeroconf/_handlers/multicast_outgoing_queue.py b/src/zeroconf/_handlers/multicast_outgoing_queue.py index caf6470b..73d5ee43 100644 --- a/src/zeroconf/_handlers/multicast_outgoing_queue.py +++ b/src/zeroconf/_handlers/multicast_outgoing_queue.py @@ -20,6 +20,8 @@ USA """ +from __future__ import annotations + import random from collections import deque from typing import TYPE_CHECKING @@ -53,7 +55,7 @@ class MulticastOutgoingQueue: "zc", ) - def __init__(self, zeroconf: "Zeroconf", additional_delay: _int, max_aggregation_delay: _int) -> None: + def __init__(self, zeroconf: Zeroconf, additional_delay: _int, max_aggregation_delay: _int) -> None: self.zc = zeroconf self.queue: deque[AnswerGroup] = deque() # Additional delay is used to implement diff --git a/src/zeroconf/_handlers/query_handler.py b/src/zeroconf/_handlers/query_handler.py index ccfc7a77..60209568 100644 --- a/src/zeroconf/_handlers/query_handler.py +++ b/src/zeroconf/_handlers/query_handler.py @@ -20,7 +20,9 @@ USA """ -from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union, cast +from __future__ import annotations + +from typing import TYPE_CHECKING, cast from .._cache import DNSCache, _UniqueRecordsType from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet @@ -52,8 +54,8 @@ _RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES} -_EMPTY_SERVICES_LIST: List[ServiceInfo] = [] -_EMPTY_TYPES_LIST: List[str] = [] +_EMPTY_SERVICES_LIST: list[ServiceInfo] = [] +_EMPTY_TYPES_LIST: list[str] = [] _IPVersion_ALL = IPVersion.All @@ -77,8 +79,8 @@ def __init__( self, question: DNSQuestion, strategy_type: _int, - types: List[str], - services: List[ServiceInfo], + types: list[str], + services: list[ServiceInfo], ) -> None: """Create an answer strategy.""" self.question = question @@ -102,17 +104,17 @@ class _QueryResponse: "_ucast", ) - def __init__(self, cache: DNSCache, questions: List[DNSQuestion], is_probe: bool, now: float) -> None: + def __init__(self, cache: DNSCache, questions: list[DNSQuestion], is_probe: bool, now: float) -> None: """Build a query response.""" self._is_probe = is_probe self._questions = questions self._now = now self._cache = cache self._additionals: _AnswerWithAdditionalsType = {} - self._ucast: Set[DNSRecord] = set() - self._mcast_now: Set[DNSRecord] = set() - self._mcast_aggregate: Set[DNSRecord] = set() - self._mcast_aggregate_last_second: Set[DNSRecord] = set() + self._ucast: set[DNSRecord] = set() + self._mcast_now: set[DNSRecord] = set() + self._mcast_aggregate: set[DNSRecord] = set() + self._mcast_aggregate_last_second: set[DNSRecord] = set() def add_qu_question_response(self, answers: _AnswerWithAdditionalsType) -> None: """Generate a response to a multicast QU query.""" @@ -199,7 +201,7 @@ class QueryHandler: "zc", ) - def __init__(self, zc: "Zeroconf") -> None: + def __init__(self, zc: Zeroconf) -> None: """Init the query handler.""" self.zc = zc self.registry = zc.registry @@ -210,7 +212,7 @@ def __init__(self, zc: "Zeroconf") -> None: def _add_service_type_enumeration_query_answers( self, - types: List[str], + types: list[str], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, ) -> None: @@ -232,7 +234,7 @@ def _add_service_type_enumeration_query_answers( def _add_pointer_answers( self, - services: List[ServiceInfo], + services: list[ServiceInfo], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, ) -> None: @@ -251,23 +253,23 @@ def _add_pointer_answers( def _add_address_answers( self, - services: List[ServiceInfo], + services: list[ServiceInfo], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, type_: _int, ) -> None: """Answer A/AAAA/ANY question.""" for service in services: - answers: List[DNSAddress] = [] - additionals: Set[DNSRecord] = set() - seen_types: Set[int] = set() + answers: list[DNSAddress] = [] + additionals: set[DNSRecord] = set() + seen_types: set[int] = set() for dns_address in service._dns_addresses(None, _IPVersion_ALL): seen_types.add(dns_address.type) if dns_address.type != type_: additionals.add(dns_address) elif not known_answers.suppresses(dns_address): answers.append(dns_address) - missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types + missing_types: set[int] = _ADDRESS_RECORD_TYPES - seen_types if answers: if missing_types: assert service.server is not None, "Service server must be set for NSEC record." @@ -282,8 +284,8 @@ def _answer_question( self, question: DNSQuestion, strategy_type: _int, - types: List[str], - services: List[ServiceInfo], + types: list[str], + services: list[ServiceInfo], known_answers: DNSRRSet, ) -> _AnswerWithAdditionalsType: """Answer a question.""" @@ -311,14 +313,14 @@ def _answer_question( return answer_set def async_response( # pylint: disable=unused-argument - self, msgs: List[DNSIncoming], ucast_source: bool - ) -> Optional[QuestionAnswers]: + self, msgs: list[DNSIncoming], ucast_source: bool + ) -> QuestionAnswers | None: """Deal with incoming query packets. Provides a response if possible. This function must be run in the event loop as it is not threadsafe. """ - strategies: List[_AnswerStrategy] = [] + strategies: list[_AnswerStrategy] = [] for msg in msgs: for question in msg._questions: strategies.extend(self._get_answer_strategies(question)) @@ -334,7 +336,7 @@ def async_response( # pylint: disable=unused-argument questions = msg._questions # Only decode known answers if we are not a probe and we have # at least one answer strategy - answers: List[DNSRecord] = [] + answers: list[DNSRecord] = [] for msg in msgs: if msg.is_probe(): is_probe = True @@ -343,7 +345,7 @@ def async_response( # pylint: disable=unused-argument query_res = _QueryResponse(self.cache, questions, is_probe, msg.now) known_answers = DNSRRSet(answers) - known_answers_set: Optional[Set[DNSRecord]] = None + known_answers_set: set[DNSRecord] | None = None now = msg.now for strategy in strategies: question = strategy.question @@ -373,12 +375,12 @@ def async_response( # pylint: disable=unused-argument def _get_answer_strategies( self, question: DNSQuestion, - ) -> List[_AnswerStrategy]: + ) -> list[_AnswerStrategy]: """Collect strategies to answer a question.""" name = question.name question_lower_name = name.lower() type_ = question.type - strategies: List[_AnswerStrategy] = [] + strategies: list[_AnswerStrategy] = [] if type_ == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME: types = self.registry.async_get_types() @@ -433,11 +435,11 @@ def _get_answer_strategies( def handle_assembled_query( self, - packets: List[DNSIncoming], + packets: list[DNSIncoming], addr: _str, port: _int, transport: _WrappedTransport, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]], + v6_flow_scope: tuple[()] | tuple[int, int], ) -> None: """Respond to a (re)assembled query. diff --git a/src/zeroconf/_handlers/record_manager.pxd b/src/zeroconf/_handlers/record_manager.pxd index d4e068c2..37232b13 100644 --- a/src/zeroconf/_handlers/record_manager.pxd +++ b/src/zeroconf/_handlers/record_manager.pxd @@ -21,7 +21,7 @@ cdef class RecordManager: cdef public DNSCache cache cdef public cython.set listeners - cpdef void async_updates(self, object now, object records) + cpdef void async_updates(self, object now, list records) cpdef void async_updates_complete(self, bint notify) diff --git a/src/zeroconf/_handlers/record_manager.py b/src/zeroconf/_handlers/record_manager.py index d4e2792c..566f0e8c 100644 --- a/src/zeroconf/_handlers/record_manager.py +++ b/src/zeroconf/_handlers/record_manager.py @@ -20,7 +20,9 @@ USA """ -from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union, cast +from __future__ import annotations + +from typing import TYPE_CHECKING, cast from .._cache import _UniqueRecordsType from .._dns import DNSQuestion, DNSRecord @@ -42,13 +44,13 @@ class RecordManager: __slots__ = ("cache", "listeners", "zc") - def __init__(self, zeroconf: "Zeroconf") -> None: + def __init__(self, zeroconf: Zeroconf) -> None: """Init the record manager.""" self.zc = zeroconf self.cache = zeroconf.cache - self.listeners: Set[RecordUpdateListener] = set() + self.listeners: set[RecordUpdateListener] = set() - def async_updates(self, now: _float, records: List[RecordUpdate]) -> None: + def async_updates(self, now: _float, records: list[RecordUpdate]) -> None: """Used to notify listeners of new information that has updated a record. @@ -79,12 +81,12 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: This function must be run in the event loop as it is not threadsafe. """ - updates: List[RecordUpdate] = [] - address_adds: List[DNSRecord] = [] - other_adds: List[DNSRecord] = [] - removes: Set[DNSRecord] = set() + updates: list[RecordUpdate] = [] + address_adds: list[DNSRecord] = [] + other_adds: list[DNSRecord] = [] + removes: set[DNSRecord] = set() now = msg.now - unique_types: Set[Tuple[str, int, int]] = set() + unique_types: set[tuple[str, int, int]] = set() cache = self.cache answers = msg.answers() @@ -165,7 +167,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: def async_add_listener( self, listener: RecordUpdateListener, - question: Optional[Union[DNSQuestion, List[DNSQuestion]]], + question: DNSQuestion | list[DNSQuestion] | None, ) -> None: """Adds a listener for a given question. The listener will have its update_record method called when information is available to @@ -188,14 +190,14 @@ def async_add_listener( self._async_update_matching_records(listener, questions) def _async_update_matching_records( - self, listener: RecordUpdateListener, questions: List[DNSQuestion] + self, listener: RecordUpdateListener, questions: list[DNSQuestion] ) -> None: """Calls back any existing entries in the cache that answer the question. This function must be run from the event loop. """ now = current_time_millis() - records: List[RecordUpdate] = [ + records: list[RecordUpdate] = [ RecordUpdate(record, None) for question in questions for record in self.cache.async_entries_with_name(question.name) diff --git a/src/zeroconf/_history.py b/src/zeroconf/_history.py index aa28519c..5bae7be0 100644 --- a/src/zeroconf/_history.py +++ b/src/zeroconf/_history.py @@ -20,7 +20,7 @@ USA """ -from typing import Dict, List, Set, Tuple +from __future__ import annotations from ._dns import DNSQuestion, DNSRecord from .const import _DUPLICATE_QUESTION_INTERVAL @@ -36,13 +36,13 @@ class QuestionHistory: def __init__(self) -> None: """Init a new QuestionHistory.""" - self._history: Dict[DNSQuestion, Tuple[float, Set[DNSRecord]]] = {} + self._history: dict[DNSQuestion, tuple[float, set[DNSRecord]]] = {} - def add_question_at_time(self, question: DNSQuestion, now: _float, known_answers: Set[DNSRecord]) -> None: + def add_question_at_time(self, question: DNSQuestion, now: _float, known_answers: set[DNSRecord]) -> None: """Remember a question with known answers.""" self._history[question] = (now, known_answers) - def suppresses(self, question: DNSQuestion, now: _float, known_answers: Set[DNSRecord]) -> bool: + def suppresses(self, question: DNSQuestion, now: _float, known_answers: set[DNSRecord]) -> bool: """Check to see if a question should be suppressed. https://datatracker.ietf.org/doc/html/rfc6762#section-7.3 @@ -66,7 +66,7 @@ def suppresses(self, question: DNSQuestion, now: _float, known_answers: Set[DNSR def async_expire(self, now: _float) -> None: """Expire the history of old questions.""" - removes: List[DNSQuestion] = [] + removes: list[DNSQuestion] = [] for question, now_known_answers in self._history.items(): than, _ = now_known_answers if now - than > _DUPLICATE_QUESTION_INTERVAL: diff --git a/src/zeroconf/_logger.py b/src/zeroconf/_logger.py index 1556522e..0d734dfd 100644 --- a/src/zeroconf/_logger.py +++ b/src/zeroconf/_logger.py @@ -21,9 +21,11 @@ USA """ +from __future__ import annotations + import logging import sys -from typing import Any, ClassVar, Dict, Union, cast +from typing import Any, ClassVar, cast log = logging.getLogger(__name__.split(".", maxsplit=1)[0]) log.addHandler(logging.NullHandler()) @@ -38,7 +40,7 @@ def set_logger_level_if_unset() -> None: class QuietLogger: - _seen_logs: ClassVar[Dict[str, Union[int, tuple]]] = {} + _seen_logs: ClassVar[dict[str, int | tuple]] = {} @classmethod def log_exception_warning(cls, *logger_data: Any) -> None: diff --git a/src/zeroconf/_protocol/__init__.py b/src/zeroconf/_protocol/__init__.py index 30920c6a..584a74ec 100644 --- a/src/zeroconf/_protocol/__init__.py +++ b/src/zeroconf/_protocol/__init__.py @@ -19,3 +19,5 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA """ + +from __future__ import annotations diff --git a/src/zeroconf/_protocol/incoming.py b/src/zeroconf/_protocol/incoming.py index 6e009b29..7f4a8eec 100644 --- a/src/zeroconf/_protocol/incoming.py +++ b/src/zeroconf/_protocol/incoming.py @@ -20,9 +20,11 @@ USA """ +from __future__ import annotations + import struct import sys -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any from .._dns import ( DNSAddress, @@ -61,7 +63,7 @@ DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError) -_seen_logs: Dict[str, Union[int, tuple]] = {} +_seen_logs: dict[str, int | tuple] = {} _str = str _int = int @@ -94,9 +96,9 @@ class DNSIncoming: def __init__( self, data: bytes, - source: Optional[Tuple[str, int]] = None, - scope_id: Optional[int] = None, - now: Optional[float] = None, + source: tuple[str, int] | None = None, + scope_id: int | None = None, + now: float | None = None, ) -> None: """Constructor from string holding bytes of packet""" self.flags = 0 @@ -104,9 +106,9 @@ def __init__( self.data = data self.view = data self._data_len = len(data) - self._name_cache: Dict[int, List[str]] = {} - self._questions: List[DNSQuestion] = [] - self._answers: List[DNSRecord] = [] + self._name_cache: dict[int, list[str]] = {} + self._questions: list[DNSQuestion] = [] + self._answers: list[DNSRecord] = [] self.id = 0 self._num_questions = 0 self._num_answers = 0 @@ -146,7 +148,7 @@ def truncated(self) -> bool: return (self.flags & _FLAGS_TC) == _FLAGS_TC @property - def questions(self) -> List[DNSQuestion]: + def questions(self) -> list[DNSQuestion]: """Questions in the packet.""" return self._questions @@ -189,7 +191,7 @@ def _log_exception_debug(cls, *logger_data: Any) -> None: log_exc_info = True log.debug(*(logger_data or ["Exception occurred"]), exc_info=log_exc_info) - def answers(self) -> List[DNSRecord]: + def answers(self) -> list[DNSRecord]: """Answers in the packet.""" if not self._did_read_others: try: @@ -306,7 +308,7 @@ def _read_others(self) -> None: def _read_record( self, domain: _str, type_: _int, class_: _int, ttl: _int, length: _int - ) -> Optional[DNSRecord]: + ) -> DNSRecord | None: """Read known records types and skip unknown ones.""" if type_ == _TYPE_A: address_rec = DNSAddress.__new__(DNSAddress) @@ -384,7 +386,7 @@ def _read_record( self.offset += length return None - def _read_bitmap(self, end: _int) -> List[int]: + def _read_bitmap(self, end: _int) -> list[int]: """Reads an NSEC bitmap from the packet.""" rdtypes = [] view = self.view @@ -404,8 +406,8 @@ def _read_bitmap(self, end: _int) -> List[int]: def _read_name(self) -> str: """Reads a domain name from the packet.""" - labels: List[str] = [] - seen_pointers: Set[int] = set() + labels: list[str] = [] + seen_pointers: set[int] = set() original_offset = self.offset self.offset = self._decode_labels_at_offset(original_offset, labels, seen_pointers) self._name_cache[original_offset] = labels @@ -416,7 +418,7 @@ def _read_name(self) -> str: ) return name - def _decode_labels_at_offset(self, off: _int, labels: List[str], seen_pointers: Set[int]) -> int: + def _decode_labels_at_offset(self, off: _int, labels: list[str], seen_pointers: set[int]) -> int: # This is a tight loop that is called frequently, small optimizations can make a difference. view = self.view while off < self._data_len: diff --git a/src/zeroconf/_protocol/outgoing.py b/src/zeroconf/_protocol/outgoing.py index c937350e..f5d09821 100644 --- a/src/zeroconf/_protocol/outgoing.py +++ b/src/zeroconf/_protocol/outgoing.py @@ -20,10 +20,12 @@ USA """ +from __future__ import annotations + import enum import logging from struct import Struct -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Sequence from .._dns import DNSPointer, DNSQuestion, DNSRecord from .._exceptions import NamePartTooLongException @@ -98,20 +100,20 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None: self.finished = False self.id = id_ self.multicast = multicast - self.packets_data: List[bytes] = [] + self.packets_data: list[bytes] = [] # these 3 are per-packet -- see also _reset_for_next_packet() - self.names: Dict[str, int] = {} - self.data: List[bytes] = [] + self.names: dict[str, int] = {} + self.data: list[bytes] = [] self.size: int = _DNS_PACKET_HEADER_LEN self.allow_long: bool = True self.state = STATE_INIT - self.questions: List[DNSQuestion] = [] - self.answers: List[Tuple[DNSRecord, float]] = [] - self.authorities: List[DNSPointer] = [] - self.additionals: List[DNSRecord] = [] + self.questions: list[DNSQuestion] = [] + self.answers: list[tuple[DNSRecord, float]] = [] + self.authorities: list[DNSPointer] = [] + self.additionals: list[DNSRecord] = [] def is_query(self) -> bool: """Returns true if this is a query.""" @@ -150,7 +152,7 @@ def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None: if not record.suppressed_by(inp): self.add_answer_at_time(record, 0.0) - def add_answer_at_time(self, record: Optional[DNSRecord], now: float_) -> None: + def add_answer_at_time(self, record: DNSRecord | None, now: float_) -> None: """Adds an answer if it does not expire by a certain time""" now_double = now if record is not None and (now_double == 0 or not record.is_expired(now_double)): @@ -220,7 +222,7 @@ def write_short(self, value: int_) -> None: self.data.append(self._get_short(value)) self.size += 2 - def _write_int(self, value: Union[float, int]) -> None: + def _write_int(self, value: float | int) -> None: """Writes an unsigned integer to the packet""" value_as_int = int(value) long_bytes = LONG_LOOKUP.get(value_as_int) @@ -313,7 +315,7 @@ def _write_question(self, question: DNSQuestion_) -> bool: self._write_record_class(question) return self._check_data_limit_or_rollback(start_data_length, start_size) - def _write_record_class(self, record: Union[DNSQuestion_, DNSRecord_]) -> None: + def _write_record_class(self, record: DNSQuestion_ | DNSRecord_) -> None: """Write out the record class including the unique/unicast (QU) bit.""" class_ = record.class_ if record.unique is True and self.multicast: @@ -409,7 +411,7 @@ def _has_more_to_add( or additional_offset < len(self.additionals) ) - def packets(self) -> List[bytes]: + def packets(self) -> list[bytes]: """Returns a list of bytestrings containing the packets' bytes No further parts should be added to the packet once this diff --git a/src/zeroconf/_record_update.py b/src/zeroconf/_record_update.py index 912ab6f1..5f817511 100644 --- a/src/zeroconf/_record_update.py +++ b/src/zeroconf/_record_update.py @@ -20,7 +20,7 @@ USA """ -from typing import Optional +from __future__ import annotations from ._dns import DNSRecord @@ -30,16 +30,16 @@ class RecordUpdate: __slots__ = ("new", "old") - def __init__(self, new: DNSRecord, old: Optional[DNSRecord] = None) -> None: + def __init__(self, new: DNSRecord, old: DNSRecord | None = None) -> None: """RecordUpdate represents a change in a DNS record.""" self._fast_init(new, old) - def _fast_init(self, new: _DNSRecord, old: Optional[_DNSRecord]) -> None: + def _fast_init(self, new: _DNSRecord, old: _DNSRecord | None) -> None: """Fast init for RecordUpdate.""" self.new = new self.old = old - def __getitem__(self, index: int) -> Optional[DNSRecord]: + def __getitem__(self, index: int) -> DNSRecord | None: """Get the new or old record.""" if index == 0: return self.new diff --git a/src/zeroconf/_services/__init__.py b/src/zeroconf/_services/__init__.py index 7a6bddeb..6936aed6 100644 --- a/src/zeroconf/_services/__init__.py +++ b/src/zeroconf/_services/__init__.py @@ -20,8 +20,10 @@ USA """ +from __future__ import annotations + import enum -from typing import TYPE_CHECKING, Any, Callable, List +from typing import TYPE_CHECKING, Any, Callable if TYPE_CHECKING: from .._core import Zeroconf @@ -35,13 +37,13 @@ class ServiceStateChange(enum.Enum): class ServiceListener: - def add_service(self, zc: "Zeroconf", type_: str, name: str) -> None: + def add_service(self, zc: Zeroconf, type_: str, name: str) -> None: raise NotImplementedError() - def remove_service(self, zc: "Zeroconf", type_: str, name: str) -> None: + def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None: raise NotImplementedError() - def update_service(self, zc: "Zeroconf", type_: str, name: str) -> None: + def update_service(self, zc: Zeroconf, type_: str, name: str) -> None: raise NotImplementedError() @@ -49,27 +51,27 @@ class Signal: __slots__ = ("_handlers",) def __init__(self) -> None: - self._handlers: List[Callable[..., None]] = [] + self._handlers: list[Callable[..., None]] = [] def fire(self, **kwargs: Any) -> None: for h in self._handlers[:]: h(**kwargs) @property - def registration_interface(self) -> "SignalRegistrationInterface": + def registration_interface(self) -> SignalRegistrationInterface: return SignalRegistrationInterface(self._handlers) class SignalRegistrationInterface: __slots__ = ("_handlers",) - def __init__(self, handlers: List[Callable[..., None]]) -> None: + def __init__(self, handlers: list[Callable[..., None]]) -> None: self._handlers = handlers - def register_handler(self, handler: Callable[..., None]) -> "SignalRegistrationInterface": + def register_handler(self, handler: Callable[..., None]) -> SignalRegistrationInterface: self._handlers.append(handler) return self - def unregister_handler(self, handler: Callable[..., None]) -> "SignalRegistrationInterface": + def unregister_handler(self, handler: Callable[..., None]) -> SignalRegistrationInterface: self._handlers.remove(handler) return self diff --git a/src/zeroconf/_services/browser.py b/src/zeroconf/_services/browser.py index 42aaa1ac..c2ab115b 100644 --- a/src/zeroconf/_services/browser.py +++ b/src/zeroconf/_services/browser.py @@ -20,6 +20,8 @@ USA """ +from __future__ import annotations + import asyncio import heapq import queue @@ -36,11 +38,7 @@ Dict, Iterable, List, - Optional, Set, - Tuple, - Type, - Union, cast, ) @@ -155,13 +153,13 @@ def __repr__(self) -> str: ">" ) - def __lt__(self, other: "_ScheduledPTRQuery") -> bool: + def __lt__(self, other: _ScheduledPTRQuery) -> bool: """Compare two scheduled queries.""" if type(other) is _ScheduledPTRQuery: return self.when_millis < other.when_millis return NotImplemented - def __le__(self, other: "_ScheduledPTRQuery") -> bool: + def __le__(self, other: _ScheduledPTRQuery) -> bool: """Compare two scheduled queries.""" if type(other) is _ScheduledPTRQuery: return self.when_millis < other.when_millis or self.__eq__(other) @@ -173,13 +171,13 @@ def __eq__(self, other: Any) -> bool: return self.when_millis == other.when_millis return NotImplemented - def __ge__(self, other: "_ScheduledPTRQuery") -> bool: + def __ge__(self, other: _ScheduledPTRQuery) -> bool: """Compare two scheduled queries.""" if type(other) is _ScheduledPTRQuery: return self.when_millis > other.when_millis or self.__eq__(other) return NotImplemented - def __gt__(self, other: "_ScheduledPTRQuery") -> bool: + def __gt__(self, other: _ScheduledPTRQuery) -> bool: """Compare two scheduled queries.""" if type(other) is _ScheduledPTRQuery: return self.when_millis > other.when_millis @@ -197,7 +195,7 @@ def __init__(self, now_millis: float, multicast: bool) -> None: self.out = DNSOutgoing(_FLAGS_QR_QUERY, multicast) self.bytes = 0 - def add(self, max_compressed_size: int_, question: DNSQuestion, answers: Set[DNSPointer]) -> None: + def add(self, max_compressed_size: int_, question: DNSQuestion, answers: set[DNSPointer]) -> None: """Add a new set of questions and known answers to the outgoing.""" self.out.add_question(question) for answer in answers: @@ -209,7 +207,7 @@ def group_ptr_queries_with_known_answers( now: float_, multicast: bool_, question_with_known_answers: _QuestionWithKnownAnswers, -) -> List[DNSOutgoing]: +) -> list[DNSOutgoing]: """Aggregate queries so that as many known answers as possible fit in the same packet without having known answers spill over into the next packet unless the question and known answers are always going to exceed the packet size. @@ -225,19 +223,19 @@ def _group_ptr_queries_with_known_answers( now_millis: float_, multicast: bool_, question_with_known_answers: _QuestionWithKnownAnswers, -) -> List[DNSOutgoing]: +) -> list[DNSOutgoing]: """Inner wrapper for group_ptr_queries_with_known_answers.""" # This is the maximum size the query + known answers can be with name compression. # The actual size of the query + known answers may be a bit smaller since other # parts may be shared when the final DNSOutgoing packets are constructed. The # goal of this algorithm is to quickly bucket the query + known answers without # the overhead of actually constructing the packets. - query_by_size: Dict[DNSQuestion, int] = { + query_by_size: dict[DNSQuestion, int] = { question: (question.max_size + sum(answer.max_size_compressed for answer in known_answers)) for question, known_answers in question_with_known_answers.items() } max_bucket_size = _MAX_MSG_TYPICAL - _DNS_PACKET_HEADER_LEN - query_buckets: List[_DNSPointerOutgoingBucket] = [] + query_buckets: list[_DNSPointerOutgoingBucket] = [] for question in sorted( query_by_size, key=query_by_size.get, # type: ignore @@ -261,12 +259,12 @@ def _group_ptr_queries_with_known_answers( def generate_service_query( - zc: "Zeroconf", + zc: Zeroconf, now_millis: float_, - types_: Set[str], + types_: set[str], multicast: bool, - question_type: Optional[DNSQuestionType], -) -> List[DNSOutgoing]: + question_type: DNSQuestionType | None, +) -> list[DNSOutgoing]: """Generate a service query for sending with zeroconf.send.""" questions_with_known_answers: _QuestionWithKnownAnswers = {} qu_question = not multicast if question_type is None else question_type is QU_QUESTION @@ -296,7 +294,7 @@ def generate_service_query( def _on_change_dispatcher( listener: ServiceListener, - zeroconf: "Zeroconf", + zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange, @@ -346,14 +344,14 @@ class QueryScheduler: def __init__( self, - zc: "Zeroconf", - types: Set[str], - addr: Optional[str], + zc: Zeroconf, + types: set[str], + addr: str | None, port: int, multicast: bool, delay: int, - first_random_delay_interval: Tuple[int, int], - question_type: Optional[DNSQuestionType], + first_random_delay_interval: tuple[int, int], + question_type: DNSQuestionType | None, ) -> None: self._zc = zc self._types = types @@ -362,11 +360,11 @@ def __init__( self._multicast = multicast self._first_random_delay_interval = first_random_delay_interval self._min_time_between_queries_millis = delay - self._loop: Optional[asyncio.AbstractEventLoop] = None + self._loop: asyncio.AbstractEventLoop | None = None self._startup_queries_sent = 0 - self._next_scheduled_for_alias: Dict[str, _ScheduledPTRQuery] = {} + self._next_scheduled_for_alias: dict[str, _ScheduledPTRQuery] = {} self._query_heap: list[_ScheduledPTRQuery] = [] - self._next_run: Optional[asyncio.TimerHandle] = None + self._next_run: asyncio.TimerHandle | None = None self._clock_resolution_millis = time.get_clock_info("monotonic").resolution * 1000 self._question_type = question_type @@ -500,10 +498,10 @@ def _process_ready_types(self) -> None: # with a minimum time between queries of _min_time_between_queries # which defaults to 10s - ready_types: Set[str] = set() - next_scheduled: Optional[_ScheduledPTRQuery] = None + ready_types: set[str] = set() + next_scheduled: _ScheduledPTRQuery | None = None end_time_millis = now_millis + self._clock_resolution_millis - schedule_rescue: List[_ScheduledPTRQuery] = [] + schedule_rescue: list[_ScheduledPTRQuery] = [] while self._query_heap: query = self._query_heap[0] @@ -538,7 +536,7 @@ def _process_ready_types(self) -> None: self._next_run = self._loop.call_at(millis_to_seconds(next_when_millis), self._process_ready_types) def async_send_ready_queries( - self, first_request: bool, now_millis: float_, ready_types: Set[str] + self, first_request: bool, now_millis: float_, ready_types: set[str] ) -> None: """Send any ready queries.""" # If they did not specify and this is the first request, ask QU questions @@ -569,14 +567,14 @@ class _ServiceBrowserBase(RecordUpdateListener): def __init__( self, - zc: "Zeroconf", - type_: Union[str, list], - handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, - listener: Optional[ServiceListener] = None, - addr: Optional[str] = None, + zc: Zeroconf, + type_: str | list, + handlers: ServiceListener | list[Callable[..., None]] | None = None, + listener: ServiceListener | None = None, + addr: str | None = None, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, - question_type: Optional[DNSQuestionType] = None, + question_type: DNSQuestionType | None = None, ) -> None: """Used to browse for a service for specific type(s). @@ -596,7 +594,7 @@ def __init__( discovers changes in the services availability. """ assert handlers or listener, "You need to specify at least one handler" - self.types: Set[str] = set(type_ if isinstance(type_, list) else [type_]) + self.types: set[str] = set(type_ if isinstance(type_, list) else [type_]) for check_type_ in self.types: # Will generate BadTypeInNameException on a bad name service_type_name(check_type_, strict=False) @@ -604,7 +602,7 @@ def __init__( self._cache = zc.cache assert zc.loop is not None self._loop = zc.loop - self._pending_handlers: Dict[Tuple[str, str], ServiceStateChange] = {} + self._pending_handlers: dict[tuple[str, str], ServiceStateChange] = {} self._service_state_changed = Signal() self.query_scheduler = QueryScheduler( zc, @@ -617,7 +615,7 @@ def __init__( question_type, ) self.done = False - self._query_sender_task: Optional[asyncio.Task] = None + self._query_sender_task: asyncio.Task | None = None if hasattr(handlers, "add_service"): listener = cast("ServiceListener", handlers) @@ -645,7 +643,7 @@ def _async_start(self) -> None: def service_state_changed(self) -> SignalRegistrationInterface: return self._service_state_changed.registration_interface - def _names_matching_types(self, names: Iterable[str]) -> List[Tuple[str, str]]: + def _names_matching_types(self, names: Iterable[str]) -> list[tuple[str, str]]: """Return the type and name for records matching the types we are browsing.""" return [ (type_, name) for name in names for type_ in self.types.intersection(cached_possible_types(name)) @@ -670,7 +668,7 @@ def _enqueue_callback( ): self._pending_handlers[key] = state_change - def async_update_records(self, zc: "Zeroconf", now: float_, records: List[RecordUpdate]) -> None: + def async_update_records(self, zc: Zeroconf, now: float_, records: list[RecordUpdate]) -> None: """Callback invoked by Zeroconf when new information arrives. Updates information required by browser in the Zeroconf cache. @@ -727,7 +725,7 @@ def async_update_records_complete(self) -> None: self._fire_service_state_changed_event(pending) self._pending_handlers.clear() - def _fire_service_state_changed_event(self, event: Tuple[Tuple[str, str], ServiceStateChange]) -> None: + def _fire_service_state_changed_event(self, event: tuple[tuple[str, str], ServiceStateChange]) -> None: """Fire a service state changed event. When running with ServiceBrowser, this will happen in the dedicated @@ -769,14 +767,14 @@ class ServiceBrowser(_ServiceBrowserBase, threading.Thread): def __init__( self, - zc: "Zeroconf", - type_: Union[str, list], - handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, - listener: Optional[ServiceListener] = None, - addr: Optional[str] = None, + zc: Zeroconf, + type_: str | list, + handlers: ServiceListener | list[Callable[..., None]] | None = None, + listener: ServiceListener | None = None, + addr: str | None = None, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, - question_type: Optional[DNSQuestionType] = None, + question_type: DNSQuestionType | None = None, ) -> None: assert zc.loop is not None if not zc.loop.is_running(): @@ -821,14 +819,14 @@ def async_update_records_complete(self) -> None: self.queue.put(pending) self._pending_handlers.clear() - def __enter__(self) -> "ServiceBrowser": + def __enter__(self) -> ServiceBrowser: return self def __exit__( # pylint: disable=useless-return self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: self.cancel() return None diff --git a/src/zeroconf/_services/info.py b/src/zeroconf/_services/info.py index a6e815b5..67777459 100644 --- a/src/zeroconf/_services/info.py +++ b/src/zeroconf/_services/info.py @@ -20,9 +20,11 @@ USA """ +from __future__ import annotations + import asyncio import random -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast +from typing import TYPE_CHECKING, Dict, List, Optional, cast from .._cache import DNSCache from .._dns import ( @@ -106,7 +108,7 @@ from .._core import Zeroconf -def instance_name_from_service_info(info: "ServiceInfo", strict: bool = True) -> str: +def instance_name_from_service_info(info: ServiceInfo, strict: bool = True) -> str: """Calculate the instance name from the ServiceInfo.""" # This is kind of funky because of the subtype based tests # need to make subtypes a first class citizen @@ -168,17 +170,17 @@ def __init__( self, type_: str, name: str, - port: Optional[int] = None, + port: int | None = None, weight: int = 0, priority: int = 0, - properties: Union[bytes, Dict] = b"", - server: Optional[str] = None, + properties: bytes | dict = b"", + server: str | None = None, host_ttl: int = _DNS_HOST_TTL, other_ttl: int = _DNS_OTHER_TTL, *, - addresses: Optional[List[bytes]] = None, - parsed_addresses: Optional[List[str]] = None, - interface_index: Optional[int] = None, + addresses: list[bytes] | None = None, + parsed_addresses: list[str] | None = None, + interface_index: int | None = None, ) -> None: # Accept both none, or one, but not both. if addresses is not None and parsed_addresses is not None: @@ -190,8 +192,8 @@ def __init__( self.type = type_ self._name = name self.key = name.lower() - self._ipv4_addresses: List[ZeroconfIPv4Address] = [] - self._ipv6_addresses: List[ZeroconfIPv6Address] = [] + self._ipv4_addresses: list[ZeroconfIPv4Address] = [] + self._ipv6_addresses: list[ZeroconfIPv6Address] = [] if addresses is not None: self.addresses = addresses elif parsed_addresses is not None: @@ -201,20 +203,20 @@ def __init__( self.priority = priority self.server = server if server else None self.server_key = server.lower() if server else None - self._properties: Optional[Dict[bytes, Optional[bytes]]] = None - self._decoded_properties: Optional[Dict[str, Optional[str]]] = None + self._properties: dict[bytes, bytes | None] | None = None + self._decoded_properties: dict[str, str | None] | None = None if isinstance(properties, bytes): self._set_text(properties) else: self._set_properties(properties) self.host_ttl = host_ttl self.other_ttl = other_ttl - self._new_records_futures: Optional[Set[asyncio.Future]] = None - self._dns_address_cache: Optional[List[DNSAddress]] = None - self._dns_pointer_cache: Optional[DNSPointer] = None - self._dns_service_cache: Optional[DNSService] = None - self._dns_text_cache: Optional[DNSText] = None - self._get_address_and_nsec_records_cache: Optional[Set[DNSRecord]] = None + self._new_records_futures: set[asyncio.Future] | None = None + self._dns_address_cache: list[DNSAddress] | None = None + self._dns_pointer_cache: DNSPointer | None = None + self._dns_service_cache: DNSService | None = None + self._dns_text_cache: DNSText | None = None + self._get_address_and_nsec_records_cache: set[DNSRecord] | None = None self._query_record_types = {_TYPE_SRV, _TYPE_TXT, _TYPE_A, _TYPE_AAAA} @property @@ -232,7 +234,7 @@ def name(self, name: str) -> None: self._dns_text_cache = None @property - def addresses(self) -> List[bytes]: + def addresses(self) -> list[bytes]: """IPv4 addresses of this service. Only IPv4 addresses are returned for backward compatibility. @@ -242,7 +244,7 @@ def addresses(self) -> List[bytes]: return self.addresses_by_version(IPVersion.V4Only) @addresses.setter - def addresses(self, value: List[bytes]) -> None: + def addresses(self, value: list[bytes]) -> None: """Replace the addresses list. This replaces all currently stored addresses, both IPv4 and IPv6. @@ -272,7 +274,7 @@ def addresses(self, value: List[bytes]) -> None: self._ipv6_addresses.append(addr) @property - def properties(self) -> Dict[bytes, Optional[bytes]]: + def properties(self) -> dict[bytes, bytes | None]: """Return properties as bytes.""" if self._properties is None: self._unpack_text_into_properties() @@ -281,7 +283,7 @@ def properties(self) -> Dict[bytes, Optional[bytes]]: return self._properties @property - def decoded_properties(self) -> Dict[str, Optional[str]]: + def decoded_properties(self) -> dict[str, str | None]: """Return properties as strings.""" if self._decoded_properties is None: self._generate_decoded_properties() @@ -297,7 +299,7 @@ def async_clear_cache(self) -> None: self._dns_text_cache = None self._get_address_and_nsec_records_cache = None - async def async_wait(self, timeout: float, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + async def async_wait(self, timeout: float, loop: asyncio.AbstractEventLoop | None = None) -> None: """Calling task waits for a given number of milliseconds or until notified.""" if not self._new_records_futures: self._new_records_futures = set() @@ -305,7 +307,7 @@ async def async_wait(self, timeout: float, loop: Optional[asyncio.AbstractEventL loop or asyncio.get_running_loop(), self._new_records_futures, timeout ) - def addresses_by_version(self, version: IPVersion) -> List[bytes]: + def addresses_by_version(self, version: IPVersion) -> list[bytes]: """List addresses matching IP version. Addresses are guaranteed to be returned in LIFO (last in, first out) @@ -325,7 +327,7 @@ def addresses_by_version(self, version: IPVersion) -> List[bytes]: def ip_addresses_by_version( self, version: IPVersion - ) -> Union[List[ZeroconfIPv4Address], List[ZeroconfIPv6Address]]: + ) -> list[ZeroconfIPv4Address] | list[ZeroconfIPv6Address]: """List ip_address objects matching IP version. Addresses are guaranteed to be returned in LIFO (last in, first out) @@ -338,7 +340,7 @@ def ip_addresses_by_version( def _ip_addresses_by_version_value( self, version_value: int_ - ) -> Union[List[ZeroconfIPv4Address], List[ZeroconfIPv6Address]]: + ) -> list[ZeroconfIPv4Address] | list[ZeroconfIPv6Address]: """Backend for addresses_by_version that uses the raw value.""" if version_value == _IPVersion_All_value: return [*self._ipv4_addresses, *self._ipv6_addresses] # type: ignore[return-value] @@ -346,7 +348,7 @@ def _ip_addresses_by_version_value( return self._ipv4_addresses return self._ipv6_addresses - def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: + def parsed_addresses(self, version: IPVersion = IPVersion.All) -> list[str]: """List addresses in their parsed string form. Addresses are guaranteed to be returned in LIFO (last in, first out) @@ -357,7 +359,7 @@ def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: """ return [str_without_scope_id(addr) for addr in self._ip_addresses_by_version_value(version.value)] - def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: + def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> list[str]: """Equivalent to parsed_addresses, with the exception that IPv6 Link-Local addresses are qualified with % when available @@ -369,9 +371,9 @@ def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[st """ return [str(addr) for addr in self._ip_addresses_by_version_value(version.value)] - def _set_properties(self, properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]]) -> None: + def _set_properties(self, properties: dict[str | bytes, str | bytes | None]) -> None: """Sets properties and text of this info from a dictionary""" - list_: List[bytes] = [] + list_: list[bytes] = [] properties_contain_str = False result = b"" for key, value in properties.items(): @@ -425,7 +427,7 @@ def _unpack_text_into_properties(self) -> None: return index = 0 - properties: Dict[bytes, Optional[bytes]] = {} + properties: dict[bytes, bytes | None] = {} while index < end: length = text[index] index += 1 @@ -443,10 +445,10 @@ def get_name(self) -> str: return self._name[: len(self._name) - len(self.type) - 1] def _get_ip_addresses_from_cache_lifo( - self, zc: "Zeroconf", now: float_, type: int_ - ) -> List[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]]: + self, zc: Zeroconf, now: float_, type: int_ + ) -> list[ZeroconfIPv4Address | ZeroconfIPv6Address]: """Set IPv6 addresses from the cache.""" - address_list: List[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]] = [] + address_list: list[ZeroconfIPv4Address | ZeroconfIPv6Address] = [] for record in self._get_address_records_from_cache_by_type(zc, type): if record.is_expired(now): continue @@ -456,7 +458,7 @@ def _get_ip_addresses_from_cache_lifo( address_list.reverse() # Reverse to get LIFO order return address_list - def _set_ipv6_addresses_from_cache(self, zc: "Zeroconf", now: float_) -> None: + def _set_ipv6_addresses_from_cache(self, zc: Zeroconf, now: float_) -> None: """Set IPv6 addresses from the cache.""" if TYPE_CHECKING: self._ipv6_addresses = cast( @@ -466,7 +468,7 @@ def _set_ipv6_addresses_from_cache(self, zc: "Zeroconf", now: float_) -> None: else: self._ipv6_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_AAAA) - def _set_ipv4_addresses_from_cache(self, zc: "Zeroconf", now: float_) -> None: + def _set_ipv4_addresses_from_cache(self, zc: Zeroconf, now: float_) -> None: """Set IPv4 addresses from the cache.""" if TYPE_CHECKING: self._ipv4_addresses = cast( @@ -476,7 +478,7 @@ def _set_ipv4_addresses_from_cache(self, zc: "Zeroconf", now: float_) -> None: else: self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A) - def async_update_records(self, zc: "Zeroconf", now: float_, records: List[RecordUpdate]) -> None: + def async_update_records(self, zc: Zeroconf, now: float_, records: list[RecordUpdate]) -> None: """Updates service information from a DNS record. This method will be run in the event loop. @@ -488,7 +490,7 @@ def async_update_records(self, zc: "Zeroconf", now: float_, records: List[Record if updated and new_records_futures: _resolve_all_futures_to_none(new_records_futures) - def _process_record_threadsafe(self, zc: "Zeroconf", record: DNSRecord, now: float_) -> bool: + def _process_record_threadsafe(self, zc: Zeroconf, record: DNSRecord, now: float_) -> bool: """Thread safe record updating. Returns True if a new record was added. @@ -575,17 +577,17 @@ def _process_record_threadsafe(self, zc: "Zeroconf", record: DNSRecord, now: flo def dns_addresses( self, - override_ttl: Optional[int] = None, + override_ttl: int | None = None, version: IPVersion = IPVersion.All, - ) -> List[DNSAddress]: + ) -> list[DNSAddress]: """Return matching DNSAddress from ServiceInfo.""" return self._dns_addresses(override_ttl, version) def _dns_addresses( self, - override_ttl: Optional[int], + override_ttl: int | None, version: IPVersion, - ) -> List[DNSAddress]: + ) -> list[DNSAddress]: """Return matching DNSAddress from ServiceInfo.""" cacheable = version is IPVersion.All and override_ttl is None if self._dns_address_cache is not None and cacheable: @@ -609,11 +611,11 @@ def _dns_addresses( self._dns_address_cache = records return records - def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: + def dns_pointer(self, override_ttl: int | None = None) -> DNSPointer: """Return DNSPointer from ServiceInfo.""" return self._dns_pointer(override_ttl) - def _dns_pointer(self, override_ttl: Optional[int]) -> DNSPointer: + def _dns_pointer(self, override_ttl: int | None) -> DNSPointer: """Return DNSPointer from ServiceInfo.""" cacheable = override_ttl is None if self._dns_pointer_cache is not None and cacheable: @@ -630,11 +632,11 @@ def _dns_pointer(self, override_ttl: Optional[int]) -> DNSPointer: self._dns_pointer_cache = record return record - def dns_service(self, override_ttl: Optional[int] = None) -> DNSService: + def dns_service(self, override_ttl: int | None = None) -> DNSService: """Return DNSService from ServiceInfo.""" return self._dns_service(override_ttl) - def _dns_service(self, override_ttl: Optional[int]) -> DNSService: + def _dns_service(self, override_ttl: int | None) -> DNSService: """Return DNSService from ServiceInfo.""" cacheable = override_ttl is None if self._dns_service_cache is not None and cacheable: @@ -657,11 +659,11 @@ def _dns_service(self, override_ttl: Optional[int]) -> DNSService: self._dns_service_cache = record return record - def dns_text(self, override_ttl: Optional[int] = None) -> DNSText: + def dns_text(self, override_ttl: int | None = None) -> DNSText: """Return DNSText from ServiceInfo.""" return self._dns_text(override_ttl) - def _dns_text(self, override_ttl: Optional[int]) -> DNSText: + def _dns_text(self, override_ttl: int | None) -> DNSText: """Return DNSText from ServiceInfo.""" cacheable = override_ttl is None if self._dns_text_cache is not None and cacheable: @@ -678,11 +680,11 @@ def _dns_text(self, override_ttl: Optional[int]) -> DNSText: self._dns_text_cache = record return record - def dns_nsec(self, missing_types: List[int], override_ttl: Optional[int] = None) -> DNSNsec: + def dns_nsec(self, missing_types: list[int], override_ttl: int | None = None) -> DNSNsec: """Return DNSNsec from ServiceInfo.""" return self._dns_nsec(missing_types, override_ttl) - def _dns_nsec(self, missing_types: List[int], override_ttl: Optional[int]) -> DNSNsec: + def _dns_nsec(self, missing_types: list[int], override_ttl: int | None) -> DNSNsec: """Return DNSNsec from ServiceInfo.""" return DNSNsec( self._name, @@ -694,17 +696,17 @@ def _dns_nsec(self, missing_types: List[int], override_ttl: Optional[int]) -> DN 0.0, ) - def get_address_and_nsec_records(self, override_ttl: Optional[int] = None) -> Set[DNSRecord]: + def get_address_and_nsec_records(self, override_ttl: int | None = None) -> set[DNSRecord]: """Build a set of address records and NSEC records for non-present record types.""" return self._get_address_and_nsec_records(override_ttl) - def _get_address_and_nsec_records(self, override_ttl: Optional[int]) -> Set[DNSRecord]: + def _get_address_and_nsec_records(self, override_ttl: int | None) -> set[DNSRecord]: """Build a set of address records and NSEC records for non-present record types.""" cacheable = override_ttl is None if self._get_address_and_nsec_records_cache is not None and cacheable: return self._get_address_and_nsec_records_cache - missing_types: Set[int] = _ADDRESS_RECORD_TYPES.copy() - records: Set[DNSRecord] = set() + missing_types: set[int] = _ADDRESS_RECORD_TYPES.copy() + records: set[DNSRecord] = set() for dns_address in self._dns_addresses(override_ttl, IPVersion.All): missing_types.discard(dns_address.type) records.add(dns_address) @@ -715,7 +717,7 @@ def _get_address_and_nsec_records(self, override_ttl: Optional[int]) -> Set[DNSR self._get_address_and_nsec_records_cache = records return records - def _get_address_records_from_cache_by_type(self, zc: "Zeroconf", _type: int_) -> List[DNSAddress]: + def _get_address_records_from_cache_by_type(self, zc: Zeroconf, _type: int_) -> list[DNSAddress]: """Get the addresses from the cache.""" if self.server_key is None: return [] @@ -738,14 +740,14 @@ def set_server_if_missing(self) -> None: self.server = self._name self.server_key = self.key - def load_from_cache(self, zc: "Zeroconf", now: Optional[float_] = None) -> bool: + def load_from_cache(self, zc: Zeroconf, now: float_ | None = None) -> bool: """Populate the service info from the cache. This method is designed to be threadsafe. """ return self._load_from_cache(zc, now or current_time_millis()) - def _load_from_cache(self, zc: "Zeroconf", now: float_) -> bool: + def _load_from_cache(self, zc: Zeroconf, now: float_) -> bool: """Populate the service info from the cache. This method is designed to be threadsafe. @@ -775,10 +777,10 @@ def _is_complete(self) -> bool: def request( self, - zc: "Zeroconf", + zc: Zeroconf, timeout: float, - question_type: Optional[DNSQuestionType] = None, - addr: Optional[str] = None, + question_type: DNSQuestionType | None = None, + addr: str | None = None, port: int = _MDNS_PORT, ) -> bool: """Returns true if the service could be discovered on the @@ -814,10 +816,10 @@ def _get_random_delay(self) -> int_: async def async_request( self, - zc: "Zeroconf", + zc: Zeroconf, timeout: float, - question_type: Optional[DNSQuestionType] = None, - addr: Optional[str] = None, + question_type: DNSQuestionType | None = None, + addr: str | None = None, port: int = _MDNS_PORT, ) -> bool: """Returns true if the service could be discovered on the @@ -914,7 +916,7 @@ def _add_question_with_known_answers( out.add_answer_at_time(answer, now) def _generate_request_query( - self, zc: "Zeroconf", now: float_, question_type: DNSQuestionType + self, zc: Zeroconf, now: float_, question_type: DNSQuestionType ) -> DNSOutgoing: """Generate the request query.""" out = DNSOutgoing(_FLAGS_QR_QUERY) diff --git a/src/zeroconf/_services/registry.py b/src/zeroconf/_services/registry.py index 4100c690..937992eb 100644 --- a/src/zeroconf/_services/registry.py +++ b/src/zeroconf/_services/registry.py @@ -20,7 +20,7 @@ USA """ -from typing import Dict, List, Optional, Union +from __future__ import annotations from .._exceptions import ServiceNameAlreadyRegistered from .info import ServiceInfo @@ -41,16 +41,16 @@ def __init__( self, ) -> None: """Create the ServiceRegistry class.""" - self._services: Dict[str, ServiceInfo] = {} - self.types: Dict[str, List] = {} - self.servers: Dict[str, List] = {} + self._services: dict[str, ServiceInfo] = {} + self.types: dict[str, list] = {} + self.servers: dict[str, list] = {} self.has_entries: bool = False def async_add(self, info: ServiceInfo) -> None: """Add a new service to the registry.""" self._add(info) - def async_remove(self, info: Union[List[ServiceInfo], ServiceInfo]) -> None: + def async_remove(self, info: list[ServiceInfo] | ServiceInfo) -> None: """Remove a new service from the registry.""" self._remove(info if isinstance(info, list) else [info]) @@ -59,27 +59,27 @@ def async_update(self, info: ServiceInfo) -> None: self._remove([info]) self._add(info) - def async_get_service_infos(self) -> List[ServiceInfo]: + def async_get_service_infos(self) -> list[ServiceInfo]: """Return all ServiceInfo.""" return list(self._services.values()) - def async_get_info_name(self, name: str) -> Optional[ServiceInfo]: + def async_get_info_name(self, name: str) -> ServiceInfo | None: """Return all ServiceInfo for the name.""" return self._services.get(name) - def async_get_types(self) -> List[str]: + def async_get_types(self) -> list[str]: """Return all types.""" return list(self.types) - def async_get_infos_type(self, type_: str) -> List[ServiceInfo]: + def async_get_infos_type(self, type_: str) -> list[ServiceInfo]: """Return all ServiceInfo matching type.""" return self._async_get_by_index(self.types, type_) - def async_get_infos_server(self, server: str) -> List[ServiceInfo]: + def async_get_infos_server(self, server: str) -> list[ServiceInfo]: """Return all ServiceInfo matching server.""" return self._async_get_by_index(self.servers, server) - def _async_get_by_index(self, records: Dict[str, List], key: _str) -> List[ServiceInfo]: + def _async_get_by_index(self, records: dict[str, list], key: _str) -> list[ServiceInfo]: """Return all ServiceInfo matching the index.""" record_list = records.get(key) if record_list is None: @@ -98,7 +98,7 @@ def _add(self, info: ServiceInfo) -> None: self.servers.setdefault(info.server_key, []).append(info.key) self.has_entries = True - def _remove(self, infos: List[ServiceInfo]) -> None: + def _remove(self, infos: list[ServiceInfo]) -> None: """Remove a services under the lock.""" for info in infos: old_service_info = self._services.get(info.key) diff --git a/src/zeroconf/_services/types.py b/src/zeroconf/_services/types.py index 63b6d19a..af25dc6d 100644 --- a/src/zeroconf/_services/types.py +++ b/src/zeroconf/_services/types.py @@ -20,8 +20,9 @@ USA """ +from __future__ import annotations + import time -from typing import Optional, Set, Tuple, Union from .._core import Zeroconf from .._services import ServiceListener @@ -37,7 +38,7 @@ class ZeroconfServiceTypes(ServiceListener): def __init__(self) -> None: """Keep track of found services in a set.""" - self.found_services: Set[str] = set() + self.found_services: set[str] = set() def add_service(self, zc: Zeroconf, type_: str, name: str) -> None: """Service added.""" @@ -52,11 +53,11 @@ def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None: @classmethod def find( cls, - zc: Optional[Zeroconf] = None, - timeout: Union[int, float] = 5, + zc: Zeroconf | None = None, + timeout: int | float = 5, interfaces: InterfacesType = InterfaceChoice.All, - ip_version: Optional[IPVersion] = None, - ) -> Tuple[str, ...]: + ip_version: IPVersion | None = None, + ) -> tuple[str, ...]: """ Return all of the advertised services on any local networks. diff --git a/src/zeroconf/_transport.py b/src/zeroconf/_transport.py index b0811094..c8d7699b 100644 --- a/src/zeroconf/_transport.py +++ b/src/zeroconf/_transport.py @@ -20,9 +20,10 @@ USA """ +from __future__ import annotations + import asyncio import socket -from typing import Tuple class _WrappedTransport: @@ -42,7 +43,7 @@ def __init__( is_ipv6: bool, sock: socket.socket, fileno: int, - sock_name: Tuple, + sock_name: tuple, ) -> None: """Initialize the wrapped transport. diff --git a/src/zeroconf/_updates.py b/src/zeroconf/_updates.py index 58be33d8..c0bf9b8c 100644 --- a/src/zeroconf/_updates.py +++ b/src/zeroconf/_updates.py @@ -20,7 +20,9 @@ USA """ -from typing import TYPE_CHECKING, List +from __future__ import annotations + +from typing import TYPE_CHECKING from ._dns import DNSRecord from ._record_update import RecordUpdate @@ -40,7 +42,7 @@ class RecordUpdateListener: """ def update_record( # pylint: disable=no-self-use - self, zc: "Zeroconf", now: float, record: DNSRecord + self, zc: Zeroconf, now: float, record: DNSRecord ) -> None: """Update a single record. @@ -49,7 +51,7 @@ def update_record( # pylint: disable=no-self-use """ raise RuntimeError("update_record is deprecated and will be removed in a future version.") - def async_update_records(self, zc: "Zeroconf", now: float_, records: List[RecordUpdate]) -> None: + def async_update_records(self, zc: Zeroconf, now: float_, records: list[RecordUpdate]) -> None: """Update multiple records in one shot. All records that are received in a single packet are passed diff --git a/src/zeroconf/_utils/__init__.py b/src/zeroconf/_utils/__init__.py index 30920c6a..584a74ec 100644 --- a/src/zeroconf/_utils/__init__.py +++ b/src/zeroconf/_utils/__init__.py @@ -19,3 +19,5 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA """ + +from __future__ import annotations diff --git a/src/zeroconf/_utils/asyncio.py b/src/zeroconf/_utils/asyncio.py index 6d070e30..07b3f422 100644 --- a/src/zeroconf/_utils/asyncio.py +++ b/src/zeroconf/_utils/asyncio.py @@ -20,11 +20,13 @@ USA """ +from __future__ import annotations + import asyncio import concurrent.futures import contextlib import sys -from typing import Any, Awaitable, Coroutine, Optional, Set +from typing import Any, Awaitable, Coroutine if sys.version_info[:2] < (3, 11): from async_timeout import timeout as asyncio_timeout @@ -47,7 +49,7 @@ def _set_future_none_if_not_done(fut: asyncio.Future) -> None: fut.set_result(None) -def _resolve_all_futures_to_none(futures: Set[asyncio.Future]) -> None: +def _resolve_all_futures_to_none(futures: set[asyncio.Future]) -> None: """Resolve all futures to None.""" for fut in futures: _set_future_none_if_not_done(fut) @@ -55,7 +57,7 @@ def _resolve_all_futures_to_none(futures: Set[asyncio.Future]) -> None: async def wait_for_future_set_or_timeout( - loop: asyncio.AbstractEventLoop, future_set: Set[asyncio.Future], timeout: float + loop: asyncio.AbstractEventLoop, future_set: set[asyncio.Future], timeout: float ) -> None: """Wait for a future or timeout (in milliseconds).""" future = loop.create_future() @@ -75,7 +77,7 @@ async def wait_event_or_timeout(event: asyncio.Event, timeout: float) -> None: await event.wait() -async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> Set[asyncio.Task]: +async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> set[asyncio.Task]: """Return all tasks running.""" await asyncio.sleep(0) # flush out any call_soon_threadsafe # If there are multiple event loops running, all_tasks is not @@ -87,7 +89,7 @@ async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> Set[asyncio.T return set() -async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None: +async def _wait_for_loop_tasks(wait_tasks: set[asyncio.Task]) -> None: """Wait for the event loop thread we started to shutdown.""" await asyncio.wait(wait_tasks, timeout=_TASK_AWAIT_TIMEOUT) @@ -130,7 +132,7 @@ def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None: loop.call_soon_threadsafe(loop.stop) -def get_running_loop() -> Optional[asyncio.AbstractEventLoop]: +def get_running_loop() -> asyncio.AbstractEventLoop | None: """Check if an event loop is already running.""" with contextlib.suppress(RuntimeError): return asyncio.get_running_loop() diff --git a/src/zeroconf/_utils/ipaddress.py b/src/zeroconf/_utils/ipaddress.py index 64cdfb63..d172d0c9 100644 --- a/src/zeroconf/_utils/ipaddress.py +++ b/src/zeroconf/_utils/ipaddress.py @@ -20,9 +20,11 @@ USA """ +from __future__ import annotations + from functools import cache, lru_cache from ipaddress import AddressValueError, IPv4Address, IPv6Address, NetmaskValueError -from typing import Any, Optional, Union +from typing import Any from .._dns import DNSAddress from ..const import _TYPE_AAAA @@ -99,8 +101,8 @@ def is_loopback(self) -> bool: @lru_cache(maxsize=512) def _cached_ip_addresses( - address: Union[str, bytes, int], -) -> Optional[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]]: + address: str | bytes | int, +) -> ZeroconfIPv4Address | ZeroconfIPv6Address | None: """Cache IP addresses.""" try: return ZeroconfIPv4Address(address) @@ -119,7 +121,7 @@ def _cached_ip_addresses( def get_ip_address_object_from_record( record: DNSAddress, -) -> Optional[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]]: +) -> ZeroconfIPv4Address | ZeroconfIPv6Address | None: """Get the IP address object from the record.""" if record.type == _TYPE_AAAA and record.scope_id: return ip_bytes_and_scope_to_address(record.address, record.scope_id) @@ -128,7 +130,7 @@ def get_ip_address_object_from_record( def ip_bytes_and_scope_to_address( address: bytes_, scope: int_ -) -> Optional[Union[ZeroconfIPv4Address, ZeroconfIPv6Address]]: +) -> ZeroconfIPv4Address | ZeroconfIPv6Address | None: """Convert the bytes and scope to an IP address object.""" base_address = cached_ip_addresses_wrapper(address) if base_address is not None and base_address.is_link_local: @@ -137,7 +139,7 @@ def ip_bytes_and_scope_to_address( return base_address -def str_without_scope_id(addr: Union[ZeroconfIPv4Address, ZeroconfIPv6Address]) -> str: +def str_without_scope_id(addr: ZeroconfIPv4Address | ZeroconfIPv6Address) -> str: """Return the string representation of the address without the scope id.""" if addr.version == 6: address_str = str(addr) diff --git a/src/zeroconf/_utils/name.py b/src/zeroconf/_utils/name.py index cda01b28..de35f7af 100644 --- a/src/zeroconf/_utils/name.py +++ b/src/zeroconf/_utils/name.py @@ -20,8 +20,9 @@ USA """ +from __future__ import annotations + from functools import lru_cache -from typing import Set from .._exceptions import BadTypeInNameException from ..const import ( @@ -162,7 +163,7 @@ def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: dis return service_name + trailer -def possible_types(name: str) -> Set[str]: +def possible_types(name: str) -> set[str]: """Build a set of all possible types from a fully qualified name.""" labels = name.split(".") label_count = len(labels) diff --git a/src/zeroconf/_utils/net.py b/src/zeroconf/_utils/net.py index 7298bec4..3cc4336b 100644 --- a/src/zeroconf/_utils/net.py +++ b/src/zeroconf/_utils/net.py @@ -20,13 +20,15 @@ USA """ +from __future__ import annotations + import enum import errno import ipaddress import socket import struct import sys -from typing import Any, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, Sequence, Tuple, Union, cast import ifaddr @@ -70,11 +72,11 @@ def _encode_address(address: str) -> bytes: return socket.inet_pton(address_family, address) -def get_all_addresses() -> List[str]: +def get_all_addresses() -> list[str]: return list({addr.ip for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv4}) -def get_all_addresses_v6() -> List[Tuple[Tuple[str, int, int], int]]: +def get_all_addresses_v6() -> list[tuple[tuple[str, int, int], int]]: # IPv6 multicast uses positive indexes for interfaces # TODO: What about multi-address interfaces? return list( @@ -82,7 +84,7 @@ def get_all_addresses_v6() -> List[Tuple[Tuple[str, int, int], int]]: ) -def ip6_to_address_and_index(adapters: List[Any], ip: str) -> Tuple[Tuple[str, int, int], int]: +def ip6_to_address_and_index(adapters: list[Any], ip: str) -> tuple[tuple[str, int, int], int]: if "%" in ip: ip = ip[: ip.index("%")] # Strip scope_id. ipaddr = ipaddress.ip_address(ip) @@ -98,7 +100,7 @@ def ip6_to_address_and_index(adapters: List[Any], ip: str) -> Tuple[Tuple[str, i raise RuntimeError(f"No adapter found for IP address {ip}") -def interface_index_to_ip6_address(adapters: List[Any], index: int) -> Tuple[str, int, int]: +def interface_index_to_ip6_address(adapters: list[Any], index: int) -> tuple[str, int, int]: for adapter in adapters: if adapter.index == index: for adapter_ip in adapter.ips: @@ -110,8 +112,8 @@ def interface_index_to_ip6_address(adapters: List[Any], index: int) -> Tuple[str def ip6_addresses_to_indexes( - interfaces: Sequence[Union[str, int, Tuple[Tuple[str, int, int], int]]], -) -> List[Tuple[Tuple[str, int, int], int]]: + interfaces: Sequence[str | int | tuple[tuple[str, int, int], int]], +) -> list[tuple[tuple[str, int, int], int]]: """Convert IPv6 interface addresses to interface indexes. IPv4 addresses are ignored. @@ -133,14 +135,14 @@ def ip6_addresses_to_indexes( def normalize_interface_choice( choice: InterfacesType, ip_version: IPVersion = IPVersion.V4Only -) -> List[Union[str, Tuple[Tuple[str, int, int], int]]]: +) -> list[str | tuple[tuple[str, int, int], int]]: """Convert the interfaces choice into internal representation. :param choice: `InterfaceChoice` or list of interface addresses or indexes (IPv6 only). :param ip_address: IP version to use (ignored if `choice` is a list). :returns: List of IP addresses (for IPv4) and indexes (for IPv6). """ - result: List[Union[str, Tuple[Tuple[str, int, int], int]]] = [] + result: list[str | tuple[tuple[str, int, int], int]] = [] if choice is InterfaceChoice.Default: if ip_version != IPVersion.V4Only: # IPv6 multicast uses interface 0 to mean the default @@ -196,7 +198,7 @@ def set_so_reuseport_if_available(s: socket.socket) -> None: def set_mdns_port_socket_options_for_ip_version( s: socket.socket, - bind_addr: Union[Tuple[str], Tuple[str, int, int]], + bind_addr: tuple[str] | tuple[str, int, int], ip_version: IPVersion, ) -> None: """Set ttl/hops and loop for mdns port.""" @@ -219,11 +221,11 @@ def set_mdns_port_socket_options_for_ip_version( def new_socket( - bind_addr: Union[Tuple[str], Tuple[str, int, int]], + bind_addr: tuple[str] | tuple[str, int, int], port: int = _MDNS_PORT, ip_version: IPVersion = IPVersion.V4Only, apple_p2p: bool = False, -) -> Optional[socket.socket]: +) -> socket.socket | None: log.debug( "Creating new socket with port %s, ip_version %s, apple_p2p %s and bind_addr %r", port, @@ -265,7 +267,7 @@ def new_socket( def add_multicast_member( listen_socket: socket.socket, - interface: Union[str, Tuple[Tuple[str, int, int], int]], + interface: str | tuple[tuple[str, int, int], int], ) -> bool: # This is based on assumptions in normalize_interface_choice is_v6 = isinstance(interface, tuple) @@ -331,9 +333,9 @@ def add_multicast_member( def new_respond_socket( - interface: Union[str, Tuple[Tuple[str, int, int], int]], + interface: str | tuple[tuple[str, int, int], int], apple_p2p: bool = False, -) -> Optional[socket.socket]: +) -> socket.socket | None: is_v6 = isinstance(interface, tuple) respond_socket = new_socket( ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only), @@ -360,7 +362,7 @@ def create_sockets( unicast: bool = False, ip_version: IPVersion = IPVersion.V4Only, apple_p2p: bool = False, -) -> Tuple[Optional[socket.socket], List[socket.socket]]: +) -> tuple[socket.socket | None, list[socket.socket]]: if unicast: listen_socket = None else: diff --git a/src/zeroconf/_utils/time.py b/src/zeroconf/_utils/time.py index 055e0658..4057f063 100644 --- a/src/zeroconf/_utils/time.py +++ b/src/zeroconf/_utils/time.py @@ -20,6 +20,8 @@ USA """ +from __future__ import annotations + import time _float = float diff --git a/src/zeroconf/asyncio.py b/src/zeroconf/asyncio.py index 926ef509..2a29a4bb 100644 --- a/src/zeroconf/asyncio.py +++ b/src/zeroconf/asyncio.py @@ -20,10 +20,12 @@ USA """ +from __future__ import annotations + import asyncio import contextlib from types import TracebackType # used in type hints -from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Awaitable, Callable from ._core import Zeroconf from ._dns import DNSQuestionType @@ -63,14 +65,14 @@ class AsyncServiceBrowser(_ServiceBrowserBase): def __init__( self, - zeroconf: "Zeroconf", - type_: Union[str, list], - handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, - listener: Optional[ServiceListener] = None, - addr: Optional[str] = None, + zeroconf: Zeroconf, + type_: str | list, + handlers: ServiceListener | list[Callable[..., None]] | None = None, + listener: ServiceListener | None = None, + addr: str | None = None, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, - question_type: Optional[DNSQuestionType] = None, + question_type: DNSQuestionType | None = None, ) -> None: super().__init__(zeroconf, type_, handlers, listener, addr, port, delay, question_type) self._async_start() @@ -79,15 +81,15 @@ async def async_cancel(self) -> None: """Cancel the browser.""" self._async_cancel() - async def __aenter__(self) -> "AsyncServiceBrowser": + async def __aenter__(self) -> AsyncServiceBrowser: return self async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: await self.async_cancel() return None @@ -98,11 +100,11 @@ class AsyncZeroconfServiceTypes(ZeroconfServiceTypes): @classmethod async def async_find( cls, - aiozc: Optional["AsyncZeroconf"] = None, - timeout: Union[int, float] = 5, + aiozc: AsyncZeroconf | None = None, + timeout: int | float = 5, interfaces: InterfacesType = InterfaceChoice.All, - ip_version: Optional[IPVersion] = None, - ) -> Tuple[str, ...]: + ip_version: IPVersion | None = None, + ) -> tuple[str, ...]: """ Return all of the advertised services on any local networks. @@ -145,9 +147,9 @@ def __init__( self, interfaces: InterfacesType = InterfaceChoice.All, unicast: bool = False, - ip_version: Optional[IPVersion] = None, + ip_version: IPVersion | None = None, apple_p2p: bool = False, - zc: Optional[Zeroconf] = None, + zc: Zeroconf | None = None, ) -> None: """Creates an instance of the Zeroconf class, establishing multicast communications, and listening. @@ -170,12 +172,12 @@ def __init__( ip_version=ip_version, apple_p2p=apple_p2p, ) - self.async_browsers: Dict[ServiceListener, AsyncServiceBrowser] = {} + self.async_browsers: dict[ServiceListener, AsyncServiceBrowser] = {} async def async_register_service( self, info: ServiceInfo, - ttl: Optional[int] = None, + ttl: int | None = None, allow_name_change: bool = False, cooperating_responders: bool = False, strict: bool = True, @@ -236,8 +238,8 @@ async def async_get_service_info( type_: str, name: str, timeout: int = 3000, - question_type: Optional[DNSQuestionType] = None, - ) -> Optional[AsyncServiceInfo]: + question_type: DNSQuestionType | None = None, + ) -> AsyncServiceInfo | None: """Returns network's service information for a particular name and type, or None if no service matches by the timeout, which defaults to 3 seconds. @@ -268,14 +270,14 @@ async def async_remove_all_service_listeners(self) -> None: *(self.async_remove_service_listener(listener) for listener in list(self.async_browsers)) ) - async def __aenter__(self) -> "AsyncZeroconf": + async def __aenter__(self) -> AsyncZeroconf: return self async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: await self.async_close() return None diff --git a/src/zeroconf/const.py b/src/zeroconf/const.py index d84cb73b..3b4b3abc 100644 --- a/src/zeroconf/const.py +++ b/src/zeroconf/const.py @@ -20,6 +20,8 @@ USA """ +from __future__ import annotations + import re import socket diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py index e69de29b..9d48db4f 100644 --- a/tests/benchmarks/__init__.py +++ b/tests/benchmarks/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/tests/benchmarks/helpers.py b/tests/benchmarks/helpers.py index e701e0b6..4f5f7d66 100644 --- a/tests/benchmarks/helpers.py +++ b/tests/benchmarks/helpers.py @@ -1,5 +1,7 @@ """Benchmark helpers.""" +from __future__ import annotations + import socket from zeroconf import DNSAddress, DNSOutgoing, DNSService, DNSText, const diff --git a/tests/benchmarks/test_cache.py b/tests/benchmarks/test_cache.py index 6fde9438..7813f679 100644 --- a/tests/benchmarks/test_cache.py +++ b/tests/benchmarks/test_cache.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pytest_codspeed import BenchmarkFixture from zeroconf import DNSCache, DNSPointer, current_time_millis diff --git a/tests/benchmarks/test_incoming.py b/tests/benchmarks/test_incoming.py index e0552f3a..6d31e51e 100644 --- a/tests/benchmarks/test_incoming.py +++ b/tests/benchmarks/test_incoming.py @@ -1,5 +1,7 @@ """Benchmark for DNSIncoming.""" +from __future__ import annotations + import socket from pytest_codspeed import BenchmarkFixture diff --git a/tests/benchmarks/test_outgoing.py b/tests/benchmarks/test_outgoing.py index 69de540e..a8db4d6f 100644 --- a/tests/benchmarks/test_outgoing.py +++ b/tests/benchmarks/test_outgoing.py @@ -1,5 +1,7 @@ """Benchmark for DNSOutgoing.""" +from __future__ import annotations + from pytest_codspeed import BenchmarkFixture from zeroconf._protocol.outgoing import State diff --git a/tests/benchmarks/test_send.py b/tests/benchmarks/test_send.py index 7a6d664b..596662a2 100644 --- a/tests/benchmarks/test_send.py +++ b/tests/benchmarks/test_send.py @@ -1,5 +1,7 @@ """Benchmark for sending packets.""" +from __future__ import annotations + import pytest from pytest_codspeed import BenchmarkFixture diff --git a/tests/benchmarks/test_txt_properties.py b/tests/benchmarks/test_txt_properties.py index ad75ab35..72afa0b6 100644 --- a/tests/benchmarks/test_txt_properties.py +++ b/tests/benchmarks/test_txt_properties.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pytest_codspeed import BenchmarkFixture from zeroconf import ServiceInfo diff --git a/tests/conftest.py b/tests/conftest.py index ba49cef6..1f323785 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ """conftest for zeroconf tests.""" +from __future__ import annotations + import threading from unittest.mock import patch diff --git a/tests/services/__init__.py b/tests/services/__init__.py index 30920c6a..584a74ec 100644 --- a/tests/services/__init__.py +++ b/tests/services/__init__.py @@ -19,3 +19,5 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA """ + +from __future__ import annotations diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 5268c341..986df64e 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._services.browser.""" +from __future__ import annotations + import asyncio import logging import os @@ -863,7 +865,7 @@ def test_legacy_record_update_listener(): class LegacyRecordUpdateListener(r.RecordUpdateListener): """A RecordUpdateListener that does not implement update_records.""" - def update_record(self, zc: "Zeroconf", now: float, record: r.DNSRecord) -> None: + def update_record(self, zc: Zeroconf, now: float, record: r.DNSRecord) -> None: nonlocal updates updates.append(record) diff --git a/tests/services/test_registry.py b/tests/services/test_registry.py index 999e422c..c3ae3a28 100644 --- a/tests/services/test_registry.py +++ b/tests/services/test_registry.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._services.registry.""" +from __future__ import annotations + import socket import unittest diff --git a/tests/services/test_types.py b/tests/services/test_types.py index 811b22c5..63292246 100644 --- a/tests/services/test_types.py +++ b/tests/services/test_types.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._services.types.""" +from __future__ import annotations + import logging import os import socket diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 86e9e8c7..40ecf816 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -1,5 +1,7 @@ """Unit tests for aio.py.""" +from __future__ import annotations + import asyncio import logging import os diff --git a/tests/test_cache.py b/tests/test_cache.py index f5304cef..9d55435d 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._cache.""" +from __future__ import annotations + import logging import unittest.mock from heapq import heapify, heappop diff --git a/tests/test_circular_imports.py b/tests/test_circular_imports.py index 8bd443a4..74ed1f12 100644 --- a/tests/test_circular_imports.py +++ b/tests/test_circular_imports.py @@ -1,5 +1,7 @@ """Test to check for circular imports.""" +from __future__ import annotations + import asyncio import sys diff --git a/tests/test_dns.py b/tests/test_dns.py index 491e2ca7..246c8dcf 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._dns.""" +from __future__ import annotations + import logging import os import socket diff --git a/tests/test_engine.py b/tests/test_engine.py index 23a03949..b7a94c86 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._engine""" +from __future__ import annotations + import asyncio import itertools import logging diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index cf004d2c..ab181db1 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._exceptions""" +from __future__ import annotations + import logging import unittest.mock diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 80ee7f40..fd0e689c 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._handlers""" +from __future__ import annotations + import asyncio import logging import os @@ -1371,7 +1373,7 @@ async def test_record_update_manager_add_listener_callsback_existing_records(): class MyListener(r.RecordUpdateListener): """A RecordUpdateListener that does not implement update_records.""" - def async_update_records(self, zc: "Zeroconf", now: float, records: list[r.RecordUpdate]) -> None: + def async_update_records(self, zc: Zeroconf, now: float, records: list[r.RecordUpdate]) -> None: """Update multiple records in one shot.""" updated.extend(records) @@ -1973,7 +1975,7 @@ async def test_add_listener_warns_when_not_using_record_update_listener(caplog): class MyListener: """A RecordUpdateListener that does not implement update_records.""" - def async_update_records(self, zc: "Zeroconf", now: float, records: list[r.RecordUpdate]) -> None: + def async_update_records(self, zc: Zeroconf, now: float, records: list[r.RecordUpdate]) -> None: """Update multiple records in one shot.""" updated.extend(records) @@ -2005,7 +2007,7 @@ async def test_async_updates_iteration_safe(): class OtherListener(r.RecordUpdateListener): """A RecordUpdateListener that does not implement update_records.""" - def async_update_records(self, zc: "Zeroconf", now: float, records: list[r.RecordUpdate]) -> None: + def async_update_records(self, zc: Zeroconf, now: float, records: list[r.RecordUpdate]) -> None: """Update multiple records in one shot.""" updated.extend(records) @@ -2014,7 +2016,7 @@ def async_update_records(self, zc: "Zeroconf", now: float, records: list[r.Recor class ListenerThatAddsListener(r.RecordUpdateListener): """A RecordUpdateListener that does not implement update_records.""" - def async_update_records(self, zc: "Zeroconf", now: float, records: list[r.RecordUpdate]) -> None: + def async_update_records(self, zc: Zeroconf, now: float, records: list[r.RecordUpdate]) -> None: """Update multiple records in one shot.""" updated.extend(records) zc.async_add_listener(other, None) diff --git a/tests/test_history.py b/tests/test_history.py index 606362d1..4c9836ce 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -1,5 +1,7 @@ """Unit tests for _history.py.""" +from __future__ import annotations + import zeroconf as r import zeroconf.const as const from zeroconf._history import QuestionHistory diff --git a/tests/test_init.py b/tests/test_init.py index 78fb1e37..a36ff8fd 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf.py""" +from __future__ import annotations + import logging import socket import time diff --git a/tests/test_logger.py b/tests/test_logger.py index ecaf9dd0..aa5b5382 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,5 +1,7 @@ """Unit tests for logger.py.""" +from __future__ import annotations + import logging from unittest.mock import call, patch diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 1397c60c..08d7e600 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._protocol""" +from __future__ import annotations + import copy import logging import os diff --git a/tests/test_services.py b/tests/test_services.py index 992070e2..e93174cc 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._services.""" +from __future__ import annotations + import logging import os import socket diff --git a/tests/test_updates.py b/tests/test_updates.py index 1af85736..a057486c 100644 --- a/tests/test_updates.py +++ b/tests/test_updates.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._updates.""" +from __future__ import annotations + import logging import socket import time @@ -45,7 +47,7 @@ def test_legacy_record_update_listener(): class LegacyRecordUpdateListener(r.RecordUpdateListener): """A RecordUpdateListener that does not implement update_records.""" - def update_record(self, zc: "Zeroconf", now: float, record: r.DNSRecord) -> None: + def update_record(self, zc: Zeroconf, now: float, record: r.DNSRecord) -> None: nonlocal updates updates.append(record) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index 30920c6a..584a74ec 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -19,3 +19,5 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA """ + +from __future__ import annotations diff --git a/tests/utils/test_ipaddress.py b/tests/utils/test_ipaddress.py index c6f63aaf..4379f458 100644 --- a/tests/utils/test_ipaddress.py +++ b/tests/utils/test_ipaddress.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._utils.ipaddress.""" +from __future__ import annotations + from zeroconf import const from zeroconf._dns import DNSAddress from zeroconf._utils import ipaddress diff --git a/tests/utils/test_name.py b/tests/utils/test_name.py index 6f2c6b13..1feb7713 100644 --- a/tests/utils/test_name.py +++ b/tests/utils/test_name.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._utils.name.""" +from __future__ import annotations + import socket import pytest diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py index 17212af2..489a6460 100644 --- a/tests/utils/test_net.py +++ b/tests/utils/test_net.py @@ -1,5 +1,7 @@ """Unit tests for zeroconf._utils.net.""" +from __future__ import annotations + import errno import socket import unittest