
Commit 87fbb8f

fix large checkpoints in pipe parallel (EleutherAI#33)

Authored-by: Samuel Weinbach
Co-authored-by: Samuel Weinbach <samuel.weinbach@gmail.com>
Parent: c45ec1c

File tree

1 file changed, 4 insertions(+), 1 deletion(-)


deepspeed/runtime/pipe/module.py

Lines changed: 4 additions & 1 deletion
@@ -553,7 +553,10 @@ def save_state_dict(self, save_dir):
             model_ckpt_path = self.ckpt_layer_path(save_dir, idx)
             if not hasattr(layer, 'state_dict'):
                 continue
-            torch.save(layer.state_dict(), model_ckpt_path)
+
+            # for some reason torch is saving lots of unnecessary information
+            # this line forces a copy of the tensor in order to get rid of temporary stuff
+            torch.save({k: (v.cpu() if torch.is_tensor(v) else v) for k, v in layer.state_dict().items()}, model_ckpt_path)

     def load_state_dir(self, load_dir, strict=True):
         rank = dist.get_rank()
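Why the copy shrinks the checkpoint: `torch.save` serializes a tensor's entire underlying storage, so a tensor that is a view into (or shares storage with) a much larger buffer drags the whole buffer into the file. Forcing a copy detaches the data from that shared storage. The snippet below is a minimal standalone sketch of the effect (not code from this repo); it uses `clone()` to force the copy, whereas the commit uses `.cpu()`, which likewise materializes a fresh CPU copy for GPU-resident tensors.

```python
# Sketch: saving a small view of a large tensor writes the full storage,
# while saving a copy writes only the copied elements.
import io
import torch

big = torch.zeros(1_000_000)
view = big[:10]  # shares big's 1M-element storage

buf_view, buf_copy = io.BytesIO(), io.BytesIO()
torch.save(view, buf_view)          # serializes the entire shared storage (~4 MB)
torch.save(view.clone(), buf_copy)  # serializes only the 10 copied elements

print(buf_view.tell(), buf_copy.tell())  # first buffer is far larger
```

Note that `.cpu()` on a tensor that already lives on the CPU can return the same tensor without copying, so `clone()` is the more general way to force a detached copy; the commit's `.cpu()` handles the common pipeline-parallel case of GPU-resident parameters.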
