Skip to content

Commit b1ae952

Browse files
hainan-xv authored and danpovey committed
[scripts,egs] Support averaging forward and backward RNNLMs (kaldi-asr#2436)
1 parent 447e964 commit b1ae952

File tree

6 files changed

+355
-3
lines changed

6 files changed

+355
-3
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tuning/run_tdnn_lstm_back_1e.sh

egs/swbd/s5c/local/rnnlm/tuning/run_tdnn_lstm_1e.sh

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ train_stage=-10
3232
# variables for lattice rescoring
3333
run_lat_rescore=true
3434
run_nbest_rescore=true
35+
run_backward_rnnlm=false
36+
3537
ac_model_dir=exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp
3638
decode_dir_suffix=rnnlm_1e
3739
ngram_order=4 # approximate the lattice-rescoring by limiting the max-ngram-order
@@ -130,10 +132,10 @@ if [ $stage -le 4 ] && $run_lat_rescore; then
130132
# Lattice rescoring
131133
rnnlm/lmrescore$pruned.sh \
132134
--cmd "$decode_cmd --mem 4G" \
133-
--weight 0.5 --max-ngram-order $ngram_order \
135+
--weight 0.45 --max-ngram-order $ngram_order \
134136
data/lang_$LM $dir \
135137
data/${decode_set}_hires ${decode_dir} \
136-
${decode_dir}_${decode_dir_suffix}
138+
${decode_dir}_${decode_dir_suffix}_0.45
137139
done
138140
fi
139141

@@ -151,4 +153,10 @@ if [ $stage -le 5 ] && $run_nbest_rescore; then
151153
done
152154
fi
153155

156+
# running backward RNNLM, which further improves WERS by combining backward with
157+
# the forward RNNLM trained in this script.
158+
if [ $stage -le 6 ] && $run_backward_rnnlm; then
159+
local/rnnlm/run_tdnn_lstm_back.sh
160+
fi
161+
154162
exit 0
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#!/bin/bash

# Copyright 2012  Johns Hopkins University (author: Daniel Povey)
#           2015  Guoguo Chen
#           2017  Hainan Xu
#           2017  Xiaohui Zhang

# This script trains a backward (right-to-left) RNNLM on the swbd LM-training
# data, and uses it to rescore either decoded lattices, or lattices that have
# already been rescored with a forward RNNLM. In order to run this, you must
# first run the forward RNNLM recipe at local/rnnlm/run_tdnn_lstm.sh

# rnnlm/train_rnnlm.sh: best iteration (out of 35) was 34, linking it to final iteration.
# rnnlm/train_rnnlm.sh: train/dev perplexity was 41.8 / 55.1.
# Train objf: -5.18 -4.46 -4.26 -4.18 -4.12 -4.07 -4.04 -4.00 -3.99 -3.98 -3.95 -3.93 -3.91 -3.90 -3.88 -3.87 -3.86 -3.85 -3.83 -3.82 -3.82 -3.81 -3.79 -3.79 -3.78 -3.77 -3.76 -3.77 -3.75 -3.74 -3.74 -3.73 -3.72 -3.71 -3.71
# Dev objf:   -10.32 -4.89 -4.57 -4.45 -4.37 -4.33 -4.29 -4.26 -4.24 -4.22 -4.18 -4.17 -4.15 -4.14 -4.13 -4.12 -4.11 -4.10 -4.09 -4.08 -4.07 -4.06 -4.06 -4.05 -4.05 -4.05 -4.04 -4.04 -4.03 -4.03 -4.02 -4.02 -4.02 -4.01 -4.01

# %WER 11.1 | 1831 21395 | 89.9 6.4 3.7 1.0 11.1 46.3 | exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp/decode_eval2000_sw1_fsh_fg_looped/score_13_0.0/eval2000_hires.ctm.swbd.filt.sys
# %WER 9.9 | 1831 21395 | 91.0 5.8 3.2 0.9 9.9 43.2 | exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp/decode_eval2000_sw1_fsh_fg_looped_rnnlm_1e/score_11_0.0/eval2000_hires.ctm.swbd.filt.sys
# %WER 9.5 | 1831 21395 | 91.4 5.5 3.1 0.9 9.5 42.5 | exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp/decode_eval2000_sw1_fsh_fg_looped_rnnlm_1e_back/score_11_0.0/eval2000_hires.ctm.swbd.filt.sys

# %WER 15.9 | 4459 42989 | 85.7 9.7 4.6 1.6 15.9 51.6 | exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp/decode_eval2000_sw1_fsh_fg_looped/score_10_0.0/eval2000_hires.ctm.filt.sys
# %WER 14.4 | 4459 42989 | 87.0 8.7 4.3 1.5 14.4 49.4 | exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp/decode_eval2000_sw1_fsh_fg_looped_rnnlm_1e/score_11_0.0/eval2000_hires.ctm.filt.sys
# %WER 13.9 | 4459 42989 | 87.6 8.4 4.0 1.5 13.9 48.6 | exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp/decode_eval2000_sw1_fsh_fg_looped_rnnlm_1e_back/score_10_0.0/eval2000_hires.ctm.filt.sys

# Begin configuration section.

dir=exp/rnnlm_lstm_1e_backward
embedding_dim=1024
lstm_rpd=256
lstm_nrpd=256
stage=-10
train_stage=-10

# variables for lattice rescoring
run_lat_rescore=true
ac_model_dir=exp/nnet3/tdnn_lstm_1a_adversarial0.3_epochs12_ld5_sp
decode_dir_suffix_forward=rnnlm_1e
decode_dir_suffix_backward=rnnlm_1e_back
ngram_order=4 # approximate the lattice-rescoring by limiting the max-ngram-order
              # if it's set, it merges histories in the lattice if they share
              # the same ngram history and this prevents the lattice from
              # exploding exponentially

. ./cmd.sh
. ./utils/parse_options.sh

text=data/train_nodev/text
fisher_text=data/local/lm/fisher/text1.gz
lexicon=data/local/dict_nosp/lexiconp.txt
text_dir=data/rnnlm/text_nosp_1e_back
mkdir -p $dir/config
set -e

for f in $text $lexicon; do
  [ ! -f $f ] && \
    echo "$0: expected file $f to exist; search for local/wsj_extend_dict.sh in run.sh" && exit 1
done

if [ $stage -le 0 ]; then
  mkdir -p $text_dir
  echo -n >$text_dir/dev.txt
  # hold out one in every 50 lines as dev data; the awk stage in the middle
  # reverses the word order of each line, which is what makes this a
  # "backward" LM training set.
  cat $text | cut -d ' ' -f2- | awk '{for(i=NF;i>0;i--) printf("%s ", $i); print""}' | awk -v text_dir=$text_dir '{if(NR%50 == 0) { print >text_dir"/dev.txt"; } else {print;}}' >$text_dir/swbd.txt
  cat > $dir/config/hesitation_mapping.txt <<EOF
hmm hum
mmm um
mm um
mhm um-hum
EOF
  # map hesitation variants to canonical forms, then reverse word order.
  gunzip -c $fisher_text | awk 'NR==FNR{a[$1]=$2;next}{for (n=1;n<=NF;n++) if ($n in a) $n=a[$n];print $0}' \
    $dir/config/hesitation_mapping.txt - | awk '{for(i=NF;i>0;i--) printf("%s ", $i); print""}' > $text_dir/fisher.txt
fi

if [ $stage -le 1 ]; then
  cp data/lang/words.txt $dir/config/
  n=`cat $dir/config/words.txt | wc -l`
  echo "<brk> $n" >> $dir/config/words.txt

  # words that are not present in words.txt but are in the training or dev
  # data will be mapped to <unk> during training.
  echo "<unk>" >$dir/config/oov.txt

  cat > $dir/config/data_weights.txt <<EOF
swbd 3 1.0
fisher 1 1.0
EOF

  rnnlm/get_unigram_probs.py --vocab-file=$dir/config/words.txt \
                             --unk-word="<unk>" \
                             --data-weights-file=$dir/config/data_weights.txt \
                             $text_dir | awk 'NF==2' >$dir/config/unigram_probs.txt

  # choose features
  rnnlm/choose_features.py --unigram-probs=$dir/config/unigram_probs.txt \
                           --use-constant-feature=true \
                           --special-words='<s>,</s>,<brk>,<unk>,[noise],[laughter],[vocalized-noise]' \
                           $dir/config/words.txt > $dir/config/features.txt

  cat >$dir/config/xconfig <<EOF
input dim=$embedding_dim name=input
relu-renorm-layer name=tdnn1 dim=$embedding_dim input=Append(0, IfDefined(-1))
fast-lstmp-layer name=lstm1 cell-dim=$embedding_dim recurrent-projection-dim=$lstm_rpd non-recurrent-projection-dim=$lstm_nrpd
relu-renorm-layer name=tdnn2 dim=$embedding_dim input=Append(0, IfDefined(-3))
fast-lstmp-layer name=lstm2 cell-dim=$embedding_dim recurrent-projection-dim=$lstm_rpd non-recurrent-projection-dim=$lstm_nrpd
relu-renorm-layer name=tdnn3 dim=$embedding_dim input=Append(0, IfDefined(-3))
output-layer name=output include-log-softmax=false dim=$embedding_dim
EOF
  rnnlm/validate_config_dir.sh $text_dir $dir/config
fi

if [ $stage -le 2 ]; then
  rnnlm/prepare_rnnlm_dir.sh $text_dir $dir/config $dir
fi

if [ $stage -le 3 ]; then
  rnnlm/train_rnnlm.sh --num-jobs-initial 1 --num-jobs-final 3 \
                  --stage $train_stage --num-epochs 10 --cmd "$train_cmd" $dir
fi

LM=sw1_fsh_fg # using the 4-gram const arpa file as old lm
if [ $stage -le 4 ] && $run_lat_rescore; then
  echo "$0: Perform lattice-rescoring on $ac_model_dir"

  for decode_set in eval2000; do
    decode_dir=${ac_model_dir}/decode_${decode_set}_${LM}_looped
    # The forward recipe writes its output with a "_0.45" suffix (the
    # interpolation weight), so check for that directory, which is also the
    # input we pass to lmrescore_back.sh below.
    if [ ! -d ${decode_dir}_${decode_dir_suffix_forward}_0.45 ]; then
      echo "$0: Must run the forward recipe first at local/rnnlm/run_tdnn_lstm.sh"
      exit 1
    fi

    # Lattice rescoring
    rnnlm/lmrescore_back.sh \
      --cmd "$decode_cmd --mem 4G" \
      --weight 0.45 --max-ngram-order $ngram_order \
      data/lang_$LM $dir \
      data/${decode_set}_hires ${decode_dir}_${decode_dir_suffix_forward}_0.45 \
      ${decode_dir}_${decode_dir_suffix_backward}_0.45
  done
fi

exit 0

scripts/rnnlm/lmrescore_back.sh

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
#!/bin/bash

# Copyright 2017 Hainan Xu
# Apache 2.0

# This script rescores lattices with KALDI RNNLM trained on reversed text.
# The input directory should already be rescored with a forward RNNLM, preferably
# with the pruned algorithm, since smaller lattices make rescoring much faster.
# An example of the forward pruned rescoring is at
# egs/swbd/s5c/local/rnnlm/run_tdnn_lstm.sh
# One example script for backward RNNLM rescoring is at
# egs/swbd/s5c/local/rnnlm/run_tdnn_lstm_back.sh

# Begin configuration section.
cmd=run.pl
skip_scoring=false
max_ngram_order=4 # Approximate the lattice-rescoring by limiting the max-ngram-order
                  # if it's set, it merges histories in the lattice if they share
                  # the same ngram history and this prevents the lattice from
                  # exploding exponentially. Details of the n-gram approximation
                  # method are described in section 2.3 of the paper
                  # http://www.danielpovey.com/files/2018_icassp_lattice_pruning.pdf

weight=0.5 # Interpolation weight for RNNLM.
normalize=false # If true, we add a normalization step to the output of the RNNLM
                # so that it adds up to *exactly* 1. Note that this is not necessary
                # as in our RNNLM setup, a properly trained network would automatically
                # have its normalization term close to 1. The details of this
                # could be found at http://www.danielpovey.com/files/2018_icassp_rnnlm.pdf

# End configuration section.

echo "$0 $@"  # Print the command line for logging

. ./utils/parse_options.sh

if [ $# != 5 ]; then
  echo "Does language model rescoring of lattices (remove old LM, add new LM)"
  echo "with Kaldi RNNLM trained on reversed text. See comments in file for details"
  echo ""
  echo "Usage: $0 [options] <old-lang-dir> <rnnlm-dir> \\"
  echo "                   <data-dir> <input-decode-dir> <output-decode-dir>"
  echo " e.g.: $0 data/lang_tg exp/rnnlm_lstm/ data/test \\"
  echo "                   exp/tri3/test_rnnlm_forward exp/tri3/test_rnnlm_bidirection"
  echo "options: [--cmd (run.pl|queue.pl [queue opts])]"
  exit 1;
fi

[ -f path.sh ] && . ./path.sh;

oldlang=$1
rnnlm_dir=$2
data=$3
indir=$4
outdir=$5

# Prefer the FST-format old LM; fall back to the const-arpa format.
oldlm=$oldlang/G.fst
if [ ! -f $oldlm ]; then
  echo "$0: file $oldlm not found; using $oldlang/G.carpa"
  oldlm=$oldlang/G.carpa
fi

[ ! -f $oldlm ] && echo "$0: Missing file $oldlm" && exit 1;
[ ! -f $rnnlm_dir/final.raw ] && echo "$0: Missing file $rnnlm_dir/final.raw" && exit 1;
[ ! -f $rnnlm_dir/feat_embedding.final.mat ] && [ ! -f $rnnlm_dir/word_embedding.final.mat ] && echo "$0: Missing word embedding file" && exit 1;

[ ! -f $oldlang/words.txt ] &&\
  echo "$0: Missing file $oldlang/words.txt" && exit 1;
! ls $indir/lat.*.gz >/dev/null &&\
  echo "$0: No lattices input directory $indir" && exit 1;
# Validate the interpolation weight; the BEGIN-only awk program reads no input.
awk -v n=$0 -v w=$weight 'BEGIN {if (w < 0 || w > 1) {
  print n": Interpolation weight should be in the range of [0, 1]"; exit 1;}}' \
  || exit 1;

