@@ -75,7 +75,7 @@ def compute_forward(self, batch, stage):
7575 def compute_objectives (self , predictions , batch , stage ):
7676 """Computes the loss (CTC+NLL) given predictions and targets."""
7777
78- (p_ctc , p_seq , wav_lens , hyps , ) = predictions
78+ (p_ctc , p_seq , wav_lens , hyps ) = predictions
7979
8080 ids = batch .id
8181 tokens_eos , tokens_eos_lens = batch .tokens_eos
@@ -169,7 +169,6 @@ def on_stage_end(self, stage, stage_loss, epoch):
169169
170170 # log stats and save checkpoint at end-of-epoch
171171 if stage == sb .Stage .VALID :
172-
173172 # report different epoch stages according current stage
174173 current_epoch = self .hparams .epoch_counter .current
175174 if current_epoch <= self .hparams .stage_one_epochs :
@@ -247,7 +246,6 @@ def on_fit_start(self):
247246
248247 # Load latest checkpoint to resume training if interrupted
249248 if self .checkpointer is not None :
250-
251249 # do not reload the weights if training is interrupted right before stage 2
252250 group = current_optimizer .param_groups [0 ]
253251 if "momentum" not in group :
@@ -263,7 +261,8 @@ def on_evaluate_start(self, max_key=None, min_key=None):
263261 max_key = max_key , min_key = min_key
264262 )
265263 ckpt = sb .utils .checkpoints .average_checkpoints (
266- ckpts , recoverable_name = "model" ,
264+ ckpts ,
265+ recoverable_name = "model" ,
267266 )
268267
269268 self .hparams .model .load_state_dict (ckpt , strict = True )
@@ -272,11 +271,13 @@ def on_evaluate_start(self, max_key=None, min_key=None):
272271
273272def dataio_prepare (hparams ):
274273 """This function prepares the datasets to be used in the brain class.
275- It also defines the data processing pipeline through user-defined functions."""
274+ It also defines the data processing pipeline through user-defined functions.
275+ """
276276 data_folder = hparams ["data_folder" ]
277277
278278 train_data = sb .dataio .dataset .DynamicItemDataset .from_csv (
279- csv_path = hparams ["train_data" ], replacements = {"data_root" : data_folder },
279+ csv_path = hparams ["train_data" ],
280+ replacements = {"data_root" : data_folder },
280281 )
281282
282283 if hparams ["sorting" ] == "ascending" :
@@ -301,12 +302,14 @@ def dataio_prepare(hparams):
301302 )
302303
303304 valid_data = sb .dataio .dataset .DynamicItemDataset .from_csv (
304- csv_path = hparams ["valid_data" ], replacements = {"data_root" : data_folder },
305+ csv_path = hparams ["valid_data" ],
306+ replacements = {"data_root" : data_folder },
305307 )
306308 valid_data = valid_data .filtered_sorted (sort_key = "duration" )
307309
308310 test_data = sb .dataio .dataset .DynamicItemDataset .from_csv (
309- csv_path = hparams ["test_data" ], replacements = {"data_root" : data_folder },
311+ csv_path = hparams ["test_data" ],
312+ replacements = {"data_root" : data_folder },
310313 )
311314 test_data = test_data .filtered_sorted (sort_key = "duration" , reverse = True )
312315
@@ -344,7 +347,8 @@ def text_pipeline(wrd):
344347
345348 # 4. Set output:
346349 sb .dataio .dataset .set_output_keys (
347- datasets , ["id" , "sig" , "wrd" , "tokens_bos" , "tokens_eos" , "tokens" ],
350+ datasets ,
351+ ["id" , "sig" , "wrd" , "tokens_bos" , "tokens_eos" , "tokens" ],
348352 )
349353
350354 # 5. If Dynamic Batching is used, we instantiate the needed samplers.
@@ -356,11 +360,11 @@ def text_pipeline(wrd):
356360 dynamic_hparams = hparams ["dynamic_batch_sampler" ]
357361
358362 train_batch_sampler = DynamicBatchSampler (
359- train_data , ** dynamic_hparams , length_func = lambda x : x ["duration" ],
363+ train_data , ** dynamic_hparams , length_func = lambda x : x ["duration" ]
360364 )
361365
362366 valid_batch_sampler = DynamicBatchSampler (
363- valid_data , ** dynamic_hparams , length_func = lambda x : x ["duration" ],
367+ valid_data , ** dynamic_hparams , length_func = lambda x : x ["duration" ]
364368 )
365369
366370 return (
@@ -374,7 +378,6 @@ def text_pipeline(wrd):
374378
375379
376380if __name__ == "__main__" :
377-
378381 # CLI:
379382 hparams_file , run_opts , overrides = sb .parse_arguments (sys .argv [1 :])
380383 with open (hparams_file ) as fin :
0 commit comments