Skip to content

Commit b2b2d1c

Browse files
Merge pull request #590 from NVIDIA/jasper-ampere
[Jasper/PyT] Updating for Ampere
2 parents b2763ae + ae7fce1 commit b2b2d1c

69 files changed

Lines changed: 918 additions & 818 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

PyTorch/SpeechRecognition/Jasper/.gitignore

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,3 +3,7 @@ __pycache__
33
results/
44
datasets/
55
checkpoints/
6+
7+
*.swp
8+
*.swo
9+
*.swn

PyTorch/SpeechRecognition/Jasper/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:19.10-py3
15+
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.06-py3
1616
FROM ${FROM_IMAGE_NAME}
1717

1818

PyTorch/SpeechRecognition/Jasper/README.md

Lines changed: 277 additions & 254 deletions
Large diffs are not rendered by default.

PyTorch/SpeechRecognition/Jasper/helpers.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -19,11 +19,6 @@
1919
from metrics import word_error_rate
2020

2121

22-
23-
24-
25-
AmpOptimizations = ["O0", "O1", "O2", "O3"]
26-
2722
def print_once(msg):
2823
if (not torch.distributed.is_initialized() or (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0)):
2924
print(msg)

PyTorch/SpeechRecognition/Jasper/inference.py

Lines changed: 75 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
import math
2020
import toml
2121
from dataset import AudioToTextDataLayer
22-
from helpers import process_evaluation_batch, process_evaluation_epoch, add_ctc_labels, AmpOptimizations, print_dict, model_multi_gpu, __ctc_decoder_predictions_tensor
22+
from helpers import process_evaluation_batch, process_evaluation_epoch, add_ctc_labels, print_dict, model_multi_gpu, __ctc_decoder_predictions_tensor
2323
from model import AudioPreprocessing, GreedyCTCDecoder, JasperEncoderDecoder
2424
from parts.features import audio_from_file
2525
import torch
@@ -46,21 +46,21 @@ def parse_args():
4646
parser.add_argument("--ckpt", default=None, type=str, required=True, help='path to model checkpoint')
4747
parser.add_argument("--max_duration", default=None, type=float, help='maximum duration of sequences. if None uses attribute from model configuration file')
4848
parser.add_argument("--pad_to", default=None, type=int, help="default is pad to value as specified in model configurations. if -1 pad to maximum duration. If > 0 pad batch to next multiple of value")
49-
parser.add_argument("--fp16", action='store_true', help='use half precision')
50-
parser.add_argument("--pyt_fp16", action='store_true', help='use half precision')
49+
parser.add_argument("--amp", "--fp16", action='store_true', help='use half precision')
5150
parser.add_argument("--cudnn_benchmark", action='store_true', help="enable cudnn benchmark")
5251
parser.add_argument("--save_prediction", type=str, default=None, help="if specified saves predictions in text form at this location")
5352
parser.add_argument("--logits_save_to", default=None, type=str, help="if specified will save logits to path")
5453
parser.add_argument("--seed", default=42, type=int, help='seed')
55-
parser.add_argument("--masked_fill", type="bool", help="Overrides the masked_fill option for the Encoder")
5654
parser.add_argument("--output_dir", default="results/", type=str, help="Output directory to store exported models. Only used if --export_model is used")
5755
parser.add_argument("--export_model", action='store_true', help="Exports the audio_featurizer, encoder and decoder using torch.jit to the output_dir")
5856
parser.add_argument("--wav", type=str, help='absolute path to .wav file (16KHz)')
57+
parser.add_argument("--cpu", action="store_true", help="Run inference on CPU")
58+
parser.add_argument("--ema", action="store_true", help="If available, load EMA model weights")
5959
return parser.parse_args()
6060

61-
def calc_wer(data_layer, audio_processor,
62-
encoderdecoder, greedy_decoder,
63-
labels, args):
61+
def calc_wer(data_layer, audio_processor,
62+
encoderdecoder, greedy_decoder,
63+
labels, args, device):
6464