normalize_opt=
if $normalize; then
  normalize_opt="--normalize-probs=true"
fi
oldlm_command="fstproject --project_output=true $oldlm |"
special_symbol_opts=$(cat $rnnlm_dir/special_symbol_opts.txt)

# Use a precomputed word embedding if present; otherwise derive it on the fly
# from the feature embedding.
word_embedding=
if [ -f $rnnlm_dir/word_embedding.final.mat ]; then
  word_embedding=$rnnlm_dir/word_embedding.final.mat
else
  word_embedding="'rnnlm-get-word-embedding $rnnlm_dir/word_feats.txt $rnnlm_dir/feat_embedding.final.mat -|'"
fi

mkdir -p $outdir/log
nj=`cat $indir/num_jobs` || exit 1;
cp $indir/num_jobs $outdir

# In order to rescore with a backward RNNLM, we first remove the original LM
# scores with lattice-lmrescore, then reverse the lattices, rescore with the
# backward RNNLM, and reverse them back.
oldlm_weight=$(perl -e "print -1.0 * $weight;")
if [ "$oldlm" == "$oldlang/G.fst" ]; then
  $cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \
    lattice-lmrescore --lm-scale=$oldlm_weight \
    "ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm_command" ark:- \| \
    lattice-reverse ark:- ark:- \| \
    lattice-lmrescore-kaldi-rnnlm --lm-scale=$weight $special_symbol_opts \
    --max-ngram-order=$max_ngram_order $normalize_opt \
    $word_embedding "$rnnlm_dir/final.raw" ark:- ark:- \| \
    lattice-reverse ark:- "ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1;
