|
21 | 21 | """ |
22 | 22 |
|
23 | 23 | import itertools |
24 | | -from typing import Dict, Iterable, Iterator, List, Optional, Union, cast |
| 24 | +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast |
25 | 25 |
|
26 | 26 | from ._dns import ( |
27 | 27 | DNSAddress, |
|
34 | 34 | DNSText, |
35 | 35 | ) |
36 | 36 | from ._utils.time import current_time_millis |
37 | | -from .const import _TYPE_PTR |
| 37 | +from .const import _ONE_SECOND, _TYPE_PTR |
38 | 38 |
|
39 | 39 | _UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService) |
40 | 40 | _UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService] |
41 | 41 | _DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]] |
42 | 42 | _DNSRecord = DNSRecord |
43 | 43 | _str = str |
| 44 | +_float = float |
| 45 | +_int = int |
44 | 46 |
|
45 | 47 |
|
46 | 48 | def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None: |
@@ -134,19 +136,29 @@ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: |
134 | 136 | return None |
135 | 137 | return store.get(entry) |
136 | 138 |
|
def async_all_by_details(self, name: _str, type_: int, class_: int) -> Iterable[DNSRecord]:
    """Get every cached entry whose name, rrtype, and rrclass match.

    Not thread-safe; callers must invoke this from the event loop.
    This public alias simply delegates to the private list-building
    helper so hot callers inside the module can skip the indirection.
    """
    return self._async_all_by_details(name, type_, class_)
def _async_all_by_details(self, name: _str, type_: int, class_: int) -> List[DNSRecord]:
    """Build the list of cached records matching name, rrtype, and rrclass.

    Not thread-safe; must be called from the event loop.  Returns an
    empty list when the name has no cache bucket at all.
    """
    lookup_key = name.lower()
    stored = self.cache.get(lookup_key)
    if stored is None:
        return []
    # Same iteration order as the cache bucket; filter keeps only exact
    # name/type/class matches.
    return [rec for rec in stored if _dns_record_matches(rec, lookup_key, type_, class_)]
150 | 162 |
|
151 | 163 | def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]: |
152 | 164 | """Returns a dict of entries whose key matches the name. |
@@ -226,6 +238,25 @@ def names(self) -> List[str]: |
226 | 238 | """Return a copy of the list of current cache names.""" |
227 | 239 | return list(self.cache) |
228 | 240 |
|
def async_mark_unique_records_older_than_1s_to_expire(
    self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float
) -> None:
    """Mark stale unique (cache-flush) records to expire in one second.

    Not thread-safe; must be called from the event loop.  Delegates to
    the private implementation so module-internal callers can bypass
    this public entry point.
    """
    self._async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now)
def _async_mark_unique_records_older_than_1s_to_expire(
    self, unique_types: Set[Tuple[_str, _int, _int]], answers: Iterable[DNSRecord], now: _float
) -> None:
    """Shorten the TTL of superseded cached records per rfc6762#section-10.2.

    Para 2: since the unique bit is set, all old records with the same
    name, rrtype, and rrclass that were received more than one second
    ago are declared invalid and marked to expire from the cache in one
    second.  Records present in *answers* are the replacements and are
    left untouched.
    """
    fresh_rrset = set(answers)
    for name, type_, class_ in unique_types:
        for cached in self._async_all_by_details(name, type_, class_):
            if (now - cached.created > _ONE_SECOND) and cached not in fresh_rrset:
                # Expire in 1s
                cached.set_created_ttl(now, 1)
229 | 260 |
|
def _dns_record_matches(record: _DNSRecord, key: _str, type_: int, class_: int) -> bool:
    """Return True when the record's key, rrtype, and rrclass all match.

    *key* is expected to already be lower-cased by the caller.
    """
    return record.key == key and record.type == type_ and record.class_ == class_
0 commit comments