forked from abetlen/llama-cpp-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllama_cpp_ext.py
More file actions
136 lines (116 loc) · 3.76 KB
/
Copy pathllama_cpp_ext.py
File metadata and controls
136 lines (116 loc) · 3.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Experimental bindings for non-public llama.cpp APIs from `llama-ext.h`.
This module is not part of the stable llama-cpp-python public API.
Downstream code should not import or depend on it directly.
"""
from __future__ import annotations
import ctypes
import functools
from typing import Any, Iterable, Union
from . import llama_cpp
_lib = llama_cpp._lib
def _ctypes_function_from_names(
    names: Iterable[str],
    argtypes: list[Any],
    restype: Any,
):
    """Build a decorator that binds a stub to the first exported symbol found.

    Extension functions may be exported under a plain C name or under a
    compiler-mangled C++ name, so every candidate in ``names`` is looked up on
    the shared library in order and the first one that exists is used.

    Args:
        names: Candidate symbol names, tried in order. A single ``str`` is
            treated as one name rather than iterated character by character.
        argtypes: ctypes argument types assigned to the resolved function.
        restype: ctypes return type assigned to the resolved function
            (``None`` for a ``void`` function).

    Returns:
        A decorator that ignores the stub's body and returns the resolved
        ctypes function, updated with the stub's metadata via
        ``functools.wraps``.

    Raises:
        AttributeError: Raised by the returned decorator when none of the
            candidate symbols is exported by the shared library.
    """
    # Materialise once: a one-shot iterator would otherwise be exhausted after
    # the decorator's first use, and a bare string would be walked per
    # character instead of being looked up as a symbol.
    candidates = (names,) if isinstance(names, str) else tuple(names)

    def decorator(f):
        for name in candidates:
            try:
                func = getattr(_lib, name)
            except AttributeError:
                # Symbol not exported under this spelling; try the next one.
                continue
            func.argtypes = argtypes
            func.restype = restype
            # Copy the stub's name/docstring/annotations onto the ctypes
            # function so it introspects like the Python declaration.
            functools.wraps(f)(func)
            return func
        # Reaching this point means every candidate lookup failed.
        raise AttributeError(
            f"None of the shared library symbols were found: {', '.join(candidates)}"
        )

    return decorator
# LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
@_ctypes_function_from_names(
    names=(
        "llama_set_embeddings_nextn",
        "_Z26llama_set_embeddings_nextnP13llama_contextbb",
        "?llama_set_embeddings_nextn@@YAXPEAUllama_context@@_N1@Z",
    ),
    argtypes=[llama_cpp.llama_context_p_ctypes, ctypes.c_bool, ctypes.c_bool],
    restype=None,
)
def llama_set_embeddings_nextn(
    ctx: llama_cpp.llama_context_p,
    value: bool,
    masked: bool,
    /,
):
    """Turn nextn embedding output on or off for a context.

    Bound to the C function
    ``void llama_set_embeddings_nextn(llama_context *, bool, bool)``; the
    symbol is tried under its plain C, Itanium-mangled and MSVC-mangled names.

    Args:
        ctx: Context to configure.
        value: Whether the context outputs nextn embeddings.
        masked: Passed through as the C ``masked`` argument; its meaning is
            defined by ``llama-ext.h``.
    """
    ...
# LLAMA_API void llama_set_nextn_layer_offset(struct llama_context * ctx, int32_t offset);
@_ctypes_function_from_names(
    names=(
        "llama_set_nextn_layer_offset",
        "_Z28llama_set_nextn_layer_offsetP13llama_contexti",
        "?llama_set_nextn_layer_offset@@YAXPEAUllama_context@@H@Z",
    ),
    argtypes=[llama_cpp.llama_context_p_ctypes, ctypes.c_int32],
    restype=None,
)
def llama_set_nextn_layer_offset(
    ctx: llama_cpp.llama_context_p,
    offset: Union[ctypes.c_int32, int],
    /,
):
    """Choose the appended NextN block that the decoder MTP graph executes.

    Bound to the C function
    ``void llama_set_nextn_layer_offset(llama_context *, int32_t)``; the
    symbol is tried under its plain C, Itanium-mangled and MSVC-mangled names.

    Args:
        ctx: Context to configure.
        offset: Value passed as the C ``int32_t offset`` argument.
    """
    ...
# LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
@_ctypes_function_from_names(
    names=(
        "llama_get_embeddings_nextn",
        "_Z26llama_get_embeddings_nextnP13llama_context",
        "?llama_get_embeddings_nextn@@YAPEAMPEAUllama_context@@@Z",
    ),
    argtypes=[llama_cpp.llama_context_p_ctypes],
    restype=ctypes.POINTER(ctypes.c_float),
)
def llama_get_embeddings_nextn(
    ctx: llama_cpp.llama_context_p,
    /,
):
    """Fetch the nextn embeddings produced by the last evaluation.

    Bound to the C function
    ``float * llama_get_embeddings_nextn(llama_context *)``; the symbol is
    tried under its plain C, Itanium-mangled and MSVC-mangled names.

    Args:
        ctx: Context to read from.

    Returns:
        A ctypes ``POINTER(c_float)`` holding the C function's return value.
    """
    ...
# LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i);
@_ctypes_function_from_names(
    names=(
        "llama_get_embeddings_nextn_ith",
        "_Z30llama_get_embeddings_nextn_ithP13llama_contexti",
        "?llama_get_embeddings_nextn_ith@@YAPEAMPEAUllama_context@@H@Z",
    ),
    argtypes=[llama_cpp.llama_context_p_ctypes, ctypes.c_int32],
    restype=ctypes.POINTER(ctypes.c_float),
)
def llama_get_embeddings_nextn_ith(
    ctx: llama_cpp.llama_context_p,
    i: Union[ctypes.c_int32, int],
    /,
):
    """Fetch the nextn embeddings of the ith output row of the last evaluation.

    Bound to the C function
    ``float * llama_get_embeddings_nextn_ith(llama_context *, int32_t)``; the
    symbol is tried under its plain C, Itanium-mangled and MSVC-mangled names.

    Args:
        ctx: Context to read from.
        i: Output row index, passed as the C ``int32_t i`` argument.

    Returns:
        A ctypes ``POINTER(c_float)`` holding the C function's return value.
    """
    ...
# LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx);
@_ctypes_function_from_names(
    names=(
        "llama_get_ctx_other",
        "_Z19llama_get_ctx_otherP13llama_context",
        "?llama_get_ctx_other@@YAPEAUllama_context@@PEAU1@@Z",
    ),
    argtypes=[llama_cpp.llama_context_p_ctypes],
    restype=llama_cpp.llama_context_p_ctypes,
)
def llama_get_ctx_other(
    ctx: llama_cpp.llama_context_p,
    /,
):
    """Return the context linked via ``llama_context_params.ctx_other``.

    Bound to the C function
    ``llama_context * llama_get_ctx_other(llama_context *)``; the symbol is
    tried under its plain C, Itanium-mangled and MSVC-mangled names.

    Args:
        ctx: Context whose linked context is requested.

    Returns:
        The C function's return value as a ``llama_context`` pointer.
    """
    ...