Skip to content

Commit 3517c6d

Browse files
PareesaMSmrwyattii
andauthored
Resolves the issue with evaluation on step2 for single GPU (deepspeedai#766)
Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>
1 parent 1ba50ed commit 3517c6d

2 files changed

Lines changed: 23 additions & 8 deletions

File tree

applications/DeepSpeed-Chat/training/utils/model/model_utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from .reward_model import RewardModel
1616
from ..utils import load_state_dict_into_model
17+
from ..utils import print_rank_0
1718

1819

1920
def configure_dropout(model_config, dropout):
@@ -130,8 +131,8 @@ def create_critic_model(model_name_or_path,
130131
critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer,
131132
ds_config, rlhf_training, dropout)
132133
end = time.time()
133-
if torch.distributed.get_rank() == 0:
134-
print(f"> Creating model from_config took {end - start} seconds")
134+
print_rank_0(f">Creating model from_config took {end - start} seconds",
135+
None)
135136

136137
critic_model = RewardModel(
137138
critic_model,
@@ -152,8 +153,8 @@ def create_critic_model(model_name_or_path,
152153
start = time.time()
153154
model_ckpt_state_dict = torch.load(model_ckpt_path, map_location='cpu')
154155
end = time.time()
155-
if torch.distributed.get_rank() == 0:
156-
print(f"> torch.load took {end - start} seconds")
156+
print_rank_0(f">Creating model from_config took {end - start} seconds",
157+
None)
157158

158159
# load critic model from checkpoint with zero-stage 3 compatibility
159160
# this functionality may be moved to DS checkpoint load API in future
@@ -163,7 +164,8 @@ def create_critic_model(model_name_or_path,
163164
"",
164165
zero_stage=zero_stage)
165166
end = time.time()
166-
if torch.distributed.get_rank() == 0:
167-
print(f"> Loading model state dict took {end - start} seconds")
167+
168+
print_rank_0(f">Creating model from_config took {end - start} seconds",
169+
None)
168170

169171
return critic_model

applications/DeepSpeed-Chat/training/utils/utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,22 @@
1414
import torch.nn as nn
1515

1616

17-
def print_rank_0(msg, rank=0):
18-
if rank <= 0:
17+
def print_rank_0(msg, rank=None):
18+
if rank is not None and rank <= 0:
1919
print(msg)
20+
elif is_rank_0():
21+
print(msg)
22+
23+
24+
def is_rank_0():
25+
"""Check whether it is rank 0."""
26+
if torch.distributed.is_initialized():
27+
if torch.distributed.get_rank() == 0:
28+
return True
29+
else:
30+
return False
31+
else:
32+
return True
2033

2134

2235
def to_device(batch, device):

0 commit comments

Comments
 (0)