else
  $cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \
    lattice-lmrescore-const-arpa --lm-scale=$oldlm_weight \
    "ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm" ark:- \| \
    lattice-reverse ark:- ark:- \| \
    lattice-lmrescore-kaldi-rnnlm --lm-scale=$weight $special_symbol_opts \
    --max-ngram-order=$max_ngram_order $normalize_opt \
    $word_embedding "$rnnlm_dir/final.raw" ark:- ark:- \| \
    lattice-reverse ark:- "ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1;
fi

if ! $skip_scoring ; then
  err_msg="$0: Not scoring because local/score.sh does not exist or not executable."
  [ ! -x local/score.sh ] && echo $err_msg && exit 1;
  echo local/score.sh --cmd "$cmd" $data $oldlang $outdir
  local/score.sh --cmd "$cmd" $data $oldlang $outdir
else
  echo "$0: Not scoring because --skip-scoring was specified."
fi

exit 0;

src/latbin/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ BINFILES = lattice-best-path lattice-prune lattice-equivalent lattice-to-nbest \
2525
lattice-determinize-phone-pruned-parallel lattice-expand-ngram \
2626
lattice-lmrescore-const-arpa lattice-lmrescore-rnnlm nbest-to-prons \
2727
lattice-arc-post lattice-determinize-non-compact lattice-lmrescore-kaldi-rnnlm \
28-
lattice-lmrescore-pruned lattice-lmrescore-kaldi-rnnlm-pruned
28+
lattice-lmrescore-pruned lattice-lmrescore-kaldi-rnnlm-pruned lattice-reverse
2929

