Skip to content

Commit 16afe7a

Browse files
committed
[src,egs,scripts]: Replace online-nnet3 decoding setup with 'looped' decoding and give example script with TDNN+LSTM.
1 parent d9a5312 commit 16afe7a

24 files changed

+824
-662
lines changed

egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1b.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# and adding
77
# --egs.chunk-left-context-initial=0
88
# and --egs.chunk-right-context-final=0
9-
9+
# See 1e for summary of results.
1010

1111
# steps/info/chain_dir_info.pl exp/chain_cleaned/tdnn_lstm1a_sp_bi
1212
# exp/chain_cleaned/tdnn_lstm1a_sp_bi: num-iters=253 nj=2..12 num-params=9.5M dim=40+100->3607 combine=-0.07->-0.07 xent:train/valid[167,252,final]=(-0.960,-0.859,-0.852/-1.05,-0.999,-0.997) logprob:train/valid[167,252,final]=(-0.076,-0.064,-0.062/-0.099,-0.092,-0.091)

egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1d.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# 1d is as 1b, but adding decay-time=40 to the fast-lstmp-layers. note: it
44
# uses egs from 1b, remember to remove that before I commit.
5+
# See 1e for summary of results.
56

67
# steps/info/chain_dir_info.pl exp/chain_cleaned/tdnn_lstm1a_sp_bi
78
# exp/chain_cleaned/tdnn_lstm1a_sp_bi: num-iters=253 nj=2..12 num-params=9.5M dim=40+100->3607 combine=-0.07->-0.07 xent:train/valid[167,252,final]=(-0.960,-0.859,-0.852/-1.05,-0.999,-0.997) logprob:train/valid[167,252,final]=(-0.076,-0.064,-0.062/-0.099,-0.092,-0.091)

egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1e.sh

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,50 @@
11
#!/bin/bash
22

3-
# 1e is as 1b, but reducing decay-time from 40 to 20.
3+
# 1e is as 1d, but reducing decay-time from 40 to 20.
4+
5+
# The following table shows comparison of various decay-time values,
6+
# namely: [b:unset=infinity, f:80, d:40, e:20, g:10, g2:5].
7+
# note: the g2 script is not checked in.
8+
# There is no clear trend on the non-looped decoding, but looped decoding seems
9+
# to improve as decay-time is decreased. We end up recommending decay-time=20,
10+
# as by then we get all the improvement on looped decoding, and it's the
11+
# most conservative setting with which we can get this improvement (although
12+
# actually it seems fine to use an even smaller decay-time).
13+
14+
# local/chain/compare_wer_general.sh --looped exp/chain_cleaned/tdnn_lstm1{b,f,d,e,g,g2}_sp_bi
15+
16+
# local/chain/compare_wer_general.sh --looped exp/chain_cleaned/tdnn_lstm1b_sp_bi exp/chain_cleaned/tdnn_lstm1f_sp_bi exp/chain_cleaned/tdnn_lstm1d_sp_bi exp/chain_cleaned/tdnn_lstm1e_sp_bi exp/chain_cleaned/tdnn_lstm1g_sp_bi exp/chain_cleaned/tdnn_lstm1g2_sp_bi
17+
# System tdnn_lstm1b_sp_bi tdnn_lstm1f_sp_bi tdnn_lstm1d_sp_bi tdnn_lstm1e_sp_bi tdnn_lstm1g_sp_bi tdnn_lstm1g2_sp_bi
18+
# WER on dev(orig) 9.1 8.8 9.0 9.0 9.0 9.4
19+
# [looped:] 9.4 9.3 9.2 9.0 8.9 9.4
20+
# WER on dev(rescored) 8.4 8.2 8.4 8.4 8.4 8.7
21+
# [looped:] 8.8 8.7 8.6 8.4 8.3 8.7
22+
# WER on test(orig) 8.9 9.0 8.9 8.8 8.8 9.3
23+
# [looped:] 9.3 9.3 9.0 8.8 8.8 9.2
24+
# WER on test(rescored) 8.4 8.6 8.3 8.4 8.4 8.9
25+
# [looped:] 8.7 8.9 8.5 8.3 8.4 8.8
26+
# Final train prob -0.0621 -0.0631 -0.0595 -0.0648 -0.0689 -0.0739
27+
# Final valid prob -0.0799 -0.0802 -0.0823 -0.0827 -0.0890 -0.0963
28+
# Final train prob (xent) -0.8300 -0.8295 -0.8129 -0.8372 -0.8610 -0.8792
29+
# Final valid prob (xent) -0.9500 -0.9662 -0.9589 -0.9497 -0.9982 -1.0256
30+
31+
32+
# the following table compares the 'online' decoding with regular and looped
33+
# decoding. online decoding is a little better than either (possibly due to
34+
# using slightly later iVectors).
35+
#
36+
# local/chain/compare_wer_general.sh --looped exp/chain_cleaned/tdnn_lstm1e_sp_bi{,_online} 2>/dev/null
37+
# local/chain/compare_wer_general.sh --looped exp/chain_cleaned/tdnn_lstm1e_sp_bi exp/chain_cleaned/tdnn_lstm1e_sp_bi_online
38+
# System tdnn_lstm1e_sp_bi tdnn_lstm1e_sp_bi_online
39+
# WER on dev(orig) 9.0 8.8
40+
# [looped:] 9.0
41+
# WER on dev(rescored) 8.4 8.4
42+
# [looped:] 8.4
43+
# WER on test(orig) 8.8 8.8
44+
# [looped:] 8.8
45+
# WER on test(rescored) 8.4 8.4
46+
# [looped:] 8.3
47+
448

