Skip to content

Commit a5d4dc1

Browse files
authored
Update model_utils.py (deepspeedai#471)
1 parent d2269a5 commit a5d4dc1

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

applications/DeepSpeed-Chat/training/utils/model/model_utils.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
 9  9       AutoConfig,
10 10       AutoModel,
11 11   )
12    -
   12 + from huggingface_hub import snapshot_download
13 13   from transformers.deepspeed import HfDeepSpeedConfig
14 14
15 15   from .reward_model import RewardModel
@@ -64,6 +64,8 @@ def create_critic_model(model_name_or_path,
64 64           num_padding_at_beginning=num_padding_at_beginning)
65 65
66 66       if rlhf_training:
   67 +         if not os.path.isdir(model_name_or_path):
   68 +             model_name_or_path = snapshot_download(model_name_or_path)
67 69           # critic model needs to load the weight here
68 70           model_ckpt_path = os.path.join(model_name_or_path, 'pytorch_model.bin')
69 71           assert os.path.exists(

0 commit comments

Comments
 (0)