4040from torch .nn .utils .rnn import pad_sequence
4141
4242import dllogger as DLLogger
43- from apex import amp
4443from dllogger import StdOutBackend , JSONStreamBackend , Verbosity
4544
4645from common import utils
46+ from common .log_helper import unique_dllogger_fpath
4747from common .text import text_to_sequence
4848from waveglow import model as glow
4949from waveglow .denoiser import Denoiser
@@ -59,8 +59,8 @@ def parse_args(parser):
5959 help = 'Full path to the input text (phareses separated by newlines)' )
6060 parser .add_argument ('-o' , '--output' , default = None ,
6161 help = 'Output folder to save audio (file per phrase)' )
62- parser .add_argument ('--log-file' , type = str , default = 'nvlog.json' ,
63- help = 'Filename for logging ' )
62+ parser .add_argument ('--log-file' , type = str , default = None ,
63+ help = 'Path to a DLLogger log file ' )
6464 parser .add_argument ('--cuda' , action = 'store_true' ,
6565 help = 'Run inference on a GPU using CUDA' )
6666 parser .add_argument ('--fastpitch' , type = str ,
@@ -75,7 +75,7 @@ def parse_args(parser):
7575 help = 'Sampling rate' )
7676 parser .add_argument ('--stft-hop-length' , type = int , default = 256 ,
7777 help = 'STFT hop length for estimating audio length from mel size' )
78- parser .add_argument ('--amp-run ' , action = 'store_true' ,
78+ parser .add_argument ('--amp' , action = 'store_true' ,
7979 help = 'Inference with AMP' )
8080 parser .add_argument ('--batch-size' , type = int , default = 64 )
8181 parser .add_argument ('--include-warmup' , action = 'store_true' ,
@@ -105,7 +105,7 @@ def parse_args(parser):
105105 return parser
106106
107107
108- def load_and_setup_model (model_name , parser , checkpoint , amp_run , device ,
108+ def load_and_setup_model (model_name , parser , checkpoint , amp , device ,
109109 unk_args = [], forward_is_infer = False , ema = True ,
110110 jitable = False ):
111111 model_parser = models .parse_model_args (model_name , parser , add_help = False )
@@ -139,7 +139,7 @@ def load_and_setup_model(model_name, parser, checkpoint, amp_run, device,
139139
140140 if model_name == "WaveGlow" :
141141 model = model .remove_weightnorm (model )
142- if amp_run :
142+ if amp :
143143 model .half ()
144144 model .eval ()
145145 return model .to (device )
@@ -232,25 +232,28 @@ def main():
232232 Launches text to speech (inference).
233233 Inference is executed on a single GPU.
234234 """
235+
236+ torch .backends .cudnn .benchmark = True
237+
235238 parser = argparse .ArgumentParser (description = 'PyTorch FastPitch Inference' ,
236239 allow_abbrev = False )
237240 parser = parse_args (parser )
238241 args , unk_args = parser .parse_known_args ()
239242
240- DLLogger .init (backends = [JSONStreamBackend (Verbosity .DEFAULT , args .log_file ),
241- StdOutBackend (Verbosity .VERBOSE )])
242- for k ,v in vars (args ).items ():
243- DLLogger .log (step = "PARAMETER" , data = {k :v })
244- DLLogger .log (step = "PARAMETER" , data = {'model_name' : 'FastPitch_PyT' })
245-
246243 if args .output is not None :
247244 Path (args .output ).mkdir (parents = False , exist_ok = True )
248245
246+ log_fpath = args .log_file or str (Path (args .output , 'nvlog_infer.json' ))
247+ log_fpath = unique_dllogger_fpath (log_fpath )
248+ DLLogger .init (backends = [JSONStreamBackend (Verbosity .DEFAULT , log_fpath ),
249+ StdOutBackend (Verbosity .VERBOSE )])
250+ [DLLogger .log ("PARAMETER" , {k :v }) for k ,v in vars (args ).items ()]
251+
249252 device = torch .device ('cuda' if args .cuda else 'cpu' )
250253
251254 if args .fastpitch is not None :
252255 generator = load_and_setup_model (
253- 'FastPitch' , parser , args .fastpitch , args .amp_run , device ,
256+ 'FastPitch' , parser , args .fastpitch , args .amp , device ,
254257 unk_args = unk_args , forward_is_infer = True , ema = args .ema ,
255258 jitable = args .torchscript )
256259
@@ -263,7 +266,7 @@ def main():
263266 with warnings .catch_warnings ():
264267 warnings .simplefilter ("ignore" )
265268 waveglow = load_and_setup_model (
266- 'WaveGlow' , parser , args .waveglow , args .amp_run , device ,
269+ 'WaveGlow' , parser , args .waveglow , args .amp , device ,
267270 unk_args = unk_args , forward_is_infer = True , ema = args .ema )
268271 denoiser = Denoiser (waveglow ).to (device )
269272 waveglow = getattr (waveglow , 'infer' , waveglow )
@@ -305,13 +308,14 @@ def main():
305308 all_frames = 0
306309
307310 reps = args .repeats
308- log_enabled = reps == 1
311+ log_enabled = True # reps == 1
309312 log = lambda s , d : DLLogger .log (step = s , data = d ) if log_enabled else None
310313
311- for repeat in (tqdm .tqdm (range (reps )) if reps > 1 else range (reps )):
314+ # for repeat in (tqdm.tqdm(range(reps)) if reps > 1 else range(reps)):
315+ for rep in range (reps ):
312316 for b in batches :
313317 if generator is None :
314- log (0 , {'Synthesizing from ground truth mels' })
318+ log (rep , {'Synthesizing from ground truth mels' })
315319 mel , mel_lens = b ['mel' ], b ['mel_lens' ]
316320 else :
317321 with torch .no_grad (), gen_measures :
@@ -321,8 +325,8 @@ def main():
321325 gen_infer_perf = mel .size (0 ) * mel .size (2 ) / gen_measures [- 1 ]
322326 all_letters += b ['text_lens' ].sum ().item ()
323327 all_frames += mel .size (0 ) * mel .size (2 )
324- log (0 , {"generator_frames_per_sec " : gen_infer_perf })
325- log (0 , {"generator_latency " : gen_measures [- 1 ]})
328+ log (rep , {"fastpitch_frames_per_sec " : gen_infer_perf })
329+ log (rep , {"fastpitch_latency " : gen_measures [- 1 ]})
326330
327331 if waveglow is not None :
328332 with torch .no_grad (), waveglow_measures :
@@ -336,8 +340,8 @@ def main():
336340 waveglow_infer_perf = (
337341 audios .size (0 ) * audios .size (1 ) / waveglow_measures [- 1 ])
338342
339- log (0 , {"waveglow_samples_per_sec" : waveglow_infer_perf })
340- log (0 , {"waveglow_latency" : waveglow_measures [- 1 ]})
343+ log (rep , {"waveglow_samples_per_sec" : waveglow_infer_perf })
344+ log (rep , {"waveglow_latency" : waveglow_measures [- 1 ]})
341345
342346 if args .output is not None and reps == 1 :
343347 for i , audio in enumerate (audios ):
@@ -354,27 +358,31 @@ def main():
354358 write (audio_path , args .sampling_rate , audio .cpu ().numpy ())
355359
356360 if generator is not None and waveglow is not None :
357- log (0 , {"latency" : (gen_measures [- 1 ] + waveglow_measures [- 1 ])})
361+ log (rep , {"latency" : (gen_measures [- 1 ] + waveglow_measures [- 1 ])})
358362
359363 log_enabled = True
360364 if generator is not None :
361365 gm = np .sort (np .asarray (gen_measures ))
362- log ('avg' , {"generator letters/s" : all_letters / gm .sum ()})
363- log ('avg' , {"generator_frames/s" : all_frames / gm .sum ()})
364- log ('avg' , {"generator_latency" : gm .mean ()})
365- log ('90%' , {"generator_latency" : gm .mean () + norm .ppf ((1.0 + 0.90 ) / 2 ) * gm .std ()})
366- log ('95%' , {"generator_latency" : gm .mean () + norm .ppf ((1.0 + 0.95 ) / 2 ) * gm .std ()})
367- log ('99%' , {"generator_latency" : gm .mean () + norm .ppf ((1.0 + 0.99 ) / 2 ) * gm .std ()})
366+ rtf = all_samples / (all_utterances * gm .mean () * args .sampling_rate )
367+ log ('avg' , {"fastpitch letters/s" : all_letters / gm .sum ()})
368+ log ('avg' , {"fastpitch_frames/s" : all_frames / gm .sum ()})
369+ log ('avg' , {"fastpitch_latency" : gm .mean ()})
370+ log ('avg' , {"fastpitch RTF" : rtf })
371+ log ('90%' , {"fastpitch_latency" : gm .mean () + norm .ppf ((1.0 + 0.90 ) / 2 ) * gm .std ()})
372+ log ('95%' , {"fastpitch_latency" : gm .mean () + norm .ppf ((1.0 + 0.95 ) / 2 ) * gm .std ()})
373+ log ('99%' , {"fastpitch_latency" : gm .mean () + norm .ppf ((1.0 + 0.99 ) / 2 ) * gm .std ()})
368374 if waveglow is not None :
369375 wm = np .sort (np .asarray (waveglow_measures ))
376+ rtf = all_samples / (all_utterances * wm .mean () * args .sampling_rate )
370377 log ('avg' , {"waveglow_samples/s" : all_samples / wm .sum ()})
371378 log ('avg' , {"waveglow_latency" : wm .mean ()})
379+ log ('avg' , {"waveglow RTF" : rtf })
372380 log ('90%' , {"waveglow_latency" : wm .mean () + norm .ppf ((1.0 + 0.90 ) / 2 ) * wm .std ()})
373381 log ('95%' , {"waveglow_latency" : wm .mean () + norm .ppf ((1.0 + 0.95 ) / 2 ) * wm .std ()})
374382 log ('99%' , {"waveglow_latency" : wm .mean () + norm .ppf ((1.0 + 0.99 ) / 2 ) * wm .std ()})
375383 if generator is not None and waveglow is not None :
376384 m = gm + wm
377- rtf = all_samples / (len ( batches ) * all_utterances * m .mean () * args .sampling_rate )
385+ rtf = all_samples / (all_utterances * m .mean () * args .sampling_rate )
378386 log ('avg' , {"samples/s" : all_samples / m .sum ()})
379387 log ('avg' , {"letters/s" : all_letters / m .sum ()})
380388 log ('avg' , {"latency" : m .mean ()})
0 commit comments