@@ -471,6 +471,8 @@ def free_lora_adapter():
471471 self ._candidates = internals .LlamaTokenDataArray (n_vocab = self ._n_vocab )
472472
473473 self .n_tokens = 0
474+ # Restored or truncated state must decode before sampling.
475+ self ._requires_eval = True
474476 self .input_ids : npt .NDArray [np .intc ] = np .ndarray ((n_ctx ,), dtype = np .intc )
475477 self .scores : npt .NDArray [np .single ] = np .ndarray (
476478 (n_ctx if logits_all == True else n_batch , self ._n_vocab ), dtype = np .single
@@ -647,6 +649,7 @@ def set_seed(self, seed: int):
647649 def reset (self ):
648650 """Reset the model state."""
649651 self .n_tokens = 0
652+ self ._requires_eval = True
650653
651654 if self ._is_recurrent or self ._is_hybrid :
652655 mem = llama_cpp .llama_get_memory (self ._ctx .ctx )
@@ -689,6 +692,7 @@ def eval(self, tokens: Sequence[int]):
689692 pass
690693 # Update n_tokens
691694 self .n_tokens += n_tokens
695+ self ._requires_eval = False
692696
693697 def _init_sampler (
694698 self ,
@@ -900,41 +904,53 @@ def generate(
900904 grammar = grammar ,
901905 )
902906
907+ tokens = list (tokens )
908+
903909 # Check for kv cache prefix match
904910 if reset and self .n_tokens > 0 :
905911 longest_prefix = 0
906- for a , b in zip (self ._input_ids , tokens [: - 1 ] ):
912+ for a , b in zip (self ._input_ids , tokens ):
907913 if a == b :
908914 longest_prefix += 1
909915 else :
910916 break
911917
912- # Recurrent and hybrid models cannot rewind state; reset if needed
913- if (
914- self ._is_recurrent or self ._is_hybrid
915- ) and longest_prefix < self .n_tokens :
916- longest_prefix = 0
917- reset = True
918+ prompt_consumed = longest_prefix == len (tokens )
919+ exact_prompt_cached = self .n_tokens == len (tokens ) and prompt_consumed
920+
921+ # Exact cache hits can sample immediately only when the current
922+ # logits were produced by a live decode, not restored state.
923+ if exact_prompt_cached and not self ._requires_eval :
924+ reset = False
925+ tokens = []
926+ reuse_prefix = 0
918927 if self .verbose :
919928 print (
920- "Llama.generate: recurrent/hybrid model requires full state reset" ,
929+ "Llama.generate: full prompt already cached, skipping reset" ,
921930 file = sys .stderr ,
922931 )
923-
924- if longest_prefix > 0 :
925- if self ._ctx .kv_cache_seq_rm (- 1 , longest_prefix , - 1 ):
932+ else :
933+ # If there is no suffix to decode, replay one token to refresh
934+ # logits after truncating to a valid prefix.
935+ reuse_prefix = longest_prefix - 1 if prompt_consumed else longest_prefix
936+
937+ # Prefix hits can reuse memory because the suffix decode refreshes
938+ # logits before sampling.
939+ if reuse_prefix > 0 :
940+ if self ._ctx .kv_cache_seq_rm (- 1 , reuse_prefix , - 1 ):
926941 reset = False
927- tokens = tokens [longest_prefix :]
928- self .n_tokens = longest_prefix
942+ tokens = tokens [reuse_prefix :]
943+ self .n_tokens = reuse_prefix
944+ self ._requires_eval = True
929945 if self .verbose :
930946 print (
931- f"Llama.generate: { longest_prefix } prefix-match hit, "
947+ f"Llama.generate: { reuse_prefix } prefix-match hit, "
932948 f"remaining { len (tokens )} prompt tokens to eval" ,
933949 file = sys .stderr ,
934950 )
935951 elif self .verbose :
936952 print (
937- f"Llama.generate: { longest_prefix } prefix-match found "
953+ f"Llama.generate: { reuse_prefix } prefix-match found "
938954 f"but partial kv removal not supported, re-evaluating full prompt" ,
939955 file = sys .stderr ,
940956 )
@@ -948,7 +964,6 @@ def generate(
948964 # grammar.reset()
949965
950966 sample_idx = self .n_tokens + len (tokens ) - 1
951- tokens = list (tokens )
952967
953968 # Eval and sample
954969 while True :
@@ -988,6 +1003,7 @@ def generate(
9881003 if sample_idx < self .n_tokens and token != self ._input_ids [sample_idx ]:
9891004 self .n_tokens = sample_idx
9901005 self ._ctx .kv_cache_seq_rm (- 1 , self .n_tokens , - 1 )
1006+ self ._requires_eval = True
9911007 break
9921008
9931009 if self .draft_model is not None :
@@ -2217,6 +2233,7 @@ def load_state(self, state: LlamaState) -> None:
22172233 rest [rest > 0 ] = 0.0
22182234 self .input_ids = state .input_ids .copy ()
22192235 self .n_tokens = state .n_tokens
2236+ self ._requires_eval = True
22202237 self ._seed = state .seed
22212238 state_size = state .llama_state_size
22222239 LLamaStateArrayType = ctypes .c_uint8 * state_size
0 commit comments