@@ -93,7 +93,8 @@ def __init__(self,
9393 base_seed = 1234 ,
9494 partition_method = 'parameters' ,
9595 activation_checkpoint_interval = 0 ,
96- activation_checkpoint_func = checkpointing .checkpoint ):
96+ activation_checkpoint_func = checkpointing .checkpoint ,
97+ checkpointable_layers = None ):
9798 """Modules to be parallelized with pipeline parallelism.
9899
99100 The key constraint that enables pipeline parallelism is the
@@ -198,6 +199,7 @@ def forward(self, inputs):
198199
199200 self .activation_checkpoint_interval = activation_checkpoint_interval
200201 self .activation_checkpoint_func = activation_checkpoint_func
202+ self .set_checkpointable_layers (checkpointable_layers )
201203
202204 def _build (self ):
203205 specs = self ._layer_specs
@@ -571,9 +573,19 @@ def load_state_dir(self, load_dir, strict=True):
571573
572574 self ._synchronize_tied_weights ()
573575
def set_checkpointable_layers(self, string):
    """Set the class-name filter used to decide which layers are checkpointable.

    Args:
        string: A substring matched against each layer's class name by
            ``_is_checkpointable`` (e.g. ``'TransformerLayer'``), or ``None``
            to fall back to the default parameter-based heuristic.

    Note:
        Despite the parameter name, no validation is performed here; the
        value is stored as-is on ``self.checkpointable_layers``.
    """
    self.checkpointable_layers = string
574582 def _is_checkpointable (self , funcs ):
575583 if self .__class__ .__name__ == 'GPT2ModelPipe' :
576584 return all ('ParallelTransformerLayerPipe' in f .__class__ .__name__
577585 for f in funcs )
586+ elif self .checkpointable_layers is not None :
587+ ret = all (self .checkpointable_layers in f .__class__ .__name__
588+ for f in funcs )
589+ return ret
578590 params = [f .parameters () for f in funcs if isinstance (f , torch .nn .Module )]
579591 return any (len (list (p )) > 0 for p in params )
0 commit comments