Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 33 additions & 37 deletions src/zeroconf/_services/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
cast,
)

from .._dns import DNSPointer, DNSQuestion, DNSQuestionType, DNSRecord
from .._dns import DNSPointer, DNSQuestion, DNSQuestionType
from .._logger import log
from .._protocol.outgoing import DNSOutgoing
from .._services import (
Expand Down Expand Up @@ -383,50 +383,46 @@ def _enqueue_callback(
):
self._pending_handlers[key] = state_change

def _async_process_record_update(
self, now: float, record: DNSRecord, old_record: Optional[DNSRecord]
) -> None:
"""Process a single record update from a batch of updates."""
record_type = record.type

if record_type is _TYPE_PTR:
if TYPE_CHECKING:
record = cast(DNSPointer, record)
for type_ in self.types.intersection(cached_possible_types(record.name)):
if old_record is None:
self._enqueue_callback(ServiceStateChange.Added, type_, record.alias)
elif record.is_expired(now):
self._enqueue_callback(ServiceStateChange.Removed, type_, record.alias)
else:
self.reschedule_type(type_, now, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT))
return

# If its expired or already exists in the cache it cannot be updated.
if old_record or record.is_expired(now):
return

if record_type in _ADDRESS_RECORD_TYPES:
# Iterate through the DNSCache and callback any services that use this address
for type_, name in self._names_matching_types(
{service.name for service in self.zc.cache.async_entries_with_server(record.name)}
):
self._enqueue_callback(ServiceStateChange.Updated, type_, name)
return

for type_, name in self._names_matching_types((record.name,)):
self._enqueue_callback(ServiceStateChange.Updated, type_, name)

def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
"""Callback invoked by Zeroconf when new information arrives.

Updates information required by browser in the Zeroconf cache.

Ensures that there is are no unecessary duplicates in the list.
Ensures that there is are no unnecessary duplicates in the list.

This method will be run in the event loop.
"""
for record in records:
self._async_process_record_update(now, record[0], record[1])
for record_update in records:
record, old_record = record_update
record_type = record.type

if record_type is _TYPE_PTR:
if TYPE_CHECKING:
record = cast(DNSPointer, record)
for type_ in self.types.intersection(cached_possible_types(record.name)):
if old_record is None:
self._enqueue_callback(ServiceStateChange.Added, type_, record.alias)
elif record.is_expired(now):
self._enqueue_callback(ServiceStateChange.Removed, type_, record.alias)
else:
expire_time = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT)
self.reschedule_type(type_, now, expire_time)
continue

# If its expired or already exists in the cache it cannot be updated.
if old_record or record.is_expired(now):
continue

if record_type in _ADDRESS_RECORD_TYPES:
# Iterate through the DNSCache and callback any services that use this address
for type_, name in self._names_matching_types(
{service.name for service in self.zc.cache.async_entries_with_server(record.name)}
):
self._enqueue_callback(ServiceStateChange.Updated, type_, name)
continue

for type_, name in self._names_matching_types((record.name,)):
self._enqueue_callback(ServiceStateChange.Updated, type_, name)

@abstractmethod
def async_update_records_complete(self) -> None:
Expand Down
22 changes: 2 additions & 20 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,35 +410,17 @@ def _set_ipv4_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None:
else:
self._ipv4_addresses = self._get_ip_addresses_from_cache_lifo(zc, now, _TYPE_A)

def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None:
"""Updates service information from a DNS record.

This method is deprecated and will be removed in a future version.
update_records should be implemented instead.

This method will be run in the event loop.
"""
if record is not None:
self._process_record_threadsafe(zc, record, now)

def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
"""Updates service information from a DNS record.

This method will be run in the event loop.
"""
new_records_futures = self._new_records_futures
if self._process_records_threadsafe(zc, now, records) and new_records_futures:
_resolve_all_futures_to_none(new_records_futures)

def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool:
"""Thread safe record updating.

Returns True if new records were added.
"""
updated: bool = False
for record_update in records:
updated |= self._process_record_threadsafe(zc, record_update.new, now)
return updated
if updated and new_records_futures:
_resolve_all_futures_to_none(new_records_futures)

