@@ -92,12 +92,19 @@ def get_shuffle_idx(seed, size):
9292 return shuffle_idx
9393
9494
95- def get_raw_dataset_split_index (local_rank , output_path , dataset_name , seed ,
96- split_name , data_split , split_index ,
97- data_size ):
95+ def get_raw_dataset_split_index (local_rank ,
96+ output_path ,
97+ dataset_name ,
98+ seed ,
99+ split_name ,
100+ data_split ,
101+ split_index ,
102+ data_size ,
103+ rebuild = False ):
98104 index_file_name = f"{ output_path } /{ dataset_name } _seed{ seed } _{ split_name } _{ data_split } _{ split_index } .npy"
99105 # reindex each time when using local jsonfile since it's more likely to get modified
100- if (not os .path .isfile (index_file_name )) or (dataset_name == 'jsonfile' ):
106+ if rebuild or (not os .path .isfile (index_file_name )) or (dataset_name
107+ == 'jsonfile' ):
101108 splits = [float (s ) for s in data_split .split (',' )]
102109 splits_sum = sum (splits )
103110 splits = [split / splits_sum for split in splits ]
@@ -176,6 +183,9 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
176183 chosen_token ["attention_mask" ] = chosen_token [
177184 "attention_mask" ].squeeze (0 )
178185 chosen_dataset .append (chosen_token )
186+ print (
187+ f'Creating dataset { raw_dataset .dataset_name_clean } for { train_phase = } size={ len (chosen_dataset )} '
188+ )
179189
180190 elif train_phase == 2 :
181191 for i , tmp_data in enumerate (current_dataset ):
@@ -204,39 +214,41 @@ def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,
204214 reject_token ["input_ids" ] = reject_token ["input_ids" ]
205215 reject_token ["attention_mask" ] = reject_token ["attention_mask" ]
206216 reject_dataset .append (reject_token )
217+ print (
218+ f'Creating dataset { raw_dataset .dataset_name_clean } for { train_phase = } size={ len (chosen_dataset )} '
219+ )
207220
208221 elif train_phase == 3 :
222+ filtered = 0
209223 for i , tmp_data in enumerate (current_dataset ):
210224 # tokenize the text
211225 prompt = raw_dataset .get_prompt (tmp_data )
212226 if prompt is not None :
213227 prompt_token = tokenizer (prompt , return_tensors = "pt" )
214- prompt_token ["input_ids" ] = prompt_token ["input_ids" ]
215- prompt_token ["attention_mask" ] = prompt_token ["attention_mask" ]
216- for key_word in ["input_ids" , "attention_mask" ]:
217- length = prompt_token [key_word ].size ()[- 1 ]
218- if length > max_seq_len :
219- y = prompt_token [key_word ].squeeze (0 )[length -
220- (max_seq_len -
221- 1 ):].flip (0 )
222- else :
223- y = prompt_token [key_word ].squeeze (0 ).flip (0 )
224- prompt_token [key_word ] = y
225- prompt_dataset .append (prompt_token )
228+ if prompt_token ["input_ids" ].size ()[- 1 ] <= max_seq_len :
229+ for key_word in ["input_ids" , "attention_mask" ]:
230+ prompt_token [key_word ] = prompt_token [
231+ key_word ].squeeze (0 ).flip (0 )
232+ prompt_dataset .append (prompt_token )
233+ else :
234+ filtered += 1
235+ print (f'Creating dataset { raw_dataset .dataset_name_clean } '
236+ f'for { train_phase = } size={ len (prompt_dataset )} { filtered = } ' )
237+
226238 return PromptDataset (prompt_dataset , chosen_dataset , reject_dataset ,
227239 tokenizer .pad_token_id , train_phase )
228240
229241
230242def create_dataset (local_rank , dataset_name , data_split , output_path ,
231243 train_phase , seed , tokenizer , end_of_conversation_token ,
232- max_seq_len ):
244+ max_seq_len , rebuild ):
233245 raw_dataset = get_raw_dataset (dataset_name , output_path , seed , local_rank )
234246 train_dataset = raw_dataset .get_train_data ()
235247 train_index = get_raw_dataset_split_index (local_rank , output_path ,
236248 raw_dataset .dataset_name_clean ,
237249 seed , "train" , data_split ,
238250 train_phase - 1 ,
239- len (train_dataset ))
251+ len (train_dataset ), rebuild )
240252 train_dataset = Subset (train_dataset , train_index )
241253 train_dataset = create_dataset_split (train_dataset , raw_dataset ,
242254 train_phase , tokenizer ,
@@ -248,7 +260,7 @@ def create_dataset(local_rank, dataset_name, data_split, output_path,
248260 raw_dataset .dataset_name_clean ,
249261 seed , "eval" ,
250262 data_split , train_phase - 1 ,
251- len (eval_dataset ))
263+ len (eval_dataset ), rebuild )
252264 eval_dataset = Subset (eval_dataset , eval_index )
253265 eval_dataset = create_dataset_split (eval_dataset , raw_dataset , train_phase ,
254266 tokenizer , end_of_conversation_token ,
@@ -287,19 +299,36 @@ def create_prompt_dataset(local_rank,
287299 torch .distributed .all_reduce (buf_create_cache )
288300
289301 if local_rank <= 0 and (buf_create_cache .item () != 0 or reload ):
302+ print (f'Creating prompt dataset { data_path } , { reload = } ' )
290303 if len (data_path ) == 1 : # Single dataset.
291304 train_dataset , eval_dataset = create_dataset (
292- local_rank , data_path [0 ], data_split , output_path , train_phase ,
293- seed , tokenizer , end_of_conversation_token , max_seq_len )
305+ local_rank ,
306+ data_path [0 ],
307+ data_split ,
308+ output_path ,
309+ train_phase ,
310+ seed ,
311+ tokenizer ,
312+ end_of_conversation_token ,
313+ max_seq_len ,
314+ rebuild = reload )
294315 else : # Blending datasets.
295316 train_datasets = []
296317 eval_datasets = []
297318 train_size = 0
298319 eval_size = 0
299320 for d_path in data_path :
300321 train_dataset , eval_dataset = create_dataset (
301- local_rank , d_path , data_split , output_path , train_phase ,
302- seed , tokenizer , end_of_conversation_token , max_seq_len )
322+ local_rank ,
323+ d_path ,
324+ data_split ,
325+ output_path ,
326+ train_phase ,
327+ seed ,
328+ tokenizer ,
329+ end_of_conversation_token ,
330+ max_seq_len ,
331+ rebuild = reload )
303332 train_datasets .append (train_dataset )
304333 eval_datasets .append (eval_dataset )
305334 train_size += len (train_dataset )
@@ -328,7 +357,7 @@ def create_prompt_dataset(local_rank,
328357 tokenizer ,
329358 end_of_conversation_token ,
330359 max_seq_len ,
331- )
360+ rebuild = reload )
332361 sft_train_datasets .append (sft_train_dataset )
333362 sft_eval_datasets .append (sft_eval_dataset )
334363 sft_train_size += len (sft_train_dataset )
0 commit comments