Skip to content

Commit 64848de

Browse files
llama-fit-params: free memory target per device (ggml-org#18679)
1 parent 9a5724d commit 64848de

6 files changed

Lines changed: 83 additions & 39 deletions

File tree

common/arg.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2255,7 +2255,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
22552255
std::vector<std::string> split_arg{ it, {} };
22562256
if (split_arg.size() >= llama_max_devices()) {
22572257
throw std::invalid_argument(
2258-
string_format("got %d input configs, but system only has %d devices", (int)split_arg.size(), (int)llama_max_devices())
2258+
string_format("got %zu input configs, but system only has %zu devices", split_arg.size(), llama_max_devices())
22592259
);
22602260
}
22612261
for (size_t i = 0; i < llama_max_devices(); ++i) {
@@ -2295,10 +2295,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
22952295
}
22962296
).set_env("LLAMA_ARG_FIT"));
22972297
add_opt(common_arg(
2298-
{ "-fitt", "--fit-target" }, "MiB",
2299-
string_format("target margin per device for --fit option, default: %zu", params.fit_params_target/(1024*1024)),
2300-
[](common_params & params, int value) {
2301-
params.fit_params_target = value * size_t(1024*1024);
2298+
{ "-fitt", "--fit-target" }, "MiB0,MiB1,MiB2,...",
2299+
string_format("target margin per device for --fit, comma-separated list of values, "
2300+
"single value is broadcast across all devices, default: %zu", params.fit_params_target[0]/(1024*1024)),
2301+
[](common_params & params, const std::string & value) {
2302+
std::string arg_next = value;
2303+
2304+
// split string by , and /
2305+
const std::regex regex{ R"([,/]+)" };
2306+
std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
2307+
std::vector<std::string> split_arg{ it, {} };
2308+
if (split_arg.size() >= llama_max_devices()) {
2309+
throw std::invalid_argument(
2310+
string_format("got %zu input configs, but system only has %zu devices", split_arg.size(), llama_max_devices())
2311+
);
2312+
}
2313+
if (split_arg.size() == 1) {
2314+
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoul(split_arg[0]) * 1024*1024);
2315+
return;
2316+
}
2317+
for (size_t i = 0; i < split_arg.size(); i++) {
2318+
params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024*1024;
2319+
}
23022320
}
23032321
).set_env("LLAMA_ARG_FIT_TARGET"));
23042322
add_opt(common_arg(

common/common.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1097,7 +1097,7 @@ common_init_result::common_init_result(common_params & params) :
10971097
if (params.fit_params) {
10981098
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
10991099
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
1100-
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
1100+
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
11011101
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
11021102
}
11031103

common/common.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -332,12 +332,14 @@ struct common_params {
332332
// offload params
333333
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
334334

335-
int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all
336-
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
337-
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
338-
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
339-
size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory
340-
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
335+
int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all
336+
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
337+
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
338+
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
339+
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
340+
341+
// margin per device in bytes for fitting parameters to free memory:
342+
std::vector<size_t> fit_params_target = std::vector<size_t>(llama_max_devices(), 1024 * 1024*1024);
341343

342344
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
343345

include/llama.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ extern "C" {
495495
struct llama_context_params * cparams,
496496
float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements
497497
struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements
498-
size_t margin, // margin of memory to leave per device in bytes
498+
size_t * margins, // margins of memory to leave per device in bytes
499499
uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use
500500
enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log
501501

src/llama.cpp

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,8 @@ class llama_params_fit_exception : public std::runtime_error {
147147
static void llama_params_fit_impl(
148148
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
149149
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
150-
size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
150+
size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
151151
constexpr int64_t MiB = 1024*1024;
152-
const int64_t margin = margin_s; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
153152
typedef std::vector<llama_device_memory_data> dmds_t;
154153
const llama_model_params default_mparams = llama_model_default_params();
155154

@@ -168,6 +167,12 @@ static void llama_params_fit_impl(
168167
return;
169168
}
170169

170+
std::vector<int64_t> margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
171+
margins.reserve(nd);
172+
for (size_t id = 0; id < nd; id++) {
173+
margins.push_back(margins_s[id]);
174+
}
175+
171176
std::vector<std::string> dev_names;
172177
{
173178
dev_names.reserve(nd);
@@ -187,9 +192,10 @@ static void llama_params_fit_impl(
187192

188193
int64_t sum_free = 0;
189194
int64_t sum_projected_free = 0;
190-
int64_t min_projected_free = INT64_MAX;
191195
int64_t sum_projected_used = 0;
192196
int64_t sum_projected_model = 0;
197+
std::vector<int64_t> projected_free_per_device;
198+
projected_free_per_device.reserve(nd);
193199

194200
if (nd > 1) {
195201
LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
@@ -199,45 +205,63 @@ static void llama_params_fit_impl(
199205

200206
const int64_t projected_used = dmd.mb.total();
201207
const int64_t projected_free = dmd.free - projected_used;
208+
projected_free_per_device.push_back(projected_free);
202209

203210
sum_free += dmd.free;
204211
sum_projected_used += projected_used;
205212
sum_projected_free += projected_free;
206-
min_projected_free = std::min(min_projected_free, projected_free);
207213
sum_projected_model += dmd.mb.model;
208214

209215
if (nd > 1) {
210-
LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n",
211-
__func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, std::abs(projected_free)/MiB,
212-
projected_free >= 0 ? "surplus" : "deficit");
216+
LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n",
217+
__func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB);
213218
}
214219
}
215220
assert(sum_free >= 0 && sum_projected_used >= 0);
216221
LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n",
217222
__func__, sum_projected_used/MiB, sum_free/MiB);
218-
if (min_projected_free >= margin) {
219-
if (nd == 1) {
223+
if (nd == 1) {
224+
if (projected_free_per_device[0] >= margins[0]) {
220225
LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n",
221-
__func__, min_projected_free/MiB, margin/MiB);
226+
__func__, projected_free_per_device[0]/MiB, margins[0]/MiB);
227+
return;
228+
}
229+
} else {
230+
bool changes_needed = false;
231+
for (size_t id = 0; id < nd; id++) {
232+
if (projected_free_per_device[id] < margins[id]) {
233+
changes_needed = true;
234+
break;
235+
}
236+
}
237+
if (!changes_needed) {
238+
LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__);
222239
return;
223240
}
224-
LLAMA_LOG_INFO("%s: will leave at least %" PRId64 " >= %" PRId64 " MiB of free memory on all devices, no changes needed\n",
225-
__func__, min_projected_free/MiB, margin/MiB);
226-
return;
227241
}
228242

229243
// step 2: try reducing memory use by reducing the context size
230244

231245
{
232-
int64_t global_surplus = sum_projected_free - int64_t(nd)*margin;
246+
int64_t global_surplus = sum_projected_free;
247+
for (size_t id = 0; id < nd; id++) {
248+
global_surplus -= margins[id];
249+
}
233250
if (global_surplus < 0) {
234-
LLAMA_LOG_INFO(nd == 1 ?
235-
"%s: cannot fulfill margin of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n" :
236-
"%s: cannot fulfill margin of %" PRId64 " MiB on all devices, need to use %" PRId64 " MiB less in total\n",
237-
__func__, margin/MiB, -global_surplus/MiB);
251+
if (nd == 1) {
252+
LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n",
253+
__func__, margins[0]/MiB, -global_surplus/MiB);
254+
} else {
255+
LLAMA_LOG_INFO(
256+
"%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n",
257+
__func__, -global_surplus/MiB);
258+
}
238259
if (cparams->n_ctx == 0) {
239260
if (hp_nct > n_ctx_min) {
240-
int64_t sum_used_target = sum_free - nd*margin_s;
261+
int64_t sum_used_target = sum_free;
262+
for (size_t id = 0; id < nd; id++) {
263+
sum_used_target -= margins[id];
264+
}
241265
if (nd > 1) {
242266
// for multiple devices we need to be more conservative in terms of how much context we think can fit:
243267
// - for dense models only whole layers can be assigned to devices
@@ -448,9 +472,9 @@ static void llama_params_fit_impl(
448472
const dmds_t dmds_cpu_moe = llama_get_device_memory_data(
449473
path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
450474

451-
for (const llama_device_memory_data & dmd : dmds_cpu_moe) {
452-
global_surplus_cpu_moe += dmd.free;
453-
global_surplus_cpu_moe -= int64_t(dmd.mb.total()) + margin;
475+
for (size_t id = 0; id < nd; id++) {
476+
global_surplus_cpu_moe += dmds_cpu_moe[id].free;
477+
global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id];
454478
}
455479

456480
if (global_surplus_cpu_moe > 0) {
@@ -469,7 +493,7 @@ static void llama_params_fit_impl(
469493
std::vector<int64_t> targets; // maximum acceptable memory use per device
470494
targets.reserve(nd);
471495
for (size_t id = 0; id < nd; id++) {
472-
targets.push_back(dmds_full[id].free - margin);
496+
targets.push_back(dmds_full[id].free - margins[id]);
473497
LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB);
474498
}
475499

@@ -701,11 +725,11 @@ static void llama_params_fit_impl(
701725
enum llama_params_fit_status llama_params_fit(
702726
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
703727
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
704-
size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
728+
size_t * margins, uint32_t n_ctx_min, enum ggml_log_level log_level) {
705729
const int64_t t0_us = llama_time_us();
706730
llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS;
707731
try {
708-
llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s, n_ctx_min, log_level);
732+
llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level);
709733
LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__);
710734
} catch (const llama_params_fit_exception & e) {
711735
LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what());

tools/fit-params/fit-params.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ int main(int argc, char ** argv) {
2727
auto mparams = common_model_params_to_llama(params);
2828
auto cparams = common_context_params_to_llama(params);
2929
const llama_params_fit_status status = llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
30-
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
30+
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
3131
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
3232
if (status != LLAMA_PARAMS_FIT_STATUS_SUCCESS) {
3333
LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__);

0 commit comments

Comments
 (0)