We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d2269a5 commit a5d4dc1 (Copy full SHA for a5d4dc1)
1 file changed
applications/DeepSpeed-Chat/training/utils/model/model_utils.py
@@ -9,7 +9,7 @@
9
AutoConfig,
10
AutoModel,
11
)
12
-
+from huggingface_hub import snapshot_download
13
from transformers.deepspeed import HfDeepSpeedConfig
14
15
from .reward_model import RewardModel
@@ -64,6 +64,8 @@ def create_critic_model(model_name_or_path,
64
num_padding_at_beginning=num_padding_at_beginning)
65
66
if rlhf_training:
67
+ if not os.path.isdir(model_name_or_path):
68
+ model_name_or_path = snapshot_download(model_name_or_path)
69
# critic model needs to load the weight here
70
model_ckpt_path = os.path.join(model_name_or_path, 'pytorch_model.bin')
71
assert os.path.exists(
0 commit comments