Skip to content

Commit c9afec3

Browse files
committed
Simplify code after changing constructor of Table classes; unify and streamline decoder interfaces; minor fixes to scripts.
git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@205 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
1 parent 48f969a commit c9afec3

34 files changed

+278
-367
lines changed

egs/rm/s1/RESULTS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,18 @@ exp/decode_tri2k_regtree_fmllr/wer:Average WER is 2.688901 (337 / 12533) # +reg
6666
exp/decode_tri2l/wer:Average WER is 2.704859 (339 / 12533) # Splice-9-frames + LDA + MLLT + SAT (fMLLR in test)
6767
exp/decode_tri2l_utt/wer:Average WER is 4.930982 (618 / 12533) # [ as decode_tri2l but per-utt in test. ]
6868

69+
# linear-VTLN on top of LDA+MLLT features.
70+
exp/decode_tri2m/wer:Average WER is 3.223490 (404 / 12533) # offset-only transform after VTLN part of LVTLN
71+
exp/decode_tri2m_diag/wer:Average WER is 3.119764 (391 / 12533) # diagonal transform after VTLN part of LVTLN
72+
exp/decode_tri2m_diag_fmllr/wer:Average WER is 2.784649 (349 / 12533) # + fMLLR
73+
exp/decode_tri2m_diag_utt/wer:Average WER is 3.279343 (411 / 12533) # [per-utt]
74+
exp/decode_tri2m_vtln/wer:Average WER is 4.747467 (595 / 12533) # feature-space VTLN, plus offset-only transform
75+
# (for some reason it failed)
76+
exp/decode_tri2m_vtln_diag/wer:Average WER is 3.087848 (387 / 12533) # + diagonal transform
77+
exp/decode_tri2m_vtln_diag_utt/wer:Average WER is 4.340541 (544 / 12533) # [per-utterance]
78+
exp/decode_tri2m_vtln_nofmllr/wer:Average WER is 5.784728 (725 / 12533) # feature-space VTLN, with no fMLLR
79+
80+
6981
# sgmma is SGMM without speaker vectors.
7082
exp/decode_sgmma/wer:Average WER is 3.319237 (416 / 12533)
7183
exp/decode_sgmma_fmllr/wer:Average WER is 2.934308 (289 / 9849)

egs/rm/s1/run.sh

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,13 @@ steps/make_mfcc_test.sh $mfccdir
5959
steps/train_mono.sh
6060
steps/decode_mono.sh &
6161
steps/train_tri1.sh
62-
(steps/decode_tri1.sh ; steps/decode_tri1_fmllr.sh; steps/decode_tri1_regtree_fmllr.sh ) &
62+
(steps/decode_tri1.sh ; steps/decode_tri1_fmllr.sh; steps/decode_tri1_regtree_fmllr.sh ; steps/decode_tri1_latgen.sh) &
6363

6464
steps/train_tri2a.sh
65-
(steps/decode_tri2a.sh ; steps/decode_tri2a_fmllr.sh; steps/decode_tri2a_fmllr_utt.sh )&
65+
(steps/decode_tri2a.sh ; steps/decode_tri2a_fmllr.sh; steps/decode_tri2a_fmllr_utt.sh ;
66+
steps/decode_tri2a_dfmllr.sh; steps/decode_tri2a_dfmllr_fmllr.sh;
67+
steps/decode_tri2a_dfmllr_utt.sh;
68+
)&
6669

6770

6871
# Then do the same for 2b, 2c, and so on
@@ -95,7 +98,7 @@ done
9598

9699
# To train and test SGMM systems:
97100

98-
steps/train_ubma.sh
101+
99102

100103
# note: if the SGMM decoding is too slow, aside from playing
101104
# with decoder beams and max-leaves, you can set e.g.
@@ -107,6 +110,9 @@ steps/train_ubma.sh
107110
# You can take this all the way down to 1 for fastest speed, although
108111
# this will degrade results.
109112

