Skip to content

Commit 9be3cd1

Browse files
allthatido and Ankur Kaul authored
fix: preserve recurrent/hybrid model state when the full prompt is already cached (abetlen#2306)
Co-authored-by: Ankur Kaul <akaul36@gatech.edu>
1 parent b11fe07 commit 9be3cd1

3 files changed

Lines changed: 323 additions & 16 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
- fix: preserve recurrent/hybrid model state when the full prompt is already cached by @allthatido and @abetlen in #2306
11+
1012
## [0.3.31]
1113

1214
- feat: update llama.cpp to ggml-org/llama.cpp@f449e0553

llama_cpp/llama.py

Lines changed: 33 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -471,6 +471,8 @@ def free_lora_adapter():
471471
self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab)
472472

473473
self.n_tokens = 0
474+
# Restored or truncated state must decode before sampling.
475+
self._requires_eval = True
474476
self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
475477
self.scores: npt.NDArray[np.single] = np.ndarray(
476478
(n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
@@ -647,6 +649,7 @@ def set_seed(self, seed: int):
647649
def reset(self):
648650
"""Reset the model state."""
649651
self.n_tokens = 0
652+
self._requires_eval = True
650653

651654
if self._is_recurrent or self._is_hybrid:
652655
mem = llama_cpp.llama_get_memory(self._ctx.ctx)
@@ -689,6 +692,7 @@ def eval(self, tokens: Sequence[int]):
689692
pass
690693
# Update n_tokens
691694
self.n_tokens += n_tokens
695+
self._requires_eval = False
692696

693697
def _init_sampler(
694698
self,
@@ -900,41 +904,53 @@ def generate(
900904
grammar=grammar,
901905
)
902906

907+
tokens = list(tokens)
908+
903909
# Check for kv cache prefix match
904910
if reset and self.n_tokens > 0:
905911
longest_prefix = 0
906-
for a, b in zip(self._input_ids, tokens[:-1]):
912+
for a, b in zip(self._input_ids, tokens):
907913
if a == b:
908914
longest_prefix += 1
909915
else:
910916
break
911917

912-
# Recurrent and hybrid models cannot rewind state; reset if needed
913-
if (
914-
self._is_recurrent or self._is_hybrid
915-
) and longest_prefix < self.n_tokens:
916-
longest_prefix = 0
917-
reset = True
918+
prompt_consumed = longest_prefix == len(tokens)
919+
exact_prompt_cached = self.n_tokens == len(tokens) and prompt_consumed
920+
921+
# Exact cache hits can sample immediately only when the current
922+
# logits were produced by a live decode, not restored state.
923+
if exact_prompt_cached and not self._requires_eval:
924+
reset = False
925+
tokens = []
926+
reuse_prefix = 0
918927
if self.verbose:
919928
print(
920-
"Llama.generate: recurrent/hybrid model requires full state reset",
929+
"Llama.generate: full prompt already cached, skipping reset",
921930
file=sys.stderr,
922931
)
923-
924-
if longest_prefix > 0:
925-
if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
932+
else:
933+
# If there is no suffix to decode, replay one token to refresh
934+
# logits after truncating to a valid prefix.
935+
reuse_prefix = longest_prefix - 1 if prompt_consumed else longest_prefix
936+
937+
# Prefix hits can reuse memory because the suffix decode refreshes
938+
# logits before sampling.
939+
if reuse_prefix > 0:
940+
if self._ctx.kv_cache_seq_rm(-1, reuse_prefix, -1):
926941
reset = False
927-
tokens = tokens[longest_prefix:]
928-
self.n_tokens = longest_prefix
942+
tokens = tokens[reuse_prefix:]
943+
self.n_tokens = reuse_prefix
944+
self._requires_eval = True
929945
if self.verbose:
930946
print(
931-
f"Llama.generate: {longest_prefix} prefix-match hit, "
947+
f"Llama.generate: {reuse_prefix} prefix-match hit, "
932948
f"remaining {len(tokens)} prompt tokens to eval",
933949
file=sys.stderr,
934950
)
935951
elif self.verbose:
936952
print(
937-
f"Llama.generate: {longest_prefix} prefix-match found "
953+
f"Llama.generate: {reuse_prefix} prefix-match found "
938954
f"but partial kv removal not supported, re-evaluating full prompt",
939955
file=sys.stderr,
940956
)
@@ -948,7 +964,6 @@ def generate(
948964
# grammar.reset()
949965

950966
sample_idx = self.n_tokens + len(tokens) - 1
951-
tokens = list(tokens)
952967

953968
# Eval and sample
954969
while True:
@@ -988,6 +1003,7 @@ def generate(
9881003
if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
9891004
self.n_tokens = sample_idx
9901005
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
1006+
self._requires_eval = True
9911007
break
9921008

9931009
if self.draft_model is not None:
@@ -2217,6 +2233,7 @@ def load_state(self, state: LlamaState) -> None:
22172233
rest[rest > 0] = 0.0
22182234
self.input_ids = state.input_ids.copy()
22192235
self.n_tokens = state.n_tokens
2236+
self._requires_eval = True
22202237
self._seed = state.seed
22212238
state_size = state.llama_state_size
22222239
LLamaStateArrayType = ctypes.c_uint8 * state_size

0 commit comments

Comments
 (0)