Skip to content

Commit 2d78b37

Browse files
authored
add 'set_checkpointable_layers' func
1 parent faae906 commit 2d78b37

1 file changed

Lines changed: 13 additions & 1 deletion

File tree

deepspeed/runtime/pipe/module.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def __init__(self,
9393
base_seed=1234,
9494
partition_method='parameters',
9595
activation_checkpoint_interval=0,
96-
activation_checkpoint_func=checkpointing.checkpoint):
96+
activation_checkpoint_func=checkpointing.checkpoint,
97+
checkpointable_layers=None):
9798
"""Modules to be parallelized with pipeline parallelism.
9899
99100
The key constraint that enables pipeline parallelism is the
@@ -198,6 +199,7 @@ def forward(self, inputs):
198199

199200
self.activation_checkpoint_interval = activation_checkpoint_interval
200201
self.activation_checkpoint_func = activation_checkpoint_func
202+
self.set_checkpointable_layers(checkpointable_layers)
201203

202204
def _build(self):
203205
specs = self._layer_specs
@@ -571,9 +573,19 @@ def load_state_dir(self, load_dir, strict=True):
571573

572574
self._synchronize_tied_weights()
573575

576+
def set_checkpointable_layers(self, string):
    """Configure which layer classes are eligible for activation checkpointing.

    The value is stored on ``self.checkpointable_layers`` and consulted by
    ``_is_checkpointable``: a partition is checkpointed only when every layer's
    class name matches this filter.

    Args:
        string: substring matched against each layer's class name, or ``None``
            to disable the filter and fall back to the default heuristics.
    """
    self.checkpointable_layers = string
581+
574582
def _is_checkpointable(self, funcs):
575583
if self.__class__.__name__ == 'GPT2ModelPipe':
576584
return all('ParallelTransformerLayerPipe' in f.__class__.__name__
577585
for f in funcs)
586+
elif self.checkpointable_layers is not None:
587+
ret = all(self.checkpointable_layers in f.__class__.__name__
588+
for f in funcs)
589+
return ret
578590
params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
579591
return any(len(list(p)) > 0 for p in params)

0 commit comments

Comments
 (0)