def _process_record_threadsafe(self, zc: 'Zeroconf', record: DNSRecord, now: float) -> bool:
"""Thread safe record updating.
Expand Down
2 changes: 1 addition & 1 deletion src/zeroconf/_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU
All records that are received in a single packet are passed
to update_records.

This implementation is a compatiblity shim to ensure older code
This implementation is a compatibility shim to ensure older code
that uses RecordUpdateListener as a base class will continue to
get calls to update_record. This method will raise
NotImplementedError in a future version.
Expand Down
165 changes: 100 additions & 65 deletions tests/services/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest

import zeroconf as r
from zeroconf import DNSAddress, const
from zeroconf import DNSAddress, RecordUpdate, const
from zeroconf._services import info
from zeroconf._services.info import ServiceInfo
from zeroconf._utils.net import IPVersion
Expand Down Expand Up @@ -68,89 +68,119 @@ def test_service_info_rejects_non_matching_updates(self):
service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address]
)
# Verify backwards compatiblity with calling with None
info.update_record(zc, now, None)
info.async_update_records(zc, now, [])
# Matching updates
info.update_record(
info.async_update_records(
zc,
now,
r.DNSText(
service_name,
const._TYPE_TXT,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==',
),
[
RecordUpdate(
r.DNSText(
service_name,
const._TYPE_TXT,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==',
),
None,
)
],
)
assert info.properties[b"ci"] == b"2"
info.update_record(
info.async_update_records(
zc,
now,
r.DNSService(
service_name,
const._TYPE_SRV,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
0,
0,
80,
'ASH-2.local.',
),
[
RecordUpdate(
r.DNSService(
service_name,
const._TYPE_SRV,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
0,
0,
80,
'ASH-2.local.',
),
None,
)
],
)
assert info.server_key == 'ash-2.local.'
assert info.server == 'ASH-2.local.'
new_address = socket.inet_aton("10.0.1.3")
info.update_record(
info.async_update_records(
zc,
now,
r.DNSAddress(
'ASH-2.local.',
const._TYPE_A,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
new_address,
),
[
RecordUpdate(
r.DNSAddress(
'ASH-2.local.',
const._TYPE_A,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
new_address,
),
None,
)
],
)
assert new_address in info.addresses
# Non-matching updates
info.update_record(
info.async_update_records(
zc,
now,
r.DNSText(
"incorrect.name.",
const._TYPE_TXT,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==',
),
[
RecordUpdate(
r.DNSText(
"incorrect.name.",
const._TYPE_TXT,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==',
),
None,
)
],
)
assert info.properties[b"ci"] == b"2"
info.update_record(
info.async_update_records(
zc,
now,
r.DNSService(
"incorrect.name.",
const._TYPE_SRV,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
0,
0,
80,
'ASH-2.local.',
),
[
RecordUpdate(
r.DNSService(
"incorrect.name.",
const._TYPE_SRV,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
0,
0,
80,
'ASH-2.local.',
),
None,
)
],
)
assert info.server_key == 'ash-2.local.'
assert info.server == 'ASH-2.local.'
new_address = socket.inet_aton("10.0.1.4")
info.update_record(
info.async_update_records(
zc,
now,
r.DNSAddress(
"incorrect.name.",
const._TYPE_A,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
new_address,
),
[
RecordUpdate(
r.DNSAddress(
"incorrect.name.",
const._TYPE_A,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
new_address,
),
None,
)
],
)
assert new_address not in info.addresses
zc.close()
Expand All @@ -169,16 +199,21 @@ def test_service_info_rejects_expired_records(self):
service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address]
)
# Matching updates
info.update_record(
info.async_update_records(
zc,
now,
r.DNSText(
service_name,
const._TYPE_TXT,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==',
),
[
RecordUpdate(
r.DNSText(
service_name,
const._TYPE_TXT,
const._CLASS_IN | const._CLASS_UNIQUE,
ttl,
b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==',
),
None,
)
],
)
assert info.properties[b"ci"] == b"2"
# Expired record
Expand All @@ -190,7 +225,7 @@ def test_service_info_rejects_expired_records(self):
b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==',
)
expired_record.set_created_ttl(1000, 1)
info.update_record(zc, now, expired_record)
info.async_update_records(zc, now, [RecordUpdate(expired_record, None)])
assert info.properties[b"ci"] == b"2"
zc.close()

Expand Down