Tags: deepspeedai/DeepSpeed
Fix hook count performance regression from v0.18.5 (#7886)

Fixes performance regressions reported in #7882 and #7885.

PR #7780 added dynamic hook count computation for reentrant checkpointing correctness, but placed the call inside every gradient hook closure. For a model with n parameter tensors, this creates significant overhead per backward pass.

Summary:

1. Added a `should_refresh_expected_hook_count()` predicate that returns true only at backward phase boundaries (first hook, or new reentrant phase), so `count_used_parameters_in_backward()` is called once per phase instead of once per hook.
2. Applied this predicate in ZeRO-1/2 (`stage_1_and_2.py`) and at both ZeRO-3 hook sites (`stage3.py`), reusing the `cached_max_expected_hooks_seen` value when a refresh isn't needed.
3. Changed `enter_backward()` to reset hook counters on the first real backward entry, preventing pollution from autograd calls made before the user's backward (e.g., TiledFusedLogitsLoss).

Benchmark with a 24-layer transformer, ~267M params (147 parameter tensors), ZeRO-2, 8×H100 80GB, bf16, batch size 8, 20 warmup + 20 measured iterations:

- Before fix: 0.1265s/iter
- After fix: 0.0505s/iter

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Ramya Ramineni <rraminen@users.noreply.github.com>
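The caching idea in the summary above can be sketched in plain Python. This is a toy model, not DeepSpeed's code: the `HookCountTracker` class, its `phase_id` argument, the `count_calls` counter, and the injected `count_fn` are all hypothetical stand-ins, and only the names `should_refresh_expected_hook_count`, `cached_max_expected_hooks_seen`, and `enter_backward` are borrowed from the commit message. The real signatures and the way phases are detected may differ.

```python
class HookCountTracker:
    """Toy model of 'recount once per backward phase, not once per hook'."""

    def __init__(self, count_fn):
        # count_fn stands in for an expensive call such as
        # count_used_parameters_in_backward(); hypothetical wiring.
        self.count_fn = count_fn
        self.hooks_fired = 0
        self.current_phase = None
        self.cached_max_expected_hooks_seen = 0
        self.count_calls = 0  # instrumentation for this sketch only

    def should_refresh_expected_hook_count(self, phase_id):
        # Refresh on the first hook of a backward pass, or when a new
        # (e.g. reentrant) phase begins. phase_id is an assumed marker.
        return self.hooks_fired == 0 or phase_id != self.current_phase

    def on_hook(self, phase_id):
        if self.should_refresh_expected_hook_count(phase_id):
            self.count_calls += 1
            self.cached_max_expected_hooks_seen = max(
                self.cached_max_expected_hooks_seen, self.count_fn()
            )
            self.current_phase = phase_id
        self.hooks_fired += 1
        # Every other hook reuses the cached value.
        return self.cached_max_expected_hooks_seen

    def enter_backward(self):
        # Reset counters so hooks fired before the real backward
        # do not leak into this pass.
        self.hooks_fired = 0
        self.current_phase = None
        self.cached_max_expected_hooks_seen = 0


tracker = HookCountTracker(count_fn=lambda: 147)
tracker.enter_backward()
for _ in range(147):
    expected = tracker.on_hook(phase_id=0)
```

In this sketch the expensive count runs once for the 147 hooks of a single phase; under the pre-fix behavior described above it would run on every hook.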
Replace torch.jit.script with torch.compile (#7835) (#7840)

Fixes #7835.

On torch==2.10.0, importing DeepSpeed emitted deprecation warnings from import-time JIT-decorated helpers. This change updates the compatibility path to align with PyTorch guidance while keeping import clean.

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
fix: avoid IndexError in BF16_Optimizer.destroy() when using DummyOptim (#7763)

Short-circuit BF16_Optimizer.destroy() if using_real_optimizer is False. When initialized with optimizer=None (DummyOptim), bf16_groups remains empty, causing an IndexError when accessing it in destroy().

Resolves #7752
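The short-circuit described above follows a common guard pattern, illustrated here with a toy class. `ToyBF16Optimizer` is a hypothetical stand-in and does not reproduce the real `BF16_Optimizer`; only the attribute names `using_real_optimizer` and `bf16_groups` come from the commit message, and the cleanup body is invented for illustration.

```python
class ToyBF16Optimizer:
    """Hypothetical stand-in showing the early-return guard in destroy()."""

    def __init__(self, optimizer=None):
        # With optimizer=None (the DummyOptim case) no groups are built.
        self.using_real_optimizer = optimizer is not None
        self.bf16_groups = [[1.0, 2.0]] if self.using_real_optimizer else []
        self.destroyed = False

    def destroy(self):
        # Guard: nothing was set up, so there is nothing to tear down.
        if not self.using_real_optimizer:
            return
        # Without the guard, this index would raise IndexError
        # when bf16_groups is empty.
        first_group = self.bf16_groups[0]
        first_group.clear()
        self.destroyed = True


dummy = ToyBF16Optimizer(optimizer=None)
dummy.destroy()  # returns early instead of raising IndexError

real = ToyBF16Optimizer(optimizer=object())
real.destroy()
```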