@@ -147,9 +147,8 @@ class llama_params_fit_exception : public std::runtime_error {
147147static void llama_params_fit_impl (
148148 const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
149149 float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
150- size_t margin_s , uint32_t n_ctx_min, enum ggml_log_level log_level) {
150+ size_t * margins_s , uint32_t n_ctx_min, enum ggml_log_level log_level) {
151151 constexpr int64_t MiB = 1024 *1024 ;
152- const int64_t margin = margin_s; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
153152 typedef std::vector<llama_device_memory_data> dmds_t ;
154153 const llama_model_params default_mparams = llama_model_default_params ();
155154
@@ -168,6 +167,12 @@ static void llama_params_fit_impl(
168167 return ;
169168 }
170169
170+ std::vector<int64_t > margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
171+ margins.reserve (nd);
172+ for (size_t id = 0 ; id < nd; id++) {
173+ margins.push_back (margins_s[id]);
174+ }
175+
171176 std::vector<std::string> dev_names;
172177 {
173178 dev_names.reserve (nd);
@@ -187,9 +192,10 @@ static void llama_params_fit_impl(
187192
188193 int64_t sum_free = 0 ;
189194 int64_t sum_projected_free = 0 ;
190- int64_t min_projected_free = INT64_MAX ;
191195 int64_t sum_projected_used = 0 ;
192196 int64_t sum_projected_model = 0 ;
197+ std::vector<int64_t > projected_free_per_device;
198+ projected_free_per_device.reserve (nd);
193199
194200 if (nd > 1 ) {
195201 LLAMA_LOG_INFO (" %s: projected memory use with initial parameters [MiB]:\n " , __func__);
@@ -199,45 +205,63 @@ static void llama_params_fit_impl(
199205
200206 const int64_t projected_used = dmd.mb .total ();
201207 const int64_t projected_free = dmd.free - projected_used;
208+ projected_free_per_device.push_back (projected_free);
202209
203210 sum_free += dmd.free ;
204211 sum_projected_used += projected_used;
205212 sum_projected_free += projected_free;
206- min_projected_free = std::min (min_projected_free, projected_free);
207213 sum_projected_model += dmd.mb .model ;
208214
209215 if (nd > 1 ) {
210- LLAMA_LOG_INFO (" %s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n " ,
211- __func__, dev_names[id].c_str (), dmd.total /MiB, projected_used/MiB, std::abs (projected_free)/MiB,
212- projected_free >= 0 ? " surplus" : " deficit" );
216+ LLAMA_LOG_INFO (" %s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 " \n " ,
217+ __func__, dev_names[id].c_str (), dmd.total /MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB);
213218 }
214219 }
215220 assert (sum_free >= 0 && sum_projected_used >= 0 );
216221 LLAMA_LOG_INFO (" %s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n " ,
217222 __func__, sum_projected_used/MiB, sum_free/MiB);
218- if (min_projected_free >= margin ) {
219- if (nd == 1 ) {
223+ if (nd == 1 ) {
224+ if (projected_free_per_device[ 0 ] >= margins[ 0 ] ) {
220225 LLAMA_LOG_INFO (" %s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n " ,
221- __func__, min_projected_free/MiB, margin/MiB);
226+ __func__, projected_free_per_device[0 ]/MiB, margins[0 ]/MiB);
227+ return ;
228+ }
229+ } else {
230+ bool changes_needed = false ;
231+ for (size_t id = 0 ; id < nd; id++) {
232+ if (projected_free_per_device[id] < margins[id]) {
233+ changes_needed = true ;
234+ break ;
235+ }
236+ }
237+ if (!changes_needed) {
238+ LLAMA_LOG_INFO (" %s: targets for free memory can be met on all devices, no changes needed\n " , __func__);
222239 return ;
223240 }
224- LLAMA_LOG_INFO (" %s: will leave at least %" PRId64 " >= %" PRId64 " MiB of free memory on all devices, no changes needed\n " ,
225- __func__, min_projected_free/MiB, margin/MiB);
226- return ;
227241 }
228242
229243 // step 2: try reducing memory use by reducing the context size
230244
231245 {
232- int64_t global_surplus = sum_projected_free - int64_t (nd)*margin;
246+ int64_t global_surplus = sum_projected_free;
247+ for (size_t id = 0 ; id < nd; id++) {
248+ global_surplus -= margins[id];
249+ }
233250 if (global_surplus < 0 ) {
234- LLAMA_LOG_INFO (nd == 1 ?
235- " %s: cannot fulfill margin of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n " :
236- " %s: cannot fulfill margin of %" PRId64 " MiB on all devices, need to use %" PRId64 " MiB less in total\n " ,
237- __func__, margin/MiB, -global_surplus/MiB);
251+ if (nd == 1 ) {
252+ LLAMA_LOG_INFO (" %s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n " ,
253+ __func__, margins[0 ]/MiB, -global_surplus/MiB);
254+ } else {
255+ LLAMA_LOG_INFO (
256+ " %s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n " ,
257+ __func__, -global_surplus/MiB);
258+ }
238259 if (cparams->n_ctx == 0 ) {
239260 if (hp_nct > n_ctx_min) {
240- int64_t sum_used_target = sum_free - nd*margin_s;
261+ int64_t sum_used_target = sum_free;
262+ for (size_t id = 0 ; id < nd; id++) {
263+ sum_used_target -= margins[id];
264+ }
241265 if (nd > 1 ) {
242266 // for multiple devices we need to be more conservative in terms of how much context we think can fit:
243267 // - for dense models only whole layers can be assigned to devices
@@ -448,9 +472,9 @@ static void llama_params_fit_impl(
448472 const dmds_t dmds_cpu_moe = llama_get_device_memory_data (
449473 path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
450474
451- for (const llama_device_memory_data & dmd : dmds_cpu_moe ) {
452- global_surplus_cpu_moe += dmd .free ;
453- global_surplus_cpu_moe -= int64_t (dmd .mb .total ()) + margin ;
475+ for (size_t id = 0 ; id < nd; id++ ) {
476+ global_surplus_cpu_moe += dmds_cpu_moe[id] .free ;
477+ global_surplus_cpu_moe -= int64_t (dmds_cpu_moe[id] .mb .total ()) + margins[id] ;
454478 }
455479
456480 if (global_surplus_cpu_moe > 0 ) {
@@ -469,7 +493,7 @@ static void llama_params_fit_impl(
469493 std::vector<int64_t > targets; // maximum acceptable memory use per device
470494 targets.reserve (nd);
471495 for (size_t id = 0 ; id < nd; id++) {
472- targets.push_back (dmds_full[id].free - margin );
496+ targets.push_back (dmds_full[id].free - margins[id] );
473497 LLAMA_LOG_DEBUG (" %s: id=%zu, target=%" PRId64 " MiB\n " , __func__, id, targets[id]/MiB);
474498 }
475499
@@ -701,11 +725,11 @@ static void llama_params_fit_impl(
701725enum llama_params_fit_status llama_params_fit (
702726 const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
703727 float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
704- size_t margin_s , uint32_t n_ctx_min, enum ggml_log_level log_level) {
728+ size_t * margins , uint32_t n_ctx_min, enum ggml_log_level log_level) {
705729 const int64_t t0_us = llama_time_us ();
706730 llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS ;
707731 try {
708- llama_params_fit_impl (path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s , n_ctx_min, log_level);
732+ llama_params_fit_impl (path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins , n_ctx_min, log_level);
709733 LLAMA_LOG_INFO (" %s: successfully fit params to free device memory\n " , __func__);
710734 } catch (const llama_params_fit_exception & e) {
711735 LLAMA_LOG_WARN (" %s: failed to fit params to free device memory: %s\n " , __func__, e.what ());
0 commit comments