55* Artem Ploujnikov 2021
66"""
77import os
8+ import logging
89import shutil
910
11+ logger = logging .getLogger (__name__ )
12+
1013
1114def save_for_pretrained (
1215 hparams ,
1316 min_key = None ,
1417 max_key = None ,
18+ ckpt_predicate = None ,
1519 pretrainer_key = "pretrainer" ,
1620 checkpointer_key = "checkpointer" ,
1721):
@@ -24,12 +28,14 @@ def save_for_pretrained(
2428 ---------
2529 hparams: dict
2630 the hyperparameter file
27- max_key : str
31+ max_key: str
2832 Key to use for finding best checkpoint (higher is better).
2933 By default, passed to ``self.checkpointer.recover_if_possible()``.
30- min_key : str
34+ min_key: str
3135 Key to use for finding best checkpoint (lower is better).
32- By def ault, passed to ``self.checkpointer.recover_if_possible()``.
36+ By default, passed to ``self.checkpointer.recover_if_possible()``.
37+ ckpt_predicate: callable
38+ a filter predicate to locate checkpoints
3339 checkpointer_key: str
3440 the key under which the checkpointer is stored
3541 pretrained_key: str
@@ -42,21 +48,42 @@ def save_for_pretrained(
4248 )
4349 pretrainer = hparams [pretrainer_key ]
4450 checkpointer = hparams [checkpointer_key ]
45- checkpoint = checkpointer .find_checkpoint (min_key = min_key , max_key = max_key )
51+ checkpoint = checkpointer .find_checkpoint (
52+ min_key = min_key , max_key = max_key , ckpt_predicate = ckpt_predicate
53+ )
54+ if checkpoint :
55+ logger .info (
56+ "Saving checkpoint '%s' a pretrained model" , checkpoint .path
57+ )
58+ pretrainer_keys = set (pretrainer .loadables .keys ())
59+ checkpointer_keys = set (checkpoint .paramfiles .keys ())
60+ keys_to_save = pretrainer_keys & checkpointer_keys
61+ for key in keys_to_save :
62+ source_path = checkpoint .paramfiles [key ]
63+ if not os .path .exists (source_path ):
64+ raise ValueError (
65+ f"File { source_path } does not exist in the checkpoint"
66+ )
67+ target_path = pretrainer .paths [key ]
68+ dirname = os .path .dirname (target_path )
69+ if not os .path .exists (dirname ):
70+ os .makedirs (dirname )
71+ if os .path .exists (target_path ):
72+ os .remove (target_path )
73+ shutil .copyfile (source_path , target_path )
74+ saved = True
75+ else :
76+ logger .info (
77+ "Unable to find a matching checkpoint for min_key = %s, max_key = %s" ,
78+ min_key ,
79+ max_key ,
80+ )
81+ checkpoints = checkpointer .list_checkpoints ()
82+ checkpoints_str = "\n " .join (
83+ f"{ checkpoint .path } : { checkpoint .meta } "
84+ for checkpoint in checkpoints
85+ )
86+ logger .info ("Available checkpoints: %s" , checkpoints_str )
87+ saved = False
4688
47- pretrainer_keys = set (pretrainer .loadables .keys ())
48- checkpointer_keys = set (checkpoint .paramfiles .keys ())
49- keys_to_save = pretrainer_keys & checkpointer_keys
50- for key in keys_to_save :
51- source_path = checkpoint .paramfiles [key ]
52- if not os .path .exists (source_path ):
53- raise ValueError (
54- f"File { source_path } does not exist in the checkpoint"
55- )
56- target_path = pretrainer .paths [key ]
57- dirname = os .path .dirname (target_path )
58- if not os .path .exists (dirname ):
59- os .makedirs (dirname )
60- if os .path .exists (target_path ):
61- os .remove (target_path )
62- shutil .copyfile (source_path , target_path )
89+ return saved
0 commit comments