6565
encoderdecoder = encoderdecoder.module if hasattr(encoderdecoder, 'module') else encoderdecoder
6666
with torch.no_grad():
@@ -74,16 +74,14 @@ def calc_wer(data_layer, audio_processor,
7474
# Evaluation mini-batch for loop
7575
for it, data in enumerate(tqdm(data_layer.data_iterator)):
7676

77-
tensors = []
78-
for d in data:
79-
tensors.append(d.cuda())
77+
tensors = [t.to(device) for t in data]
8078

8179
t_audio_signal_e, t_a_sig_length_e, t_transcript_e, t_transcript_len_e = tensors
82-
83-
t_processed_signal = audio_processor(t_audio_signal_e, t_a_sig_length_e)
80+
81+
t_processed_signal = audio_processor(t_audio_signal_e, t_a_sig_length_e)
8482
t_log_probs_e, _ = encoderdecoder.infer(t_processed_signal)
8583
t_predictions_e = greedy_decoder(t_log_probs_e)
86-
84+
8785
values_dict = dict(
8886
predictions=[t_predictions_e],
8987
transcript=[t_transcript_e],
@@ -92,7 +90,7 @@ def calc_wer(data_layer, audio_processor,
9290
)
9391
# values_dict will contain results from all workers
9492
process_evaluation_batch(values_dict, _global_var_dict, labels=labels)
95-
93+
9694
if args.steps is not None and it + 1 >= args.steps:
9795
break
9896

@@ -102,18 +100,13 @@ def calc_wer(data_layer, audio_processor,
102100
return wer, _global_var_dict
103101

104102

105-
def jit_export(
106-
audio, audio_len,
107-
audio_processor,
108-
encoderdecoder,
109-
greedy_decoder,
110-
args):
103+
def jit_export(audio, audio_len, audio_processor, encoderdecoder, greedy_decoder, args):
111104

112105
print("##############")
113106

114-
module_name = "{}_{}".format(os.path.basename(args.model_toml), "fp16" if args.fp16 else "fp32")
107+
module_name = "{}_{}".format(os.path.basename(args.model_toml), "fp16" if args.amp else "fp32")
115108

116-
if args.masked_fill is not None and args.masked_fill == False:
109+
if args.use_conv_mask:
117110
module_name = module_name + "_noMaskConv"
118111

119112
# Export just the featurizer
@@ -137,12 +130,18 @@ def jit_export(
137130

138131
return traced_module_feat, traced_module_acoustic, traced_module_decode
139132

140-
def run_once(audio_processor, encoderdecoder, greedy_decoder, audio, audio_len, labels):
141-
features = audio_processor(audio, audio_len)
142-
torch.cuda.synchronize()
133+
def run_once(audio_processor, encoderdecoder, greedy_decoder, audio, audio_len, labels, device):
134+
features, lens = audio_processor(audio, audio_len)
135+
if not device.type == 'cpu':
136+
torch.cuda.synchronize()
143137
t0 = time.perf_counter()
144-
t_log_probs_e = encoderdecoder(features[0])
145-
torch.cuda.synchronize()
138+
# TorchScripted model does not support (features, lengths)
139+
if isinstance(encoderdecoder, torch.jit.TracedModule):
140+
t_log_probs_e = encoderdecoder(features)
141+
else:
142+
t_log_probs_e, _ = encoderdecoder.infer((features, lens))
143+
if not device.type == 'cpu':
144+
torch.cuda.synchronize()
146145
t1 = time.perf_counter()
147146
t_predictions_e = greedy_decoder(log_probs=t_log_probs_e)
148147
hypotheses = __ctc_decoder_predictions_tensor(t_predictions_e, labels=labels)
@@ -157,6 +156,7 @@ def eval(
157156
greedy_decoder,
158157
labels,
159158
multi_gpu,
159+
device,
160160
args):
161161
"""performs inference / evaluation
162162
Args:
@@ -169,21 +169,19 @@ def eval(
169169
args: script input arguments
170170
"""
171171
logits_save_to=args.logits_save_to
172-
172+
173173
with torch.no_grad():
174174
if args.wav:
175175
audio, audio_len = audio_from_file(args.wav)
176-
run_once(audio_processor, encoderdecoder, greedy_decoder, audio, audio_len, labels)
176+
run_once(audio_processor, encoderdecoder, greedy_decoder, audio, audio_len, labels, device)
177177
if args.export_model:
178-
jit_audio_processor, jit_encoderdecoder, jit_greedy_decoder = jit_export(audio, audio_len, audio_processor,
179-
encoderdecoder,
180-
greedy_decoder,args)
181-
run_once(jit_audio_processor, jit_encoderdecoder, jit_greedy_decoder, audio, audio_len, labels)
178+
jit_audio_processor, jit_encoderdecoder, jit_greedy_decoder = jit_export(audio, audio_len, audio_processor, encoderdecoder,greedy_decoder,args)
179+
run_once(jit_audio_processor, jit_encoderdecoder, jit_greedy_decoder, audio, audio_len, labels, device)
182180
return
183-
wer, _global_var_dict = calc_wer(data_layer, audio_processor, encoderdecoder, greedy_decoder, labels, args)
181+
wer, _global_var_dict = calc_wer(data_layer, audio_processor, encoderdecoder, greedy_decoder, labels, args, device)
184182
if (not multi_gpu or (multi_gpu and torch.distributed.get_rank() == 0)):
185183
print("==========>>>>>>Evaluation WER: {0}\n".format(wer))
186-
184+
187185
if args.save_prediction is not None:
188186
with open(args.save_prediction, 'w') as fp:
189187
fp.write('\n'.join(_global_var_dict['predictions']))
@@ -203,26 +201,29 @@ def eval(
203201
# print("===>>>Diff : {0} %".format((wer_after - wer_before) * 100.0 / wer_before))
204202
# print("")
205203

206-
204+
207205
def main(args):
208206
random.seed(args.seed)
209207
np.random.seed(args.seed)
210208
torch.manual_seed(args.seed)
211-
torch.backends.cudnn.benchmark = args.cudnn_benchmark
212-
print("CUDNN BENCHMARK ", args.cudnn_benchmark)
213-
assert(torch.cuda.is_available())
214209

215-
if args.local_rank is not None:
216-
torch.cuda.set_device(args.local_rank)
217-
torch.distributed.init_process_group(backend='nccl', init_method='env://')
218210
multi_gpu = args.local_rank is not None
219-
if multi_gpu:
220-
print("DISTRIBUTED with ", torch.distributed.get_world_size())
221211

222-
if args.fp16:
223-
optim_level = 3
212+
if args.cpu:
213+
assert(not multi_gpu)
214+
device = torch.device('cpu')
224215
else:
225-
optim_level = 0
216+
assert(torch.cuda.is_available())
217+
device = torch.device('cuda')
218+
torch.backends.cudnn.benchmark = args.cudnn_benchmark
219+
print("CUDNN BENCHMARK ", args.cudnn_benchmark)
220+
221+
if multi_gpu:
222+
print("DISTRIBUTED with ", torch.distributed.get_world_size())
223+
torch.cuda.set_device(args.local_rank)
224+
torch.distributed.init_process_group(backend='nccl', init_method='env://')
225+
226+
optim_level = 3 if args.amp else 0
226227

227228
jasper_model_definition = toml.load(args.model_toml)
228229
dataset_vocab = jasper_model_definition['labels']['labels']
@@ -231,32 +232,32 @@ def main(args):
231232
val_manifest = args.val_manifest
232233
featurizer_config = jasper_model_definition['input_eval']
233234
featurizer_config["optimization_level"] = optim_level
234-
featurizer_config["fp16"] = args.fp16
235-
args.use_conv_mask = jasper_model_definition['encoder'].get('convmask', True)
235+
featurizer_config["fp16"] = args.amp
236236

237-
if args.masked_fill is not None:
238-
print("{} masked_fill".format("Enabling" if args.masked_fill else "Disabling"))
239-
jasper_model_definition["encoder"]["conv_mask"] = args.masked_fill
237+
args.use_conv_mask = jasper_model_definition['encoder'].get('convmask', True)
238+
if args.use_conv_mask and args.export_model:
239+
print('WARNING: Masked convs currently not supported for TorchScript. Disabling.')
240+
jasper_model_definition['encoder']['convmask'] = False
240241

241242
if args.max_duration is not None:
242243
featurizer_config['max_duration'] = args.max_duration
243244
if args.pad_to is not None:
244-
featurizer_config['pad_to'] = args.pad_to
245+
featurizer_config['pad_to'] = args.pad_to
245246

246247
if featurizer_config['pad_to'] == "max":
247248
featurizer_config['pad_to'] = -1
248-
249+
249250
print('=== model_config ===')
250251
print_dict(jasper_model_definition)
251252
print()
252253
print('=== feature_config ===')
253254
print_dict(featurizer_config)
254255
print()
255256
data_layer = None
256-
257+
257258
if args.wav is None:
258259
data_layer = AudioToTextDataLayer(
259-
dataset_dir=args.dataset_dir,
260+
dataset_dir=args.dataset_dir,
260261
featurizer_config=featurizer_config,
261262
manifest_filepath=val_manifest,
262263
labels=dataset_vocab,
@@ -274,10 +275,16 @@ def main(args):
274275
exit(0)
275276
else:
276277
checkpoint = torch.load(args.ckpt, map_location="cpu")
278+
if args.ema and 'ema_state_dict' in checkpoint:
279+
print('Loading EMA state dict')
280+
sd = 'ema_state_dict'
281+
else:
282+
sd = 'state_dict'
283+
277284
for k in audio_preprocessor.state_dict().keys():
278-
checkpoint['state_dict'][k] = checkpoint['state_dict'].pop("audio_preprocessor." + k)
279-
audio_preprocessor.load_state_dict(checkpoint['state_dict'], strict=False)
280-
encoderdecoder.load_state_dict(checkpoint['state_dict'], strict=False)
285+
checkpoint[sd][k] = checkpoint[sd].pop("audio_preprocessor." + k)
286+
audio_preprocessor.load_state_dict(checkpoint[sd], strict=False)
287+
encoderdecoder.load_state_dict(checkpoint[sd], strict=False)
281288

282289
greedy_decoder = GreedyCTCDecoder()
283290

@@ -298,24 +305,27 @@ def main(args):
298305
print('-----------------')
299306

300307
print ("audio_preprocessor.normalize: ", audio_preprocessor.featurizer.normalize)
301-
audio_preprocessor.cuda()
302-
encoderdecoder.cuda()
303-
if args.fp16:
304-
encoderdecoder = amp.initialize( models=encoderdecoder,
305-
opt_level=AmpOptimizations[optim_level])
308+
309+
audio_preprocessor.to(device)
310+
encoderdecoder.to(device)
311+
312+
if args.amp:
313+
encoderdecoder = amp.initialize(models=encoderdecoder,
314+
opt_level='O'+str(optim_level))
306315

307316
encoderdecoder = model_multi_gpu(encoderdecoder, multi_gpu)
308317
audio_preprocessor.eval()
309318
encoderdecoder.eval()
310319
greedy_decoder.eval()
311-
320+
312321
eval(
313322
data_layer=data_layer,
314323
audio_processor=audio_preprocessor,
315324
encoderdecoder=encoderdecoder,
316325
greedy_decoder=greedy_decoder,
317326
labels=ctc_vocab,
318327
args=args,
328+
device=device,
319329
multi_gpu=multi_gpu)
320330

321331
if __name__=="__main__":

0 commit comments

Comments
 (0)