diff --git a/tests/test_handlers.py b/tests/test_handlers.py index f9e7639ea..92d95fa2f 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -4,6 +4,7 @@ """ Unit tests for zeroconf._handlers """ +import asyncio import logging import pytest import socket @@ -834,3 +835,83 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): # unregister zc.registry.remove(info) await aiozc.async_close() + + +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_cache_flush_bit(): + """Test that the cache flush bit sets the TTL to one for matching records.""" + # instantiate a zeroconf instance + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + + type_ = "_cacheflush._tcp.local." + name = "knownname" + registration_name = "%s.%s" % (name, type_) + desc = {'path': '/~paulsm/'} + server_name = "server-uu1.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + a_record = info.dns_addresses()[0] + zc.cache.async_add_records([info.dns_pointer(), a_record, info.dns_text(), info.dns_service()]) + + info.addresses = [socket.inet_aton("10.0.1.5"), socket.inet_aton("10.0.1.6")] + new_records = info.dns_addresses() + for new_record in new_records: + assert new_record.unique is True + + original_a_record = zc.cache.get(a_record) + # Do the run within 1s to verify the original record is not going to be expired + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) + for answer in new_records: + out.add_answer_at_time(answer, 0) + for packet in out.packets(): + zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) + assert zc.cache.get(a_record) is original_a_record + assert original_a_record.ttl != 1 + for record in new_records: + assert zc.cache.get(record) is not None + + original_a_record.created = current_time_millis() - 1001 + + # Do the run within 1s to verify the original record is not going to be expired + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) + for answer in new_records: + out.add_answer_at_time(answer, 0) + for packet in out.packets(): + zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) + assert original_a_record.ttl == 1 + for record in new_records: + assert zc.cache.get(record) is not None + + cached_records = [zc.cache.get(record) for record in new_records] + for record in cached_records: + record.created = current_time_millis() - 1001 + + fresh_address = socket.inet_aton("4.4.4.4") + info.addresses = [fresh_address] + # Do the run within 1s to verify the two new records get marked as expired + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) + for answer in info.dns_addresses(): + out.add_answer_at_time(answer, 0) + for packet in out.packets(): + zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) + for record in cached_records: + assert record.ttl == 1 + + for entry in zc.cache.get_all_by_details(server_name, const._TYPE_A, const._CLASS_IN): + if entry.address == fresh_address: + assert entry.ttl > 1 + else: + assert entry.ttl == 1 + + # Wait for the ttl 1 records to expire + await asyncio.sleep(1.01) + + loaded_info = r.ServiceInfo(type_, registration_name) + loaded_info.load_from_cache(zc) + assert loaded_info.addresses == info.addresses + + await aiozc.async_close()