Skip to content

Commit 8c551d2

Browse files
mosheislandMoshe Island
andauthored
deepspeed-chat: filter stage3 too long prompts (deepspeedai#782)
In case stage3 prompts are too long, the prompts are still used, but they are arbitrarily sliced at the start to fit into the configured max prompt length. This arbitrary slicing sometimes makes prompts less meaningful, which in turn causes the generator to generate garbage. This phenomenon was observed to destabilize RLHF stage3. To fix it, we filter out prompts that are too long. In addition, the dataset rebuild flag is propagated to the other consumers that require it. Note that since generated datasets are cached on disk, this commit will take effect only if the step3 cached datasets are cleaned up. Change-Id: I440f09decf0784e4c2c8167a893006dff312281b Signed-off-by: Moshe Island <misland@habana.ai> Co-authored-by: Moshe Island <misland@habana.ai>
1 parent 09af71a commit 8c551d2

1 file changed

Lines changed: 53 additions & 24 deletions

File tree

applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,19 @@ def get_shuffle_idx(seed, size):
9292
return shuffle_idx
9393

9494

95-
def get_raw_dataset_split_index(local_rank, output_path, dataset_name, seed,
96-
split_name, data_split, split_index,
97-
data_size):
95+
def get_raw_dataset_split_index(local_rank,
96+
output_path,
97+
dataset_name,
98+
seed,
99+
split_name,
100+
data_split,
101+
split_index,
102+
data_size,
103+
rebuild=False):
98104
index_file_name = f"{output_path}/{dataset_name}_seed{seed}_{split_name}_{data_split}_{split_index}.npy"
99105
# reindex each time when using local jsonfile since it's more likely to get modified
100-
if (not os.path.isfile(index_file_name)) or (dataset_name == 'jsonfile'):
106+
if rebuild or (not os.path.isfile(index_file_name)) or (dataset_name
107+
== 'jsonfile'):
101108
splits = [float(s) for s in data_split.split(',')]
102109
splits_sum = sum(splits)
103110
splits = [split / splits_sum for split in splits]
@@ -176,6 +183,9 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
176183
chosen_token["attention_mask"] = chosen_token[
177184
"attention_mask"].squeeze(0)
178185
chosen_dataset.append(chosen_token)
186+
print(
187+
f'Creating dataset {raw_dataset.dataset_name_clean} for {train_phase=} size={len(chosen_dataset)}'
188+
)
179189

180190
elif train_phase == 2:
181191
for i, tmp_data in enumerate(current_dataset):
@@ -204,39 +214,41 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
204214
reject_token["input_ids"] = reject_token["input_ids"]
205215
reject_token["attention_mask"] = reject_token["attention_mask"]
206216
reject_dataset.append(reject_token)
217+
print(
218+
f'Creating dataset {raw_dataset.dataset_name_clean} for {train_phase=} size={len(chosen_dataset)}'
219+
)
207220

208221
elif train_phase == 3:
222+
filtered = 0
209223
for i, tmp_data in enumerate(current_dataset):
210224
# tokenize the text
211225
prompt = raw_dataset.get_prompt(tmp_data)
212226
if prompt is not None:
213227
prompt_token = tokenizer(prompt, return_tensors="pt")
214-
prompt_token["input_ids"] = prompt_token["input_ids"]
215-
prompt_token["attention_mask"] = prompt_token["attention_mask"]
216-
for key_word in ["input_ids", "attention_mask"]:
217-
length = prompt_token[key_word].size()[-1]
218-
if length > max_seq_len:
219-
y = prompt_token[key_word].squeeze(0)[length -
220-
(max_seq_len -
221-
1):].flip(0)
222-
else:
223-
y = prompt_token[key_word].squeeze(0).flip(0)
224-
prompt_token[key_word] = y
225-
prompt_dataset.append(prompt_token)
228+
if prompt_token["input_ids"].size()[-1] <= max_seq_len:
229+
for key_word in ["input_ids", "attention_mask"]:
230+
prompt_token[key_word] = prompt_token[
231+
key_word].squeeze(0).flip(0)
232+
prompt_dataset.append(prompt_token)
233+
else:
234+
filtered += 1
235+
print(f'Creating dataset {raw_dataset.dataset_name_clean} '
236+
f'for {train_phase=} size={len(prompt_dataset)} {filtered=}')
237+
226238
return PromptDataset(prompt_dataset, chosen_dataset, reject_dataset,
227239
tokenizer.pad_token_id, train_phase)
228240

229241

230242
def create_dataset(local_rank, dataset_name, data_split, output_path,
231243
train_phase, seed, tokenizer, end_of_conversation_token,
232-
max_seq_len):
244+
max_seq_len, rebuild):
233245
raw_dataset = get_raw_dataset(dataset_name, output_path, seed, local_rank)
234246
train_dataset = raw_dataset.get_train_data()
235247
train_index = get_raw_dataset_split_index(local_rank, output_path,
236248
raw_dataset.dataset_name_clean,
237249
seed, "train", data_split,
238250
train_phase - 1,
239-
len(train_dataset))
251+
len(train_dataset), rebuild)
240252
train_dataset = Subset(train_dataset, train_index)
241253
train_dataset = create_dataset_split(train_dataset, raw_dataset,
242254
train_phase, tokenizer,
@@ -248,7 +260,7 @@ def create_dataset(local_rank, dataset_name, data_split, output_path,
248260
raw_dataset.dataset_name_clean,
249261
seed, "eval",
250262
data_split, train_phase - 1,
251-
len(eval_dataset))
263+
len(eval_dataset), rebuild)
252264
eval_dataset = Subset(eval_dataset, eval_index)
253265
eval_dataset = create_dataset_split(eval_dataset, raw_dataset, train_phase,
254266
tokenizer, end_of_conversation_token,
@@ -287,19 +299,36 @@ def create_prompt_dataset(local_rank,
287299
torch.distributed.all_reduce(buf_create_cache)
288300

289301
if local_rank <= 0 and (buf_create_cache.item() != 0 or reload):
302+
print(f'Creating prompt dataset {data_path}, {reload=}')
290303
if len(data_path) == 1: # Single dataset.
291304
train_dataset, eval_dataset = create_dataset(
292-
local_rank, data_path[0], data_split, output_path, train_phase,
293-
seed, tokenizer, end_of_conversation_token, max_seq_len)
305+
local_rank,
306+
data_path[0],
307+
data_split,
308+
output_path,
309+
train_phase,
310+
seed,
311+
tokenizer,
312+
end_of_conversation_token,
313+
max_seq_len,
314+
rebuild=reload)
294315
else: # Blending datasets.
295316
train_datasets = []
296317
eval_datasets = []
297318
train_size = 0
298319
eval_size = 0
299320
for d_path in data_path:
300321
train_dataset, eval_dataset = create_dataset(
301-
local_rank, d_path, data_split, output_path, train_phase,
302-
seed, tokenizer, end_of_conversation_token, max_seq_len)
322+
local_rank,
323+
d_path,
324+
data_split,
325+
output_path,
326+
train_phase,
327+
seed,
328+
tokenizer,
329+
end_of_conversation_token,
330+
max_seq_len,
331+
rebuild=reload)
303332
train_datasets.append(train_dataset)
304333
eval_datasets.append(eval_dataset)
305334
train_size += len(train_dataset)
@@ -328,7 +357,7 @@ def create_prompt_dataset(local_rank,
328357
tokenizer,
329358
end_of_conversation_token,
330359
max_seq_len,
331-
)
360+
rebuild=reload)
332361
sft_train_datasets.append(sft_train_dataset)
333362
sft_eval_datasets.append(sft_eval_dataset)
334363
sft_train_size += len(sft_train_dataset)

0 commit comments

Comments
 (0)