diff --git a/src/zeroconf/_cache.pxd b/src/zeroconf/_cache.pxd index a39ed756..273d46c3 100644 --- a/src/zeroconf/_cache.pxd +++ b/src/zeroconf/_cache.pxd @@ -39,7 +39,7 @@ cdef class DNSCache: @cython.locals(store=cython.dict) cpdef DNSRecord async_get_unique(self, DNSRecord entry) - @cython.locals(record=DNSRecord) + @cython.locals(record=DNSRecord, when_record=tuple, when=double) cpdef list async_expire(self, double now) @cython.locals(records=cython.dict, record=DNSRecord) @@ -57,8 +57,10 @@ cdef class DNSCache: @cython.locals( store=cython.dict, + service_store=cython.dict, service_record=DNSService, - when=object + when=object, + new=bint ) cdef bint _async_add(self, DNSRecord record) diff --git a/src/zeroconf/_cache.py b/src/zeroconf/_cache.py index a43bdc5c..1b7aae38 100644 --- a/src/zeroconf/_cache.py +++ b/src/zeroconf/_cache.py @@ -86,7 +86,8 @@ def _async_add(self, record: _DNSRecord) -> bool: # replaces any existing records that are __eq__ to each other which # removes the risk that accessing the cache from the wrong # direction would return the old incorrect entry. - store = self.cache.setdefault(record.key, {}) + if (store := self.cache.get(record.key)) is None: + store = self.cache[record.key] = {} new = record not in store and not isinstance(record, DNSNsec) store[record] = record when = record.created + (record.ttl * 1000) @@ -97,7 +98,9 @@ def _async_add(self, record: _DNSRecord) -> bool: if isinstance(record, DNSService): service_record = record - self.service_cache.setdefault(record.server_key, {})[service_record] = service_record + if (service_store := self.service_cache.get(service_record.server_key)) is None: + service_store = self.service_cache[service_record.server_key] = {} + service_store[service_record] = service_record return new def async_add_records(self, entries: Iterable[DNSRecord]) -> bool: @@ -145,7 +148,8 @@ def async_expire(self, now: _float) -> List[DNSRecord]: expired: List[DNSRecord] = [] # Find any expired records and add them to the to-delete list while self._expire_heap: - when, record = self._expire_heap[0] + when_record = self._expire_heap[0] + when = when_record[0] if when > now: break heappop(self._expire_heap) @@ -153,6 +157,7 @@ def async_expire(self, now: _float) -> List[DNSRecord]: # with a different expiration time as it will be removed # later when it reaches the top of the heap and its # expiration time is met. + record = when_record[1] if self._expirations.get(record) == when: expired.append(record)