diff --git a/.github/workflows/build-and-release.yaml b/.github/workflows/build-and-release.yaml index 4ae37b1745..c931ead34d 100644 --- a/.github/workflows/build-and-release.yaml +++ b/.github/workflows/build-and-release.yaml @@ -139,6 +139,37 @@ jobs: name: wheels_riscv64 path: ./wheelhouse/*.whl + build_wheels_pyodide: + name: Build Pyodide wheel + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + with: + submodules: "recursive" + + - uses: actions/setup-python@v6 + with: + python-version: "3.12" + + - name: Build wheel + uses: pypa/cibuildwheel@v4.1.0 + env: + CIBW_PLATFORM: "pyodide" + CIBW_BUILD: "cp314-pyodide_wasm32" + CIBW_BUILD_VERBOSITY: "1" + CIBW_REPAIR_WHEEL_COMMAND: "" + CIBW_BEFORE_TEST: "curl -L --fail --retry 3 -o /tmp/stories260K.gguf https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories260K.gguf" + CIBW_TEST_COMMAND: "python -c \"import llama_cpp.mtmd_cpp as mtmd; from llama_cpp import Llama; print('mtmd marker', mtmd.mtmd_default_marker().decode()); llm = Llama(model_path='/tmp/stories260K.gguf', n_ctx=64, n_batch=8, n_threads=1, verbose=False); print('loaded', llm.n_vocab(), llm.n_ctx()); print('generated', llm('Once upon a', max_tokens=1, temperature=0)['choices'][0]['text'])\"" + CMAKE_ARGS: "-DLLAMA_WASM_MEM64=OFF -DEMSCRIPTEN_SYSTEM_PROCESSOR=wasm32 -DGGML_NATIVE=OFF -DGGML_OPENMP=OFF -DGGML_METAL=OFF -DGGML_BLAS=OFF -DGGML_CUDA=OFF -DGGML_HIP=OFF -DGGML_VULKAN=OFF -DGGML_OPENCL=OFF -DGGML_RPC=OFF -DLLAMA_CURL=OFF -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_TOOLS=OFF -DLLAMA_BUILD_SERVER=OFF" + with: + output-dir: wheelhouse + + - name: Upload wheels as artifacts + uses: actions/upload-artifact@v7 + with: + name: wheels_pyodide + path: ./wheelhouse/*.whl + build_sdist: name: Build source distribution runs-on: ubuntu-latest @@ -183,7 +214,7 @@ jobs: release: name: Release - needs: [build_wheels, build_wheels_arm64, build_wheels_riscv64, build_sdist] + needs: [build_wheels, build_wheels_arm64, build_wheels_riscv64, build_wheels_pyodide, build_sdist] if: startsWith(github.ref, 'refs/tags/') runs-on: ubuntu-latest diff --git a/CHANGELOG.md b/CHANGELOG.md index 56c5ffb557..925e941d88 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- feat: update llama.cpp to ggml-org/llama.cpp@92e854ab8 +- fix: preserve recurrent/hybrid model state when the full prompt is already cached by @allthatido and @abetlen in #2306 + +## [0.3.31] + +- feat: update llama.cpp to ggml-org/llama.cpp@f449e0553 + +## [0.3.30] + +- feat: update llama.cpp to ggml-org/llama.cpp@e3a74b299 +- feat: add Pyodide wheel support by @abetlen in #2309 + ## [0.3.29] - feat(example): use MTMD batch encoding by @abetlen in #2301 diff --git a/CMakeLists.txt b/CMakeLists.txt index 0474863a48..5feaaca5b9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,14 +10,22 @@ function(llama_cpp_python_install_target target) return() endif() - install( - TARGETS ${target} - LIBRARY DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp/lib - RUNTIME DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp/lib - ARCHIVE DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp/lib - FRAMEWORK DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp/lib - RESOURCE DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp/lib - ) + if(EMSCRIPTEN) + set_target_properties(${target} PROPERTIES + OUTPUT_NAME "${target}.cpython-00-wasm32-emscripten" + ) + endif() + + if(NOT EMSCRIPTEN) + install( + TARGETS ${target} + LIBRARY DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp/lib + RUNTIME DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp/lib + ARCHIVE DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp/lib + FRAMEWORK DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp/lib + RESOURCE DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp/lib + ) + endif() install( TARGETS ${target} LIBRARY DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp/lib @@ -65,6 +73,32 @@ if (LLAMA_BUILD) # Disable building curl support set(LLAMA_CURL OFF CACHE BOOL "llama.cpp: enable curl" FORCE) + if (EMSCRIPTEN) + if (DEFINED EMSCRIPTEN_SYSTEM_PROCESSOR) + set(CMAKE_SYSTEM_PROCESSOR ${EMSCRIPTEN_SYSTEM_PROCESSOR} CACHE STRING "Target processor" FORCE) + else() + set(CMAKE_SYSTEM_PROCESSOR wasm32 CACHE STRING "Target processor" FORCE) + endif() + + set(LLAMA_WASM_MEM64 OFF CACHE BOOL "llama.cpp: enable wasm64 memory" FORCE) + set(GGML_NATIVE OFF CACHE BOOL "ggml: enable -march=native" FORCE) + set(GGML_OPENMP OFF CACHE BOOL "ggml: use OpenMP" FORCE) + set(GGML_METAL OFF CACHE BOOL "ggml: use Metal" FORCE) + set(GGML_BLAS OFF CACHE BOOL "ggml: use BLAS" FORCE) + set(GGML_CUDA OFF CACHE BOOL "ggml: use CUDA" FORCE) + set(GGML_HIP OFF CACHE BOOL "ggml: use HIP" FORCE) + set(GGML_VULKAN OFF CACHE BOOL "ggml: use Vulkan" FORCE) + set(GGML_OPENCL OFF CACHE BOOL "ggml: use OpenCL" FORCE) + set(GGML_RPC OFF CACHE BOOL "ggml: use RPC" FORCE) + + # Pyodide auto-loads side modules from top-level site-packages/lib + # before Python imports run, so keep upstream installs package-local. + set(CMAKE_INSTALL_BINDIR llama_cpp/lib CACHE PATH "Install binaries" FORCE) + set(CMAKE_INSTALL_INCLUDEDIR llama_cpp/include CACHE PATH "Install headers" FORCE) + set(CMAKE_INSTALL_LIBDIR llama_cpp/lib CACHE PATH "Install libraries" FORCE) + set(LLAMA_BUILD_COMMON OFF CACHE BOOL "Build llama.cpp common library" FORCE) + endif() + # Architecture detection and settings for Apple platforms if (APPLE) # Get the target architecture diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index 42f807ef61..ed3c342f20 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.3.29" +__version__ = "0.3.31" diff --git a/llama_cpp/_ctypes_extensions.py b/llama_cpp/_ctypes_extensions.py index e88ed387df..02cee8a88f 100644 --- a/llama_cpp/_ctypes_extensions.py +++ b/llama_cpp/_ctypes_extensions.py @@ -19,6 +19,9 @@ from typing_extensions import TypeAlias +_EMSCRIPTEN_SIDE_MODULE_SUFFIX = ".cpython-00-wasm32-emscripten.so" + + # Load the library def load_shared_library(lib_base_name: str, base_path: pathlib.Path): """Platform independent shared library loader""" @@ -26,7 +29,12 @@ def load_shared_library(lib_base_name: str, base_path: pathlib.Path): # for llamacpp) and "llama" (default name for this repo) lib_paths: List[pathlib.Path] = [] # Determine the file extension based on the platform - if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"): + if sys.platform == "emscripten": + # Use a CPython-style tag that Pyodide skips during package auto-load. + lib_paths += [ + base_path / f"lib{lib_base_name}{_EMSCRIPTEN_SIDE_MODULE_SUFFIX}", + ] + elif sys.platform.startswith("linux") or sys.platform.startswith("freebsd"): lib_paths += [ base_path / f"lib{lib_base_name}.so", ] @@ -60,6 +68,33 @@ def load_shared_library(lib_base_name: str, base_path: pathlib.Path): os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "lib")) cdll_args["winmode"] = ctypes.RTLD_GLOBAL + if sys.platform == "emscripten": + cdll_args["mode"] = ctypes.RTLD_GLOBAL + lib_dir = str(base_path) + ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") + if lib_dir not in ld_library_path.split(os.pathsep): + os.environ["LD_LIBRARY_PATH"] = ( + lib_dir + if not ld_library_path + else f"{lib_dir}{os.pathsep}{ld_library_path}" + ) + + emscripten_dependencies = { + "llama": ("ggml-base", "ggml-cpu", "ggml"), + "mtmd": ("ggml-base", "ggml-cpu", "ggml", "llama"), + } + for dependency in emscripten_dependencies.get(lib_base_name, ()): + dependency_path = ( + base_path / f"lib{dependency}{_EMSCRIPTEN_SIDE_MODULE_SUFFIX}" + ) + if dependency_path.exists(): + try: + ctypes.CDLL(str(dependency_path), **cdll_args) # type: ignore + except Exception as e: + raise RuntimeError( + f"Failed to load shared library '{dependency_path}': {e}" + ) + # Try to load the shared library, handling potential errors for lib_path in lib_paths: if lib_path.exists(): diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 4a09b55ee5..b5bffd46b5 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -471,6 +471,8 @@ def free_lora_adapter(): self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab) self.n_tokens = 0 + # Restored or truncated state must decode before sampling. + self._requires_eval = True self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc) self.scores: npt.NDArray[np.single] = np.ndarray( (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single @@ -647,6 +649,7 @@ def set_seed(self, seed: int): def reset(self): """Reset the model state.""" self.n_tokens = 0 + self._requires_eval = True if self._is_recurrent or self._is_hybrid: mem = llama_cpp.llama_get_memory(self._ctx.ctx) @@ -689,6 +692,7 @@ def eval(self, tokens: Sequence[int]): pass # Update n_tokens self.n_tokens += n_tokens + self._requires_eval = False def _init_sampler( self, @@ -900,41 +904,53 @@ def generate( grammar=grammar, ) + tokens = list(tokens) + # Check for kv cache prefix match if reset and self.n_tokens > 0: longest_prefix = 0 - for a, b in zip(self._input_ids, tokens[:-1]): + for a, b in zip(self._input_ids, tokens): if a == b: longest_prefix += 1 else: break - # Recurrent and hybrid models cannot rewind state; reset if needed - if ( - self._is_recurrent or self._is_hybrid - ) and longest_prefix < self.n_tokens: - longest_prefix = 0 - reset = True + prompt_consumed = longest_prefix == len(tokens) + exact_prompt_cached = self.n_tokens == len(tokens) and prompt_consumed + + # Exact cache hits can sample immediately only when the current + # logits were produced by a live decode, not restored state. + if exact_prompt_cached and not self._requires_eval: + reset = False + tokens = [] + reuse_prefix = 0 if self.verbose: print( - "Llama.generate: recurrent/hybrid model requires full state reset", + "Llama.generate: full prompt already cached, skipping reset", file=sys.stderr, ) - - if longest_prefix > 0: - if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1): + else: + # If there is no suffix to decode, replay one token to refresh + # logits after truncating to a valid prefix. + reuse_prefix = longest_prefix - 1 if prompt_consumed else longest_prefix + + # Prefix hits can reuse memory because the suffix decode refreshes + # logits before sampling. + if reuse_prefix > 0: + if self._ctx.kv_cache_seq_rm(-1, reuse_prefix, -1): reset = False - tokens = tokens[longest_prefix:] - self.n_tokens = longest_prefix + tokens = tokens[reuse_prefix:] + self.n_tokens = reuse_prefix + self._requires_eval = True if self.verbose: print( - f"Llama.generate: {longest_prefix} prefix-match hit, " + f"Llama.generate: {reuse_prefix} prefix-match hit, " f"remaining {len(tokens)} prompt tokens to eval", file=sys.stderr, ) elif self.verbose: print( - f"Llama.generate: {longest_prefix} prefix-match found " + f"Llama.generate: {reuse_prefix} prefix-match found " f"but partial kv removal not supported, re-evaluating full prompt", file=sys.stderr, ) @@ -948,7 +964,6 @@ def generate( # grammar.reset() sample_idx = self.n_tokens + len(tokens) - 1 - tokens = list(tokens) # Eval and sample while True: @@ -988,6 +1003,7 @@ def generate( if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]: self.n_tokens = sample_idx self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) + self._requires_eval = True break if self.draft_model is not None: @@ -2217,6 +2233,7 @@ def load_state(self, state: LlamaState) -> None: rest[rest > 0] = 0.0 self.input_ids = state.input_ids.copy() self.n_tokens = state.n_tokens + self._requires_eval = True self._seed = state.seed state_size = state.llama_state_size LLamaStateArrayType = ctypes.c_uint8 * state_size diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 21f85c81c3..176709d96e 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -1744,6 +1744,11 @@ def llama_model_n_embd_out(model: llama_model_p, /) -> int: def llama_model_n_layer(model: llama_model_p, /) -> int: ... +# LLAMA_API int32_t llama_model_n_layer_nextn(const struct llama_model * model); +@ctypes_function("llama_model_n_layer_nextn", [llama_model_p_ctypes], ctypes.c_int32) +def llama_model_n_layer_nextn(model: llama_model_p, /) -> int: ... + + # LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); @ctypes_function("llama_model_n_head", [llama_model_p_ctypes], ctypes.c_int32) def llama_model_n_head(model: llama_model_p, /) -> int: ... diff --git a/llama_cpp/llama_cpp_ext.py b/llama_cpp/llama_cpp_ext.py index 284811086a..a4b424eb63 100644 --- a/llama_cpp/llama_cpp_ext.py +++ b/llama_cpp/llama_cpp_ext.py @@ -62,6 +62,25 @@ def llama_set_embeddings_nextn( ... +# LLAMA_API void llama_set_nextn_layer_offset(struct llama_context * ctx, int32_t offset); +@_ctypes_function_from_names( + ( + "llama_set_nextn_layer_offset", + "_Z28llama_set_nextn_layer_offsetP13llama_contexti", + "?llama_set_nextn_layer_offset@@YAXPEAUllama_context@@H@Z", + ), + [llama_cpp.llama_context_p_ctypes, ctypes.c_int32], + None, +) +def llama_set_nextn_layer_offset( + ctx: llama_cpp.llama_context_p, + offset: Union[ctypes.c_int32, int], + /, +): + """Select which appended NextN block the decoder MTP graph runs.""" + ... + + # LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); @_ctypes_function_from_names( ( diff --git a/llama_cpp/mtmd_cpp.py b/llama_cpp/mtmd_cpp.py index 46eb2c879b..35357a3279 100644 --- a/llama_cpp/mtmd_cpp.py +++ b/llama_cpp/mtmd_cpp.py @@ -20,6 +20,7 @@ ) import pathlib from typing import ( + Callable, Union, NewType, Optional, @@ -84,6 +85,8 @@ MTMD_INPUT_CHUNK_TYPE_IMAGE = 1 MTMD_INPUT_CHUNK_TYPE_AUDIO = 2 +mtmd_progress_callback = CFUNCTYPE(c_bool, c_float, c_void_p) + # Structures class mtmd_context_params(Structure): @@ -106,6 +109,8 @@ class mtmd_context_params(Structure): cb_eval: llama_cpp.ggml_backend_sched_eval_callback cb_eval_user_data: c_void_p batch_max_tokens: int + progress_callback: Callable[[float, c_void_p], bool] + progress_callback_user_data: c_void_p _fields_ = [ ("use_gpu", c_bool), @@ -120,6 +125,8 @@ class mtmd_context_params(Structure): ("cb_eval", llama_cpp.ggml_backend_sched_eval_callback), ("cb_eval_user_data", c_void_p), ("batch_max_tokens", c_int), + ("progress_callback", mtmd_progress_callback), + ("progress_callback_user_data", c_void_p), ] @@ -169,6 +176,12 @@ class mtmd_caps(Structure): POINTER(c_char_p), ) +mtmd_helper_post_decode_callback = CFUNCTYPE( + c_int, + llama_cpp.llama_batch, + c_void_p, +) + class mtmd_helper_bitmap_wrapper(Structure): """Bitmap wrapper returned by MTMD helper media loaders.""" @@ -860,7 +873,9 @@ def mtmd_helper_eval_chunk_single( # llama_pos n_past, # llama_seq_id seq_id, # int32_t n_batch, -# llama_pos * new_n_past); +# llama_pos * new_n_past, +# mtmd_helper_post_decode_callback callback, +# void * user_data); @ctypes_function( "mtmd_helper_decode_image_chunk", [ @@ -872,6 +887,8 @@ def mtmd_helper_eval_chunk_single( llama_cpp.llama_seq_id, c_int, POINTER(llama_cpp.llama_pos), + mtmd_helper_post_decode_callback, + c_void_p, ], c_int, ) @@ -884,6 +901,8 @@ def mtmd_helper_decode_image_chunk( seq_id: llama_cpp.llama_seq_id, n_batch: Union[c_int, int], new_n_past: "_Pointer[llama_cpp.llama_pos]", + callback: Optional[mtmd_helper_post_decode_callback], + user_data: c_void_p, /, ) -> int: """Decode a pre-encoded image chunk.""" diff --git a/tests/test_llama.py b/tests/test_llama.py index 336d6a6122..70fce12d8e 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,4 +1,5 @@ import ctypes +import itertools import multiprocessing import numpy as np @@ -64,6 +65,14 @@ def llama_cpp_model_path(): return model_path +@pytest.fixture +def llama_cpp_transformer_model_path(): + repo_id = "ggml-org/models" + filename = "tinyllamas/stories15M-q4_0.gguf" + model_path = hf_hub_download(repo_id, filename) + return model_path + + @pytest.fixture def llama_cpp_embedding_model_path(): repo_id = "CompendiumLabs/bge-small-en-v1.5-gguf" @@ -339,6 +348,285 @@ def test_hybrid_model_prompt_cache_reset(llama_cpp_hybrid_model_path): ) +def _create_test_model(model_path): + return llama_cpp.Llama( + model_path, + n_ctx=64, + n_batch=64, + n_ubatch=64, + n_threads=multiprocessing.cpu_count(), + n_threads_batch=multiprocessing.cpu_count(), + logits_all=False, + verbose=False, + ) + + +def _generate_test_tokens(model, tokens, max_tokens=3): + return list( + itertools.islice( + model.generate( + tokens, + temp=0.0, + ), + max_tokens, + ) + ) + + +MODEL_CACHE_CASES = ( + ("llama_cpp_transformer_model_path", False, False), + ("llama_cpp_recurrent_model_path", True, False), + ("llama_cpp_hybrid_model_path", False, True), +) + +RESTORED_CACHE_CASES = MODEL_CACHE_CASES + + +def _eval_alternate_same_length_prompt(model, tokens, expected_next_token): + replacement_tokens = ( + model.token_eos(), + model.token_nl(), + 0, + 1, + 2, + model.n_vocab() - 1, + ) + + for replacement_token in replacement_tokens: + alternate_tokens = list(tokens) + alternate_tokens[-1] = replacement_token + if alternate_tokens == tokens: + continue + + model.reset() + model.eval(alternate_tokens) + if model.sample(temp=0.0, idx=len(tokens) - 1) != expected_next_token: + return + + raise AssertionError("failed to find an alternate same-length prompt") + + +def _assert_exact_cached_prompt_reuse_matches_fresh( + model_path, + *, + is_recurrent: bool, + is_hybrid: bool, +): + prompt = "The quick brown fox" + fresh = _create_test_model(model_path) + tokens = fresh.tokenize(prompt.encode(), add_bos=True, special=True) + + assert fresh._is_recurrent is is_recurrent + assert fresh._is_hybrid is is_hybrid + + expected_tokens = _generate_test_tokens(fresh, tokens) + + cached = _create_test_model(model_path) + assert cached._is_recurrent is is_recurrent + assert cached._is_hybrid is is_hybrid + + cached.eval(tokens) + assert cached.n_tokens == len(tokens) + assert cached.input_ids[: cached.n_tokens].tolist() == tokens + assert cached.sample(temp=0.0, idx=len(tokens) - 1) == expected_tokens[0] + + reset_calls = 0 + original_reset = cached.reset + + def reset_tracker(): + nonlocal reset_calls + reset_calls += 1 + original_reset() + + cached.reset = reset_tracker + + cached_tokens = _generate_test_tokens(cached, tokens) + assert reset_calls == 0 + assert cached_tokens == expected_tokens + assert cached.n_tokens == len(tokens) + len(cached_tokens) - 1 + + +def _assert_loaded_exact_cached_prompt_reuse_matches_fresh( + model_path, + *, + is_recurrent: bool, + is_hybrid: bool, +): + prompt = "The quick brown fox" + fresh = _create_test_model(model_path) + tokens = fresh.tokenize(prompt.encode(), add_bos=True, special=True) + expected_tokens = _generate_test_tokens(fresh, tokens) + + source = _create_test_model(model_path) + assert source._is_recurrent is is_recurrent + assert source._is_hybrid is is_hybrid + + source.eval(tokens) + state = source.save_state() + + loaded = _create_test_model(model_path) + assert loaded._is_recurrent is is_recurrent + assert loaded._is_hybrid is is_hybrid + + _eval_alternate_same_length_prompt( + loaded, + tokens, + expected_tokens[0], + ) + loaded.load_state(state) + + assert loaded.n_tokens == len(tokens) + assert loaded.input_ids[: loaded.n_tokens].tolist() == tokens + + loaded_tokens = _generate_test_tokens(loaded, tokens) + assert loaded_tokens == expected_tokens + assert loaded.n_tokens == len(tokens) + len(loaded_tokens) - 1 + + +def _assert_ram_cache_exact_prompt_hit_matches_fresh( + model_path, + *, + is_recurrent: bool, + is_hybrid: bool, +): + prompt = "The quick brown fox" + fresh = _create_test_model(model_path) + tokens = fresh.tokenize(prompt.encode(), add_bos=True, special=True) + expected = fresh.create_completion( + tokens, + max_tokens=1, + temperature=0.0, + seed=1337, + ) + + cache = llama_cpp.LlamaRAMCache() + writer = _create_test_model(model_path) + writer.set_cache(cache) + writer.create_completion( + tokens, + max_tokens=1, + temperature=0.0, + seed=1337, + ) + + cached = _create_test_model(model_path) + assert cached._is_recurrent is is_recurrent + assert cached._is_hybrid is is_hybrid + cached.set_cache(cache) + + load_state_calls = 0 + original_load_state = cached.load_state + + def load_state_tracker(state): + nonlocal load_state_calls + load_state_calls += 1 + original_load_state(state) + + cached.load_state = load_state_tracker + + actual = cached.create_completion( + tokens, + max_tokens=1, + temperature=0.0, + seed=1337, + ) + + assert load_state_calls == 1 + assert actual["choices"][0]["text"] == expected["choices"][0]["text"] + assert ( + actual["usage"]["completion_tokens"] == expected["usage"]["completion_tokens"] + ) + + +def _assert_shorter_prompt_prefix_reuse_matches_fresh( + model_path, + *, + is_recurrent: bool, + is_hybrid: bool, +): + prompt = "The quick brown fox" + history = " jumps over the lazy dog" + fresh = _create_test_model(model_path) + tokens = fresh.tokenize(prompt.encode(), add_bos=True, special=True) + history_tokens = fresh.tokenize(history.encode(), add_bos=False, special=True) + expected_tokens = _generate_test_tokens(fresh, tokens) + + cached = _create_test_model(model_path) + assert cached._is_recurrent is is_recurrent + assert cached._is_hybrid is is_hybrid + + cached.eval(tokens + history_tokens) + assert cached.n_tokens > len(tokens) + assert cached.input_ids[: len(tokens)].tolist() == tokens + + cached_tokens = _generate_test_tokens(cached, tokens) + assert cached_tokens == expected_tokens + + +@pytest.mark.parametrize( + ("model_path_fixture", "is_recurrent", "is_hybrid"), MODEL_CACHE_CASES +) +def test_exact_cached_prompt_reuse_matches_fresh( + request, + model_path_fixture, + is_recurrent, + is_hybrid, +): + _assert_exact_cached_prompt_reuse_matches_fresh( + request.getfixturevalue(model_path_fixture), + is_recurrent=is_recurrent, + is_hybrid=is_hybrid, + ) + + +@pytest.mark.parametrize( + ("model_path_fixture", "is_recurrent", "is_hybrid"), RESTORED_CACHE_CASES +) +def test_loaded_exact_cached_prompt_reuse_matches_fresh( + request, + model_path_fixture, + is_recurrent, + is_hybrid, +): + _assert_loaded_exact_cached_prompt_reuse_matches_fresh( + request.getfixturevalue(model_path_fixture), + is_recurrent=is_recurrent, + is_hybrid=is_hybrid, + ) + + +@pytest.mark.parametrize( + ("model_path_fixture", "is_recurrent", "is_hybrid"), RESTORED_CACHE_CASES +) +def test_ram_cache_exact_prompt_hit_matches_fresh( + request, + model_path_fixture, + is_recurrent, + is_hybrid, +): + _assert_ram_cache_exact_prompt_hit_matches_fresh( + request.getfixturevalue(model_path_fixture), + is_recurrent=is_recurrent, + is_hybrid=is_hybrid, + ) + + +@pytest.mark.parametrize( + ("model_path_fixture", "is_recurrent", "is_hybrid"), MODEL_CACHE_CASES +) +def test_shorter_prompt_prefix_reuse_matches_fresh( + request, + model_path_fixture, + is_recurrent, + is_hybrid, +): + _assert_shorter_prompt_prefix_reuse_matches_fresh( + request.getfixturevalue(model_path_fixture), + is_recurrent=is_recurrent, + is_hybrid=is_hybrid, + ) + + def test_real_llama_embeddings(llama_cpp_embedding_model_path): model = llama_cpp.Llama( llama_cpp_embedding_model_path, diff --git a/vendor/llama.cpp b/vendor/llama.cpp index f05cf4676a..92e854ab83 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit f05cf4676af46c2f017c0e6ba25b6e20204f700e +Subproject commit 92e854ab836254bb7f2eb49babd5613474bdb700