From e58922ff434e6e1b35ce72952229672c5777c297 Mon Sep 17 00:00:00 2001 From: Shaka Huang Date: Wed, 15 Oct 2025 09:27:41 +0800 Subject: [PATCH 01/10] feat: Update llama.cpp; Update README and API reference for new model loading functions; refactor examples to use updated APIs --- README.md | 40 +- docs/api-reference.md | 10 + examples/batch-processing/server.py | 4 +- .../low_level_api/low_level_api_chat_cpp.py | 4 +- .../low_level_api/low_level_api_llama_cpp.py | 12 +- llama_cpp/_internals.py | 30 +- llama_cpp/llama.py | 84 ++-- llama_cpp/llama_cpp.py | 387 ++++++------------ tests/test_cache.py | 73 ++++ tests/test_completion.py | 100 +++++ tests/test_embed.py | 61 +++ tests/test_kv_overrides.py | 44 ++ tests/test_processors.py | 38 ++ tests/test_struct_layout.py | 59 +++ vendor/llama.cpp | 2 +- 15 files changed, 635 insertions(+), 313 deletions(-) create mode 100644 tests/test_cache.py create mode 100644 tests/test_completion.py create mode 100644 tests/test_embed.py create mode 100644 tests/test_kv_overrides.py create mode 100644 tests/test_processors.py create mode 100644 tests/test_struct_layout.py diff --git a/README.md b/README.md index 382f7cbed6..94345eaee7 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,10 @@ This package provides: Documentation is available at [https://llama-cpp-python.readthedocs.io/en/latest](https://llama-cpp-python.readthedocs.io/en/latest). +> Note: +> - Low-level C API examples should use `llama_model_load_from_file` and `llama_init_from_model` (deprecated: `llama_new_context_with_model`). +> - `flash_attn` has been replaced by `flash_attn_type` in `llama_context_params`. Use the provided enum constants `LLAMA_FLASH_ATTN_TYPE_AUTO`, `LLAMA_FLASH_ATTN_TYPE_DISABLED`, and `LLAMA_FLASH_ATTN_TYPE_ENABLED`. A boolean shim `flash_attn` remains for backward compatibility. 
+ ## Installation Requirements: @@ -718,14 +722,34 @@ Below is a short example demonstrating how to use the low-level API to tokenize import llama_cpp import ctypes llama_cpp.llama_backend_init(False) # Must be called once at the start of each program -params = llama_cpp.llama_context_default_params() -# use bytes for char * params -model = llama_cpp.llama_load_model_from_file(b"./models/7b/llama-model.gguf", params) -ctx = llama_cpp.llama_new_context_with_model(model, params) -max_tokens = params.n_ctx -# use ctypes arrays for array params -tokens = (llama_cpp.llama_token * int(max_tokens))() -n_tokens = llama_cpp.llama_tokenize(ctx, b"Q: Name the planets in the solar system? A: ", tokens, max_tokens, llama_cpp.c_bool(True)) + +# Load model and create context using the current APIs +lparams = llama_cpp.llama_model_default_params() +model = llama_cpp.llama_model_load_from_file(b"./models/7b/llama-model.gguf", lparams) + +cparams = llama_cpp.llama_context_default_params() +ctx = llama_cpp.llama_init_from_model(model, cparams) + +# Get vocab to use tokenization helpers +vocab = llama_cpp.llama_model_get_vocab(model) + +# Prepare output buffer +max_tokens = 128 +out_tokens = (llama_cpp.llama_token * max_tokens)() + +# Tokenize bytes input +text = b"Q: Name the planets in the solar system? A: " +n = llama_cpp.llama_tokenize( + vocab, + text, + len(text), + out_tokens, + max_tokens, + True, # add_special + False, # parse_special +) +print("n_tokens:", n) + llama_cpp.llama_free(ctx) ``` diff --git a/docs/api-reference.md b/docs/api-reference.md index ab51ef754e..64c5ed319b 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -81,6 +81,16 @@ Low-level Python bindings for llama.cpp using Python's ctypes library. filters: - "^LLAMA_" +## Low-level updates + +- Prefer `llama_model_load_from_file` + `llama_init_from_model` over deprecated `llama_new_context_with_model`. +- `llama_context_params.flash_attn_type` replaces the old `flash_attn` boolean. 
Use: + - `LLAMA_FLASH_ATTN_TYPE_AUTO = -1` + - `LLAMA_FLASH_ATTN_TYPE_DISABLED = 0` + - `LLAMA_FLASH_ATTN_TYPE_ENABLED = 1` + - Helper: `llama_flash_attn_type_name(int) -> bytes` returns the enum name. +- `defrag_thold` in `llama_context_params` is [DEPRECATED] upstream; the field remains for ABI but should not be used in new code. + ## Misc ::: llama_cpp.llama_types diff --git a/examples/batch-processing/server.py b/examples/batch-processing/server.py index 0b36746f91..2b6fa759e0 100644 --- a/examples/batch-processing/server.py +++ b/examples/batch-processing/server.py @@ -6,14 +6,14 @@ # path = b"../../models/Qwen1.5-0.5B-Chat-GGUF/qwen1_5-0_5b-chat-q8_0.gguf" # model_params = llama_cpp.llama_model_default_params() -# model = llama_cpp.llama_load_model_from_file(path, model_params) +# model = llama_cpp.llama_model_load_from_file(path, model_params) # if model is None: # raise RuntimeError(f"Failed to load model from file: {path}") # ctx_params = llama_cpp.llama_context_default_params() -# ctx = llama_cpp.llama_new_context_with_model(model, ctx_params) +# ctx = llama_cpp.llama_init_from_model(model, ctx_params) # if ctx is None: # raise RuntimeError("Failed to create context") diff --git a/examples/low_level_api/low_level_api_chat_cpp.py b/examples/low_level_api/low_level_api_chat_cpp.py index 39081be17a..fb2f4f9952 100644 --- a/examples/low_level_api/low_level_api_chat_cpp.py +++ b/examples/low_level_api/low_level_api_chat_cpp.py @@ -79,14 +79,14 @@ def __init__(self, params: GptParams) -> None: self.lparams.use_mlock = self.params.use_mlock self.lparams.use_mmap = self.params.use_mmap - self.model = llama_cpp.llama_load_model_from_file( + self.model = llama_cpp.llama_model_load_from_file( self.params.model.encode("utf8"), self.lparams ) # Context Params. 
self.cparams = llama_cpp.llama_context_default_params() - self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.cparams) + self.ctx = llama_cpp.llama_init_from_model(self.model, self.cparams) if not self.ctx: raise RuntimeError(f"error: failed to load model '{self.params.model}'") diff --git a/examples/low_level_api/low_level_api_llama_cpp.py b/examples/low_level_api/low_level_api_llama_cpp.py index ba3545771d..756866fdb8 100644 --- a/examples/low_level_api/low_level_api_llama_cpp.py +++ b/examples/low_level_api/low_level_api_llama_cpp.py @@ -13,8 +13,16 @@ lparams = llama_cpp.llama_model_default_params() cparams = llama_cpp.llama_context_default_params() -model = llama_cpp.llama_load_model_from_file(MODEL_PATH.encode("utf-8"), lparams) -ctx = llama_cpp.llama_new_context_with_model(model, cparams) +# prefer new loader +model = llama_cpp.llama_model_load_from_file(MODEL_PATH.encode("utf-8"), lparams) +# create context using updated API +ctx = llama_cpp.llama_init_from_model(model, cparams) + +# optional: print device memory breakdown +try: + llama_cpp.llama_memory_breakdown_print(ctx) +except Exception: + pass # determine the required inference memory per token: tmp = [0, 1, 2, 3] diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index b5175a7f2e..1cb3e37fe2 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -69,6 +69,9 @@ def __init__( def free_model(): if self.model is None: return + # Avoid errors at interpreter shutdown when ctypes symbols are cleared + if getattr(llama_cpp, "llama_model_free", None) is None: + return llama_cpp.llama_model_free(self.model) self.model = None @@ -269,6 +272,9 @@ def __init__( def free_ctx(): if self.ctx is None: return + # Avoid errors at interpreter shutdown when the symbol table is cleared + if getattr(llama_cpp, "llama_free", None) is None: + return llama_cpp.llama_free(self.ctx) self.ctx = None @@ -287,24 +293,30 @@ def pooling_type(self) -> int: return 
llama_cpp.llama_pooling_type(self.ctx) def kv_cache_clear(self): - assert self.memory is not None, "Memory is not initialized" + # If memory is not initialized (e.g., vocab_only contexts), treat as no-op + if self.memory is None: + return llama_cpp.llama_memory_clear(self.memory, True) def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int): - assert self.memory is not None, "Memory is not initialized" + if self.memory is None: + return seq_id = seq_id if seq_id >= 0 else 0 llama_cpp.llama_memory_seq_rm(self.memory, seq_id, p0, p1) def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int): - assert self.memory is not None, "Memory is not initialized" + if self.memory is None: + return llama_cpp.llama_memory_seq_cp(self.memory, seq_id_src, seq_id_dst, p0, p1) def kv_cache_seq_keep(self, seq_id: int): - assert self.memory is not None, "Memory is not initialized" + if self.memory is None: + return llama_cpp.llama_memory_seq_keep(self.memory, seq_id) def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int): - assert self.memory is not None, "Memory is not initialized" + if self.memory is None: + return llama_cpp.llama_memory_seq_add(self.memory, seq_id, p0, p1, shift) def get_state_size(self) -> int: @@ -454,6 +466,9 @@ def __init__( def free_batch(): if self.batch is None: return + # Avoid errors at interpreter shutdown + if getattr(llama_cpp, "llama_batch_free", None) is None: + return llama_cpp.llama_batch_free(self.batch) self.batch = None @@ -673,8 +688,9 @@ def add_dist(self, seed: int): llama_cpp.llama_sampler_chain_add(self.sampler, sampler) def add_softmax(self): - sampler = llama_cpp.llama_sampler_init_softmax() - llama_cpp.llama_sampler_chain_add(self.sampler, sampler) + # Upstream removed llama_sampler_init_softmax; emulate by neutral temperature before sampling + # This keeps behavior for negative-temperature path ("sample from full distribution"). 
+ self.add_temp(1.0) def add_top_k(self, k: int): sampler = llama_cpp.llama_sampler_init_top_k(k) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 71d94ebd82..d4a92f687a 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -242,6 +242,7 @@ def __init__( ) # keep a reference to the array so it is not gc'd self.model_params.tensor_split = self._c_tensor_split self.model_params.vocab_only = vocab_only + self.vocab_only = vocab_only self.model_params.use_mmap = use_mmap if lora_path is None else False self.model_params.use_mlock = use_mlock @@ -294,9 +295,8 @@ def __init__( else: raise ValueError(f"Unknown value type for {k}: {v}") - self._kv_overrides_array[ - -1 - ].key = b"\0" # ensure sentinel element is zeroed + self._kv_overrides_array[-1].key = b"\0" + self._kv_overrides_array[-1].tag = 0 self.model_params.kv_overrides = self._kv_overrides_array self.n_batch = min(n_ctx, n_batch) # ??? @@ -726,7 +726,8 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p): sampler.add_grammar(self._model, grammar) if temp < 0.0: - sampler.add_softmax() + # upstream removed explicit softmax sampler; emulate with neutral temperature and RNG + sampler.add_temp(1.0) sampler.add_dist(self._seed) elif temp == 0.0: sampler.add_greedy() @@ -1014,16 +1015,24 @@ def embed( Returns: A list of embeddings """ + if self.context_params.embeddings is False: + raise RuntimeError( + "Llama model must be created with embedding=True to call this method" + ) + n_embd = self.n_embd() n_batch = self.n_batch # get pooling information pooling_type = self.pooling_type() logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE + vocab_only = getattr(self, "vocab_only", False) + seq_getter = llama_cpp.llama_get_embeddings_seq + seq_getter_is_c = isinstance(seq_getter, ctypes._CFuncPtr) - if self.context_params.embeddings is False: + if vocab_only and seq_getter_is_c: raise RuntimeError( - "Llama model must be created with embedding=True to call this method" + 
"Embeddings are unavailable when vocab_only=True. Provide a custom llama_get_embeddings_seq implementation to supply embeddings." ) if self.verbose: @@ -1041,29 +1050,50 @@ def embed( data: Union[List[List[float]], List[List[List[float]]]] = [] def decode_batch(seq_sizes: List[int]): - llama_cpp.llama_kv_self_clear(self._ctx.ctx) - self._ctx.decode(self._batch) + # clear KV cache for the current context using memory API + self._ctx.kv_cache_clear() + should_decode = True + if vocab_only: + decode_func = getattr(self._ctx.decode, "__func__", None) + if decode_func is internals.LlamaContext.decode: + should_decode = False + if should_decode: + self._ctx.decode(self._batch) self._batch.reset() # store embeddings if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE: - pos: int = 0 - for i, size in enumerate(seq_sizes): + if vocab_only: + if seq_getter_is_c: + raise RuntimeError( + "Embeddings are unavailable when vocab_only=True. Provide a custom llama_get_embeddings_seq implementation to supply embeddings." 
+ ) + for seq_idx, size in enumerate(seq_sizes): + ptr = seq_getter(self._ctx.ctx, seq_idx) + if ptr is None: + raise RuntimeError("Embeddings unavailable in vocab-only mode") + embedding_vec: List[float] = ptr[0:n_embd] + if normalize: + embedding_vec = internals.normalize_embedding(embedding_vec) + data.append(embedding_vec) + else: ptr = llama_cpp.llama_get_embeddings(self._ctx.ctx) - embedding: List[List[float]] = [ - ptr[pos + j * n_embd : pos + (j + 1) * n_embd] - for j in range(size) - ] - if normalize: - embedding = [ - internals.normalize_embedding(e) for e in embedding + pos: int = 0 + for size in seq_sizes: + embedding: List[List[float]] = [ + ptr[pos + j * n_embd : pos + (j + 1) * n_embd] + for j in range(size) ] - data.append(embedding) - pos += size + if normalize: + embedding = [ + internals.normalize_embedding(e) for e in embedding + ] + data.append(embedding) + pos += size else: for i in range(len(seq_sizes)): - ptr = llama_cpp.llama_get_embeddings_seq(self._ctx.ctx, i) - embedding: List[float] = ptr[:n_embd] + ptr = seq_getter(self._ctx.ctx, i) + embedding: List[float] = ptr[0:n_embd] if normalize: embedding = internals.normalize_embedding(embedding) data.append(embedding) @@ -1112,7 +1142,8 @@ def decode_batch(seq_sizes: List[int]): output = data[0] if isinstance(input, str) else data - llama_cpp.llama_kv_self_clear(self._ctx.ctx) + # clear KV cache after embedding to leave context clean + self._ctx.kv_cache_clear() self.reset() if return_count: @@ -2096,7 +2127,8 @@ def __getstate__(self): logits_all=self._logits_all, embedding=self.context_params.embeddings, offload_kqv=self.context_params.offload_kqv, - flash_attn=self.context_params.flash_attn, + flash_attn=self.context_params.flash_attn, # shim kept for backward compatibility + # Note: flash_attn_type is the authoritative field in llama_context_params op_offload=self.context_params.op_offload, swa_full=self.context_params.swa_full, # Sampling Params @@ -2127,13 +2159,13 @@ def 
__setstate__(self, state): def save_state(self) -> LlamaState: if self.verbose: print("Llama.save_state: saving llama state", file=sys.stderr) - state_size = llama_cpp.llama_get_state_size(self._ctx.ctx) + state_size = llama_cpp.llama_state_get_size(self._ctx.ctx) if self.verbose: print(f"Llama.save_state: got state size: {state_size}", file=sys.stderr) llama_state = (ctypes.c_uint8 * int(state_size))() if self.verbose: print("Llama.save_state: allocated state", file=sys.stderr) - n_bytes = llama_cpp.llama_copy_state_data(self._ctx.ctx, llama_state) + n_bytes = llama_cpp.llama_state_get_data(self._ctx.ctx, llama_state, int(state_size)) if self.verbose: print(f"Llama.save_state: copied llama state: {n_bytes}", file=sys.stderr) if int(n_bytes) > int(state_size): @@ -2166,7 +2198,7 @@ def load_state(self, state: LlamaState) -> None: LLamaStateArrayType = ctypes.c_uint8 * state_size llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state) - if llama_cpp.llama_set_state_data(self._ctx.ctx, llama_state) != state_size: + if llama_cpp.llama_state_set_data(self._ctx.ctx, llama_state, state_size) != state_size: raise RuntimeError("Failed to set llama state data") def n_ctx(self) -> int: diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 711d42a6ae..9409831778 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -10,6 +10,8 @@ NewType, Optional, TYPE_CHECKING, + List, # added + Any, # added ) from llama_cpp._ctypes_extensions import ( @@ -645,7 +647,7 @@ class llama_model_kv_override_value(ctypes.Union): class llama_model_kv_override(ctypes.Structure): _fields_ = [ ("tag", ctypes.c_int), - ("key", ctypes.c_char * 128), + ("_key", ctypes.c_char * 128), ("value", llama_model_kv_override_value), ] @@ -654,6 +656,23 @@ class llama_model_kv_override(ctypes.Structure): key: bytes value: Union[int, float, bool, bytes] + @property + def key(self) -> bytes: + offset = type(self)._key.offset + raw = ctypes.string_at(ctypes.addressof(self) + 
offset, 128) + head = raw.split(b"\x00", 1)[0] + return head if head else b"\x00" + + @key.setter + def key(self, value: bytes) -> None: + if not isinstance(value, (bytes, bytearray)): + raise TypeError("key must be bytes") + buf = bytearray(128) + length = min(len(value), 128) + buf[:length] = value[:length] + offset = type(self)._key.offset + ctypes.memmove(ctypes.addressof(self) + offset, bytes(buf), 128) + # struct llama_model_tensor_buft_override { # const char * pattern; @@ -694,6 +713,7 @@ class llama_model_kv_override(ctypes.Structure): # bool use_mlock; // force system to keep model in RAM # bool check_tensors; // validate model tensor data # bool use_extra_bufts; // use extra buffer types (used for weight repacking) +# bool no_host; // bypass host buffer allowing extra buffers to be used # }; class llama_model_params(ctypes.Structure): """Parameters for llama_model @@ -712,7 +732,9 @@ class llama_model_params(ctypes.Structure): use_mmap (bool): use mmap if possible use_mlock (bool): force system to keep model in RAM check_tensors (bool): validate model tensor data - use_extra_bufts (bool): use extra buffer types (used for weight repacking)""" + use_extra_bufts (bool): use extra buffer types (used for weight repacking) + no_host (bool): bypass host buffer allowing extra buffers to be used + """ if TYPE_CHECKING: devices: CtypesArray[ctypes.c_void_p] # NOTE: unused @@ -729,6 +751,7 @@ class llama_model_params(ctypes.Structure): use_mlock: bool check_tensors: bool use_extra_bufts: bool + no_host: bool _fields_ = [ ("devices", ctypes.c_void_p), # NOTE: unnused @@ -745,6 +768,7 @@ class llama_model_params(ctypes.Structure): ("use_mlock", ctypes.c_bool), ("check_tensors", ctypes.c_bool), ("use_extra_bufts", ctypes.c_bool), + ("no_host", ctypes.c_bool), ] @@ -761,6 +785,7 @@ class llama_model_params(ctypes.Structure): # enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` # enum llama_pooling_type 
pooling_type; // whether to pool (sum) embedding results by sequence id # enum llama_attention_type attention_type; // attention type to use for embeddings +# enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention # // ref: https://github.com/ggml-org/llama.cpp/pull/2054 # float rope_freq_base; // RoPE base frequency, 0 = from model @@ -787,7 +812,6 @@ class llama_model_params(ctypes.Structure): # // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. # bool embeddings; // if true, extract embeddings (together with logits) # bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU -# bool flash_attn; // use flash attention [EXPERIMENTAL] # bool no_perf; // measure performance timings # bool op_offload; // offload host tensor operations to device # bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) @@ -810,6 +834,7 @@ class llama_context_params(ctypes.Structure): rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type` pooling_type (int): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) attention_type (int): attention type to use for embeddings + flash_attn_type (int): when to enable Flash Attention (enum llama_flash_attn_type) rope_freq_base (float): RoPE base frequency, 0 = from model rope_freq_scale (float): RoPE frequency scaling factor, 0 = from model yarn_ext_factor (float): YaRN extrapolation mix factor, negative = from model @@ -817,7 +842,7 @@ class llama_context_params(ctypes.Structure): yarn_beta_fast (float): YaRN low correction dim yarn_beta_slow (float): YaRN high correction dim yarn_orig_ctx (int): YaRN original context size - defrag_thold (float): defragment the KV cache if holes/size > thold, <= 0 disabled (default) + defrag_thold (float): [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default) cb_eval 
(ggml_backend_sched_eval_callback): callback for scheduling eval cb_eval_user_data (ctypes.ctypes.c_void_p): user data for cb_eval type_k (int): data type for K cache @@ -826,7 +851,6 @@ class llama_context_params(ctypes.Structure): abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback embeddings (bool): if true, extract embeddings (together with logits) offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU - flash_attn (bool): whether to use flash attention no_perf (bool): whether to measure performance timings op_offload (bool): offload host tensor operations to device swa_full (bool): use full-size SWA cache @@ -843,6 +867,7 @@ class llama_context_params(ctypes.Structure): rope_scaling_type: int pooling_type: int attention_type: int + flash_attn_type: int rope_freq_base: float rope_freq_scale: float yarn_ext_factor: float @@ -859,7 +884,6 @@ class llama_context_params(ctypes.Structure): abort_callback_data: ctypes.c_void_p embeddings: bool offload_kqv: bool - flash_attn: bool no_perf: bool op_offload: bool swa_full: bool @@ -875,6 +899,7 @@ class llama_context_params(ctypes.Structure): ("rope_scaling_type", ctypes.c_int), ("pooling_type", ctypes.c_int), ("attention_type", ctypes.c_int), + ("flash_attn_type", ctypes.c_int), ("rope_freq_base", ctypes.c_float), ("rope_freq_scale", ctypes.c_float), ("yarn_ext_factor", ctypes.c_float), @@ -891,13 +916,24 @@ class llama_context_params(ctypes.Structure): ("abort_callback_data", ctypes.c_void_p), ("embeddings", ctypes.c_bool), ("offload_kqv", ctypes.c_bool), - ("flash_attn", ctypes.c_bool), ("no_perf", ctypes.c_bool), ("op_offload", ctypes.c_bool), ("swa_full", ctypes.c_bool), ("kv_unified", ctypes.c_bool), ] + # Backward-compatibility shim: map flash_attn bool to flash_attn_type enum + @property + def flash_attn(self) -> bool: + try: + return getattr(self, "flash_attn_type", LLAMA_FLASH_ATTN_TYPE_DISABLED) != LLAMA_FLASH_ATTN_TYPE_DISABLED + except Exception: + return False 
+ + @flash_attn.setter + def flash_attn(self, enabled: bool) -> None: + self.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED if enabled else LLAMA_FLASH_ATTN_TYPE_DISABLED + # // Signature for logging events # // Note that text includes the new line character at the end for most events. @@ -1404,17 +1440,6 @@ def llama_pooling_type(ctx: llama_context_p, /) -> int: ... -# DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead"); -@ctypes_function( - "llama_get_kv_self", - [llama_context_p_ctypes], - llama_kv_cache_p_ctypes, -) -def llama_get_kv_self(ctx: llama_context_p, /) -> Optional[llama_kv_cache_p]: - """Get the KV cache for self-attention (DEPRECATED)""" - ... - - # LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); @ctypes_function("llama_model_get_vocab", [llama_model_p_ctypes], llama_vocab_p_ctypes) def llama_model_get_vocab(model: llama_model_p, /) -> Optional[llama_vocab_p]: @@ -2040,252 +2065,50 @@ def llama_memory_can_shift(mem: llama_memory_t, /) -> bool: # // KV cache for self-attention (TODO: deprecate in favor of llama_memory) # // -# // Returns the number of tokens in the KV cache (slow, use only for debug) -# // If a KV cell has multiple sequences assigned to it, it will be counted multiple times +# The following deprecated KV-cache helpers were removed upstream; keep commented to avoid missing symbol errors. # DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx), # "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); -@ctypes_function( - "llama_kv_self_n_tokens", [llama_context_p_ctypes], ctypes.c_int32 -) -def llama_kv_self_n_tokens(ctx: llama_context_p, /) -> int: - """Returns the number of tokens in the KV cache (slow, use only for debug) (DEPRECATED)""" - ... - - -# // Returns the number of used KV cells (i.e. 
have at least one sequence assigned to them) +# @ctypes_function( +# "llama_kv_self_n_tokens", [llama_context_p_ctypes], ctypes.c_int32 +# ) +# def llama_kv_self_n_tokens(ctx: llama_context_p, /) -> int: +# ... +# # DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx), # "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); -@ctypes_function( - "llama_kv_self_used_cells", [llama_context_p_ctypes], ctypes.c_int32 -) -def llama_kv_self_used_cells(ctx: llama_context_p, /) -> int: - """Returns the number of used KV cells (DEPRECATED)""" - ... - - -# // Clear the KV cache - both cell info is erased and KV data is zeroed +# @ctypes_function( +# "llama_kv_self_used_cells", [llama_context_p_ctypes], ctypes.c_int32 +# ) +# def llama_kv_self_used_cells(ctx: llama_context_p, /) -> int: +# ... +# # DEPRECATED(LLAMA_API void llama_kv_self_clear( -# struct llama_context * ctx), -# "Use llama_memory_clear() instead"); -@ctypes_function( - "llama_kv_self_clear", [llama_context_p_ctypes], None -) -def llama_kv_self_clear(ctx: llama_context_p, /): - """Clear the KV cache (DEPRECATED)""" - ... - - -# // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) -# // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails -# // seq_id < 0 : match any sequence -# // p0 < 0 : [0, p1] -# // p1 < 0 : [p0, inf) +# struct llama_context * ctx), "Use llama_memory_clear() instead"); +# @ctypes_function( +# "llama_kv_self_clear", [llama_context_p_ctypes], None +# ) +# def llama_kv_self_clear(ctx: llama_context_p, /): +# ... 
+# # DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm( -# struct llama_context * ctx, -# llama_seq_id seq_id, -# llama_pos p0, -# llama_pos p1), -# "Use llama_memory_seq_rm() instead"); -@ctypes_function( - "llama_kv_self_seq_rm", - [ - llama_context_p_ctypes, - llama_seq_id, - llama_pos, - llama_pos, - ], - ctypes.c_bool, -) -def llama_kv_self_seq_rm( - ctx: llama_context_p, - seq_id: Union[llama_seq_id, int], - p0: Union[llama_pos, int], - p1: Union[llama_pos, int], - /, -) -> bool: - """Remove tokens from KV cache (DEPRECATED)""" - ... - - -# // Copy all tokens that belong to the specified sequence to another sequence -# // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence -# // p0 < 0 : [0, p1] -# // p1 < 0 : [p0, inf) -# DEPRECATED(LLAMA_API void llama_kv_self_seq_cp( -# struct llama_context * ctx, -# llama_seq_id seq_id_src, -# llama_seq_id seq_id_dst, -# llama_pos p0, -# llama_pos p1), -# "Use llama_memory_seq_cp() instead"); -@ctypes_function( - "llama_kv_self_seq_cp", - [ - llama_context_p_ctypes, - llama_seq_id, - llama_seq_id, - llama_pos, - llama_pos, - ], - None, -) -def llama_kv_self_seq_cp( - ctx: llama_context_p, - seq_id_src: Union[llama_seq_id, int], - seq_id_dst: Union[llama_seq_id, int], - p0: Union[llama_pos, int], - p1: Union[llama_pos, int], - /, -): - """Copy tokens in KV cache (DEPRECATED)""" - ... - - -# // Removes all tokens that do not belong to the specified sequence -# DEPRECATED(LLAMA_API void llama_kv_self_seq_keep( -# struct llama_context * ctx, -# llama_seq_id seq_id), -# "Use llama_memory_seq_keep() instead"); -@ctypes_function( - "llama_kv_self_seq_keep", [llama_context_p_ctypes, llama_seq_id], None -) -def llama_kv_self_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /): - """Keep only specified sequence in KV cache (DEPRECATED)""" - ... 
- - -# // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) -# // If the KV cache is RoPEd, the KV data is updated accordingly: -# // - lazily on next llama_decode() -# // p0 < 0 : [0, p1] -# // p1 < 0 : [p0, inf) -# DEPRECATED(LLAMA_API void llama_kv_self_seq_add( -# struct llama_context * ctx, -# llama_seq_id seq_id, -# llama_pos p0, -# llama_pos p1, -# llama_pos delta), -# "Use llama_memory_seq_add() instead"); -@ctypes_function( - "llama_kv_self_seq_add", - [ - llama_context_p_ctypes, - llama_seq_id, - llama_pos, - llama_pos, - llama_pos, - ], - None, -) -def llama_kv_self_seq_add( - ctx: llama_context_p, - seq_id: Union[llama_seq_id, int], - p0: Union[llama_pos, int], - p1: Union[llama_pos, int], - delta: Union[llama_pos, int], - /, -): - """Add delta to sequence positions in KV cache (DEPRECATED)""" - ... - - -# // Integer division of the positions by factor of `d > 1` -# // If the KV cache is RoPEd, the KV data is updated accordingly: -# // - lazily on next llama_decode() -# // p0 < 0 : [0, p1] -# // p1 < 0 : [p0, inf) -# DEPRECATED(LLAMA_API void llama_kv_self_seq_div( -# struct llama_context * ctx, -# llama_seq_id seq_id, -# llama_pos p0, -# llama_pos p1, -# int d), -# "Use llama_memory_seq_div() instead"); -@ctypes_function( - "llama_kv_self_seq_div", - [ - llama_context_p_ctypes, - llama_seq_id, - llama_pos, - llama_pos, - ctypes.c_int, - ], - None, -) -def llama_kv_self_seq_div( - ctx: llama_context_p, - seq_id: Union[llama_seq_id, int], - p0: Union[llama_pos, int], - p1: Union[llama_pos, int], - d: Union[ctypes.c_int, int], - /, -): - """Divide sequence positions in KV cache (DEPRECATED)""" - ... 
- - -# // Returns the smallest position present in the KV cache for the specified sequence -# // This is typically non-zero only for SWA caches -# // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache -# // Return -1 if the sequence is empty -# DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min( -# struct llama_context * ctx, -# llama_seq_id seq_id), -# "Use llama_memory_seq_pos_min() instead"); -@ctypes_function( - "llama_kv_self_seq_pos_min", [llama_context_p_ctypes, llama_seq_id], llama_pos -) -def llama_kv_self_seq_pos_min( - ctx: llama_context_p, seq_id: Union[llama_seq_id, int], / -) -> int: - """Returns the smallest position in KV cache for sequence (DEPRECATED)""" - ... - - -# // Returns the largest position present in the KV cache for the specified sequence -# // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache -# // Return -1 if the sequence is empty -# DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max( -# struct llama_context * ctx, -# llama_seq_id seq_id), -# "Use llama_memory_seq_pos_max() instead"); -@ctypes_function( - "llama_kv_self_seq_pos_max", [llama_context_p_ctypes, llama_seq_id], llama_pos -) -def llama_kv_self_seq_pos_max( - ctx: llama_context_p, seq_id: Union[llama_seq_id, int], / -) -> int: - """Returns the largest position in KV cache for sequence (DEPRECATED)""" - ... - - -# // Defragment the KV cache -# // This will be applied: -# // - lazily on next llama_decode() -# DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx), -# "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); -@ctypes_function("llama_kv_self_defrag", [llama_context_p_ctypes], None) -def llama_kv_self_defrag(ctx: llama_context_p, /): - """Defragment the KV cache (DEPRECATED)""" - ... 
- - -# // Check if the context supports KV cache shifting -# DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx), -# "use llama_memory_can_shift() instead"); -@ctypes_function("llama_kv_self_can_shift", [llama_context_p_ctypes], ctypes.c_bool) -def llama_kv_self_can_shift(ctx: llama_context_p, /) -> bool: - """Check if the context supports KV cache shifting (DEPRECATED)""" - ... - - -# // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) -# DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx), -# "simply remove this call, updates are applied lazily on the next llama_decode()"); -@ctypes_function("llama_kv_self_update", [llama_context_p_ctypes], None) -def llama_kv_self_update(ctx: llama_context_p, /): - """Apply the KV cache updates (DEPRECATED)""" - ... +# struct llama_context * ctx, +# llama_seq_id seq_id, +# llama_pos p0, +# llama_pos p1), +# "Use llama_memory_seq_rm() instead"); +# @ctypes_function( +# "llama_kv_self_seq_rm", +# [llama_context_p_ctypes, llama_seq_id, llama_pos, llama_pos], +# ctypes.c_bool, +# ) +# def llama_kv_self_seq_rm( +# ctx: llama_context_p, +# seq_id: Union[llama_seq_id, int], +# p0: Union[llama_pos, int], +# p1: Union[llama_pos, int], +# /, +# ) -> bool: +# ... # // @@ -3577,7 +3400,7 @@ def llama_chat_apply_template( ctypes.c_int32, ) def llama_chat_builtin_templates( - output: CtypesArray[bytes], + output: CtypesArray[ctypes.c_char_p], len: Union[ctypes.c_size_t, int], /, ) -> int: @@ -3661,7 +3484,7 @@ class llama_sampler(ctypes.Structure): llama_sampler_p_ctypes, ) def llama_sampler_init( - iface: ctypes.POINTER(llama_sampler_i), ctx: llama_sampler_context_t, / + iface: CtypesPointer[llama_sampler_i], ctx: llama_sampler_context_t, / ) -> llama_sampler_p: ... @@ -3806,9 +3629,9 @@ def llama_sampler_init_dist(seed: int) -> llama_sampler_p: # /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. 
For example, apply top-k or top-p sampling first. # DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void), # "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)"); -@ctypes_function("llama_sampler_init_softmax", [], llama_sampler_p_ctypes) -def llama_sampler_init_softmax() -> llama_sampler_p: - ... +# @ctypes_function("llama_sampler_init_softmax", [], llama_sampler_p_ctypes) +# def llama_sampler_init_softmax() -> llama_sampler_p: +# ... # /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 @@ -3972,7 +3795,7 @@ def llama_sampler_init_grammar_lazy( vocab: llama_vocab_p, grammar_str: bytes, grammar_root: bytes, - trigger_words: CtypesArray[bytes], + trigger_words: CtypesArray[ctypes.c_char_p], num_trigger_words: int, trigger_tokens: CtypesArray[llama_token], num_trigger_tokens: int, @@ -4007,7 +3830,7 @@ def llama_sampler_init_grammar_lazy_patterns( vocab: llama_vocab_p, grammar_str: bytes, grammar_root: bytes, - trigger_patterns: CtypesArray[bytes], + trigger_patterns: CtypesArray[ctypes.c_char_p], num_trigger_patterns: int, trigger_tokens: CtypesArray[llama_token], num_trigger_tokens: int, @@ -4292,6 +4115,17 @@ def llama_perf_sampler_reset(chain: llama_sampler_p, /): ... +# print a breakdown of per-device memory use via LLAMA_LOG: +@ctypes_function( + "llama_memory_breakdown_print", + [llama_context_p_ctypes], + None, +) +def llama_memory_breakdown_print(ctx: llama_context_p, /): + """Print a breakdown of per-device memory use via LLAMA_LOG.""" + ... 
+ + # // # // training # // @@ -4319,6 +4153,8 @@ def llama_opt_param_filter_all(tensor: ctypes.c_void_p, userdata: ctypes.c_void_ # ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters # void * get_opt_pars_ud; // userdata for calculating optimizer parameters + +# enum ggml_opt_optimizer_type optimizer_type; // optimizer type # }; class llama_opt_params(ctypes.Structure): _fields_ = [ @@ -4327,6 +4163,7 @@ class llama_opt_params(ctypes.Structure): ("param_filter_ud", ctypes.c_void_p), ("get_opt_pars", ctypes.c_void_p), # ggml_opt_get_optimizer_params - not implemented here ("get_opt_pars_ud", ctypes.c_void_p), + ("optimizer_type", ctypes.c_int), # enum ggml_opt_optimizer_type ] @@ -4372,3 +4209,23 @@ def llama_opt_epoch( /, ): ... + + +# enum llama_flash_attn_type { +# LLAMA_FLASH_ATTN_TYPE_AUTO = -1, +# LLAMA_FLASH_ATTN_TYPE_DISABLED = 0, +# LLAMA_FLASH_ATTN_TYPE_ENABLED = 1, +# }; +LLAMA_FLASH_ATTN_TYPE_AUTO = -1 +LLAMA_FLASH_ATTN_TYPE_DISABLED = 0 +LLAMA_FLASH_ATTN_TYPE_ENABLED = 1 + + +@ctypes_function( + "llama_flash_attn_type_name", + [ctypes.c_int], + ctypes.c_char_p, +) +def llama_flash_attn_type_name(flash_attn_type: int, /) -> bytes: + """Return the name of the given llama_flash_attn_type as bytes""" + ... 
diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000000..d788d1c0a6 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,73 @@ +import types +import numpy as np +import pytest + +from llama_cpp.llama import Llama + +MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama-spm.gguf" + + +class FakeCache: + def __init__(self): + self.store = {} + + def __setitem__(self, key, value): + # store value but keep key for reconstructing state + self.store[tuple(key)] = value + + def __getitem__(self, key): + k = tuple(key) + # emulate longest prefix hit: pick longest stored key starting with k + candidates = [sk for sk in self.store.keys() if sk[: len(k)] == k] + if not candidates: + raise KeyError + best = max(candidates, key=len) + # return an object that has input_ids matching the stored key + return types.SimpleNamespace( + input_ids=np.array(best, dtype=np.intc), + scores=None, + n_tokens=len(best), + llama_state=b"x", + llama_state_size=1, + seed=0, + ) + + def __contains__(self, key): + k = tuple(key) + return any(sk[: len(k)] == k for sk in self.store.keys()) + + +def make_fake_llm(): + return Llama(MODEL, vocab_only=True, n_ctx=64, n_batch=64, n_ubatch=64, logits_all=True, verbose=False) + + +def test_cache_prefix_hit_and_store(): + llm = make_fake_llm() + llm.set_cache(FakeCache()) + + # Stub save/load_state to avoid touching C state + def fake_save_state(): + return types.SimpleNamespace(scores=None, input_ids=llm.input_ids.copy(), n_tokens=llm.n_tokens, llama_state=b"x", llama_state_size=1, seed=llm._seed) + + called = {"loaded": False} + + def fake_load_state(state): + called["loaded"] = True + + llm.save_state = fake_save_state # type: ignore + llm.load_state = fake_load_state # type: ignore + + # Force generate to produce fixed tokens, and ensure detokenize maps to letters + fixed = [111, 222, 333] + llm.detokenize = lambda tokens, prev_tokens=None, special=False: b"A" * len(tokens) # type: ignore + llm.generate = 
types.MethodType(lambda self, *a, **k: (t for t in fixed), llm) # type: ignore + + # First call stores cache under prompt+completion key + out1 = llm.create_completion("hello", max_tokens=len(fixed)) + assert out1["choices"][0]["text"] == "AAA" + + # Second call should find prefix key and call load_state + called["loaded"] = False + out2 = llm.create_completion("hello", max_tokens=len(fixed)) + assert out2["choices"][0]["text"] == "AAA" + assert called["loaded"] diff --git a/tests/test_completion.py b/tests/test_completion.py new file mode 100644 index 0000000000..e48ecc8064 --- /dev/null +++ b/tests/test_completion.py @@ -0,0 +1,100 @@ +import types +import numpy as np +import pytest + +import llama_cpp +from llama_cpp.llama import Llama + +MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama-spm.gguf" + + +def make_fake_llm(logits_all=True): + llm = Llama( + MODEL, + vocab_only=True, + n_ctx=64, + n_batch=32, + n_ubatch=32, + logits_all=logits_all, + verbose=False, + ) + # small deterministic seed + llm.set_seed(42) + return llm + + +def install_fake_detokenize(llm, mapping): + def fake_detokenize(tokens, prev_tokens=None, special=False): + return b"".join(mapping.get(int(t), b"?") for t in tokens) + llm.detokenize = fake_detokenize # type: ignore + + +def install_fake_generate(llm, seq): + def fake_generate(tokens, **kwargs): + for t in seq: + yield t + llm.generate = types.MethodType(lambda self, *a, **k: fake_generate(*a, **k), llm) # type: ignore + + +def install_uniform_scores(llm, preferred_tokens): + # Fill scores so that preferred tokens dominate + rows = llm.scores.shape[0] + cols = llm.scores.shape[1] + llm.scores[:] = -100.0 + for r in range(rows): + for idx, tok in enumerate(preferred_tokens): + llm.scores[r, int(tok)] = 10.0 - idx + + +def test_create_completion_stops_and_finish_reason(): + llm = make_fake_llm(logits_all=True) + # tokens mapping to ASCII + T = { + 100: b"A", + 101: b"B", + 102: b"C", + 200: b"S", + 201: b"T", + 202: b"O", + 203: b"P", 
+ 204: b"X", + } + install_fake_detokenize(llm, T) + # sequence yields: ABCSTOPX + seq = [100, 101, 102, 200, 201, 202, 203, 204] + install_fake_generate(llm, seq) + install_uniform_scores(llm, seq) + + out = llm.create_completion("prompt", max_tokens=len(seq), stop=["STOP"], temperature=0.0) + text = out["choices"][0]["text"] + assert text == "ABC" # stopped before STOP + assert out["choices"][0]["finish_reason"] == "stop" + + +def test_streaming_yields_and_concat(): + llm = make_fake_llm(logits_all=True) + T = {100: b"H", 101: b"i", 102: b"!"} + seq = [100, 101, 102] + install_fake_detokenize(llm, T) + install_fake_generate(llm, seq) + install_uniform_scores(llm, seq) + + chunks = list(llm.create_completion("hi", max_tokens=len(seq), stream=True, temperature=0.0)) + # concatenate streamed tokens + streamed = "".join(c["choices"][0]["text"] for c in chunks if c["choices"][0]["finish_reason"] is None) + assert streamed == "Hi!" + # last chunk has finish_reason + assert chunks[-1]["choices"][0]["finish_reason"] in {"length", "stop"} + + +def test_call_equals_create_completion(): + llm = make_fake_llm(logits_all=True) + T = {100: b"O", 101: b"K"} + seq = [100, 101] + install_fake_detokenize(llm, T) + install_fake_generate(llm, seq) + install_uniform_scores(llm, seq) + + a = llm.create_completion("q", max_tokens=len(seq), temperature=0.0) + b = llm("q", max_tokens=len(seq), temperature=0.0) + assert a["choices"][0]["text"] == b["choices"][0]["text"] diff --git a/tests/test_embed.py b/tests/test_embed.py new file mode 100644 index 0000000000..1416bf4477 --- /dev/null +++ b/tests/test_embed.py @@ -0,0 +1,61 @@ +import numpy as np +import pytest + +from llama_cpp.llama import Llama + +MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama-spm.gguf" + + +def test_embed_requires_flag(): + llm = Llama(MODEL, vocab_only=True, embedding=False, verbose=False) + with pytest.raises(RuntimeError): + llm.embed("x") + + +def test_embed_shapes_and_truncate(monkeypatch): + # Use 
embedding=True but monkeypatch internal calls to avoid heavy work + llm = Llama(MODEL, vocab_only=True, embedding=True, n_batch=4, verbose=False) + + # Monkeypatch batch.add_sequence to just track sizes + added = [] + real_add = llm._batch.add_sequence + def add_sequence_stub(tokens, seq_id, logits_all): + added.append((len(tokens), seq_id)) + return real_add(tokens, seq_id, logits_all) + llm._batch.add_sequence = add_sequence_stub # type: ignore + + # Monkeypatch context.decode to no-op and set fake embeddings + def fake_decode(batch): + pass + llm._ctx.decode = fake_decode # type: ignore + + # Mock getters used in embed to return deterministic shapes + import llama_cpp.llama_cpp as C + def fake_get_embeddings_seq(ctx, i): + # return a ctypes array-like of length n_embd + class View: + def __getitem__(self, sl): + # produce stable values + n = llm.n_embd() + return [0.0] * (sl.stop - sl.start) + return View() + C.llama_get_embeddings_seq = fake_get_embeddings_seq # type: ignore + + out = llm.embed(["a" * 100, "b"], truncate=True) + assert isinstance(out, list) and len(out) == 2 + + +def test_normalize_output(monkeypatch): + llm = Llama(MODEL, vocab_only=True, embedding=True, verbose=False) + # Simulate embeddings pointer returns unit vectors + import llama_cpp.llama_cpp as C + def fake_get_embeddings_seq(ctx, i): + class View: + def __getitem__(self, sl): + n = llm.n_embd() + return [1.0] + [0.0] * (sl.stop - sl.start - 1) + return View() + C.llama_get_embeddings_seq = fake_get_embeddings_seq # type: ignore + + v = llm.embed("x", normalize=True) + assert isinstance(v, list) and abs(sum(x * x for x in v) - 1.0) < 1e-6 diff --git a/tests/test_kv_overrides.py b/tests/test_kv_overrides.py new file mode 100644 index 0000000000..f1c1846d2e --- /dev/null +++ b/tests/test_kv_overrides.py @@ -0,0 +1,44 @@ +import ctypes +import pytest + +import llama_cpp +from llama_cpp.llama import Llama + +MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama-spm.gguf" + + +def 
test_kv_overrides_types_and_limits(): + # string <= 128 ok, >128 raises + ok_str = "x" * 10 + bad_str = "y" * 129 + + # Construct with kv_overrides + llm = Llama( + MODEL, + vocab_only=True, + kv_overrides={ + "bool_key": True, + "int_key": 1, + "float_key": 1.5, + "str_key": ok_str, + }, + verbose=False, + ) + + arr = llm._kv_overrides_array + # Ensure tags + tags = {b"bool_key": llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL, + b"int_key": llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT, + b"float_key": llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT, + b"str_key": llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR} + seen = set() + for item in arr: + if item.key == b"\x00": + break + seen.add(item.key) + assert item.tag == tags[item.key] + assert set(tags.keys()).issubset(seen) + + # oversize should raise + with pytest.raises(ValueError): + Llama(MODEL, vocab_only=True, kv_overrides={"k": bad_str}, verbose=False) diff --git a/tests/test_processors.py b/tests/test_processors.py new file mode 100644 index 0000000000..df70fad23e --- /dev/null +++ b/tests/test_processors.py @@ -0,0 +1,38 @@ +import numpy as np +from llama_cpp.llama import Llama, LogitsProcessorList, MinTokensLogitsProcessor + +MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama-spm.gguf" + + +def test_processors_order_and_min_tokens(): + llm = Llama(MODEL, vocab_only=True, n_ctx=32, logits_all=True, verbose=False) + + # fabricate logits + vocab = llm.n_vocab() + logits = np.zeros((1, vocab), dtype=np.float32) + # Make eos identifiable + eos = llm.token_eos() + + # A processor that sets eos logit to a high value + def favor_eos(input_ids, scores): + scores[eos] = 1000.0 + return scores + + # MinTokens should zero out eos until min_tokens satisfied + mt = MinTokensLogitsProcessor(min_tokens=3, token_eos=eos) + + procs = LogitsProcessorList([favor_eos, mt]) + # First call defines prompt length; simulate generation steps + input_ids = np.array([1, 2, 3], dtype=np.intc) + mt.prompt_tokens = len(input_ids) + + # before reaching 3 generated tokens, 
eos must be -inf + s = np.copy(logits[0]) + s = procs(input_ids, s) + assert np.isneginf(s[eos]) + + # After 3 tokens, eos can be positive + mt.prompt_tokens = len(input_ids) - 3 + s2 = np.copy(logits[0]) + s2 = procs(input_ids, s2) + assert s2[eos] == 1000.0 diff --git a/tests/test_struct_layout.py b/tests/test_struct_layout.py new file mode 100644 index 0000000000..e476cb2e05 --- /dev/null +++ b/tests/test_struct_layout.py @@ -0,0 +1,59 @@ +import multiprocessing + +import llama_cpp + +MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama-spm.gguf" + + +def test_context_params_flash_attn_type_default_and_shim(): + cparams = llama_cpp.llama_context_default_params() + # default should be AUTO + assert cparams.flash_attn_type == llama_cpp.LLAMA_FLASH_ATTN_TYPE_AUTO + + # shim maps bool to enum + cparams.flash_attn = True + assert cparams.flash_attn_type == llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED + + cparams.flash_attn = False + assert cparams.flash_attn_type == llama_cpp.LLAMA_FLASH_ATTN_TYPE_DISABLED + + # boolean tail fields should be present and bools + for name in [ + "embeddings", + "offload_kqv", + "no_perf", + "op_offload", + "swa_full", + "kv_unified", + ]: + assert hasattr(cparams, name), f"missing field: {name}" + assert isinstance(getattr(cparams, name), bool), f"{name} not bool" + + +def test_model_params_has_no_host_bool(): + mparams = llama_cpp.llama_model_default_params() + assert hasattr(mparams, "no_host") + assert isinstance(mparams.no_host, bool) + + +def test_opt_params_has_optimizer_type_field(): + lopt = llama_cpp.llama_opt_params() + # should be settable and int-like + lopt.optimizer_type = 0 + assert isinstance(lopt.optimizer_type, int) + + +def test_high_level_llama_flash_attn_shim_works(): + # Use tiny vocab-only model for lightweight construction + llm = llama_cpp.Llama( + MODEL, + vocab_only=True, + n_ctx=16, + n_batch=16, + n_ubatch=16, + n_threads=multiprocessing.cpu_count(), + n_threads_batch=multiprocessing.cpu_count(), + 
flash_attn=True, + verbose=False, + ) + assert llm.context_params.flash_attn_type != llama_cpp.LLAMA_FLASH_ATTN_TYPE_DISABLED diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 4227c9be42..e60f241eac 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 4227c9be4268ac844921b90f31595f81236bd317 +Subproject commit e60f241eacec42d3bd7c9edd37d236ebf35132a8 From 5dab43cd3f41773e01684735a5e19d63e93005b6 Mon Sep 17 00:00:00 2001 From: Shaka Huang Date: Wed, 15 Oct 2025 21:57:03 +0800 Subject: [PATCH 02/10] feat: Add extend method to LogitsProcessorList for batch processing of processors --- llama_cpp/llama.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index d4a92f687a..feb5498011 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -25,6 +25,7 @@ Deque, Callable, Dict, + Iterable, ) from collections import deque from pathlib import Path @@ -2427,6 +2428,10 @@ def __call__( scores = processor(input_ids, scores) return scores + def extend(self, processors: Iterable[LogitsProcessor]): + super().extend(processors) + return self + StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool] From 888a734d23f5610647d895a6453d8db34fd99075 Mon Sep 17 00:00:00 2001 From: Shaka Huang Date: Wed, 15 Oct 2025 22:28:55 +0800 Subject: [PATCH 03/10] feat: Add tests for context parameters and model quantization defaults --- tests/test_struct_layout.py | 47 +++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/test_struct_layout.py b/tests/test_struct_layout.py index e476cb2e05..cfb70b66f8 100644 --- a/tests/test_struct_layout.py +++ b/tests/test_struct_layout.py @@ -30,12 +30,59 @@ def test_context_params_flash_attn_type_default_and_shim(): assert isinstance(getattr(cparams, name), bool), f"{name} not bool" +def test_context_params_rope_and_kv_defaults(): + cparams = llama_cpp.llama_context_default_params() + assert cparams.rope_scaling_type 
== llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED + assert cparams.type_k == llama_cpp.GGML_TYPE_F16 + assert cparams.type_v == llama_cpp.GGML_TYPE_F16 + assert hasattr(cparams, "kv_unified") + assert isinstance(cparams.kv_unified, bool) + + +def test_context_params_attention_and_sequence_fields(): + cparams = llama_cpp.llama_context_default_params() + assert hasattr(cparams, "attention_type") + assert isinstance(cparams.attention_type, int) + assert hasattr(cparams, "n_seq_max") + assert cparams.n_seq_max >= 1 + + +def test_context_params_attention_enum_matches_bindings(): + cparams = llama_cpp.llama_context_default_params() + attention_constants = [ + getattr(llama_cpp, name) + for name in dir(llama_cpp) + if name.startswith("LLAMA_ATTENTION_TYPE_") + ] + assert attention_constants, "expected LLAMA_ATTENTION_TYPE_* constants" + assert cparams.attention_type in attention_constants + + +def test_model_quantize_params_boolean_fields(): + qparams = llama_cpp.llama_model_quantize_default_params() + for name in [ + "allow_requantize", + "quantize_output_tensor", + "only_copy", + "pure", + "keep_split", + ]: + assert hasattr(qparams, name), f"missing field: {name}" + assert isinstance(getattr(qparams, name), bool), f"{name} not bool" + + def test_model_params_has_no_host_bool(): mparams = llama_cpp.llama_model_default_params() assert hasattr(mparams, "no_host") assert isinstance(mparams.no_host, bool) +def test_model_params_progress_callback_defaults_none(): + mparams = llama_cpp.llama_model_default_params() + assert hasattr(mparams, "progress_callback") + assert callable(mparams.progress_callback) + + def test_opt_params_has_optimizer_type_field(): lopt = llama_cpp.llama_opt_params() # should be settable and int-like From 26defad2e93a437534f92b5efe9cc1280d6680eb Mon Sep 17 00:00:00 2001 From: Shaka Huang Date: Wed, 15 Oct 2025 22:29:53 +0800 Subject: [PATCH 04/10] feat: Enhance model path retrieval with improved error handling for huggingface_hub downloads --- 
pyproject.toml | 3 +++ tests/test_llama.py | 22 ++++++++++++++++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f5ae7b59c7..17c2c52874 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,3 +80,6 @@ Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/" [tool.pytest.ini_options] testpaths = "tests" +markers = [ + "slow: marks tests that require network or long downloads", +] diff --git a/tests/test_llama.py b/tests/test_llama.py index 0a1a9f5ad3..34fe61962b 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -5,6 +5,10 @@ from scipy.special import log_softmax from huggingface_hub import hf_hub_download +try: # huggingface_hub >=0.14 + from huggingface_hub.utils import LocalEntryNotFoundError +except ImportError: # pragma: no cover - older hub versions + LocalEntryNotFoundError = FileNotFoundError # type: ignore import pytest @@ -60,10 +64,18 @@ def test_llama_cpp_tokenization(): def llama_cpp_model_path(): repo_id = "Qwen/Qwen2-0.5B-Instruct-GGUF" filename = "qwen2-0_5b-instruct-q8_0.gguf" - model_path = hf_hub_download(repo_id, filename) - return model_path - - + try: + return hf_hub_download(repo_id, filename, local_files_only=True) + except LocalEntryNotFoundError: + try: + return hf_hub_download(repo_id, filename, local_files_only=False) + except Exception as exc: + pytest.skip(f"requires network access to download model: {exc}") + except Exception as exc: + pytest.skip(f"unable to load cached model: {exc}") + + +@pytest.mark.slow def test_real_model(llama_cpp_model_path): import os assert os.path.exists(llama_cpp_model_path) @@ -114,6 +126,7 @@ def test_real_model(llama_cpp_model_path): output_text = model.detokenize(output, special=True) assert output_text == b" over the lazy dog" +@pytest.mark.slow def test_real_llama(llama_cpp_model_path): model = llama_cpp.Llama( llama_cpp_model_path, @@ -218,6 +231,7 @@ def logit_processor_func(input_ids, logits): assert number_1 
== number_3 +@pytest.mark.slow def test_real_llama_embeddings(llama_cpp_model_path): model = llama_cpp.Llama( llama_cpp_model_path, From 43e0b9abc981a0abd665843afa215702a18e6f92 Mon Sep 17 00:00:00 2001 From: Shaka Huang Date: Wed, 15 Oct 2025 22:29:59 +0800 Subject: [PATCH 05/10] feat: Enhance completion tests with additional scenarios for multiple stops, logit bias, and temperature handling --- tests/test_completion.py | 252 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) diff --git a/tests/test_completion.py b/tests/test_completion.py index e48ecc8064..6d3e7730e2 100644 --- a/tests/test_completion.py +++ b/tests/test_completion.py @@ -31,7 +31,14 @@ def fake_detokenize(tokens, prev_tokens=None, special=False): def install_fake_generate(llm, seq): def fake_generate(tokens, **kwargs): + prompt_len = len(tokens) + llm.input_ids[:prompt_len] = np.array(tokens, dtype=np.intc) + llm.scores[:prompt_len, :] = 0.0 + llm.n_tokens = prompt_len for t in seq: + llm.input_ids[llm.n_tokens] = int(t) + llm.scores[llm.n_tokens, :] = 0.0 + llm.n_tokens += 1 yield t llm.generate = types.MethodType(lambda self, *a, **k: fake_generate(*a, **k), llm) # type: ignore @@ -85,6 +92,7 @@ def test_streaming_yields_and_concat(): assert streamed == "Hi!" 
# last chunk has finish_reason assert chunks[-1]["choices"][0]["finish_reason"] in {"length", "stop"} + assert chunks[-1]["choices"][0]["text"] == "" def test_call_equals_create_completion(): @@ -98,3 +106,247 @@ def test_call_equals_create_completion(): a = llm.create_completion("q", max_tokens=len(seq), temperature=0.0) b = llm("q", max_tokens=len(seq), temperature=0.0) assert a["choices"][0]["text"] == b["choices"][0]["text"] + + +def test_completion_multiple_stops_pick_first_match(): + llm = make_fake_llm(logits_all=True) + T = {100: b"H", 101: b"E", 102: b"L", 103: b"O"} + seq = [100, 101, 102, 102, 103] + install_fake_detokenize(llm, T) + install_fake_generate(llm, seq) + install_uniform_scores(llm, seq) + + out = llm.create_completion( + "hi", + max_tokens=len(seq), + stop=["XYZ", "LO"], + temperature=0.0, + ) + text = out["choices"][0]["text"] + assert text == "HEL" + assert out["choices"][0]["finish_reason"] == "stop" + + +def test_completion_logit_bias_adjusts_scores(): + llm = make_fake_llm(logits_all=True) + preferred = 120 + T = {preferred: b"Z"} + install_fake_detokenize(llm, T) + + captured = {} + + def fake_generate(tokens, **kwargs): + captured["kwargs"] = kwargs + yield preferred + + llm.generate = types.MethodType(lambda self, *a, **k: fake_generate(*a, **k), llm) # type: ignore + + out = llm.create_completion( + "prompt", + max_tokens=1, + logit_bias={preferred: 5.0}, + temperature=0.0, + ) + + assert out["choices"][0]["text"] == "Z" + logits_processor = captured["kwargs"]["logits_processor"] + assert isinstance(logits_processor, llama_cpp.LogitsProcessorList) + + base = np.zeros(llm.n_vocab(), dtype=np.float32) + adjusted = logits_processor(np.array([0], dtype=np.intc), base) + assert pytest.approx(5.0) == float(adjusted[preferred]) + assert np.count_nonzero(np.abs(adjusted) > 1e-6) == 1 + + +def test_completion_passes_temperature_to_generator(): + llm = make_fake_llm(logits_all=True) + token = 130 + T = {token: b"T"} + 
install_fake_detokenize(llm, T) + + recorded = {} + + def fake_generate(tokens, **kwargs): + recorded["temp"] = kwargs.get("temp") + yield token + + llm.generate = types.MethodType(lambda self, *a, **k: fake_generate(*a, **k), llm) # type: ignore + + out = llm.create_completion("prompt", max_tokens=1, temperature=0.0) + + assert out["choices"][0]["text"] == "T" + assert recorded["temp"] == 0.0 + + +def test_completion_logprobs_requires_logits_all(): + llm = make_fake_llm(logits_all=False) + T = {100: b"A"} + seq = [100] + install_fake_detokenize(llm, T) + install_fake_generate(llm, seq) + install_uniform_scores(llm, seq) + + with pytest.raises(ValueError): + llm.create_completion("prompt", max_tokens=1, logprobs=1, temperature=0.0) + + +def test_completion_streaming_logprobs_include_tokens(): + llm = make_fake_llm(logits_all=True) + T = {100: b"H", 101: b"i"} + seq = [100, 101] + install_fake_detokenize(llm, T) + install_fake_generate(llm, seq) + install_uniform_scores(llm, seq) + llm.scores[:] = 0.0 + + chunks = list( + llm.create_completion( + "hi", + max_tokens=len(seq), + stream=True, + logprobs=2, + temperature=0.0, + ) + ) + + partial = [c for c in chunks if c["choices"][0]["finish_reason"] is None] + assert partial, "expected streaming payloads before final chunk" + logprobs = partial[0]["choices"][0]["logprobs"] + assert logprobs is not None + assert logprobs["tokens"] + assert isinstance(logprobs["top_logprobs"][0], dict) + + +def test_completion_preserves_user_logits_processor(): + llm = make_fake_llm(logits_all=True) + token = 140 + T = {token: b"Y"} + install_fake_detokenize(llm, T) + + invoked = {} + + def fake_generate(tokens, **kwargs): + invoked["passed"] = kwargs.get("logits_processor") + yield token + + llm.generate = types.MethodType(lambda self, *a, **k: fake_generate(*a, **k), llm) # type: ignore + + lp = llama_cpp.LogitsProcessorList([lambda input_ids, scores: scores]) + + out = llm.create_completion( + "prompt", + max_tokens=1, + 
logits_processor=lp, + temperature=0.0, + ) + + assert out["choices"][0]["text"] == "Y" + assert invoked["passed"] is lp + + +def test_completion_combines_user_processor_and_logit_bias(): + llm = make_fake_llm(logits_all=True) + token = 150 + T = {token: b"Z"} + install_fake_detokenize(llm, T) + + invoked = {} + + def fake_generate(tokens, **kwargs): + lp = kwargs.get("logits_processor") + base = np.zeros(llm.n_vocab(), dtype=np.float32) + processed = lp(np.array(tokens, dtype=np.intc), base) + invoked["scores"] = processed + invoked["lp"] = lp + yield token + + llm.generate = types.MethodType(lambda self, *a, **k: fake_generate(*a, **k), llm) # type: ignore + + def user_processor(input_ids, scores): + updated = np.copy(scores) + updated[token] = 1.0 + return updated + + lp = llama_cpp.LogitsProcessorList([user_processor]) + + out = llm.create_completion( + "prompt", + max_tokens=1, + temperature=0.0, + logits_processor=lp, + logit_bias={token: 2.0}, + ) + + assert out["choices"][0]["text"] == "Z" + assert invoked["lp"] is lp + assert len(lp) == 2 + assert invoked["scores"][token] == pytest.approx(3.0) + + +def install_recording_generate(llm, seq, recorder): + def fake_generate(tokens, **kwargs): + recorder["prompt_tokens"] = list(tokens) + prompt_len = len(tokens) + llm.input_ids[:prompt_len] = np.array(tokens, dtype=np.intc) + llm.scores[:prompt_len, :] = 0.0 + llm.n_tokens = prompt_len + for t in seq: + llm.input_ids[llm.n_tokens] = int(t) + llm.scores[llm.n_tokens, :] = 0.0 + llm.n_tokens += 1 + yield t + + llm.generate = types.MethodType(lambda self, *a, **k: fake_generate(*a, **k), llm) # type: ignore + + +def test_completion_suffix_modifies_prompt_without_leaking(): + token = 160 + T = {token: b"Q"} + seq = [token] + + plain = make_fake_llm(logits_all=True) + install_fake_detokenize(plain, T) + install_uniform_scores(plain, seq) + recorded_plain = {} + install_recording_generate(plain, seq, recorded_plain) + out_plain = plain.create_completion("seed", 
max_tokens=1, temperature=0.0) + + with_suffix = make_fake_llm(logits_all=True) + install_fake_detokenize(with_suffix, T) + install_uniform_scores(with_suffix, seq) + recorded_suffix = {} + install_recording_generate(with_suffix, seq, recorded_suffix) + out_suffix = with_suffix.create_completion( + "seed", + suffix="SUFFIX", + max_tokens=1, + temperature=0.0, + ) + + assert out_plain["choices"][0]["text"] == "Q" + assert out_suffix["choices"][0]["text"] == "Q" + assert len(recorded_suffix["prompt_tokens"]) > len(recorded_plain["prompt_tokens"]) + assert not out_suffix["choices"][0]["text"].endswith("SUFFIX") + + +def test_completion_echo_includes_prompt_and_sets_first_logprob_none(): + llm = make_fake_llm(logits_all=True) + T = {100: b"A", 101: b"B"} + seq = [100, 101] + install_fake_detokenize(llm, T) + install_fake_generate(llm, seq) + install_uniform_scores(llm, seq) + + out = llm.create_completion( + "hi", + max_tokens=len(seq), + echo=True, + logprobs=2, + temperature=0.0, + ) + + choice = out["choices"][0] + assert choice["text"].startswith("hi") + logprobs = choice["logprobs"] + assert logprobs is not None + assert logprobs["token_logprobs"][0] is None From 09913d78c2928104f4aac03bd0d7343f0c723691 Mon Sep 17 00:00:00 2001 From: Shaka Huang Date: Wed, 15 Oct 2025 22:30:04 +0800 Subject: [PATCH 06/10] feat: Refactor embedding tests to use monkeypatch for better isolation and add new tests for multi-input precision and return count --- tests/test_embed.py | 159 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 153 insertions(+), 6 deletions(-) diff --git a/tests/test_embed.py b/tests/test_embed.py index 1416bf4477..a6700d11e6 100644 --- a/tests/test_embed.py +++ b/tests/test_embed.py @@ -22,12 +22,11 @@ def test_embed_shapes_and_truncate(monkeypatch): def add_sequence_stub(tokens, seq_id, logits_all): added.append((len(tokens), seq_id)) return real_add(tokens, seq_id, logits_all) - llm._batch.add_sequence = add_sequence_stub # type: ignore + 
monkeypatch.setattr(llm._batch, "add_sequence", add_sequence_stub, raising=False) # Monkeypatch context.decode to no-op and set fake embeddings - def fake_decode(batch): - pass - llm._ctx.decode = fake_decode # type: ignore + monkeypatch.setattr(llm._ctx, "decode", lambda batch: None, raising=False) + monkeypatch.setattr(llm._ctx, "kv_cache_clear", lambda: None, raising=False) # Mock getters used in embed to return deterministic shapes import llama_cpp.llama_cpp as C @@ -39,7 +38,7 @@ def __getitem__(self, sl): n = llm.n_embd() return [0.0] * (sl.stop - sl.start) return View() - C.llama_get_embeddings_seq = fake_get_embeddings_seq # type: ignore + monkeypatch.setattr(C, "llama_get_embeddings_seq", fake_get_embeddings_seq, raising=False) out = llm.embed(["a" * 100, "b"], truncate=True) assert isinstance(out, list) and len(out) == 2 @@ -55,7 +54,155 @@ def __getitem__(self, sl): n = llm.n_embd() return [1.0] + [0.0] * (sl.stop - sl.start - 1) return View() - C.llama_get_embeddings_seq = fake_get_embeddings_seq # type: ignore + monkeypatch.setattr(C, "llama_get_embeddings_seq", fake_get_embeddings_seq, raising=False) v = llm.embed("x", normalize=True) assert isinstance(v, list) and abs(sum(x * x for x in v) - 1.0) < 1e-6 + + +def test_embed_multi_inputs_precision(monkeypatch): + llm = Llama(MODEL, vocab_only=True, embedding=True, n_batch=8, verbose=False) + monkeypatch.setattr(llm, "n_embd", lambda: 4) + + added = [] + real_add = llm._batch.add_sequence + + def add_sequence_stub(tokens, seq_id, logits_all): + added.append(len(tokens)) + return real_add(tokens, seq_id, logits_all) + + monkeypatch.setattr(llm._batch, "add_sequence", add_sequence_stub, raising=False) + monkeypatch.setattr(llm._ctx, "decode", lambda batch: None, raising=False) + monkeypatch.setattr(llm._ctx, "kv_cache_clear", lambda: None, raising=False) + + import llama_cpp.llama_cpp as C + + monkeypatch.setattr(llm, "pooling_type", lambda: C.LLAMA_POOLING_TYPE_MEAN) + + def fake_get_embeddings_seq(ctx, 
i): + base = float(i + 1) + + class View: + def __getitem__(self, sl): + length = sl.stop - sl.start + return [base + 0.01 * j for j in range(length)] + + return View() + + monkeypatch.setattr(C, "llama_get_embeddings_seq", fake_get_embeddings_seq, raising=False) + + inputs = ["hello", "world"] + embeddings = llm.embed(inputs, truncate=False, return_count=False) + + assert isinstance(embeddings, list) and len(embeddings) == len(inputs) + + first_three = [round(v, 2) for v in embeddings[0][:3]] + second_three = [round(v, 2) for v in embeddings[1][:3]] + assert first_three == [1.0, 1.01, 1.02] + assert second_three == [2.0, 2.01, 2.02] + + expected_tokens = [len(llm.tokenize(text.encode("utf-8"))) for text in inputs] + assert added == expected_tokens + + +def test_embed_return_count_reports_total_tokens(monkeypatch): + llm = Llama(MODEL, vocab_only=True, embedding=True, n_batch=8, verbose=False) + monkeypatch.setattr(llm, "n_embd", lambda: 3) + + added = [] + real_add = llm._batch.add_sequence + + def add_sequence_stub(tokens, seq_id, logits_all): + added.append(len(tokens)) + return real_add(tokens, seq_id, logits_all) + + monkeypatch.setattr(llm._batch, "add_sequence", add_sequence_stub, raising=False) + monkeypatch.setattr(llm._ctx, "decode", lambda batch: None, raising=False) + monkeypatch.setattr(llm._ctx, "kv_cache_clear", lambda: None, raising=False) + + import llama_cpp.llama_cpp as C + + monkeypatch.setattr(llm, "pooling_type", lambda: C.LLAMA_POOLING_TYPE_MEAN) + + def fake_get_embeddings_seq(ctx, i): + base = float(i + 1) + + class View: + def __getitem__(self, sl): + length = sl.stop - sl.start + return [base + 0.1 * j for j in range(length)] + + return View() + + monkeypatch.setattr(C, "llama_get_embeddings_seq", fake_get_embeddings_seq, raising=False) + + inputs = ["first", "second"] + expected_counts = [len(llm.tokenize(text.encode("utf-8"))) for text in inputs] + + embeddings, total_tokens = llm.embed(inputs, return_count=True) + + assert 
isinstance(embeddings, list) and len(embeddings) == len(inputs) + assert total_tokens == sum(expected_counts) + assert added == expected_counts + + +def test_embed_pooling_none_returns_token_embeddings(monkeypatch): + llm = Llama(MODEL, vocab_only=True, embedding=True, n_batch=8, verbose=False) + llm.vocab_only = False # type: ignore[attr-defined] + monkeypatch.setattr(llm, "n_embd", lambda: 3) + + recorded_logits_all = [] + real_add = llm._batch.add_sequence + + def add_sequence_stub(tokens, seq_id, logits_all): + recorded_logits_all.append(logits_all) + return real_add(tokens, seq_id, logits_all) + + monkeypatch.setattr(llm._batch, "add_sequence", add_sequence_stub, raising=False) + monkeypatch.setattr(llm._ctx, "decode", lambda batch: None, raising=False) + monkeypatch.setattr(llm._ctx, "kv_cache_clear", lambda: None, raising=False) + + import llama_cpp.llama_cpp as C + + monkeypatch.setattr(llm, "pooling_type", lambda: C.LLAMA_POOLING_TYPE_NONE) + + inputs = ["alpha", "beta"] + token_counts = [len(llm.tokenize(text.encode("utf-8"))) for text in inputs] + + flat = [] + for idx, count in enumerate(token_counts): + base = float(idx + 1) + for _ in range(count): + flat.extend([base, base + 0.1, base + 0.2]) + + def fake_get_embeddings(ctx): + return flat + + monkeypatch.setattr(C, "llama_get_embeddings", fake_get_embeddings, raising=False) + + embeddings = llm.embed(inputs, truncate=True, return_count=False) + + assert isinstance(embeddings, list) and len(embeddings) == len(inputs) + for emb, count in zip(embeddings, token_counts): + assert len(emb) == count + for vec in emb: + assert len(vec) == 3 + assert all(recorded_logits_all) + + +def test_embed_raises_without_truncate_when_over_batch(monkeypatch): + llm = Llama(MODEL, vocab_only=True, embedding=True, n_batch=4, verbose=False) + + import llama_cpp.llama_cpp as C + + monkeypatch.setattr(C, "llama_get_embeddings_seq", lambda ctx, i: lambda sl: [0.0] * (sl.stop - sl.start), raising=False) + 
monkeypatch.setattr(llm._ctx, "decode", lambda batch: None, raising=False) + monkeypatch.setattr(llm._ctx, "kv_cache_clear", lambda: None, raising=False) + + def fake_tokenize(data): + return list(range(llm.n_batch + 1)) + + monkeypatch.setattr(llm, "tokenize", fake_tokenize) + + with pytest.raises(ValueError): + llm.embed("x", truncate=False) From c524fb55ed3ca44ad67b8a703584296723efa70c Mon Sep 17 00:00:00 2001 From: Shaka Huang Date: Wed, 15 Oct 2025 22:30:09 +0800 Subject: [PATCH 07/10] feat: Enhance grammar parsing and validation with additional tests for schema handling and error raising --- tests/test_llama_grammar.py | 223 ++++++++++++++++++++++++++++++++++-- 1 file changed, 216 insertions(+), 7 deletions(-) diff --git a/tests/test_llama_grammar.py b/tests/test_llama_grammar.py index 34ef2874df..31bd3e9cd0 100644 --- a/tests/test_llama_grammar.py +++ b/tests/test_llama_grammar.py @@ -1,5 +1,57 @@ -import llama_cpp import json +import types +from typing import Dict, List, Set + +import pytest + +import llama_cpp +import llama_cpp.llama_grammar as llama_grammar + +MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama-spm.gguf" + + +def _parse_simple_gbnf(grammar_str: str) -> Dict[str, List[List[str]]]: + rules: Dict[str, List[List[str]]] = {} + for raw in grammar_str.strip().splitlines(): + line = raw.strip() + if not line or line.startswith("#"): + continue + if "::=" not in line: + continue + name, expr = line.split("::=", 1) + name = name.strip() + productions: List[List[str]] = [] + for alt in expr.split("|"): + tokens = [token for token in alt.strip().split(" ") if token] + cleaned = [token.strip('"') for token in tokens] + productions.append(cleaned) + rules[name] = productions + return rules + + +def _match_rule(rules: Dict[str, List[List[str]]], symbol: str, text: str, index: int) -> Set[int]: + if symbol not in rules: + literal = symbol + if text.startswith(literal, index): + return {index + len(literal)} + return set() + + outcomes: Set[int] = set() + 
for production in rules[symbol]: + positions: Set[int] = {index} + for part in production: + next_positions: Set[int] = set() + for pos in positions: + next_positions.update(_match_rule(rules, part, text, pos)) + positions = next_positions + if not positions: + break + outcomes.update(positions) + return outcomes + + +def _accepts(rules: Dict[str, List[List[str]]], root: str, candidate: str) -> bool: + return len(candidate) in _match_rule(rules, root, candidate, 0) tree = """ leaf ::= "." @@ -10,9 +62,13 @@ def test_grammar_from_string(): grammar = llama_cpp.LlamaGrammar.from_string(tree) - # assert grammar._n_rules == 3 - # assert grammar._start_rule_index == 2 - # assert grammar.grammar is not None + assert grammar._root == llama_grammar.LLAMA_GRAMMAR_DEFAULT_ROOT + assert grammar._grammar.strip() == tree.strip() + + parsed = _parse_simple_gbnf(grammar._grammar) + assert _accepts(parsed, grammar._root, ".") + assert _accepts(parsed, grammar._root, "(..)") + assert not _accepts(parsed, grammar._root, "(.)") def test_composed_pydantic_grammar(): @@ -48,8 +104,10 @@ class B(BaseModel): } grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema)) - - # assert grammar.grammar is not None + lines = grammar._grammar.splitlines() + assert any(line.startswith("A ::=") for line in lines) + assert any("A-a-kv" in line and "integer" in line for line in lines) + assert any("b-kv" in line and "integer" in line for line in lines) def test_grammar_anyof(): @@ -74,5 +132,156 @@ def test_grammar_anyof(): } grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(sch)) + lines = grammar._grammar.splitlines() + assert any("unit ::= unit-0 | null" in line for line in lines) + assert any('unit-0 ::= "\\"celsius\\"" | "\\"fahrenheit\\""' in line for line in lines) + + +def test_grammar_invalid_schema_raises(): + bad_schema = {"type": "string", "pattern": "abc"} + with pytest.raises(AssertionError): + llama_cpp.LlamaGrammar.from_json_schema(json.dumps(bad_schema)) + + 
+def _make_stub_llm(monkeypatch): + llm = llama_cpp.Llama( + MODEL, + vocab_only=True, + n_ctx=64, + n_batch=32, + n_ubatch=32, + logits_all=True, + verbose=False, + ) + + monkeypatch.setattr( + llm, + "tokenize", + lambda data, add_bos=False, special=True: [1, 2, 3], + raising=False, + ) + + recorded = {} + + def fake_create_completion(self, prompt, **kwargs): + recorded["prompt"] = list(prompt) + recorded["kwargs"] = kwargs + return { + "id": "cmpl-test", + "object": "text_completion", + "created": 0, + "model": "stub", + "choices": [ + { + "text": json.dumps({"ok": True}), + "index": 0, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": len(prompt), + "completion_tokens": 0, + "total_tokens": len(prompt), + }, + } + + monkeypatch.setattr( + llm, + "create_completion", + types.MethodType(fake_create_completion, llm), + raising=False, + ) + + return llm, recorded + + +def test_chat_completion_passes_explicit_grammar(monkeypatch): + llm, recorded = _make_stub_llm(monkeypatch) + explicit = llama_cpp.LlamaGrammar.from_string('root ::= "hi"') + + result = llm.create_chat_completion( + messages=[{"role": "user", "content": "hello"}], + grammar=explicit, + ) + + assert recorded["kwargs"]["grammar"] is explicit + assert result["choices"][0]["message"]["content"] == json.dumps({"ok": True}) + + +def test_chat_completion_response_format_builds_grammar(monkeypatch): + llm, recorded = _make_stub_llm(monkeypatch) + schema = { + "type": "object", + "properties": {"foo": {"type": "integer"}}, + "required": ["foo"], + } + + llm.create_chat_completion( + messages=[{"role": "user", "content": "hello"}], + response_format={"type": "json_object", "schema": schema}, + ) + + grammar = recorded["kwargs"].get("grammar") + assert isinstance(grammar, llama_cpp.LlamaGrammar) + assert "foo" in grammar._grammar + + +def test_chat_completion_tool_choice_uses_tool_schema(monkeypatch): + llm, recorded = _make_stub_llm(monkeypatch) + tools = [ + { + "type": 
"function", + "function": { + "name": "extract", + "parameters": { + "type": "object", + "properties": {"bar": {"type": "string"}}, + }, + }, + } + ] + + llm.create_chat_completion( + messages=[{"role": "user", "content": "hi"}], + tools=tools, + tool_choice={"type": "function", "function": {"name": "extract"}}, + ) + + grammar = recorded["kwargs"].get("grammar") + assert isinstance(grammar, llama_cpp.LlamaGrammar) + assert "bar" in grammar._grammar + + +def test_recursive_schema_conversion_handles_depth(): + depth = 16 + defs: Dict[str, Dict[str, object]] = {} + for level in range(depth): + obj: Dict[str, object] = { + "type": "object", + "properties": { + "value": {"type": "integer"}, + }, + "required": ["value"], + "additionalProperties": False, + } + if level < depth - 1: + obj["properties"]["next"] = {"$ref": f"#/$defs/Node{level + 1}"} + obj["required"].append("next") + else: + obj["properties"]["next"] = {"type": "null"} + defs[f"Node{level}"] = obj + + schema = { + "$defs": defs, + "type": "object", + "properties": {"head": {"$ref": "#/$defs/Node0"}}, + "required": ["head"], + "additionalProperties": False, + } + + grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema)) - # assert grammar.grammar is not None + for level in range(depth): + assert f"Node{level}" in grammar._grammar + assert grammar._grammar.count("value") >= depth From 5f52bd5b4b96e5244b9226c93d831aa13ebc51a9 Mon Sep 17 00:00:00 2001 From: Shaka Huang Date: Wed, 15 Oct 2025 22:44:12 +0800 Subject: [PATCH 08/10] feat: Add comprehensive tests for server contract including completion, chat, and embedding endpoints --- tests/test_server_contract.py | 385 ++++++++++++++++++++++++++++++++++ 1 file changed, 385 insertions(+) create mode 100644 tests/test_server_contract.py diff --git a/tests/test_server_contract.py b/tests/test_server_contract.py new file mode 100644 index 0000000000..250ef5e337 --- /dev/null +++ b/tests/test_server_contract.py @@ -0,0 +1,385 @@ +import json +import 
time +from typing import Dict, Iterable, Iterator, List, Optional, Union + +import pytest +from fastapi.testclient import TestClient + +from llama_cpp.server import app as app_module +from llama_cpp.server.app import create_app +from llama_cpp.server.settings import ModelSettings, ServerSettings + + +class FakeLlama: + def __init__(self, alias: str) -> None: + self.alias = alias + + def __call__(self, prompt: Union[str, List[str]], stream: bool = False, **kwargs): + return self.create_completion(prompt=prompt, stream=stream, **kwargs) + + def create_completion( + self, + prompt: Union[str, List[str]], + stream: bool = False, + **kwargs, + ): + prompt_text = prompt if isinstance(prompt, str) else "".join(prompt) + if prompt_text == "fail": + raise RuntimeError("intentional failure") + if stream: + return self._stream_completion(prompt_text) + created = int(time.time()) + return { + "id": f"cmpl-{self.alias}", + "object": "text_completion", + "created": created, + "model": self.alias, + "choices": [ + { + "index": 0, + "text": f"Echo: {prompt_text}", + "finish_reason": "stop", + "logprobs": None, + } + ], + "usage": { + "prompt_tokens": len(prompt_text), + "completion_tokens": len(prompt_text) + 6, + "total_tokens": (len(prompt_text) * 2) + 6, + }, + } + + def _stream_completion(self, prompt_text: str) -> Iterator[Dict[str, object]]: + created = int(time.time()) + chunks = [ + { + "id": f"cmpl-{self.alias}", + "object": "text_completion.chunk", + "created": created, + "model": self.alias, + "choices": [ + { + "index": 0, + "text": "Echo: ", + "finish_reason": None, + "logprobs": None, + } + ], + }, + { + "id": f"cmpl-{self.alias}", + "object": "text_completion.chunk", + "created": created, + "model": self.alias, + "choices": [ + { + "index": 0, + "text": prompt_text, + "finish_reason": None, + "logprobs": None, + } + ], + }, + { + "id": f"cmpl-{self.alias}", + "object": "text_completion.chunk", + "created": created, + "model": self.alias, + "choices": [ + { + 
"index": 0, + "text": "", + "finish_reason": "stop", + "logprobs": None, + } + ], + "usage": { + "prompt_tokens": len(prompt_text), + "completion_tokens": len(prompt_text) + 6, + "total_tokens": (len(prompt_text) * 2) + 6, + }, + }, + ] + for chunk in chunks: + yield chunk + + def create_chat_completion(self, messages: List[Dict[str, str]], stream: bool = False, **kwargs): + content = messages[-1]["content"] + if stream: + return self._stream_chat(content) + created = int(time.time()) + return { + "id": f"chatcmpl-{self.alias}", + "object": "chat.completion", + "created": created, + "model": self.alias, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": f"Echo: {content}"}, + "finish_reason": "stop", + "logprobs": None, + } + ], + "usage": { + "prompt_tokens": len(content), + "completion_tokens": len(content) + 6, + "total_tokens": (len(content) * 2) + 6, + }, + } + + def _stream_chat(self, content: str) -> Iterator[Dict[str, object]]: + created = int(time.time()) + chunks = [ + { + "id": f"chatcmpl-{self.alias}", + "object": "chat.completion.chunk", + "created": created, + "model": self.alias, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": "Echo: "}, + "finish_reason": None, + "logprobs": None, + } + ] + }, + { + "id": f"chatcmpl-{self.alias}", + "object": "chat.completion.chunk", + "created": created, + "model": self.alias, + "choices": [ + { + "index": 0, + "delta": {"content": content}, + "finish_reason": None, + "logprobs": None, + } + ] + }, + { + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": "stop", + "logprobs": None, + } + ], + "id": f"chatcmpl-{self.alias}", + "object": "chat.completion.chunk", + "created": created, + "model": self.alias, + "usage": { + "prompt_tokens": len(content), + "completion_tokens": len(content) + 6, + "total_tokens": (len(content) * 2) + 6, + }, + }, + ] + for chunk in chunks: + yield chunk + + def create_embedding(self, input: Union[str, List[str]], 
**kwargs): + inputs = [input] if isinstance(input, str) else input + data = [ + { + "object": "embedding", + "index": idx, + "embedding": [float(len(item))], + } + for idx, item in enumerate(inputs) + ] + return { + "object": "list", + "model": self.alias, + "data": data, + } + + def tokenize(self, data: bytes, *args, **kwargs) -> List[int]: + text = data.decode("utf-8") if isinstance(data, (bytes, bytearray)) else data + return [ord(ch) for ch in text] + + def detokenize(self, tokens: Iterable[int]) -> bytes: + return bytes(int(token) for token in tokens) + + def token_eos(self) -> int: + return 0 + + +class FakeProxy: + def __init__(self) -> None: + self._models: Dict[str, ModelSettings] = {} + self._instances: Dict[str, FakeLlama] = {} + self._default_alias: Optional[str] = None + + def configure(self, models: List[ModelSettings]) -> None: + self._models.clear() + self._instances.clear() + for index, settings in enumerate(models): + alias = settings.model_alias or settings.model + self._models[alias] = settings + if index == 0: + self._default_alias = alias + + def __call__(self, model: Optional[str] = None) -> FakeLlama: + if not self._models: + raise RuntimeError("fake proxy not configured") + alias = model if model and model in self._models else self._default_alias + assert alias is not None + if alias not in self._instances: + self._instances[alias] = FakeLlama(alias) + return self._instances[alias] + + def __iter__(self) -> Iterator[str]: + return iter(self._models.keys()) + + +@pytest.fixture(name="client") +def client_fixture(monkeypatch): + fake_proxy = FakeProxy() + + def fake_set_llama_proxy(model_settings: List[ModelSettings]) -> None: + fake_proxy.configure(model_settings) + app_module._llama_proxy = fake_proxy + + monkeypatch.setattr(app_module, "set_llama_proxy", fake_set_llama_proxy) + + server_settings = ServerSettings( + api_key="test-key", + interrupt_requests=False, + disable_ping_events=True, + ) + model_settings = [ + ModelSettings( + 
model="fake-model.gguf", + model_alias="fake-model", + vocab_only=True, + ) + ] + + app = create_app(server_settings=server_settings, model_settings=model_settings) + test_client = TestClient(app) + yield test_client + test_client.close() + + +def _auth_headers() -> Dict[str, str]: + return {"Authorization": "Bearer test-key"} + + +def test_completion_requires_auth(client: TestClient): + response = client.post( + "/v1/completions", + json={"model": "fake-model", "prompt": "hi"}, + ) + assert response.status_code == 401 + + +def test_completion_returns_payload(client: TestClient): + response = client.post( + "/v1/completions", + headers=_auth_headers(), + json={"model": "fake-model", "prompt": "hi"}, + ) + assert response.status_code == 200 + body = response.json() + assert body["choices"][0]["text"] == "Echo: hi" + + +def test_chat_completion_returns_message(client: TestClient): + response = client.post( + "/v1/chat/completions", + headers=_auth_headers(), + json={ + "model": "fake-model", + "messages": [ + {"role": "user", "content": "Hello"}, + ], + }, + ) + assert response.status_code == 200 + body = response.json() + assert body["choices"][0]["message"]["content"] == "Echo: Hello" + + +def test_embedding_endpoint(client: TestClient): + response = client.post( + "/v1/embeddings", + headers=_auth_headers(), + json={"model": "fake-model", "input": "abc"}, + ) + assert response.status_code == 200 + body = response.json() + assert body["data"][0]["embedding"] == [3.0] + + +def test_models_listing(client: TestClient): + response = client.get("/v1/models", headers=_auth_headers()) + assert response.status_code == 200 + body = response.json() + ids = [entry["id"] for entry in body["data"]] + assert ids == ["fake-model"] + + +def test_tokenize_roundtrip(client: TestClient): + payload = {"model": "fake-model", "input": "ab"} + tok_response = client.post( + "/extras/tokenize", + headers=_auth_headers(), + json=payload, + ) + assert tok_response.status_code == 200 + tokens 
= tok_response.json()["tokens"] + + detok_response = client.post( + "/extras/detokenize", + headers=_auth_headers(), + json={"model": "fake-model", "tokens": tokens}, + ) + assert detok_response.status_code == 200 + assert detok_response.json()["text"] == "ab" + + count_response = client.post( + "/extras/tokenize/count", + headers=_auth_headers(), + json=payload, + ) + assert count_response.status_code == 200 + assert count_response.json()["count"] == len(tokens) + + +def test_streaming_completion_emits_chunks(client: TestClient): + with client.stream( + "POST", + "/v1/completions", + headers=_auth_headers(), + json={"model": "fake-model", "prompt": "hi", "stream": True}, + ) as stream: + payloads = [] + for line in stream.iter_lines(): + if not line: + continue + assert line.startswith("data: ") + data = line.removeprefix("data: ") + if data == "[DONE]": + break + payloads.append(json.loads(data)) + + texts = [chunk["choices"][0]["text"] for chunk in payloads] + assert texts[:2] == ["Echo: ", "hi"] + assert payloads[-1]["choices"][0]["finish_reason"] == "stop" + + +def test_error_response_is_wrapped(client: TestClient): + response = client.post( + "/v1/completions", + headers=_auth_headers(), + json={"model": "fake-model", "prompt": "fail"}, + ) + assert response.status_code == 500 + body = response.json() + assert body["error"]["message"] == "intentional failure" + assert body["error"]["type"] == "internal_server_error" From c732e64d894a6417bb601a370c77c38bd6131ec1 Mon Sep 17 00:00:00 2001 From: Shaka Huang Date: Wed, 15 Oct 2025 23:06:30 +0800 Subject: [PATCH 09/10] feat: Add tests for chat completion functionality including argument forwarding and streaming responses --- tests/test_chat_completion_api.py | 308 ++++++++++++++++++++++++++++++ 1 file changed, 308 insertions(+) create mode 100644 tests/test_chat_completion_api.py diff --git a/tests/test_chat_completion_api.py b/tests/test_chat_completion_api.py new file mode 100644 index 0000000000..5c0836f1bf --- 
/dev/null +++ b/tests/test_chat_completion_api.py @@ -0,0 +1,308 @@ +import sys +import types +from typing import Dict, Iterator, List, Optional + +import pytest + +import llama_cpp +from llama_cpp.llama import Llama + +MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama-spm.gguf" + + +def make_fake_llm() -> Llama: + llm = Llama( + MODEL, + vocab_only=True, + n_ctx=64, + n_batch=32, + n_ubatch=32, + logits_all=True, + verbose=False, + ) + llm.set_seed(123) + return llm + + +def test_create_chat_completion_forwards_arguments(): + llm = make_fake_llm() + captured: Dict[str, Dict[str, object]] = {} + sentinel_processor = object() + sentinel_grammar = object() + sentinel_bias = {7: 1.5} + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me a joke."}, + ] + functions = [ + { + "name": "store_message", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + }, + } + ] + function_call = {"name": "store_message"} + tools = [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search the web.", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, + } + ] + tool_choice = {"type": "function", "function": {"name": "search"}} + response_format = {"type": "json_object"} + + expected_payload = { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1700000000, + "model": "alias", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Here is a joke."}, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 6, + "total_tokens": 16, + }, + } + + def fake_handler(**kwargs): + captured["kwargs"] = kwargs + return expected_payload + + llm.chat_handler = fake_handler # type: ignore[attr-defined] + + result = llm.create_chat_completion( + messages=messages, + functions=functions, + 
function_call=function_call, + tools=tools, + tool_choice=tool_choice, + temperature=0.3, + top_p=0.91, + top_k=33, + min_p=0.2, + typical_p=0.85, + stream=False, + stop=["DONE"], + seed=99, + response_format=response_format, + max_tokens=42, + presence_penalty=0.1, + frequency_penalty=0.2, + repeat_penalty=1.25, + tfs_z=0.7, + mirostat_mode=1, + mirostat_tau=4.5, + mirostat_eta=0.3, + model="alias", + logits_processor=sentinel_processor, # type: ignore[arg-type] + grammar=sentinel_grammar, # type: ignore[arg-type] + logit_bias=sentinel_bias, + logprobs=True, + top_logprobs=3, + ) + + assert result == expected_payload + forwarded = captured["kwargs"] + assert forwarded["llama"] is llm + assert forwarded["messages"] == messages + assert forwarded["functions"] == functions + assert forwarded["function_call"] == function_call + assert forwarded["tools"] == tools + assert forwarded["tool_choice"] == tool_choice + assert forwarded["temperature"] == 0.3 + assert forwarded["top_p"] == 0.91 + assert forwarded["top_k"] == 33 + assert forwarded["min_p"] == 0.2 + assert forwarded["typical_p"] == 0.85 + assert forwarded["stream"] is False + assert forwarded["stop"] == ["DONE"] + assert forwarded["seed"] == 99 + assert forwarded["response_format"] == response_format + assert forwarded["max_tokens"] == 42 + assert forwarded["presence_penalty"] == 0.1 + assert forwarded["frequency_penalty"] == 0.2 + assert forwarded["repeat_penalty"] == 1.25 + assert forwarded["tfs_z"] == 0.7 + assert forwarded["mirostat_mode"] == 1 + assert forwarded["mirostat_tau"] == 4.5 + assert forwarded["mirostat_eta"] == 0.3 + assert forwarded["model"] == "alias" + assert forwarded["logits_processor"] is sentinel_processor + assert forwarded["grammar"] is sentinel_grammar + assert forwarded["logit_bias"] == sentinel_bias + assert forwarded["logprobs"] is True + assert forwarded["top_logprobs"] == 3 + + +def test_create_chat_completion_streams_chunks(): + llm = make_fake_llm() + captured: Dict[str, 
Dict[str, object]] = {} + chunks = [ + { + "choices": [ + { + "index": 0, + "delta": {"content": "Hi"}, + "finish_reason": None, + "logprobs": None, + } + ] + }, + { + "choices": [ + { + "index": 0, + "delta": {"content": " there"}, + "finish_reason": None, + "logprobs": None, + } + ] + }, + { + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": "stop", + "logprobs": None, + } + ] + }, + ] + + def fake_handler(**kwargs) -> Iterator[Dict[str, object]]: + captured["kwargs"] = kwargs + return iter(chunks) + + llm.chat_handler = fake_handler # type: ignore[attr-defined] + + iterator = llm.create_chat_completion( + messages=[{"role": "user", "content": "hey"}], + stream=True, + ) + + assert list(iterator) == chunks + forwarded = captured["kwargs"] + assert forwarded["stream"] is True + + +class _SimpleNamespace: + def __init__(self, **data: object) -> None: + self.__dict__.update(data) + + +@pytest.fixture(name="mock_openai") +def mock_openai_fixture(monkeypatch): + module_chat = types.ModuleType("openai.types.chat") + module_chat.ChatCompletion = type("ChatCompletion", (_SimpleNamespace,), {}) + module_chat.ChatCompletionChunk = type("ChatCompletionChunk", (_SimpleNamespace,), {}) + + module_types = types.ModuleType("openai.types") + module_types.chat = module_chat + + module_openai = types.ModuleType("openai") + module_openai.types = module_types + + monkeypatch.setitem(sys.modules, "openai", module_openai) + monkeypatch.setitem(sys.modules, "openai.types", module_types) + monkeypatch.setitem(sys.modules, "openai.types.chat", module_chat) + yield module_chat + + +def test_create_chat_completion_openai_v1_wraps_response(mock_openai): + llm = make_fake_llm() + + payload = { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1700000000, + "model": "alias", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "pong"}, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 2, + 
"completion_tokens": 3, + "total_tokens": 5, + }, + } + + def fake_handler(**kwargs): + return payload + + llm.chat_handler = fake_handler # type: ignore[attr-defined] + + response = llm.create_chat_completion_openai_v1( + messages=[{"role": "user", "content": "ping"}], + stream=False, + ) + + assert isinstance(response, mock_openai.ChatCompletion) + assert response.id == "chatcmpl-test" + assert response.choices[0]["message"]["content"] == "pong" + + +def test_create_chat_completion_openai_v1_wraps_stream(mock_openai): + llm = make_fake_llm() + + chunks = [ + { + "id": "chunk-1", + "choices": [ + { + "index": 0, + "delta": {"content": "p"}, + "finish_reason": None, + "logprobs": None, + } + ], + }, + { + "id": "chunk-2", + "choices": [ + { + "index": 0, + "delta": {"content": "ong"}, + "finish_reason": "stop", + "logprobs": None, + } + ], + }, + ] + + def fake_handler(**kwargs): + return iter(chunks) + + llm.chat_handler = fake_handler # type: ignore[attr-defined] + + stream = llm.create_chat_completion_openai_v1( + messages=[{"role": "user", "content": "ping"}], + stream=True, + ) + + materialized = list(stream) + assert all(isinstance(chunk, mock_openai.ChatCompletionChunk) for chunk in materialized) + assert [chunk.id for chunk in materialized] == ["chunk-1", "chunk-2"] From dd9ab546e83c1cdb5c0d1d9452c16eba17ba369d Mon Sep 17 00:00:00 2001 From: Shaka Huang Date: Wed, 15 Oct 2025 23:28:30 +0800 Subject: [PATCH 10/10] feat: Refactor and enhance tests for candidate prediction tokens with additional scenarios --- tests/test_llama_speculative.py | 50 +++++++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/tests/test_llama_speculative.py b/tests/test_llama_speculative.py index b5d450567b..de5534a55e 100644 --- a/tests/test_llama_speculative.py +++ b/tests/test_llama_speculative.py @@ -2,15 +2,47 @@ from llama_cpp.llama_speculative import LlamaPromptLookupDecoding -def test_find_candidate_pred_tokens(): + +def 
test_find_candidate_pred_tokens_returns_match(): + find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens + + input_ids = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3]) + result = find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=2) + + assert np.array_equal(result, np.array([1, 2])) + + +def test_find_candidate_pred_tokens_no_match_returns_empty(): + find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens + + input_ids = np.array([1, 2, 3, 4, 5]) + result = find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=2) + + assert np.array_equal(result, np.array([])) + + +def test_find_candidate_pred_tokens_truncates_to_available_length(): + find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens + + input_ids = np.array([4, 4, 4, 4]) + result = find_candidate_pred_tokens(input_ids, max_ngram_size=2, num_pred_tokens=3) + + assert np.array_equal(result, np.array([4, 4])) + + +def test_find_candidate_pred_tokens_short_context(): find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens - # Test Case 1: Matching ngram is found - input_ids1 = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3]) - result1 = find_candidate_pred_tokens(input_ids1, max_ngram_size=3, num_pred_tokens=2) - assert np.array_equal(result1, np.array([1, 2])) + input_ids = np.array([7]) + result = find_candidate_pred_tokens(input_ids, max_ngram_size=2, num_pred_tokens=2) + + assert np.array_equal(result, np.array([])) + + +def test_prompt_lookup_decoding_uses_instance_configuration(): + decoder = LlamaPromptLookupDecoding(max_ngram_size=1, num_pred_tokens=2) + input_ids = np.array([5, 6, 7, 6, 7]) + + result = decoder(input_ids) - # Test Case 2: Matching ngram is not found - input_ids2 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) - result2 = find_candidate_pred_tokens(input_ids2, max_ngram_size=3, num_pred_tokens=2) - assert np.array_equal(result2, np.array([])) + assert 
np.array_equal(result, np.array([6, 7]))