113+
114+
steps/train_ubma.sh
115+
110116
(steps/train_sgmma.sh; steps/decode_sgmma.sh; steps/decode_sgmma_fmllr.sh;
111117
steps/decode_sgmma_fmllr_utt.sh; steps/train_sgmma_fmllrbasis.sh;
112118
steps/decode_sgmma_fmllrbasis_utt.sh )&
@@ -120,6 +126,9 @@ steps/train_ubma.sh
120126
# as sgmmb but with LDA+STC features.
121127
(steps/train_ubmc.sh; steps/train_sgmmd.sh; steps/decode_sgmmd.sh; steps/decode_sgmmd_fmllr.sh )&
122128

129+
(steps/train_ubmd.sh; steps/train_sgmme.sh; steps/decode_sgmme.sh; steps/decode_sgmme_fmllr.sh;
130+
steps/decode_sgmme_latgen.sh )&
131+
123132

124133

125134

egs/rm/s1/steps/decode_tri2g_diag_fmllr.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dir=exp/decode_tri2g_diag_fmllr
2121
mkdir -p $dir
2222
model=exp/tri2g/final.mdl
2323
alignmodel=exp/tri2g/final.alimdl
24+
lvtln=exp/tri2g/final.lvtln
2425
tree=exp/tri2g/tree
2526
graphdir=exp/graph_tri2g
2627
silphones=`cat data/silphones.csl`

egs/rm/s1/steps/decode_tri2m_vtln_nofmllr.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
if [ -f path.sh ]; then . path.sh; fi
2222
dir=exp/decode_tri2m_vtln_nofmllr
2323
mkdir -p $dir
24+
mat=exp/tri2f/final.mat
2425
vtlnmodel=exp/tri2m/final.vtlnmdl
2526
lvtlnmodel=exp/tri2m/final.mdl
2627
alignmodel=exp/tri2m/final.alimdl

src/bin/decode-faster-mapped.cc

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ int main(int argc, char *argv[]) {
4141
ParseOptions po(usage);
4242
bool binary = false;
4343
BaseFloat acoustic_scale = 0.1;
44-
44+
bool allow_partial = true;
4545
std::string word_syms_filename;
4646
FasterDecoderOptions decoder_opts;
4747
decoder_opts.Register(&po, true); // true == include obscure settings.
4848
po.Register("binary", &binary, "Write output in binary mode");
4949
po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods");
50-
50+
po.Register("allow-partial", &allow_partial, "Produce output even when final state was not reached");
5151
po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]");
5252