549
# 1d is as 1b, but adding decay-time=40 to the fast-lstmp-layers. note: it
650
# uses egs from 1b, remember to remove that before I commit.
@@ -77,6 +121,8 @@ tdnn_lstm_affix=1e #affix for TDNN-LSTM directory, e.g. "a" or "b", in case we
77121
common_egs_dir= # you can set this to use previously dumped egs.
78122
remove_egs=true
79123

124+
test_online_decoding=false # if true, it will run the last decoding stage.
125+
80126
# End configuration section.
81127
echo "$0 $@" # Print the command line for logging
82128

@@ -289,8 +335,10 @@ if [ $stage -le 21 ]; then
289335
# 'looped' decoding. we didn't write a -parallel version of this program yet,
290336
# so it will take a bit longer as the --num-threads option is not supported.
291337
# we just hardcode the --frames-per-chunk option as it doesn't have to
292-
# match any value used in training, and it won't affect the results (unlike
293-
# regular decoding).
338+
# match any value used in training, and it won't affect the results very much (unlike
339+
# regular decoding)... [it will affect them slightly due to differences in the
340+
# iVector extraction; probably smaller will be worse as it sees less of the future,
341+
# but in a real scenario, long chunks will introduce excessive latency].
294342
rm $dir/.error 2>/dev/null || true
295343
for dset in dev test; do
296344
(
@@ -313,4 +361,35 @@ if [ $stage -le 21 ]; then
313361
fi
314362

315363

364+
if $test_online_decoding && [ $stage -le 22 ]; then
365+
# note: if the features change (e.g. you add pitch features), you will have to
366+
# change the options of the following command line.
367+
steps/online/nnet3/prepare_online_decoding.sh \
368+
--mfcc-config conf/mfcc_hires.conf \
369+
data/lang_chain exp/nnet3${nnet3_affix}/extractor ${dir} ${dir}_online
370+
371+
rm $dir/.error 2>/dev/null || true
372+
for dset in dev test; do
373+
(
374+
# note: we just give it "$dset" as it only uses the wav.scp, the
375+
# feature type does not matter.
376+
377+
steps/online/nnet3/decode.sh --nj $decode_nj --cmd "$decode_cmd" \
378+
--extra-left-context-initial $extra_left_context_initial \
379+
--acwt 1.0 --post-decode-acwt 10.0 \
380+
--scoring-opts "--min-lmwt 5 " \
381+
$dir/graph data/${dset} ${dir}_online/decode_${dset} || exit 1;
382+
steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang data/lang_rescore \
383+
data/${dset}_hires ${dir}_online/decode_${dset} ${dir}_online/decode_${dset}_rescore || exit 1
384+
) || touch $dir/.error &
385+
done
386+
wait
387+
if [ -f $dir/.error ]; then
388+
echo "$0: something went wrong in decoding"
389+
exit 1
390+
fi
391+
fi
392+
393+
394+
316395
exit 0

egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1f.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/bin/bash
22

3-
# 1f is as 1b, but increasing decay-time from 40 to 80. [see also 1e, at 20.]
3+
# 1f is as 1d, but increasing decay-time from 40 to 80. [see also 1e, at 20.]
4+
# see 1e for summary of results.
45

56
# 1d is as 1b, but adding decay-time=40 to the fast-lstmp-layers. note: it
67
# uses egs from 1b, remember to remove that before I commit.

egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_1g.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#######################
44
# 1g is as 1e, but reducing decay-time further from 20 to 10.
5+
# see 1e for summary of results.
6+
57
# 1e is as 1b, but reducing decay-time from 40 to 20.
68

79
# 1d is as 1b, but adding decay-time=40 to the fast-lstmp-layers. note: it

egs/wsj/s5/steps/online/nnet3/decode.sh

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
stage=0
99
nj=4
1010
cmd=run.pl
11+
frames_per_chunk=20
12+
extra_left_context_initial=0
1113
min_active=200
1214
max_active=7000
1315
beam=15.0
@@ -114,11 +116,6 @@ else
114116
fi
115117

116118

117-
decoder=online2-wav-nnet3-latgen-faster
118-
parallel_opts=
119-
opts="--online=$online"
120-
121-
122119
if [ "$post_decode_acwt" == 1.0 ]; then
123120
lat_wspecifier="ark:|gzip -c >$dir/lat.JOB.gz"
124121
else
@@ -132,8 +129,12 @@ if [ -f $srcdir/frame_subsampling_factor ]; then
132129
fi
133130

134131
if [ $stage -le 0 ]; then
135-
$cmd $parallel_opts JOB=1:$nj $dir/log/decode.JOB.log \
136-
$decoder $opts $silence_weighting_opts --do-endpointing=$do_endpointing $frame_subsampling_opt \
132+
$cmd JOB=1:$nj $dir/log/decode.JOB.log \
133+
online2-wav-nnet3-latgen-faster $silence_weighting_opts --do-endpointing=$do_endpointing \
134+
--frames-per-chunk=$frames_per_chunk \
135+
--extra-left-context-initial=$extra_left_context_initial \
136+
--online=$online \
137+
$frame_subsampling_opt \
137138
--config=$online_config \
138139
--min-active=$min_active --max-active=$max_active --beam=$beam --lattice-beam=$lattice_beam \
139140
--acoustic-scale=$acwt --word-symbol-table=$graphdir/words.txt \

src/itf/decodable-itf.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class DecodableInterface {
112112

113113
/// Returns the number of states in the acoustic model
114114
/// (they will be indexed one-based, i.e. from 1 to NumIndices();
115-
/// this is for compatibility with OpenFst.
115+
/// this is for compatibility with OpenFst).
116116
virtual int32 NumIndices() const = 0;
117117

118118
virtual ~DecodableInterface() {}

src/nnet2/online-nnet2-decodable.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ int32 DecodableNnet2Online::NumFramesReady() const {
8080

8181
void DecodableNnet2Online::ComputeForFrame(int32 frame) {
8282
int32 features_ready = features_->NumFramesReady();
83-
bool input_finished = features_->IsLastFrame(features_ready - 1);
83+
bool input_finished = features_->IsLastFrame(features_ready - 1);
8484
KALDI_ASSERT(frame >= 0);
8585
if (frame >= begin_frame_ &&
8686
frame < begin_frame_ + scaled_loglikes_.NumRows())
@@ -112,20 +112,20 @@ void DecodableNnet2Online::ComputeForFrame(int32 frame) {
112112
t_modified = features_ready - 1;
113113
features_->GetFrame(t_modified, &row);
114114
}
115-
CuMatrix<BaseFloat> cu_features;
115+
CuMatrix<BaseFloat> cu_features;
116116
cu_features.Swap(&features); // Copy to GPU, if we're using one.
117-
117+
118118

119119
int32 num_frames_out = input_frame_end - input_frame_begin -
120120
left_context_ - right_context_;
121-
121+
122122
CuMatrix<BaseFloat> cu_posteriors(num_frames_out, num_pdfs_);
123-
123+
124124
// The "false" below tells it not to pad the input: we've already done
125125
// any padding that we needed to do.
126126
NnetComputation(nnet_.GetNnet(), cu_features,
127127
false, &cu_posteriors);
128-
128+
129129
cu_posteriors.ApplyFloor(1.0e-20); // Avoid log of zero which leads to NaN.
130130
cu_posteriors.ApplyLog();
131131
// subtract log-prior (divide by prior)

src/nnet3/Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \
2828
discriminative-supervision.o nnet-discriminative-example.o \
2929
nnet-discriminative-diagnostics.o \
3030
discriminative-training.o nnet-discriminative-training.o \
31-
online-nnet3-decodable-simple.o nnet-compile-looped.o \
32-
decodable-simple-looped.o
31+
nnet-compile-looped.o decodable-simple-looped.o \
32+
decodable-online-looped.o
3333

3434

3535
LIBNAME = kaldi-nnet3

0 commit comments

Comments (0)