forked from tc-wolf/llama-cpp-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllama_cache.py
More file actions
279 lines (228 loc) · 9.56 KB
/
Copy pathllama_cache.py
File metadata and controls
279 lines (228 loc) · 9.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
import pickle
import sys
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Optional, Sequence, Tuple
import diskcache
import pytrie
import llama_cpp.llama
from .llama_types import *
class BaseLlamaCache(ABC):
    """Base cache class for a llama.cpp model.

    Concrete caches map token sequences to saved ``LlamaState`` objects and
    choose their own storage, prefix-lookup and eviction strategy.
    """

    def __init__(self, capacity_bytes: int = (2 << 30)):
        # Size budget compared against ``cache_size``; default is 2 << 30 (2 GiB).
        self.capacity_bytes = capacity_bytes

    # --- mapping protocol every concrete cache must provide ---------------

    @abstractmethod
    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Return the cached state that best matches ``key``."""
        raise NotImplementedError()

    @abstractmethod
    def __contains__(self, key: Sequence[int]) -> bool:
        """Report whether a lookup for ``key`` would find an entry."""
        raise NotImplementedError()

    @abstractmethod
    def __setitem__(
        self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"
    ) -> None:
        """Store ``value`` under ``key``."""
        raise NotImplementedError()

    # --- abstract read-only properties ------------------------------------

    @property
    @abstractmethod
    def cache_size(self) -> int:
        """Current size of the cache, compared against ``capacity_bytes``."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def is_ro(self) -> bool:
        """Whether the cache refuses writes through ``__setitem__``."""
        raise NotImplementedError()

    # --- overridable hook ---------------------------------------------------

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Default prefix search: never finds a match.

        Subclasses override this to return the stored key that best matches
        ``key``, or ``None`` when nothing matches.
        """
        return None
class LlamaRAMCache(BaseLlamaCache):
    """Cache for a llama.cpp model using RAM.

    Entries live in an ``OrderedDict`` kept in least- to most-recently-used
    order: a successful lookup moves the matched entry to the end, inserts
    append to the end, and eviction pops from the front.
    """

    def __init__(self, capacity_bytes: int = (2 << 30)):
        # The base class stores ``capacity_bytes``; no second assignment needed.
        super().__init__(capacity_bytes)
        self.cache_state: OrderedDict[Tuple[int, ...], "llama_cpp.llama.LlamaState"] = (
            OrderedDict()
        )

    @property
    def cache_size(self) -> int:
        """Sum of ``llama_state_size`` over all cached states."""
        return sum(state.llama_state_size for state in self.cache_state.values())

    @property
    def is_ro(self) -> bool:
        return False

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the cached key sharing the longest common token prefix with ``key``.

        Returns ``None`` when every cached key has a zero-length common prefix
        with ``key``. On ties the key met first in ``cache_state`` order (the
        least recently used one) wins, because only a strictly longer prefix
        replaces the current best.
        """
        best_len = 0
        best_key: Optional[Tuple[int, ...]] = None
        for cached_key in self.cache_state:
            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(cached_key, key)
            if prefix_len > best_len:
                best_len = prefix_len
                best_key = cached_key
        return best_key

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Return the best prefix-matching state and mark it most recently used.

        Raises:
            KeyError: if no cached key shares a non-empty prefix with ``key``.
        """
        key = tuple(key)
        _key = self._find_longest_prefix_key(key)
        if _key is None:
            raise KeyError("Key not found")
        value = self.cache_state[_key]
        self.cache_state.move_to_end(_key)
        return value

    def __contains__(self, key: Sequence[int]) -> bool:
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        """Insert ``value`` as the most recently used entry, then evict from the
        least recently used end until the cache fits in ``capacity_bytes``.

        If ``value`` alone is larger than ``capacity_bytes`` it is evicted as
        well, leaving the cache empty.
        """
        key = tuple(key)
        # Remove any existing entry first so the re-insert lands at the MRU end.
        self.cache_state.pop(key, None)
        self.cache_state[key] = value
        # Sum once and subtract as entries leave, instead of re-summing every
        # remaining entry on each iteration (quadratic in the entry count).
        total = self.cache_size
        while total > self.capacity_bytes and self.cache_state:
            _, evicted = self.cache_state.popitem(last=False)
            total -= evicted.llama_state_size


# Alias kept for backwards compatibility with code that refers to ``LlamaCache``.
LlamaCache = LlamaRAMCache
class LlamaDiskCache(BaseLlamaCache):
    """Cache for a llama.cpp model using disk.

    Entries are kept in a ``diskcache.Cache`` rooted at ``cache_dir``. After
    each insert, keys are deleted in the cache's own iteration order until
    ``cache_size`` no longer exceeds ``capacity_bytes``.
    """

    def __init__(
        self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
    ):
        super().__init__(capacity_bytes)
        self.cache = diskcache.Cache(cache_dir)

    @property
    def is_ro(self) -> bool:
        return False

    @property
    def cache_size(self):
        # diskcache's estimate of the cache's on-disk size, as an int.
        return int(self.cache.volume())  # type: ignore

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Linearly scan all stored keys for the longest common token prefix.

        Returns ``None`` when no stored key has a non-zero common prefix with
        ``key``; on ties the first key yielded by ``iterkeys()`` is kept.
        """
        champion: Optional[Tuple[int, ...]] = None
        champion_len = 0
        for stored in self.cache.iterkeys():  # type: ignore
            shared = llama_cpp.llama.Llama.longest_token_prefix(stored, key)
            if shared <= champion_len:
                continue
            champion, champion_len = stored, shared  # type: ignore
        return champion

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Remove and return the stored state whose key best matches ``key``.

        NOTE: this read is destructive; the matched entry is popped from disk.

        Raises:
            KeyError: if no stored key shares a prefix with ``key``.
        """
        match = self._find_longest_prefix_key(tuple(key))
        if match is None:
            raise KeyError("Key not found")
        # Re-adding the entry via diskcache's push() (an LRU refresh) is
        # intentionally not done: push() files the value under an integer key,
        # which later breaks Llama.longest_token_prefix during key scans.
        popped: "llama_cpp.llama.LlamaState" = self.cache.pop(match)  # type: ignore
        return popped

    def __contains__(self, key: Sequence[int]) -> bool:
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        """Store ``value`` under ``key``, then trim the cache back under budget."""
        print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
        token_key = tuple(key)
        if token_key in self.cache:
            print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
            del self.cache[token_key]
        self.cache[token_key] = value
        print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
        while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
            victim = next(iter(self.cache))
            del self.cache[victim]
            print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)
class LlamaStaticDiskCache(BaseLlamaCache):
    """
    Cache that only reads from the cache, doesn't store / overwrite items, and
    doesn't pop from cache.

    Still using diskcache.Cache for underlying cache, but uses a trie to store
    keys so that can more efficiently look for prefixes.

    Want to store C++ state as bytes (from `llama_copy_state_data`), but for now
    still storing LlamaState, because need scores/input_ids/n_tokens so that Python
    code can continue inference.

    Lookup semantics differ from the other caches: a query matches the longest
    *stored key that is itself a prefix of the query*, not the stored key with
    the longest common prefix.

    Values are stored pickled and read back with ``pickle.loads``; only open
    cache directories you trust, since unpickling can execute arbitrary code.
    """

    def __init__(
        self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
    ):
        super().__init__(capacity_bytes)
        # cull_limit=0 and eviction_policy="none" stop diskcache from evicting
        # entries on its own; capacity is enforced in _private_setitem instead.
        self.cache = diskcache.Cache(
            cache_dir, size_limit=capacity_bytes, cull_limit=0, eviction_policy="none"
        )
        # Don't want to have to iterate over all keys when doing longest matching prefix search
        self.keys = pytrie.Trie.fromkeys(self.cache.iterkeys())

    @property
    def cache_size(self) -> int:
        return int(self.cache.volume())  # type: ignore

    @property
    def is_ro(self) -> bool:
        return True

    def _private_setitem(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        """Pickle ``value`` and store it under ``key``; used while building the cache.

        The key is also added to the prefix trie so that lookups on this
        instance see it immediately.

        Raises:
            ValueError: if ``cache_size`` already exceeds ``capacity_bytes``.
        """
        if self.cache_size > self.capacity_bytes:
            # I think it's okay to raise an error here, because only done when building cache anyway.
            raise ValueError("Cache is full, refusing to set more")
        key = tuple(key)
        if key in self.cache:
            print(
                "LlamaStaticDiskCache._private_setitem: delete (overwriting)",
                file=sys.stderr,
            )
            del self.cache[key]
        # This is what diskcache does anyway, eventually want this to be more compact
        print("LlamaStaticDiskCache._private_setitem: set", file=sys.stderr)
        self.cache[key] = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
        # Keep the trie in sync with what is on disk.
        self.keys[key] = None

    @staticmethod
    def build_cache(
        cache_dir: str,
        prompts: Sequence[str],
        model: "llama_cpp.Llama",
        # Same default as LlamaDiskCache: 2 << 30 bytes (2 GiB)
        capacity_bytes: int = 2 << 30,
        seed: Optional[int] = None,
    ) -> "LlamaStaticDiskCache":
        """
        Using model passed in, evaluates each prompt and stores LlamaState in cache.
        Returns a new LlamaStaticDiskCache instance with cache at cache_dir.

        Args:
            cache_dir: directory for the underlying ``diskcache.Cache``.
            prompts: prompts to tokenize, evaluate and store.
            model: model used to tokenize/evaluate and whose state is saved.
            capacity_bytes: size budget; storing fails once it is exceeded.
            seed: if not ``None``, passed to ``model.set_seed`` before each
                prompt (``0`` is a valid seed and is honoured).

        Raises:
            ValueError: if the cache exceeds ``capacity_bytes`` mid-build.
        """
        cache = LlamaStaticDiskCache(cache_dir, capacity_bytes)
        for p in prompts:
            if seed is not None:
                model.set_seed(seed)
            # Special tokens == control characters like in ChatML
            toks = model.tokenize(p.encode("utf-8"), add_bos=True, special=True)
            # Will always eval at least one token, same logic as in
            # `Llama.generate` for prefix-match hit.
            # pylint: disable=protected-access
            shared_prefix_len = model.longest_token_prefix(toks[:-1], model._input_ids)
            # Reset to shared prefix length so that don't have to re-eval system prompt
            model.n_tokens = shared_prefix_len
            eval_toks = toks[shared_prefix_len:]
            print("LlamaStaticDiskCache.build_cache: eval", file=sys.stderr)
            model.eval(eval_toks)
            state = model.save_state()
            # _private_setitem also records the key in the prefix trie.
            cache._private_setitem(toks, state)  # pylint: disable=protected-access
        return cache

    def _find_longest_prefix_key(
        self, key: Tuple[int, ...]
    ) -> Optional[Tuple[int, ...]]:
        """Return the longest stored key that is a prefix of ``key``, else ``None``."""
        try:
            return self.keys.longest_prefix(key)
        except KeyError:
            # pytrie raises KeyError when no stored key is a prefix of ``key``.
            return None

    def __contains__(self, key: Sequence[int]) -> bool:
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """
        Return the state stored under the longest cached key that is a prefix
        of ``key`` (an exact match counts as a prefix). Use case is that have
        some prompt + context that want to match against.

        Raises:
            KeyError: if no stored key is a prefix of ``key``.
        """
        key = tuple(key)
        longest_prefix = self._find_longest_prefix_key(key)
        if longest_prefix is None:
            raise KeyError("Key not found")
        # Unpickles cache contents: only use cache directories you trust.
        value: "llama_cpp.llama.LlamaState" = pickle.loads(self.cache[longest_prefix])
        return value

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        """Always refuse: this cache is read-only (see ``build_cache``)."""
        # Should this just be a warning?
        raise ValueError("Cannot set items in a static cache")