Skip to content

Commit 07155e9

Browse files
committed
Update save_for_pretrained to avoid crashes when a matching checkpoint is not found and output details
1 parent fbb4da7 commit 07155e9

1 file changed

Lines changed: 47 additions & 20 deletions

File tree

speechbrain/pretrained/training.py

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@
55
* Artem Ploujnikov 2021
66
"""
77
import os
8+
import logging
89
import shutil
910

11+
logger = logging.getLogger(__name__)
12+
1013

1114
def save_for_pretrained(
1215
hparams,
1316
min_key=None,
1417
max_key=None,
18+
ckpt_predicate=None,
1519
pretrainer_key="pretrainer",
1620
checkpointer_key="checkpointer",
1721
):
@@ -24,12 +28,14 @@ def save_for_pretrained(
2428
---------
2529
hparams: dict
2630
the hyperparameter file
27-
max_key : str
31+
max_key: str
2832
Key to use for finding best checkpoint (higher is better).
2933
By default, passed to ``self.checkpointer.recover_if_possible()``.
30-
min_key : str
34+
min_key: str
3135
Key to use for finding best checkpoint (lower is better).
32-
By def ault, passed to ``self.checkpointer.recover_if_possible()``.
36+
By default, passed to ``self.checkpointer.recover_if_possible()``.
37+
ckpt_predicate: callable
38+
a filter predicate to locate checkpoints
3339
checkpointer_key: str
3440
the key under which the checkpointer is stored
3541
pretrained_key: str
@@ -42,21 +48,42 @@ def save_for_pretrained(
4248
)
4349
pretrainer = hparams[pretrainer_key]
4450
checkpointer = hparams[checkpointer_key]
45-
checkpoint = checkpointer.find_checkpoint(min_key=min_key, max_key=max_key)
51+
checkpoint = checkpointer.find_checkpoint(
52+
min_key=min_key, max_key=max_key, ckpt_predicate=ckpt_predicate
53+
)
54+
if checkpoint:
55+
logger.info(
56+
"Saving checkpoint '%s' a pretrained model", checkpoint.path
57+
)
58+
pretrainer_keys = set(pretrainer.loadables.keys())
59+
checkpointer_keys = set(checkpoint.paramfiles.keys())
60+
keys_to_save = pretrainer_keys & checkpointer_keys
61+
for key in keys_to_save:
62+
source_path = checkpoint.paramfiles[key]
63+
if not os.path.exists(source_path):
64+
raise ValueError(
65+
f"File {source_path} does not exist in the checkpoint"
66+
)
67+
target_path = pretrainer.paths[key]
68+
dirname = os.path.dirname(target_path)
69+
if not os.path.exists(dirname):
70+
os.makedirs(dirname)
71+
if os.path.exists(target_path):
72+
os.remove(target_path)
73+
shutil.copyfile(source_path, target_path)
74+
saved = True
75+
else:
76+
logger.info(
77+
"Unable to find a matching checkpoint for min_key = %s, max_key = %s",
78+
min_key,
79+
max_key,
80+
)
81+
checkpoints = checkpointer.list_checkpoints()
82+
checkpoints_str = "\n".join(
83+
f"{checkpoint.path}: {checkpoint.meta}"
84+
for checkpoint in checkpoints
85+
)
86+
logger.info("Available checkpoints: %s", checkpoints_str)
87+
saved = False
4688

47-
pretrainer_keys = set(pretrainer.loadables.keys())
48-
checkpointer_keys = set(checkpoint.paramfiles.keys())
49-
keys_to_save = pretrainer_keys & checkpointer_keys
50-
for key in keys_to_save:
51-
source_path = checkpoint.paramfiles[key]
52-
if not os.path.exists(source_path):
53-
raise ValueError(
54-
f"File {source_path} does not exist in the checkpoint"
55-
)
56-
target_path = pretrainer.paths[key]
57-
dirname = os.path.dirname(target_path)
58-
if not os.path.exists(dirname):
59-
os.makedirs(dirname)
60-
if os.path.exists(target_path):
61-
os.remove(target_path)
62-
shutil.copyfile(source_path, target_path)
89+
return saved

0 commit comments

Comments
 (0)