Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 8 additions & 19 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def test_order(self):
record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
cache = r.DNSCache()
cache.add(record1)
cache.add(record2)
cache.add_records([record1, record2])
entry = r.DNSEntry('a', const._TYPE_SOA, const._CLASS_IN)
cached_record = cache.get(entry)
assert cached_record == record2
Expand All @@ -46,13 +45,11 @@ def test_adding_same_record_to_cache_different_ttls(self):
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a')
cache = r.DNSCache()
cache.add(record1)
cache.add(record2)
cache.add_records([record1, record2])
entry = r.DNSEntry(record2)
cached_record = cache.get(entry)
assert cached_record == record2

@unittest.skip('This bug in the implementation needs to be fixed.')
def test_adding_same_record_to_cache_different_ttls(self):
"""Verify we only get one record back.

Expand All @@ -64,34 +61,26 @@ def test_adding_same_record_to_cache_different_ttls(self):
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a')
cache = r.DNSCache()
cache.add(record1)
cache.add(record2)
cache.add_records([record1, record2])
cached_records = cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN)
assert cached_records == [record2]

def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self):
record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
cache = r.DNSCache()
cache.add(record1)
cache.add(record2)
cache.add_records([record1, record2])
assert 'a' in cache.cache
cache.remove(record1)
cache.remove(record2)
cache.remove_records([record1, record2])
assert 'a' not in cache.cache

def test_cache_empty_multiple_calls_does_not_throw(self):
def test_cache_empty_multiple_calls(self):
record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
cache = r.DNSCache()
cache.add(record1)
cache.add(record2)
cache.add_records([record1, record2])
assert 'a' in cache.cache
cache.remove(record1)
cache.remove(record2)
# Ensure multiple removes does not throw
cache.remove(record1)
cache.remove(record2)
cache.remove_records([record1, record2])
assert 'a' not in cache.cache


Expand Down
103 changes: 71 additions & 32 deletions zeroconf/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,46 +27,83 @@
from .const import _TYPE_PTR


_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]]


def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None:
"""Remove a key from a DNSRecord cache

This function must be run in from event loop.
"""
del cache[key][entry]
if not cache[key]:
del cache[key]


class DNSCache:
"""A cache of DNS entries."""

def __init__(self) -> None:
self.cache: Dict[str, List[DNSRecord]] = {}
self.service_cache: Dict[str, List[DNSRecord]] = {}
self.cache: _DNSRecordCacheType = {}
self.service_cache: _DNSRecordCacheType = {}

# Functions prefixed with are NOT threadsafe and must
# be run in the event loop.

def add(self, entry: DNSRecord) -> None:
"""Adds an entry"""
# Insert last in list, get will return newest entry
# iteration will result in last update winning
self.cache.setdefault(entry.key, []).append(entry)
"""Adds an entry.

This function must be run in from event loop.
"""
# Previously storage of records was implemented as a list
# instead a dict. Since DNSRecords are now hashable, the implementation
# uses a dict to ensure that adding a new record to the cache
# replaces any existing records that are __eq__ to each other which
# removes the risk that accessing the cache from the wrong
# direction would return the old incorrect entry.
self.cache.setdefault(entry.key, {})[entry] = entry
if isinstance(entry, DNSService):
self.service_cache.setdefault(entry.server, []).append(entry)
self.service_cache.setdefault(entry.server, {})[entry] = entry

def add_records(self, entries: Iterable[DNSRecord]) -> None:
"""Add multiple records."""
"""Add multiple records.

This function must be run in from event loop.
"""
for entry in entries:
self.add(entry)

def remove(self, entry: DNSRecord) -> None:
"""Removes an entry."""
"""Removes an entry.

This function must be run in from event loop.
"""
if isinstance(entry, DNSService):
DNSCache.remove_key(self.service_cache, entry.server, entry)
DNSCache.remove_key(self.cache, entry.key, entry)
_remove_key(self.service_cache, entry.server, entry)
_remove_key(self.cache, entry.key, entry)

def remove_records(self, entries: Iterable[DNSRecord]) -> None:
"""Remove multiple records."""
"""Remove multiple records.

This function must be run in from event loop.
"""
for entry in entries:
self.remove(entry)

@staticmethod
def remove_key(cache: dict, key: str, entry: DNSRecord) -> None:
"""Forgiving remove of a cache key."""
try:
cache[key].remove(entry)
if not cache[key]:
del cache[key]
except (KeyError, ValueError):
pass
def expire(self, now: float) -> Iterable[DNSRecord]:
"""Purge expired entries from the cache.

This function must be run in from event loop.
"""
for name in self.names():
for record in self.entries_with_name(name):
if record.is_expired(now):
self.remove(record)
yield record

# The below functions are threadsafe and do not need to be run in the
# event loop, however they all make copies so they significantly
# inefficent

def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
"""Gets an entry by key. Will return None if there is no
Expand All @@ -77,7 +114,17 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
return None

def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]:
"""Gets the first matching entry by details. Returns None if no entries match."""
"""Gets the first matching entry by details. Returns None if no entries match.

Calling this function is not recommended as it will only
return one record even if there are multiple entries.

For example if there are multiple A or AAAA addresses this
function will return the last one that was added to the cache
which may not be the one you expect.

Use get_all_by_details instead.
"""
return self.get(DNSEntry(name, type_, class_))

def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]:
Expand All @@ -87,11 +134,11 @@ def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSReco

def entries_with_server(self, server: str) -> List[DNSRecord]:
"""Returns a list of entries whose server matches the name."""
return self.service_cache.get(server, [])[:]
return list(self.service_cache.get(server, {}))

def entries_with_name(self, name: str) -> List[DNSRecord]:
"""Returns a list of entries whose key matches the name."""
return self.cache.get(name.lower(), [])[:]
return list(self.cache.get(name.lower(), {}))

def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]:
now = current_time_millis()
Expand All @@ -107,11 +154,3 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D
def names(self) -> List[str]:
"""Return a copy of the list of current cache names."""
return list(self.cache)

def expire(self, now: float) -> Iterable[DNSRecord]:
"""Purge expired entries from the cache."""
for name in self.names():
for record in self.entries_with_name(name):
if record.is_expired(now):
self.remove(record)
yield record