2727from .const import _TYPE_PTR
2828
2929
30+ _DNSRecordCacheType = Dict [str , Dict [DNSRecord , DNSRecord ]]
31+
32+
33+ def _remove_key (cache : _DNSRecordCacheType , key : str , entry : DNSRecord ) -> None :
34+ """Remove a key from a DNSRecord cache
35+
36+ This function must be run in from event loop.
37+ """
38+ del cache [key ][entry ]
39+ if not cache [key ]:
40+ del cache [key ]
41+
42+
3043class DNSCache :
3144 """A cache of DNS entries."""
3245
3346 def __init__ (self ) -> None :
34- self .cache : Dict [str , List [DNSRecord ]] = {}
35- self .service_cache : Dict [str , List [DNSRecord ]] = {}
47+ self .cache : _DNSRecordCacheType = {}
48+ self .service_cache : _DNSRecordCacheType = {}
49+
50+ # Functions prefixed with are NOT threadsafe and must
51+ # be run in the event loop.
3652
3753 def add (self , entry : DNSRecord ) -> None :
38- """Adds an entry"""
39- # Insert last in list, get will return newest entry
40- # iteration will result in last update winning
41- self .cache .setdefault (entry .key , []).append (entry )
54+ """Adds an entry.
55+
56+ This function must be run in from event loop.
57+ """
58+ # Previously storage of records was implemented as a list
59+ # instead a dict. Since DNSRecords are now hashable, the implementation
60+ # uses a dict to ensure that adding a new record to the cache
61+ # replaces any existing records that are __eq__ to each other which
62+ # removes the risk that accessing the cache from the wrong
63+ # direction would return the old incorrect entry.
64+ self .cache .setdefault (entry .key , {})[entry ] = entry
4265 if isinstance (entry , DNSService ):
43- self .service_cache .setdefault (entry .server , []). append ( entry )
66+ self .service_cache .setdefault (entry .server , {})[ entry ] = entry
4467
4568 def add_records (self , entries : Iterable [DNSRecord ]) -> None :
46- """Add multiple records."""
69+ """Add multiple records.
70+
71+ This function must be run in from event loop.
72+ """
4773 for entry in entries :
4874 self .add (entry )
4975
5076 def remove (self , entry : DNSRecord ) -> None :
51- """Removes an entry."""
77+ """Removes an entry.
78+
79+ This function must be run in from event loop.
80+ """
5281 if isinstance (entry , DNSService ):
53- DNSCache . remove_key (self .service_cache , entry .server , entry )
54- DNSCache . remove_key (self .cache , entry .key , entry )
82+ _remove_key (self .service_cache , entry .server , entry )
83+ _remove_key (self .cache , entry .key , entry )
5584
5685 def remove_records (self , entries : Iterable [DNSRecord ]) -> None :
57- """Remove multiple records."""
86+ """Remove multiple records.
87+
88+ This function must be run in from event loop.
89+ """
5890 for entry in entries :
5991 self .remove (entry )
6092
61- @staticmethod
62- def remove_key (cache : dict , key : str , entry : DNSRecord ) -> None :
63- """Forgiving remove of a cache key."""
64- try :
65- cache [key ].remove (entry )
66- if not cache [key ]:
67- del cache [key ]
68- except (KeyError , ValueError ):
69- pass
93+ def expire (self , now : float ) -> Iterable [DNSRecord ]:
94+ """Purge expired entries from the cache.
95+
96+ This function must be run in from event loop.
97+ """
98+ for name in self .names ():
99+ for record in self .entries_with_name (name ):
100+ if record .is_expired (now ):
101+ self .remove (record )
102+ yield record
103+
104+ # The below functions are threadsafe and do not need to be run in the
105+ # event loop, however they all make copies so they significantly
106+ # inefficent
70107
71108 def get (self , entry : DNSEntry ) -> Optional [DNSRecord ]:
72109 """Gets an entry by key. Will return None if there is no
@@ -77,7 +114,17 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
77114 return None
78115
79116 def get_by_details (self , name : str , type_ : int , class_ : int ) -> Optional [DNSRecord ]:
80- """Gets the first matching entry by details. Returns None if no entries match."""
117+ """Gets the first matching entry by details. Returns None if no entries match.
118+
119+ Calling this function is not recommended as it will only
120+ return one record even if there are multiple entries.
121+
122+ For example if there are multiple A or AAAA addresses this
123+ function will return the last one that was added to the cache
124+ which may not be the one you expect.
125+
126+ Use get_all_by_details instead.
127+ """
81128 return self .get (DNSEntry (name , type_ , class_ ))
82129
83130 def get_all_by_details (self , name : str , type_ : int , class_ : int ) -> List [DNSRecord ]:
@@ -87,11 +134,11 @@ def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSReco
87134
88135 def entries_with_server (self , server : str ) -> List [DNSRecord ]:
89136 """Returns a list of entries whose server matches the name."""
90- return self .service_cache .get (server , [])[:]
137+ return list ( self .service_cache .get (server , {}))
91138
92139 def entries_with_name (self , name : str ) -> List [DNSRecord ]:
93140 """Returns a list of entries whose key matches the name."""
94- return self .cache .get (name .lower (), [])[:]
141+ return list ( self .cache .get (name .lower (), {}))
95142
96143 def current_entry_with_name_and_alias (self , name : str , alias : str ) -> Optional [DNSRecord ]:
97144 now = current_time_millis ()
@@ -107,11 +154,3 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D
107154 def names (self ) -> List [str ]:
108155 """Return a copy of the list of current cache names."""
109156 return list (self .cache )
110-
111- def expire (self , now : float ) -> Iterable [DNSRecord ]:
112- """Purge expired entries from the cache."""
113- for name in self .names ():
114- for record in self .entries_with_name (name ):
115- if record .is_expired (now ):
116- self .remove (record )
117- yield record
0 commit comments