5353
po.Read(argc, argv);
@@ -72,10 +72,7 @@ int main(int argc, char *argv[]) {
7272

7373
Int32VectorWriter words_writer(words_wspecifier);
7474

75-
Int32VectorWriter alignment_writer;
76-
if (alignment_wspecifier != "")
77-
if (!alignment_writer.Open(alignment_wspecifier))
78-
KALDI_ERR << "Failed to open alignments output.";
75+
Int32VectorWriter alignment_writer(alignment_wspecifier);
7976

8077
fst::SymbolTable *word_syms = NULL;
8178
if (word_syms_filename != "") {
@@ -123,15 +120,13 @@ int main(int argc, char *argv[]) {
123120
decoder.Decode(&decodable);
124121

125122
VectorFst<StdArc> decoded; // linear FST.
126-
bool saw_endstate = decoder.GetOutput(true, // consider only final states.
127-
&decoded);
128123

129-
if (saw_endstate || decoder.GetOutput(false,
130-
&decoded)) {
124+
if ( (allow_partial || decoder.ReachedFinal())
125+
&& decoder.GetBestPath(&decoded) ) {
131126
num_success++;
132-
if (!saw_endstate) {
127+
if (!decoder.ReachedFinal())
133128
KALDI_WARN << "Decoder did not reach end-state, outputting partial traceback.";
134-
}
129+
135130
std::vector<int32> alignment;
136131
std::vector<int32> words;
137132
StdArc::Weight weight;

src/bin/decode-faster.cc

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@ int main(int argc, char *argv[]) {
3939
ParseOptions po(usage);
4040
bool binary = false;
4141
BaseFloat acoustic_scale = 0.1;
42-
42+
bool allow_partial = true;
4343
std::string word_syms_filename;
4444
FasterDecoderOptions decoder_opts;
4545
decoder_opts.Register(&po, true); // true == include obscure settings.
4646
po.Register("binary", &binary, "Write output in binary mode");
47+
po.Register("allow-partial", &allow_partial, "Produce output even when final state was not reached");
4748
po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods");
4849
po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]");
4950

@@ -61,10 +62,7 @@ int main(int argc, char *argv[]) {
6162

6263
Int32VectorWriter words_writer(words_wspecifier);
6364

64-
Int32VectorWriter alignment_writer;
65-
if (alignment_wspecifier != "")
66-
if (!alignment_writer.Open(alignment_wspecifier))
67-
KALDI_ERR << "Failed to open alignments output.";
65+
Int32VectorWriter alignment_writer(alignment_wspecifier);
6866

6967
fst::SymbolTable *word_syms = NULL;
7068
if (word_syms_filename != "") {
@@ -112,15 +110,13 @@ int main(int argc, char *argv[]) {
112110
decoder.Decode(&decodable);
113111

114112
VectorFst<StdArc> decoded; // linear FST.
115-
bool saw_endstate = decoder.GetOutput(true, // consider only final states.
116-
&decoded);
117113

118-
if (saw_endstate || decoder.GetOutput(false,
119-
&decoded)) {
114+
if ( (allow_partial || decoder.ReachedFinal())
115+
&& decoder.GetBestPath(&decoded) ) {
120116
num_success++;
121-
if (!saw_endstate) {
117+
if (!decoder.ReachedFinal())
122118
KALDI_WARN << "Decoder did not reach end-state, outputting partial traceback.";
123-
}
119+
124120
std::vector<int32> alignment;
125121
std::vector<int32> words;
126122
StdArc::Weight weight;

src/decoder/faster-decoder.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,25 @@ class FasterDecoder {
9797
}
9898
}
9999

100-
bool GetOutput(bool is_final, fst::MutableFst<fst::StdArc> *fst_out) {
101-
// GetOutput gets the decoding output. If is_final == true, it limits itself
100+
bool ReachedFinal() {
101+
Weight best_weight = Weight::Zero();
102+
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
103+
Weight this_weight = Times(e->val->weight, fst_.Final(e->key));
104+
if (this_weight != Weight::Zero())
105+
return true;
106+
}
107+
return false;
108+
}
109+
110+
bool GetBestPath(fst::MutableFst<fst::StdArc> *fst_out) {
111+
// GetBestPath gets the decoding output. If is_final == true, it limits itself
102112
// to final states; otherwise it gets the most likely token not taking into
103113
// account final-probs. fst_out will be empty (Start() == kNoStateId) if
104114
// nothing was available. It returns true if it got output (thus, fst_out
105115
// will be nonempty).
106116
fst_out->DeleteStates();
107117
Token *best_tok = NULL;
118+
bool is_final = ReachedFinal();
108119
if (!is_final) {
109120
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail)
110121
if (best_tok == NULL || *best_tok < *(e->val) )

src/decoder/lattice-simple-decoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ class LatticeSimpleDecoder {
120120

121121
// Outputs an FST corresponding to the single best path
122122
// through the lattice.
123-
bool GetTraceback(fst::MutableFst<LatticeArc> *ofst) const {
123+
bool GetBestPath(fst::MutableFst<LatticeArc> *ofst) const {
124124
fst::VectorFst<LatticeArc> fst;
125125
if (!GetRawLattice(&fst)) return false;
126126
// std::cout << "Raw lattice is:\n";

src/decoder/simple-decoder.h

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ class SimpleDecoder {
5252
ClearToks(prev_toks_);
5353
}
5454

55-
void Decode(DecodableInterface *decodable) {
55+
// Returns true if any tokens reached the end of the file (regardless of
56+
// whether they are in a final state); query ReachedFinal() after Decode()
57+
// to see whether we reached a final state.
58+
bool Decode(DecodableInterface *decodable) {
5659
// clean up from last time:
5760
ClearToks(cur_toks_);
5861
ClearToks(prev_toks_);
@@ -68,15 +71,30 @@ class SimpleDecoder {
6871
ProcessNonemitting();
6972
PruneToks(beam_, &cur_toks_);
7073
}
74+
return (!cur_toks_.empty());
7175
}
7276

73-
bool GetOutput(bool is_final, fst::MutableFst<fst::StdArc> *fst_out) {
74-
// GetOutput gets the decoding output. If is_final == true, it limits itself to final states;
75-
// otherwise it gets the most likely token not taking into account final-probs.
76-
// fst_out will be empty (Start() == kNoStateId) if nothing was available.
77-
// It returns true if it got output (thus, fst_out will be nonempty).
77+
bool ReachedFinal() {
78+
Weight best_weight = Weight::Zero();
79+
for (unordered_map<StateId, Token*>::iterator iter = cur_toks_.begin();
80+
iter != cur_toks_.end();
81+
++iter) {
82+
Weight this_weight = Times(iter->second->weight_, fst_.Final(iter->first));
83+
if (this_weight != Weight::Zero())
84+
return true;
85+
}
86+
return false;
87+
}
88+
89+
// GetBestPath gets the decoding traceback. If we reached a final state,
90+
// it limits itself to final states;
91+
// otherwise it gets the most likely token not taking into account final-probs.
92+
// fst_out will be empty (Start() == kNoStateId) if nothing was available.
93+
// If Decode() returned true, it is safe to assume GetBestPath will return true.
94+
bool GetBestPath(fst::MutableFst<fst::StdArc> *fst_out) {
7895
fst_out->DeleteStates();
7996
Token *best_tok = NULL;
97+
bool is_final = ReachedFinal();
8098
if (!is_final) {
8199
for (unordered_map<StateId, Token*>::iterator iter = cur_toks_.begin();
82100
iter != cur_toks_.end();

src/featbin/compute-cmvn-stats.cc

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@ int main(int argc, char *argv[]) {
2727

2828
const char *usage =
2929
"Compute cepstral mean and variance normalization statistics\n"
30-
"Per-utterance by default, or per-speaker if spk2utt option provided\n"
31-
"Usage: compute-cmvn-stats [options] feats-rspecifier stats-wspecifier\n";
32-
30+
"If wspecifier provided: per-utterance by default, or per-speaker if\n"
31+
"spk2utt option provided; if wxfilename: global\n"
32+
"Usage: compute-cmvn-stats [options] feats-rspecifier (stats-wspecifier|stats-wxfilename)\n";
33+
3334
ParseOptions po(usage);
3435
std::string spk2utt_rspecifier;
36+
bool binary = false;
3537
po.Register("spk2utt", &spk2utt_rspecifier, "rspecifier for speaker to utterance-list map");
36-
38+
po.Register("binary", &binary, "write in binary mode (applies only to global CMN/CVN)");
3739
po.Read(argc, argv);
3840

3941
if (po.NumArgs() != 2) {
@@ -42,45 +44,66 @@ int main(int argc, char *argv[]) {
4244
}
4345

4446
std::string rspecifier = po.GetArg(1);
45-
std::string wspecifier = po.GetArg(2);
47+
std::string wspecifier_or_wxfilename = po.GetArg(2);
48+
49+
if (ClassifyWspecifier(wspecifier_or_wxfilename, NULL, NULL, NULL)
50+
!= kNoWspecifier) { // writing to a Table: per-speaker or per-utt CMN/CVN.
51+
std::string wspecifier = wspecifier_or_wxfilename;
4652

47-
DoubleMatrixWriter writer(wspecifier);
53+
DoubleMatrixWriter writer(wspecifier);
4854

49-
if (spk2utt_rspecifier != "") {
50-
SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
51-
RandomAccessBaseFloatMatrixReader feat_reader(rspecifier);
52-
for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
53-
std::string spk = spk2utt_reader.Key();
54-
const std::vector<std::string> &uttlist = spk2utt_reader.Value();
55-
bool is_init = false;
56-
Matrix<double> stats;
57-
for (size_t i = 0; i < uttlist.size(); i++) {
58-
std::string utt = uttlist[i];
59-
if (!feat_reader.HasKey(utt))
60-
KALDI_WARN << "Did not find features for utterance " << utt;
61-
else {
62-
const Matrix<BaseFloat> &feats = feat_reader.Value(utt);
63-
if (!is_init) {
64-
InitCmvnStats(feats.NumCols(), &stats);
65-
is_init = true;
55+
if (spk2utt_rspecifier != "") {
56+
SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
57+
RandomAccessBaseFloatMatrixReader feat_reader(rspecifier);
58+
for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
59+
std::string spk = spk2utt_reader.Key();
60+
const std::vector<std::string> &uttlist = spk2utt_reader.Value();
61+
bool is_init = false;
62+
Matrix<double> stats;
63+
for (size_t i = 0; i < uttlist.size(); i++) {
64+
std::string utt = uttlist[i];
65+
if (!feat_reader.HasKey(utt))
66+
KALDI_WARN << "Did not find features for utterance " << utt;
67+
else {
68+
const Matrix<BaseFloat> &feats = feat_reader.Value(utt);
69+
if (!is_init) {
70+
InitCmvnStats(feats.NumCols(), &stats);
71+
is_init = true;
72+
}
73+
AccCmvnStats(feats, NULL, &stats);
6674
}
67-
AccCmvnStats(feats, NULL, &stats);
6875
}
76+
if (stats.NumRows() == 0)
77+
KALDI_WARN << "No stats accumulated for speaker " << spk;
78+
else
79+
writer.Write(spk, stats);
80+
}
81+
} else { // per-utterance normalization
82+
SequentialBaseFloatMatrixReader feat_reader(rspecifier);
83+
for (; !feat_reader.Done(); feat_reader.Next()) {
84+
Matrix<double> stats;
85+
const Matrix<BaseFloat> &feats = feat_reader.Value();
86+
InitCmvnStats(feats.NumCols(), &stats);
87+
AccCmvnStats(feats, NULL, &stats);
88+
writer.Write(feat_reader.Key(), stats);
6989
}
70-
if (stats.NumRows() == 0)
71-
KALDI_WARN << "No stats accumulated for speaker " << spk;
72-
else
73-
writer.Write(spk, stats);
7490
}
75-
} else { // per-utterance normalization
91+
} else { // accumulate global stats
92+
std::string wxfilename = wspecifier_or_wxfilename;
93+
bool is_init = false;
94+
Matrix<double> stats;
7695
SequentialBaseFloatMatrixReader feat_reader(rspecifier);
7796
for (; !feat_reader.Done(); feat_reader.Next()) {
78-
Matrix<double> stats;
7997
const Matrix<BaseFloat> &feats = feat_reader.Value();
80-
InitCmvnStats(feats.NumCols(), &stats);
98+
if (!is_init) {
99+
InitCmvnStats(feats.NumCols(), &stats);
100+
is_init = true;
101+
}
81102
AccCmvnStats(feats, NULL, &stats);
82-
writer.Write(feat_reader.Key(), stats);
83103
}
104+
Matrix<float> fstats(stats);
105+
Output ko(wxfilename, binary);
106+
fstats.Write(ko.Stream(), binary);
84107
}
85108
return 0;
86109
} catch(const std::exception& e) {

0 commit comments

Comments
 (0)