1919import math
2020import toml
2121from dataset import AudioToTextDataLayer
22- from helpers import process_evaluation_batch , process_evaluation_epoch , add_ctc_labels , AmpOptimizations , print_dict , model_multi_gpu , __ctc_decoder_predictions_tensor
22+ from helpers import process_evaluation_batch , process_evaluation_epoch , add_ctc_labels , print_dict , model_multi_gpu , __ctc_decoder_predictions_tensor
2323from model import AudioPreprocessing , GreedyCTCDecoder , JasperEncoderDecoder
2424from parts .features import audio_from_file
2525import torch
@@ -46,21 +46,21 @@ def parse_args():
4646 parser .add_argument ("--ckpt" , default = None , type = str , required = True , help = 'path to model checkpoint' )
4747 parser .add_argument ("--max_duration" , default = None , type = float , help = 'maximum duration of sequences. if None uses attribute from model configuration file' )
4848 parser .add_argument ("--pad_to" , default = None , type = int , help = "default is pad to value as specified in model configurations. if -1 pad to maximum duration. If > 0 pad batch to next multiple of value" )
49- parser .add_argument ("--fp16" , action = 'store_true' , help = 'use half precision' )
50- parser .add_argument ("--pyt_fp16" , action = 'store_true' , help = 'use half precision' )
49+ parser .add_argument ("--amp" , "--fp16" , action = 'store_true' , help = 'use half precision' )
5150 parser .add_argument ("--cudnn_benchmark" , action = 'store_true' , help = "enable cudnn benchmark" )
5251 parser .add_argument ("--save_prediction" , type = str , default = None , help = "if specified saves predictions in text form at this location" )
5352 parser .add_argument ("--logits_save_to" , default = None , type = str , help = "if specified will save logits to path" )
5453 parser .add_argument ("--seed" , default = 42 , type = int , help = 'seed' )
55- parser .add_argument ("--masked_fill" , type = "bool" , help = "Overrides the masked_fill option for the Encoder" )
5654 parser .add_argument ("--output_dir" , default = "results/" , type = str , help = "Output directory to store exported models. Only used if --export_model is used" )
5755 parser .add_argument ("--export_model" , action = 'store_true' , help = "Exports the audio_featurizer, encoder and decoder using torch.jit to the output_dir" )
5856 parser .add_argument ("--wav" , type = str , help = 'absolute path to .wav file (16KHz)' )
57+ parser .add_argument ("--cpu" , action = "store_true" , help = "Run inference on CPU" )
58+ parser .add_argument ("--ema" , action = "store_true" , help = "If available, load EMA model weights" )
5959 return parser .parse_args ()
6060
61- def calc_wer (data_layer , audio_processor ,
62- encoderdecoder , greedy_decoder ,
63- labels , args ):
61+ def calc_wer (data_layer , audio_processor ,
62+ encoderdecoder , greedy_decoder ,
63+ labels , args , device ):
6464
6565 encoderdecoder = encoderdecoder .module if hasattr (encoderdecoder , 'module' ) else encoderdecoder
6666 with torch .no_grad ():
@@ -74,16 +74,14 @@ def calc_wer(data_layer, audio_processor,
7474 # Evaluation mini-batch for loop
7575 for it , data in enumerate (tqdm (data_layer .data_iterator )):
7676
77- tensors = []
78- for d in data :
79- tensors .append (d .cuda ())
77+ tensors = [t .to (device ) for t in data ]
8078
8179 t_audio_signal_e , t_a_sig_length_e , t_transcript_e , t_transcript_len_e = tensors
82-
83- t_processed_signal = audio_processor (t_audio_signal_e , t_a_sig_length_e )
80+
81+ t_processed_signal = audio_processor (t_audio_signal_e , t_a_sig_length_e )
8482 t_log_probs_e , _ = encoderdecoder .infer (t_processed_signal )
8583 t_predictions_e = greedy_decoder (t_log_probs_e )
86-
84+
8785 values_dict = dict (
8886 predictions = [t_predictions_e ],
8987 transcript = [t_transcript_e ],
@@ -92,7 +90,7 @@ def calc_wer(data_layer, audio_processor,
9290 )
9391 # values_dict will contain results from all workers
9492 process_evaluation_batch (values_dict , _global_var_dict , labels = labels )
95-
93+
9694 if args .steps is not None and it + 1 >= args .steps :
9795 break
9896
@@ -102,18 +100,13 @@ def calc_wer(data_layer, audio_processor,
102100 return wer , _global_var_dict
103101
104102
105- def jit_export (
106- audio , audio_len ,
107- audio_processor ,
108- encoderdecoder ,
109- greedy_decoder ,
110- args ):
103+ def jit_export (audio , audio_len , audio_processor , encoderdecoder , greedy_decoder , args ):
111104
112105 print ("##############" )
113106
114- module_name = "{}_{}" .format (os .path .basename (args .model_toml ), "fp16" if args .fp16 else "fp32" )
107+ module_name = "{}_{}" .format (os .path .basename (args .model_toml ), "fp16" if args .amp else "fp32" )
115108
116- if args .masked_fill is not None and args . masked_fill == False :
109+ if args .use_conv_mask :
117110 module_name = module_name + "_noMaskConv"
118111
119112 # Export just the featurizer
@@ -137,12 +130,18 @@ def jit_export(
137130
138131 return traced_module_feat , traced_module_acoustic , traced_module_decode
139132
140- def run_once (audio_processor , encoderdecoder , greedy_decoder , audio , audio_len , labels ):
141- features = audio_processor (audio , audio_len )
142- torch .cuda .synchronize ()
133+ def run_once (audio_processor , encoderdecoder , greedy_decoder , audio , audio_len , labels , device ):
134+ features , lens = audio_processor (audio , audio_len )
135+ if not device .type == 'cpu' :
136+ torch .cuda .synchronize ()
143137 t0 = time .perf_counter ()
144- t_log_probs_e = encoderdecoder (features [0 ])
145- torch .cuda .synchronize ()
138+ # TorchScripted model does not support (features, lengths)
139+ if isinstance (encoderdecoder , torch .jit .TracedModule ):
140+ t_log_probs_e = encoderdecoder (features )
141+ else :
142+ t_log_probs_e , _ = encoderdecoder .infer ((features , lens ))
143+ if not device .type == 'cpu' :
144+ torch .cuda .synchronize ()
146145 t1 = time .perf_counter ()
147146 t_predictions_e = greedy_decoder (log_probs = t_log_probs_e )
148147 hypotheses = __ctc_decoder_predictions_tensor (t_predictions_e , labels = labels )
@@ -157,6 +156,7 @@ def eval(
157156 greedy_decoder ,
158157 labels ,
159158 multi_gpu ,
159+ device ,
160160 args ):
161161 """performs inference / evaluation
162162 Args:
@@ -169,21 +169,19 @@ def eval(
169169 args: script input arguments
170170 """
171171 logits_save_to = args .logits_save_to
172-
172+
173173 with torch .no_grad ():
174174 if args .wav :
175175 audio , audio_len = audio_from_file (args .wav )
176- run_once (audio_processor , encoderdecoder , greedy_decoder , audio , audio_len , labels )
176+ run_once (audio_processor , encoderdecoder , greedy_decoder , audio , audio_len , labels , device )
177177 if args .export_model :
178- jit_audio_processor , jit_encoderdecoder , jit_greedy_decoder = jit_export (audio , audio_len , audio_processor ,
179- encoderdecoder ,
180- greedy_decoder ,args )
181- run_once (jit_audio_processor , jit_encoderdecoder , jit_greedy_decoder , audio , audio_len , labels )
178+ jit_audio_processor , jit_encoderdecoder , jit_greedy_decoder = jit_export (audio , audio_len , audio_processor , encoderdecoder ,greedy_decoder ,args )
179+ run_once (jit_audio_processor , jit_encoderdecoder , jit_greedy_decoder , audio , audio_len , labels , device )
182180 return
183- wer , _global_var_dict = calc_wer (data_layer , audio_processor , encoderdecoder , greedy_decoder , labels , args )
181+ wer , _global_var_dict = calc_wer (data_layer , audio_processor , encoderdecoder , greedy_decoder , labels , args , device )
184182 if (not multi_gpu or (multi_gpu and torch .distributed .get_rank () == 0 )):
185183 print ("==========>>>>>>Evaluation WER: {0}\n " .format (wer ))
186-
184+
187185 if args .save_prediction is not None :
188186 with open (args .save_prediction , 'w' ) as fp :
189187 fp .write ('\n ' .join (_global_var_dict ['predictions' ]))
@@ -203,26 +201,29 @@ def eval(
203201 # print("===>>>Diff : {0} %".format((wer_after - wer_before) * 100.0 / wer_before))
204202 # print("")
205203
206-
204+
207205def main (args ):
208206 random .seed (args .seed )
209207 np .random .seed (args .seed )
210208 torch .manual_seed (args .seed )
211- torch .backends .cudnn .benchmark = args .cudnn_benchmark
212- print ("CUDNN BENCHMARK " , args .cudnn_benchmark )
213- assert (torch .cuda .is_available ())
214209
215- if args .local_rank is not None :
216- torch .cuda .set_device (args .local_rank )
217- torch .distributed .init_process_group (backend = 'nccl' , init_method = 'env://' )
218210 multi_gpu = args .local_rank is not None
219- if multi_gpu :
220- print ("DISTRIBUTED with " , torch .distributed .get_world_size ())
221211
222- if args .fp16 :
223- optim_level = 3
212+ if args .cpu :
213+ assert (not multi_gpu )
214+ device = torch .device ('cpu' )
224215 else :
225- optim_level = 0
216+ assert (torch .cuda .is_available ())
217+ device = torch .device ('cuda' )
218+ torch .backends .cudnn .benchmark = args .cudnn_benchmark
219+ print ("CUDNN BENCHMARK " , args .cudnn_benchmark )
220+
221+ if multi_gpu :
222+ print ("DISTRIBUTED with " , torch .distributed .get_world_size ())
223+ torch .cuda .set_device (args .local_rank )
224+ torch .distributed .init_process_group (backend = 'nccl' , init_method = 'env://' )
225+
226+ optim_level = 3 if args .amp else 0
226227
227228 jasper_model_definition = toml .load (args .model_toml )
228229 dataset_vocab = jasper_model_definition ['labels' ]['labels' ]
@@ -231,32 +232,32 @@ def main(args):
231232 val_manifest = args .val_manifest
232233 featurizer_config = jasper_model_definition ['input_eval' ]
233234 featurizer_config ["optimization_level" ] = optim_level
234- featurizer_config ["fp16" ] = args .fp16
235- args .use_conv_mask = jasper_model_definition ['encoder' ].get ('convmask' , True )
235+ featurizer_config ["fp16" ] = args .amp
236236
237- if args .masked_fill is not None :
238- print ("{} masked_fill" .format ("Enabling" if args .masked_fill else "Disabling" ))
239- jasper_model_definition ["encoder" ]["conv_mask" ] = args .masked_fill
237+ args .use_conv_mask = jasper_model_definition ['encoder' ].get ('convmask' , True )
238+ if args .use_conv_mask and args .export_model :
239+ print ('WARNING: Masked convs currently not supported for TorchScript. Disabling.' )
240+ jasper_model_definition ['encoder' ]['convmask' ] = False
240241
241242 if args .max_duration is not None :
242243 featurizer_config ['max_duration' ] = args .max_duration
243244 if args .pad_to is not None :
244- featurizer_config ['pad_to' ] = args .pad_to
245+ featurizer_config ['pad_to' ] = args .pad_to
245246
246247 if featurizer_config ['pad_to' ] == "max" :
247248 featurizer_config ['pad_to' ] = - 1
248-
249+
249250 print ('=== model_config ===' )
250251 print_dict (jasper_model_definition )
251252 print ()
252253 print ('=== feature_config ===' )
253254 print_dict (featurizer_config )
254255 print ()
255256 data_layer = None
256-
257+
257258 if args .wav is None :
258259 data_layer = AudioToTextDataLayer (
259- dataset_dir = args .dataset_dir ,
260+ dataset_dir = args .dataset_dir ,
260261 featurizer_config = featurizer_config ,
261262 manifest_filepath = val_manifest ,
262263 labels = dataset_vocab ,
@@ -274,10 +275,16 @@ def main(args):
274275 exit (0 )
275276 else :
276277 checkpoint = torch .load (args .ckpt , map_location = "cpu" )
278+ if args .ema and 'ema_state_dict' in checkpoint :
279+ print ('Loading EMA state dict' )
280+ sd = 'ema_state_dict'
281+ else :
282+ sd = 'state_dict'
283+
277284 for k in audio_preprocessor .state_dict ().keys ():
278- checkpoint ['state_dict' ][k ] = checkpoint ['state_dict' ].pop ("audio_preprocessor." + k )
279- audio_preprocessor .load_state_dict (checkpoint ['state_dict' ], strict = False )
280- encoderdecoder .load_state_dict (checkpoint ['state_dict' ], strict = False )
285+ checkpoint [sd ][k ] = checkpoint [sd ].pop ("audio_preprocessor." + k )
286+ audio_preprocessor .load_state_dict (checkpoint [sd ], strict = False )
287+ encoderdecoder .load_state_dict (checkpoint [sd ], strict = False )
281288
282289 greedy_decoder = GreedyCTCDecoder ()
283290
@@ -298,24 +305,27 @@ def main(args):
298305 print ('-----------------' )
299306
300307 print ("audio_preprocessor.normalize: " , audio_preprocessor .featurizer .normalize )
301- audio_preprocessor .cuda ()
302- encoderdecoder .cuda ()
303- if args .fp16 :
304- encoderdecoder = amp .initialize ( models = encoderdecoder ,
305- opt_level = AmpOptimizations [optim_level ])
308+
309+ audio_preprocessor .to (device )
310+ encoderdecoder .to (device )
311+
312+ if args .amp :
313+ encoderdecoder = amp .initialize (models = encoderdecoder ,
314+ opt_level = 'O' + str (optim_level ))
306315
307316 encoderdecoder = model_multi_gpu (encoderdecoder , multi_gpu )
308317 audio_preprocessor .eval ()
309318 encoderdecoder .eval ()
310319 greedy_decoder .eval ()
311-
320+
312321 eval (
313322 data_layer = data_layer ,
314323 audio_processor = audio_preprocessor ,
315324 encoderdecoder = encoderdecoder ,
316325 greedy_decoder = greedy_decoder ,
317326 labels = ctc_vocab ,
318327 args = args ,
328+ device = device ,
319329 multi_gpu = multi_gpu )
320330
321331if __name__ == "__main__" :
0 commit comments