Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions examples/resolve_address.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python

"""Example of resolving a name to an IP address."""

import asyncio
import logging
import sys

from zeroconf import AddressResolver, IPVersion
from zeroconf.asyncio import AsyncZeroconf


async def resolve_name(name: str) -> None:
aiozc = AsyncZeroconf()
await aiozc.zeroconf.async_wait_for_start()
resolver = AddressResolver(name)
if await resolver.async_request(aiozc.zeroconf, 3000):
print(f"{name} IP addresses:", resolver.ip_addresses_by_version(IPVersion.All))
else:
print(f"Name {name} not resolved")
await aiozc.async_close()


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
argv = sys.argv.copy()
if "--debug" in argv:
logging.getLogger("zeroconf").setLevel(logging.DEBUG)
argv.remove("--debug")

if len(argv) < 2 or not argv[1]:
raise ValueError("Usage: resolve_address.py [--debug] <name>")

name = argv[1]
if not name.endswith("."):
name += "."

asyncio.run(resolve_name(name))
3 changes: 3 additions & 0 deletions src/zeroconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
from ._services.browser import ServiceBrowser
from ._services.info import ( # noqa # import needed for backwards compat
ServiceInfo,
AddressResolver,
AddressResolverIPv4,
AddressResolverIPv6,
instance_name_from_service_info,
)
from ._services.registry import ( # noqa # import needed for backwards compat
Expand Down
13 changes: 13 additions & 0 deletions src/zeroconf/_services/info.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ from .._utils.ipaddress cimport (
)
from .._utils.time cimport current_time_millis

cdef cython.set _TYPE_AAAA_RECORDS
cdef cython.set _TYPE_A_RECORDS
cdef cython.set _TYPE_A_AAAA_RECORDS

cdef object _resolve_all_futures_to_none

Expand Down Expand Up @@ -75,6 +78,7 @@ cdef class ServiceInfo(RecordUpdateListener):
cdef public DNSText _dns_text_cache
cdef public cython.list _dns_address_cache
cdef public cython.set _get_address_and_nsec_records_cache
cdef public cython.set _query_record_types

@cython.locals(record_update=RecordUpdate, update=bint, cache=DNSCache)
cpdef void async_update_records(self, object zc, double now, cython.list records)
Expand Down Expand Up @@ -155,3 +159,12 @@ cdef class ServiceInfo(RecordUpdateListener):
cdef double _get_initial_delay(self)

cdef double _get_random_delay(self)

cdef class AddressResolver(ServiceInfo):
pass

cdef class AddressResolverIPv6(ServiceInfo):
pass

cdef class AddressResolverIPv4(ServiceInfo):
pass
76 changes: 64 additions & 12 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@
# the A/AAAA/SRV records for a host.
_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120)

_TYPE_AAAA_RECORDS = {_TYPE_AAAA}
_TYPE_A_RECORDS = {_TYPE_A}
_TYPE_A_AAAA_RECORDS = {_TYPE_A, _TYPE_AAAA}

