forked from abetlen/llama-cpp-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllama_cpp_ext.py
More file actions
117 lines (99 loc) · 3.2 KB
/
Copy pathllama_cpp_ext.py
File metadata and controls
117 lines (99 loc) · 3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Experimental bindings for non-public llama.cpp APIs from `llama-ext.h`.
This module is not part of the stable llama-cpp-python public API.
Downstream code should not import or depend on it directly.
"""
from __future__ import annotations
import ctypes
import functools
from typing import Any, Iterable, Union
from . import llama_cpp
# Reuse the shared-library object already exposed by the main `llama_cpp`
# bindings; no second library is loaded by this module. Symbols are looked up
# on this object by `_ctypes_function_from_names` via `getattr`.
_lib = llama_cpp._lib
def _ctypes_function_from_names(
    names: Union[str, Iterable[str]],
    argtypes: list[Any],
    restype: Any,
    *,
    lib: Any = None,
):
    """Decorator for extension functions whose exported symbol name can vary by ABI.

    Each candidate in ``names`` is looked up, in order, as an attribute of the
    shared library. The first one that resolves has its ``argtypes`` and
    ``restype`` set, takes over the decorated stub's metadata (name, docstring,
    annotations) via ``functools.wraps``, and is returned *instead of* the
    stub, so the stub's body is never executed.

    Args:
        names: Candidate exported symbol names, tried in order (e.g. the plain
            C name followed by C++-mangled variants). A single ``str`` is
            treated as one candidate rather than being iterated character by
            character. Candidates are captured once, so the returned decorator
            can be applied more than once even if ``names`` was a one-shot
            iterator.
        argtypes: ctypes argument types assigned to the resolved symbol.
        restype: ctypes return type assigned to the resolved symbol.
        lib: Object to resolve symbols on. Defaults to the module-level
            ``_lib`` (looked up when the decorator is applied).

    Raises:
        AttributeError: When the decorator is applied and none of the
            candidate names resolves on the library.
    """
    if isinstance(names, str):
        # ``str`` is itself an Iterable[str]; without this guard a single name
        # would be split into one-character candidates that never resolve.
        candidates = (names,)
    else:
        # Materialize once so the decorator stays reusable for iterators.
        candidates = tuple(names)

    def decorator(f):
        target = _lib if lib is None else lib
        for name in candidates:
            try:
                func = getattr(target, name)
            except AttributeError:
                # Not exported under this spelling; try the next candidate.
                continue
            func.argtypes = argtypes
            func.restype = restype
            functools.wraps(f)(func)
            return func
        raise AttributeError(
            f"None of the shared library symbols were found: {', '.join(candidates)}"
        )

    return decorator
# LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
# Candidate symbols, tried in order: the plain C name, the Itanium C++ mangled
# name (`_Z...`), then the 64-bit MSVC mangled name (`?...@@...`).
@_ctypes_function_from_names(
    (
        "llama_set_embeddings_nextn",
        "_Z26llama_set_embeddings_nextnP13llama_contextbb",
        "?llama_set_embeddings_nextn@@YAXPEAUllama_context@@_N1@Z",
    ),
    [llama_cpp.llama_context_p_ctypes, ctypes.c_bool, ctypes.c_bool],
    None,
)
def llama_set_embeddings_nextn(
    ctx: llama_cpp.llama_context_p,
    value: bool,
    masked: bool,
    /,
) -> None:
    """Set whether the context outputs nextn embeddings or not.

    This stub only carries the signature and docstring: the decorator returns
    the resolved library symbol in its place, so this body never runs.

    Args:
        ctx: Context to configure.
        value: Whether nextn embeddings should be output.
        masked: Passed through as the C ``bool masked`` argument.
            NOTE(review): its exact meaning is defined by llama.cpp
            (``llama-ext.h``) and is not documented here; confirm before use.
    """
    ...
# LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
# Candidate symbols, tried in order: the plain C name, the Itanium C++ mangled
# name (`_Z...`), then the 64-bit MSVC mangled name (`?...@@...`).
@_ctypes_function_from_names(
    (
        "llama_get_embeddings_nextn",
        "_Z26llama_get_embeddings_nextnP13llama_context",
        "?llama_get_embeddings_nextn@@YAPEAMPEAUllama_context@@@Z",
    ),
    [llama_cpp.llama_context_p_ctypes],
    ctypes.POINTER(ctypes.c_float),
)
def llama_get_embeddings_nextn(
    ctx: llama_cpp.llama_context_p,
    /,
):
    """Get the nextn embeddings from the last evaluation.

    This stub only carries the signature and docstring: the decorator returns
    the resolved library symbol in its place, so this body never runs.

    Args:
        ctx: Context whose last evaluation is queried.

    Returns:
        A ``ctypes.POINTER(ctypes.c_float)`` to the C-side float buffer.
        The binding does not convey the buffer length; the caller must know
        it from llama.cpp. NOTE(review): presumably the memory is owned by
        ``ctx`` (no free function is bound here), and a NULL result should be
        checked before indexing; confirm both against ``llama-ext.h``.
    """
    ...
# LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i);
# Candidate symbols, tried in order: the plain C name, the Itanium C++ mangled
# name (`_Z...`), then the 64-bit MSVC mangled name (`?...@@...`).
@_ctypes_function_from_names(
    (
        "llama_get_embeddings_nextn_ith",
        "_Z30llama_get_embeddings_nextn_ithP13llama_contexti",
        "?llama_get_embeddings_nextn_ith@@YAPEAMPEAUllama_context@@H@Z",
    ),
    [llama_cpp.llama_context_p_ctypes, ctypes.c_int32],
    ctypes.POINTER(ctypes.c_float),
)
def llama_get_embeddings_nextn_ith(
    ctx: llama_cpp.llama_context_p,
    i: Union[ctypes.c_int32, int],
    /,
):
    """Get the nextn embeddings for the ith output row from the last evaluation.

    This stub only carries the signature and docstring: the decorator returns
    the resolved library symbol in its place, so this body never runs.

    Args:
        ctx: Context whose last evaluation is queried.
        i: Output row index, converted to a C ``int32_t``. NOTE(review):
            valid range (and whether negative indices are accepted) is
            defined by llama.cpp; confirm before relying on it.

    Returns:
        A ``ctypes.POINTER(ctypes.c_float)`` to that row's C-side floats.
        The row length is not conveyed by this binding. NOTE(review): a NULL
        result should presumably be checked before indexing; confirm.
    """
    ...
# LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx);
# Candidate symbols, tried in order: the plain C name, the Itanium C++ mangled
# name (`_Z...`), then the 64-bit MSVC mangled name (`?...@@...`).
@_ctypes_function_from_names(
    (
        "llama_get_ctx_other",
        "_Z19llama_get_ctx_otherP13llama_context",
        "?llama_get_ctx_other@@YAPEAUllama_context@@PEAU1@@Z",
    ),
    [llama_cpp.llama_context_p_ctypes],
    llama_cpp.llama_context_p_ctypes,
)
def llama_get_ctx_other(
    ctx: llama_cpp.llama_context_p,
    /,
):
    """Get the context linked through llama_context_params.ctx_other.

    This stub only carries the signature and docstring: the decorator returns
    the resolved library symbol in its place, so this body never runs.

    Args:
        ctx: Context whose linked context is requested.

    Returns:
        The linked context, converted via ``llama_cpp.llama_context_p_ctypes``.
        NOTE(review): if no context is linked the C side presumably returns
        NULL; how that surfaces in Python depends on that ctypes type (e.g.
        ``None`` for ``c_void_p``). Confirm before dereferencing.
    """
    ...