[2/N] Simplify KDTrainer and enhance ModelOptHFTrainer #1191

realAsma wants to merge 1 commit into asma/new-qat-1

Conversation
Codecov Report

❌ Patch coverage is

@@           Coverage Diff            @@
##   asma/new-qat-1    #1191      +/-  ##
==================================================
+ Coverage    75.63%   75.64%   +0.01%
==================================================
  Files          462      462
  Lines        49873    50116     +243
==================================================
+ Hits         37719    37912     +193
- Misses       12154    12204      +50
cjluo-nv left a comment:
Summary: This PR simplifies the HF knowledge distillation trainer by removing mtd.convert() class-swap in favor of explicit teacher forwarding, enhances ModelOptHFTrainer with Liger fused loss, per-parameter LRs, parameter freezing, and refactors the llm_qat example to use YAML configs with a new ModelOptArgParser.
Issues Found:

- [Correctness] CRITICAL: `recipe.ptq_cfg` does not exist — should be `recipe.quantize`.
  `ModelOptPTQRecipe` (in `modelopt/recipe/config.py:69`) exposes its quant config via the `quantize` attribute, not `ptq_cfg`. This will raise `AttributeError` at runtime in two places:
  - `examples/llm_qat/simple_qat_train.py:126` — `model = mtq.quantize(model, recipe.ptq_cfg, calibrate)`
  - `modelopt/torch/quantization/plugins/transformers_trainer.py:217` — `return recipe.ptq_cfg`

  The correct usage is already in `examples/llm_qat/quantize.py:77` (`ptq_cfg = recipe.quantize`), confirming this is a copy-paste error. Existing usage in `examples/llm_ptq/hf_ptq.py` also uses `recipe.quantize`.
- [Correctness] `LMLogitsLoss.forward` double-sums — returns a scalar instead of per-token losses.
  `LogitsDistillationLoss.forward` with `reduction="none"` already sums over the vocab dimension (line 64 of `losses.py`), returning shape `(B*S,)`. The new `LMLogitsLoss.forward` then does another `.sum(dim=-1)` on this 1-D tensor, collapsing it to a scalar. This makes the ignore-index masking in `_standard_kd_loss` a no-op — padding tokens contribute equally to the loss.
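To make the shape issue concrete, here is a small NumPy sketch of the reported bug (illustrative values; the variable names are assumptions, not the actual code):

```python
import numpy as np

# The parent loss with reduction="none" already returns one value per token: shape (B*S,).
per_token = np.array([0.5, 0.2, 0.9, 0.1])   # per-token KD losses, (B*S,)
keep_mask = np.array([1.0, 1.0, 0.0, 1.0])   # 0 marks ignore-index (padding) positions

# Buggy path: a second .sum(dim=-1) collapses the 1-D tensor to a 0-d scalar,
# so the later ignore-index masking has no per-token values left to mask.
buggy = per_token.sum(axis=-1)               # 0-d scalar (~1.7)

# Fixed path: keep per-token losses, mask, then average over real tokens only.
fixed = (per_token * keep_mask).sum() / keep_mask.sum()
```

With the fix, the padding position (the `0.9` entry) no longer contributes to the loss.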
- [Correctness] Inconsistent causal shift between the standard and Liger KD paths.
  `_liger_kd_loss` applies the standard causal LM shift (`hidden_states[..., :-1, :]`, `labels[..., 1:]`) before computing JSD. But `_standard_kd_loss` applies no shift — it computes KL-div on all `B*S` positions and masks with unshifted labels. While comparing student/teacher at the same position is valid, the different masking alignment means the two paths produce semantically different losses for the same input. This will be surprising when toggling `--use_liger_kernel`.
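For reference, a minimal NumPy sketch of applying the same causal shift in the standard path (array names and shapes are assumptions following the usual `(B, S, V)` convention):

```python
import numpy as np

B, S, V = 2, 5, 7
student_logits = np.zeros((B, S, V))
teacher_logits = np.zeros((B, S, V))
labels = np.ones((B, S), dtype=np.int64)

# Causal LM shift: position t's logits predict token t+1, so drop the last
# logit position and the first label before masking / computing the divergence.
shift_student = student_logits[..., :-1, :]   # (B, S-1, V)
shift_teacher = teacher_logits[..., :-1, :]   # (B, S-1, V)
shift_labels = labels[..., 1:]                # (B, S-1)
```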
- [Correctness] `_forward_redirect` doesn't restore `module.forward` on failure.
  If the `module(dummy)` call raises before entering `wrapped_forward` (e.g., an FSDP pre-forward hook fails), `module.forward` remains patched to `wrapped_forward`. A `try/finally` would be safer.
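A safe restore pattern would look roughly like this (hypothetical sketch; the real `_forward_redirect` signature may differ):

```python
from contextlib import contextmanager


@contextmanager
def forward_redirect(module, wrapped_forward):
    """Temporarily patch module.forward, restoring it even if the call raises."""
    original = module.forward
    module.forward = wrapped_forward
    try:
        yield module
    finally:
        # Restored on both success and failure, e.g. when a pre-forward hook raises.
        module.forward = original
```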
- [Tests] Low coverage on core library files.
  Codecov reports ~20% patch coverage for `transformers.py` (165 missing lines) and `huggingface.py` (82 missing lines). The new Liger fused loss path, `_forward_redirect`, `_sharded_liger_compute`, parameter freezing, and `save_dtype` rewriting have no unit test coverage. These are critical code paths for distributed training correctness.
- [Correctness] `save_dtype` defaults to `"bfloat16"` instead of preserving the original model dtype.
  The old `QATTrainer` saved the model's original dtype (`self._original_dtype`). The new `ModelOptHFTrainer.save_model` hardcodes `save_dtype="bfloat16"` by default. For models originally in `float16`, this silently changes the config.json dtype, which may affect downstream inference engines.
- [Readability] Misleading shape comments in `LMLogitsLoss`.
  The comment `# (B*S, V)` on the `super().forward()` call is wrong — the parent returns `(B*S,)` when `reduction="none"`. This directly led to bug #2.
Suggestions:

- The `liger-kernel>=0.5.0` addition to the `pyproject.toml` `[hf]` extras makes it a hard install dependency for all HF users. Since usage is guarded by `--use_liger_kernel`, consider making it an optional extra (`[liger]` or `[hf-liger]`) to avoid install issues in constrained environments.
- The pre-commit hook for `generate-arguments-md` uses `language: system` and runs `python examples/llm_qat/train.py --generate_docs`. This requires all modelopt dependencies to be installed in the pre-commit environment, which will fail for most contributors. Consider `language: python` with explicit dependencies, or making it a manual step.
- `_dataset_cache` (a module-level mutable dict in `dataset_utils.py`) acts as an in-memory cache but is never evicted. In long-running processes or notebooks, this could hold large datasets in memory indefinitely.
Overall Assessment: The architectural direction is sound — removing mtd.convert() class-swap in favor of explicit teacher forwarding is a significant simplification. The ModelOptArgParser and YAML-based config system is a good developer experience improvement. However, there are two critical correctness bugs (recipe.ptq_cfg and LMLogitsLoss double-sum) that must be fixed before merge, plus the KD loss path inconsistency warrants discussion.
```python
def _update_config_json_dtype(self, output_dir: str, dtype_str: str | None) -> None:
    """Rewrite <output_dir>/config.json 'dtype' (preferred) or 'torch_dtype' to dtype_str."""
    cfg_path = os.path.join(output_dir, "config.json")
    if not os.path.isfile(cfg_path):
```
save_dtype defaults to "bfloat16" which silently changes the dtype for float16 models. Consider defaulting to None and falling back to the model's original dtype when not explicitly set.
Partial fix landed in 1f1c2507 (dataclass default set to None). Follow-up fix staged locally: _update_config_json_dtype now early-returns when dtype_str is None, so the original model dtype written by super().save_model() is preserved. The getattr fallback in save_model is also aligned to None for consistency. Will be included in the next push.
Introduce ModelOptHFTrainer wrapping HF Trainer with modelopt features (quantization, LR config, trainable/frozen param globs, save_dtype config rewrite, Liger fused CE, manual GC, etc.) and simplify the KDTrainer distillation API on top of it.

Also includes follow-up fixes applied during review:

- Causal shift fix and forward-restore safety in KDTrainer
- DeepSpeed ZeRO-3 support in KDTrainer; Liger hidden-states dtype fix
- save_dtype defaults to "bfloat16"; config.json rewrite skipped when save_dtype is None
- Narrowed exceptions, moved defaults to configs, fixed recipe.quantize reference in transformers_trainer.py

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Summary

This PR simplifies the HuggingFace knowledge distillation trainer and enhances the base `ModelOptHFTrainer` with Liger fused loss, per-parameter learning rates, and training utilities.

Model-agnostic Liger kernel fused loss
Adds custom Liger kernel integration in `ModelOptHFTrainer` that extends HuggingFace's built-in support:

- Works with any model exposing an `lm_head`, unlike HF's Liger which only supports a fixed set of model architectures.
- `KDTrainer` extends fused loss to knowledge distillation via `LigerFusedLinearJSD` for fused lm_head + Jensen-Shannon divergence.

Liger kernel memory sweep (Qwen3-1.7B, 2×H100 FSDP2, NVFP4+FP8_KV)
Max per-GPU batch size before OOM at each sequence length, comparing QAT (no teacher) and QAD (with teacher): Liger fused loss enables 2-4× larger batch sizes at long context lengths by avoiding the materialization of the full logit tensor.
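To illustrate why avoiding the full logit tensor matters: a fused linear + cross-entropy only ever materializes logits for a small chunk of positions at a time. A pure-Python sketch of the idea (the real Liger kernel fuses this inside a Triton kernel; all names here are illustrative, not the library's API):

```python
import math


def chunked_linear_ce(hidden, weight, labels, chunk=2):
    """Cross-entropy over lm_head(hidden) without building the full (N, V) logit matrix.

    hidden: N feature vectors of length H; weight: V rows of length H; labels: N ints.
    """
    total, n = 0.0, len(hidden)
    for start in range(0, n, chunk):
        for h, y in zip(hidden[start:start + chunk], labels[start:start + chunk]):
            # Only this position's V logits are alive at any one time.
            logits = [sum(w * x for w, x in zip(row, h)) for row in weight]
            m = max(logits)
            log_z = m + math.log(sum(math.exp(l - m) for l in logits))
            total += log_z - logits[y]   # -log softmax(logits)[y]
    return total / n
```

The result matches a direct full-matrix cross-entropy; only the peak memory differs.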
ModelOptHFTrainer enhancements

- New `ModelOptTrainerArguments` with `--trainable_params`, `--frozen_params`, `--lr_config`, `--save_dtype`, and `--manual_gc` flags
- Per-parameter learning rates via `lr_config`
- `_prepare_model` and `_update_config_json_dtype` promoted to the base class

KDTrainer simplification + fix
Removes `mtd.convert()` and the `DistillationModel` in-place class-swap for the HF path. The teacher model now lives directly on the trainer and is forwarded explicitly inside `compute_kd_loss_func`. This eliminates:

- the `mtd.convert()` in-place class swap and DynamicModule wrapping
- the `hide_teacher_model`/`hide_loss_modules` context managers for checkpointing
- the `save_model` and `QADTrainer._quantize_model` overrides

Bug fix: the previous `DistillationModel`/`mtd.convert()` approach did not support CPU RAM-efficient loading for QAD. The teacher model had to be fully loaded on GPU before wrapping, which doubled peak memory during initialization. The new approach loads the teacher lazily on the trainer, enabling standard HF device-map and low-cpu-mem-usage loading.

Only logit-level distillation is supported for the HF path. The core `DistillationModel`/`mtd.convert()` API remains for Megatron and advanced intermediate-layer distillation use cases.

Test plan
- `pytest tests/unit/torch/distill/` (29 passed)
- `pytest tests/unit/torch/opt/plugins/test_hf_patching.py` (2 passed)
- `pytest tests/unit/torch/opt/plugins/test_lr_config.py`
- `pytest tests/examples/llm_qat/` (QAT, QAD, LoRA QAT, QLoRA)
- `pytest tests/examples/llm_distill/`

🤖 Generated with Claude Code