
[2/N] Simplify KDTrainer and enhance ModelOptHFTrainer#1191

Open
realAsma wants to merge 1 commit into asma/new-qat-1 from asma/new-qat-2

Conversation

realAsma (Contributor) commented Apr 7, 2026

Summary

This PR simplifies the HuggingFace knowledge distillation trainer and enhances the base ModelOptHFTrainer with Liger fused loss, per-parameter learning rates, and training utilities.

Model-agnostic Liger kernel fused loss

Adds custom Liger kernel integration in ModelOptHFTrainer that extends HuggingFace's built-in support in three ways:

  1. Model-agnostic: Works with any causal LM that has an lm_head, unlike HF's Liger integration, which only supports a fixed set of model architectures.
  2. DeepSpeed ZeRO-3 support: HF's Liger integration only works with FSDP. ModelOpt adds distributed param gathering for DeepSpeed ZeRO-3 and DDP as well.
  3. KD loss support: KDTrainer extends fused loss to knowledge distillation via LigerFusedLinearJSD for fused lm_head + Jensen-Shannon divergence.
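The memory win of a fused linear + loss comes from never materializing the full (tokens × vocab) logit tensor. Below is a minimal PyTorch sketch of the chunking idea (illustrative only; the real Liger kernel fuses this in Triton and also recomputes logit chunks in the backward pass, and all names here are hypothetical):

```python
import torch
import torch.nn.functional as F

def chunked_linear_ce(hidden, lm_head_weight, labels, chunk_size=1024):
    """Cross-entropy over lm_head(hidden) without materializing all logits.

    hidden: (N, H) flattened hidden states, lm_head_weight: (V, H),
    labels: (N,) token ids. Only a (chunk_size, V) logit slice is alive
    at any time, instead of the full (N, V) tensor.
    """
    total, count = hidden.new_zeros(()), 0
    for start in range(0, hidden.shape[0], chunk_size):
        h = hidden[start:start + chunk_size]
        y = labels[start:start + chunk_size]
        logits = h @ lm_head_weight.T  # (chunk, V); freed after this iteration
        total = total + F.cross_entropy(logits, y, reduction="sum")
        count += y.numel()
    return total / count
```

The chunked result matches a full-logits cross-entropy up to floating-point error, which is why the loss curves are unchanged while peak memory drops.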

Liger kernel memory sweep (Qwen3-1.7B, 2×H100 FSDP2, NVFP4+FP8_KV)

Max per-GPU batch size before OOM at each sequence length:

QAT (no teacher)

| Seq Length | 512 | 1024 | 2048 | 4096 | 8192 | 16384 |
|------------|-----|------|------|------|------|-------|
| Liger      | 16  | 16   | 16   | 16   | 8    | 4     |
| No Liger   | 16  | 16   | 8    | 4    | 2    | OOM   |

QAD (with teacher)

| Seq Length | 512 | 1024 | 2048 | 4096 | 8192 | 16384 |
|------------|-----|------|------|------|------|-------|
| Liger      | 16  | 16   | 8    | 4    | 2    | 1     |
| No Liger   | 8   | 4    | 2    | 1    | OOM  | OOM   |

Liger fused loss enables 2-4× larger batch sizes at long context lengths by avoiding the materialization of the full logit tensor.
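The headroom tracks the size of the skipped logit tensor. A back-of-the-envelope check, assuming Qwen3's vocabulary size of roughly 151,936 and 2-byte bf16 logits:

```python
# Size of the full logit tensor that the fused loss avoids materializing,
# for a single 16k-token sample (vocab size assumed ~151,936 for Qwen3).
vocab, seq, bytes_per = 151_936, 16_384, 2  # bf16 logits
logits_gib = seq * vocab * bytes_per / 2**30
print(f"{logits_gib:.1f} GiB per sample")  # roughly 4.6 GiB of logits alone
```

With the teacher's logits doubling that for QAD, it is plausible that skipping logit materialization dominates the batch-size headroom at long context.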

ModelOptHFTrainer enhancements

  • ModelOptTrainerArguments with --trainable_params, --frozen_params, --lr_config, --save_dtype, and --manual_gc flags
  • Per-parameter learning rate support via YAML config (lr_config)
  • _prepare_model and _update_config_json_dtype promoted to base class
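As a sketch of how a glob-based lr_config could map onto optimizer parameter groups (the actual YAML schema and helper names in ModelOpt may differ; this is illustrative only):

```python
from fnmatch import fnmatch

def build_param_groups(named_params, lr_config, default_lr):
    """Assign each parameter the LR of the first matching glob pattern.

    lr_config: list of {"pattern": <glob>, "lr": <float>} entries, e.g.
    loaded from a YAML file; parameters matching no pattern fall back
    to default_lr. Hypothetical helper, not the PR's actual API.
    """
    groups = []
    for name, param in named_params:
        lr = next(
            (e["lr"] for e in lr_config if fnmatch(name, e["pattern"])),
            default_lr,
        )
        groups.append({"params": [param], "lr": lr})
    return groups
```

The resulting list can be handed directly to a torch optimizer constructor in place of `model.parameters()`.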

KDTrainer simplification + fix

Removes mtd.convert() and the DistillationModel in-place class-swap for the HF path. The teacher model now lives directly on the trainer and is forwarded explicitly inside compute_kd_loss_func. This eliminates:

  • mtd.convert() in-place class swap and DynamicModule wrapping
  • Forward hooks for capturing intermediate outputs
  • hide_teacher_model / hide_loss_modules context managers for checkpointing
  • Deferred initialization branching (FSDP2 vs DDP/DeepSpeed)
  • save_model and QADTrainer._quantize_model overrides

Bug fix: The previous DistillationModel/mtd.convert() approach did not support CPU RAM-efficient loading for QAD. The teacher model had to be fully loaded on GPU before wrapping, which doubled peak memory during initialization. The new approach loads the teacher lazily on the trainer, enabling standard HF device-map and low-cpu-mem-usage loading.
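The explicit-forward pattern described above can be sketched as follows (a conceptual example using plain temperature-scaled KL logit distillation; the PR's actual loss functions and trainer hooks may differ):

```python
import torch
import torch.nn.functional as F

def compute_kd_loss(student_logits, teacher_logits, temperature=1.0):
    """Logit-level KD: KL(teacher || student) on temperature-softened distributions."""
    t = temperature
    log_p_student = F.log_softmax(student_logits / t, dim=-1)
    p_teacher = F.softmax(teacher_logits / t, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (t * t)

# Inside the trainer's loss function, the teacher is forwarded explicitly
# under no_grad instead of being fused into the student via mtd.convert():
#   with torch.no_grad():
#       teacher_logits = self.teacher_model(**inputs).logits
#   loss = compute_kd_loss(student_logits, teacher_logits)
```

Because the teacher is never wrapped, it can be loaded with the standard HF low-memory paths, which is exactly what enables the bug fix above.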

Only logit-level distillation is supported for the HF path. The core DistillationModel/mtd.convert() API remains for Megatron and advanced intermediate-layer distillation use cases.

Test plan

  • pytest tests/unit/torch/distill/ (29 passed)
  • pytest tests/unit/torch/opt/plugins/test_hf_patching.py (2 passed)
  • pytest tests/unit/torch/opt/plugins/test_lr_config.py
  • Pre-commit hooks pass
  • GPU example tests: pytest tests/examples/llm_qat/ (QAT, QAD, LoRA QAT, QLoRA)
  • GPU distill example: pytest tests/examples/llm_distill/

🤖 Generated with Claude Code

@realAsma realAsma requested review from a team as code owners April 7, 2026 21:40
@realAsma realAsma requested review from Edwardf0t1 and removed request for a team April 7, 2026 21:40
coderabbitai bot (Contributor) commented Apr 7, 2026

Important

Review skipped

Auto reviews are disabled on base/target branches other than the default branch.

🗂️ Base branches to auto review (3)
  • main
  • release/.*
  • feature/.*

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 2501b6ca-d635-4b19-93f7-ac7498c5357e


@realAsma realAsma requested review from ChenhanYu and shengliangxu and removed request for a team April 7, 2026 21:40
codecov bot commented Apr 7, 2026

Codecov Report

❌ Patch coverage is 74.17417% with 86 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.64%. Comparing base (30875fc) to head (f246115).

| File with missing lines | Patch % | Missing |
|-------------------------|---------|---------|
| modelopt/torch/opt/plugins/transformers.py | 62.14% | 81 ⚠️ |
| modelopt/torch/distill/plugins/huggingface.py | 96.61% | 4 ⚠️ |
| ...torch/quantization/plugins/transformers_trainer.py | 0.00% | 1 ⚠️ |
Additional details and impacted files
@@                Coverage Diff                 @@
##           asma/new-qat-1    #1191      +/-   ##
==================================================
+ Coverage           75.63%   75.64%   +0.01%     
==================================================
  Files                 462      462              
  Lines               49873    50116     +243     
==================================================
+ Hits                37719    37912     +193     
- Misses              12154    12204      +50     
| Flag | Coverage | Δ |
|------|----------|---|
| examples | 41.74% <74.17%> | +0.20% ⬆️ |
| gpu | 58.37% <19.81%> | -0.20% ⬇️ |
| regression | 14.89% <19.81%> | +0.01% ⬆️ |

Flags with carried forward coverage won't be shown.

@realAsma realAsma force-pushed the asma/new-qat-2 branch 2 times, most recently from e349a90 to b762870 Compare April 7, 2026 22:36
@realAsma realAsma force-pushed the asma/new-qat-1 branch 2 times, most recently from 97759a4 to bfc343c Compare April 8, 2026 16:03
@realAsma realAsma force-pushed the asma/new-qat-1 branch 3 times, most recently from cc45203 to 9dd1732 Compare April 9, 2026 18:48
cjluo-nv (Collaborator) left a comment:

Summary: This PR simplifies the HF knowledge distillation trainer by removing mtd.convert() class-swap in favor of explicit teacher forwarding, enhances ModelOptHFTrainer with Liger fused loss, per-parameter LRs, parameter freezing, and refactors the llm_qat example to use YAML configs with a new ModelOptArgParser.

Issues Found:

  1. [Correctness] CRITICAL: recipe.ptq_cfg does not exist — should be recipe.quantize
    ModelOptPTQRecipe (in modelopt/recipe/config.py:69) exposes its quant config via the quantize attribute, not ptq_cfg. This will raise AttributeError at runtime in two places:

    • examples/llm_qat/simple_qat_train.py:126: model = mtq.quantize(model, recipe.ptq_cfg, calibrate)
    • modelopt/torch/quantization/plugins/transformers_trainer.py:217: return recipe.ptq_cfg

    The correct usage is already in examples/llm_qat/quantize.py:77 (ptq_cfg = recipe.quantize), confirming this is a copy-paste error. Existing usage in examples/llm_ptq/hf_ptq.py also uses recipe.quantize.

  2. [Correctness] LMLogitsLoss.forward double-sums — returns scalar instead of per-token losses
    LogitsDistillationLoss.forward with reduction="none" already sums over the vocab dimension (line 64 of losses.py), returning shape (B*S,). The new LMLogitsLoss.forward then does another .sum(dim=-1) on this 1D tensor, collapsing it to a scalar. This makes the ignore-index masking in _standard_kd_loss a no-op — padding tokens contribute equally to the loss.

  3. [Correctness] Inconsistent causal shift between standard and Liger KD paths
    _liger_kd_loss applies the standard causal LM shift (hidden_states[..., :-1, :], labels[..., 1:]) before computing JSD. But _standard_kd_loss applies no shift — it computes KL-div on all B*S positions and masks with unshifted labels. While comparing student/teacher at the same position is valid, the different masking alignment means the two paths produce semantically different losses for the same input. This will be surprising when toggling --use_liger_kernel.

  4. [Correctness] _forward_redirect doesn't restore module.forward on failure
    If the module(dummy) call raises before entering wrapped_forward (e.g., FSDP pre-forward hook fails), module.forward remains patched to wrapped_forward. A try/finally would be safer.

  5. [Tests] Low coverage on core library files
    Codecov reports ~20% patch coverage for transformers.py (165 missing lines) and huggingface.py (82 missing lines). The new Liger fused loss path, _forward_redirect, _sharded_liger_compute, parameter freezing, and save_dtype rewriting have no unit test coverage. These are critical code paths for distributed training correctness.

  6. [Correctness] save_dtype defaults to "bfloat16" instead of preserving original model dtype
    The old QATTrainer saved the model's original dtype (self._original_dtype). The new ModelOptHFTrainer.save_model hardcodes save_dtype="bfloat16" by default. For models originally in float16, this silently changes the config.json dtype, which may affect downstream inference engines.

  7. [Readability] Misleading shape comments in LMLogitsLoss
    The comment # (B*S, V) on the super().forward() call is wrong — the parent returns (B*S,) when reduction="none". This directly led to bug #2.

Suggestions:

  • The liger-kernel>=0.5.0 addition to pyproject.toml [hf] extras makes it a hard install dependency for all HF users. Since usage is guarded by --use_liger_kernel, consider making it an optional extra ([liger] or [hf-liger]) to avoid install issues in constrained environments.
  • The pre-commit hook for generate-arguments-md uses language: system and runs python examples/llm_qat/train.py --generate_docs. This requires all modelopt dependencies to be installed in the pre-commit environment, which will fail for most contributors. Consider language: python with explicit dependencies, or making it a manual step.
  • _dataset_cache (module-level mutable dict in dataset_utils.py) acts as an in-memory cache but is never evicted. In long-running processes or notebooks, this could hold large datasets in memory indefinitely.

Overall Assessment: The architectural direction is sound — removing mtd.convert() class-swap in favor of explicit teacher forwarding is a significant simplification. The ModelOptArgParser and YAML-based config system is a good developer experience improvement. However, there are two critical correctness bugs (recipe.ptq_cfg and LMLogitsLoss double-sum) that must be fixed before merge, plus the KD loss path inconsistency warrants discussion.
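For issues 2 and 3, a sketch of a per-token KD loss that sums over the vocabulary exactly once and only then masks padding (illustrative only, not the PR's code) shows the shape discipline the review is asking for:

```python
import torch
import torch.nn.functional as F

def masked_kd_loss(student_logits, teacher_logits, labels, ignore_index=-100):
    """Per-token KL that stays shape (B*S,) until padding is masked out.

    Summing over the vocab dimension exactly once avoids the double-sum
    from issue 2, and masking with labels keeps ignore_index positions
    out of the loss instead of silently averaging them in.
    """
    log_p = F.log_softmax(student_logits, dim=-1)          # (B*S, V)
    q = F.softmax(teacher_logits, dim=-1)                  # (B*S, V)
    per_token = (q * (q.clamp_min(1e-12).log() - log_p)).sum(dim=-1)  # (B*S,)
    mask = labels != ignore_index
    return per_token[mask].mean()
```

Whatever causal shift is chosen for issue 3, applying the same shift to both logits and labels before calling such a function keeps the Liger and standard paths semantically aligned.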

Comment thread examples/llm_qat/simple_qat_train.py Outdated
Comment thread modelopt/torch/quantization/plugins/transformers_trainer.py Outdated
Comment thread modelopt/torch/distill/plugins/huggingface.py Outdated
Comment thread modelopt/torch/distill/plugins/huggingface.py Outdated
Comment thread modelopt/torch/distill/plugins/huggingface.py
Comment thread modelopt/torch/opt/plugins/transformers.py
def _update_config_json_dtype(self, output_dir: str, dtype_str: str | None) -> None:
"""Rewrite <output_dir>/config.json 'dtype' (preferred) or 'torch_dtype' to dtype_str."""
cfg_path = os.path.join(output_dir, "config.json")
if not os.path.isfile(cfg_path):
Collaborator:
save_dtype defaults to "bfloat16" which silently changes the dtype for float16 models. Consider defaulting to None and falling back to the model's original dtype when not explicitly set.

realAsma (Contributor, Author):
Partial fix landed in 1f1c2507 (dataclass default set to None). Follow-up fix staged locally: _update_config_json_dtype now early-returns when dtype_str is None, so the original model dtype written by super().save_model() is preserved. The getattr fallback in save_model is also aligned to None for consistency. Will be included in the next push.

Comment thread .pre-commit-config.yaml
@realAsma realAsma force-pushed the asma/new-qat-1 branch 2 times, most recently from 4472470 to 1088304 Compare April 14, 2026 14:12
@realAsma realAsma force-pushed the asma/new-qat-1 branch 2 times, most recently from 564b431 to 35b4f2a Compare April 15, 2026 17:34
@realAsma realAsma changed the title Simplify KDTrainer and enhance ModelOptHFTrainer [2/N] Simplify KDTrainer and enhance ModelOptHFTrainer Apr 16, 2026
Introduce ModelOptHFTrainer wrapping HF Trainer with modelopt features
(quantization, LR config, trainable/frozen param globs, save_dtype
config rewrite, Liger fused CE, manual GC, etc.) and simplify the
KDTrainer distillation API on top of it.

Also includes follow-up fixes applied during review:
- Causal shift fix and forward-restore safety in KDTrainer
- DeepSpeed ZeRO-3 support in KDTrainer; Liger hidden-states dtype fix
- save_dtype defaults to "bfloat16"; config.json rewrite skipped when
  save_dtype is None
- Narrowed exceptions, moved defaults to configs, fixed recipe.quantize
  reference in transformers_trainer.py

Signed-off-by: realAsma <akuriparambi@nvidia.com>
2 participants