3030
OBJFILES =
3131

src/latbin/lattice-reverse.cc

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// latbin/lattice-reverse.cc

// Copyright 2018 Hainan Xu

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.


#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"

// Reverses every lattice in the input rspecifier and writes the result to the
// output wspecifier.  Reversing is what lets a backward (right-to-left) RNNLM
// rescore a lattice; see egs/swbd/s5c/local/rnnlm/tuning/run_tdnn_lstm_back_1e.sh
// for an example.  Exit status is 0 if at least one lattice was processed.
int main(int argc, char *argv[]) {
  try {
    using namespace kaldi;
    typedef kaldi::int32 int32;

    const char *usage =
        "Reverse a lattice in order to rescore the lattice with an RNNLM \n"
        "trained on reversed text. An example for its application is at \n"
        "swbd/local/rnnlm/run_tdnn_lstm_back.sh\n"
        "Usage: lattice-reverse lattice-rspecifier lattice-wspecifier\n"
        " e.g.: lattice-reverse ark:forward.lats ark:backward.lats\n";

    ParseOptions po(usage);
    po.Read(argc, argv);

    if (po.NumArgs() != 2) {
      po.PrintUsage();
      exit(1);
    }

    std::string lats_rspecifier = po.GetArg(1),
        lats_wspecifier = po.GetArg(2);

    int32 n_done = 0;  // number of lattices successfully reversed

    SequentialLatticeReader lattice_reader(lats_rspecifier);
    LatticeWriter lattice_writer(lats_wspecifier);

    for (; !lattice_reader.Done(); lattice_reader.Next(), n_done++) {
      Lattice &lat = lattice_reader.Value();
      Lattice olat;
      fst::Reverse(lat, &olat);
      lattice_writer.Write(lattice_reader.Key(), olat);
    }

    KALDI_LOG << "Done reversing " << n_done << " lattices.";

    // Follow the Kaldi binary convention: fail if nothing was processed.
    return (n_done != 0 ? 0 : 1);
  } catch(const std::exception &e) {
    std::cerr << e.what();
    return -1;
  }
}

0 commit comments

Comments
 (0)