bytes_ = bytes
float_ = float
int_ = int
Expand Down Expand Up @@ -146,6 +150,7 @@ class ServiceInfo(RecordUpdateListener):
"_name",
"_new_records_futures",
"_properties",
"_query_record_types",
"host_ttl",
"interface_index",
"key",
Expand Down Expand Up @@ -210,6 +215,7 @@ def __init__(
self._dns_service_cache: Optional[DNSService] = None
self._dns_text_cache: Optional[DNSText] = None
self._get_address_and_nsec_records_cache: Optional[Set[DNSRecord]] = None
self._query_record_types = {_TYPE_SRV, _TYPE_TXT, _TYPE_A, _TYPE_AAAA}

@property
def name(self) -> str:
Expand Down Expand Up @@ -917,18 +923,22 @@ def _generate_request_query(
cache = zc.cache
history = zc.question_history
qu_question = question_type is QU_QUESTION
self._add_question_with_known_answers(
out, qu_question, history, cache, now, name, _TYPE_SRV, _CLASS_IN, True
)
self._add_question_with_known_answers(
out, qu_question, history, cache, now, name, _TYPE_TXT, _CLASS_IN, True
)
self._add_question_with_known_answers(
out, qu_question, history, cache, now, server, _TYPE_A, _CLASS_IN, False
)
self._add_question_with_known_answers(
out, qu_question, history, cache, now, server, _TYPE_AAAA, _CLASS_IN, False
)
if _TYPE_SRV in self._query_record_types:
self._add_question_with_known_answers(
out, qu_question, history, cache, now, name, _TYPE_SRV, _CLASS_IN, True
)
if _TYPE_TXT in self._query_record_types:
self._add_question_with_known_answers(
out, qu_question, history, cache, now, name, _TYPE_TXT, _CLASS_IN, True
)
if _TYPE_A in self._query_record_types:
self._add_question_with_known_answers(
out, qu_question, history, cache, now, server, _TYPE_A, _CLASS_IN, False
)
if _TYPE_AAAA in self._query_record_types:
self._add_question_with_known_answers(
out, qu_question, history, cache, now, server, _TYPE_AAAA, _CLASS_IN, False
)
return out

def __repr__(self) -> str:
Expand All @@ -954,3 +964,45 @@ def __repr__(self) -> str:

class AsyncServiceInfo(ServiceInfo):
"""An async version of ServiceInfo."""


class AddressResolver(ServiceInfo):
"""Resolve a host name to an IP address."""

def __init__(self, server: str) -> None:
"""Initialize the AddressResolver."""
super().__init__(server, server, server=server)
self._query_record_types = _TYPE_A_AAAA_RECORDS

@property
def _is_complete(self) -> bool:
"""The ServiceInfo has all expected properties."""
return bool(self._ipv4_addresses) or bool(self._ipv6_addresses)


class AddressResolverIPv6(ServiceInfo):
"""Resolve a host name to an IPv6 address."""

def __init__(self, server: str) -> None:
"""Initialize the AddressResolver."""
super().__init__(server, server, server=server)
self._query_record_types = _TYPE_AAAA_RECORDS

@property
def _is_complete(self) -> bool:
"""The ServiceInfo has all expected properties."""
return bool(self._ipv6_addresses)


class AddressResolverIPv4(ServiceInfo):
"""Resolve a host name to an IPv4 address."""

def __init__(self, server: str) -> None:
"""Initialize the AddressResolver."""
super().__init__(server, server, server=server)
self._query_record_types = _TYPE_A_RECORDS

@property
def _is_complete(self) -> bool:
"""The ServiceInfo has all expected properties."""
return bool(self._ipv4_addresses)
74 changes: 74 additions & 0 deletions tests/services/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -1797,3 +1797,77 @@ async def test_service_info_nsec_records():
assert nsec_record.type == const._TYPE_NSEC
assert nsec_record.ttl == 50
assert nsec_record.rdtypes == [const._TYPE_A, const._TYPE_AAAA]


@pytest.mark.asyncio
async def test_address_resolver():
"""Test that the address resolver works."""
aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])
await aiozc.zeroconf.async_wait_for_start()
resolver = r.AddressResolver("address_resolver_test.local.")
resolve_task = asyncio.create_task(resolver.async_request(aiozc.zeroconf, 3000))
outgoing = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
outgoing.add_answer_at_time(
r.DNSAddress(
"address_resolver_test.local.",
const._TYPE_A,
const._CLASS_IN,
10000,
b"\x7f\x00\x00\x01",
),
0,
)

aiozc.zeroconf.async_send(outgoing)
assert await resolve_task
assert resolver.addresses == [b"\x7f\x00\x00\x01"]


@pytest.mark.asyncio
async def test_address_resolver_ipv4():
"""Test that the IPv4 address resolver works."""
aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])
await aiozc.zeroconf.async_wait_for_start()
resolver = r.AddressResolverIPv4("address_resolver_test_ipv4.local.")
resolve_task = asyncio.create_task(resolver.async_request(aiozc.zeroconf, 3000))
outgoing = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
outgoing.add_answer_at_time(
r.DNSAddress(
"address_resolver_test_ipv4.local.",
const._TYPE_A,
const._CLASS_IN,
10000,
b"\x7f\x00\x00\x01",
),
0,
)

aiozc.zeroconf.async_send(outgoing)
assert await resolve_task
assert resolver.addresses == [b"\x7f\x00\x00\x01"]


@pytest.mark.asyncio
@unittest.skipIf(not has_working_ipv6(), "Requires IPv6")
@unittest.skipIf(os.environ.get("SKIP_IPV6"), "IPv6 tests disabled")
async def test_address_resolver_ipv6():
"""Test that the IPv6 address resolver works."""
aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])
await aiozc.zeroconf.async_wait_for_start()
resolver = r.AddressResolverIPv6("address_resolver_test_ipv6.local.")
resolve_task = asyncio.create_task(resolver.async_request(aiozc.zeroconf, 3000))
outgoing = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
outgoing.add_answer_at_time(
r.DNSAddress(
"address_resolver_test_ipv6.local.",
const._TYPE_AAAA,
const._CLASS_IN,
10000,
socket.inet_pton(socket.AF_INET6, "fe80::52e:c2f2:bc5f:e9c6"),
),
0,
)

aiozc.zeroconf.async_send(outgoing)
assert await resolve_task
assert resolver.ip_addresses_by_version(IPVersion.All) == [ip_address("fe80::52e:c2f2:bc5f:e9c6")]