We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c45ec1c commit 87fbb8fCopy full SHA for 87fbb8f
deepspeed/runtime/pipe/module.py
@@ -553,7 +553,10 @@ def save_state_dict(self, save_dir):
553
model_ckpt_path = self.ckpt_layer_path(save_dir, idx)
554
if not hasattr(layer, 'state_dict'):
555
continue
556
- torch.save(layer.state_dict(), model_ckpt_path)
+
557
+ # for some reason torch is saving lots of unnecessary information
558
+ # this line forces a copy of the tensor in order to get rid of temporary stuff
559
+ torch.save({k: (v.cpu() if torch.is_tensor(v) else v) for k, v in layer.state_dict().items()}, model_ckpt_path)
560
561
def load_state_dir(self, load_dir, strict=True):
562
rank = dist.get_rank()
0 commit comments