Skip to content

Commit 43136eb

Browse files
committed
update for lattice output
git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@291 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
1 parent 17478bd commit 43136eb

File tree

2 files changed

+51
-68
lines changed

2 files changed

+51
-68
lines changed

src/decoder/faster-decoder.h

Lines changed: 25 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,10 @@
2323
#include "util/hash-list.h"
2424
#include "fst/fstlib.h"
2525
#include "itf/decodable-itf.h"
26+
#include "lat/kaldi-lattice.h" // for {Compact}LatticeArc
2627

2728
namespace kaldi {
2829

29-
// macros to switch off all debugging messages without runtime cost
30-
#define DEBUG_CMD(x) x;
31-
#define DEBUG_OUT3(x) KALDI_VLOG(3) << x;
32-
#define DEBUG_OUT2(x) KALDI_VLOG(2) << x;
33-
#define DEBUG_OUT1(x) KALDI_VLOG(1) << x;
34-
//#define DEBUG_OUT1(x)
35-
//#define DEBUG_OUT2(x)
36-
//#define DEBUG_OUT3(x)
37-
//#define DEBUG_CMD(x)
38-
3930
struct FasterDecoderOptions {
4031
BaseFloat beam;
4132
int32 max_active;
@@ -83,15 +74,11 @@ class FasterDecoder {
8374
// clean up from last time:
8475
ClearToks(toks_.Clear());
8576
StateId start_state = fst_.Start();
86-
DEBUG_OUT2("Initial state: " << start_state)
8777
assert(start_state != fst::kNoStateId);
8878
Arc dummy_arc(0, 0, Weight::One(), start_state);
8979
toks_.Insert(start_state, new Token(dummy_arc, NULL));
9080
ProcessNonemitting(std::numeric_limits<float>::max());
9181
for (int32 frame = 0; !decodable->IsLastFrame(frame-1); frame++) {
92-
DEBUG_OUT1("==== FRAME " << frame << " =====")
93-
if ((frame%50) == 0)
94-
KALDI_VLOG(2) << "==== FRAME " << frame << " =====";
9582
BaseFloat adaptive_beam = ProcessEmitting(decodable, frame);
9683
ProcessNonemitting(adaptive_beam);
9784
}
@@ -107,7 +94,7 @@ class FasterDecoder {
10794
return false;
10895
}
10996

110-
bool GetBestPath(fst::MutableFst<fst::StdArc> *fst_out) {
97+
bool GetBestPath(fst::MutableFst<LatticeArc> *fst_out) {
11198
// GetBestPath gets the decoding output. If is_final == true, it limits itself
11299
// to final states; otherwise it gets the most likely token not taking into
113100
// account final-probs. fst_out will be empty (Start() == kNoStateId) if
@@ -124,7 +111,6 @@ class FasterDecoder {
124111
Weight best_weight = Weight::Zero();
125112
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
126113
Weight this_weight = Times(e->val->weight, fst_.Final(e->key));
127-
if (this_weight != Weight::Zero()) DEBUG_OUT1("final state reached: " << e->key << " path weight:" << this_weight)
128114
if (this_weight != Weight::Zero() &&
129115
this_weight.Value() < best_weight.Value()) {
130116
best_weight = this_weight;
@@ -133,29 +119,37 @@ class FasterDecoder {
133119
}
134120
}
135121
if (best_tok == NULL) return false; // No output.
136-
DEBUG_OUT1("best final token: path weight:" << best_tok->weight)
137-
138122

139-
std::vector<Arc> arcs_reverse; // arcs in reverse order.
123+
std::vector<LatticeArc> arcs_reverse; // arcs in reverse order.
140124
for (Token *tok = best_tok; tok != NULL; tok = tok->prev_) {
141-
arcs_reverse.push_back(tok->arc_);
125+
BaseFloat amscore = tok->weight_a.Value(),
126+
lmscore = tok->arc_.weight.Value() - amscore;
127+
LatticeArc l_arc(tok->arc_.ilabel,
128+
tok->arc_.olabel,
129+
LatticeWeight(lmscore, amscore),
130+
tok->arc_.nextstate);
131+
arcs_reverse.push_back(l_arc);
142132
}
143133
assert(arcs_reverse.back().nextstate == fst_.Start());
144134
arcs_reverse.pop_back(); // that was a "fake" token... gives no info.
145135

146136
StateId cur_state = fst_out->AddState();
147137
fst_out->SetStart(cur_state);
148138
for (ssize_t i = static_cast<ssize_t>(arcs_reverse.size())-1; i >= 0; i--) {
149-
Arc arc = arcs_reverse[i];
139+
LatticeArc arc = arcs_reverse[i];
150140
arc.nextstate = fst_out->AddState();
151141
fst_out->AddArc(cur_state, arc);
152-
DEBUG_OUT1("arc: " << arc.ilabel << " : " << arc.olabel)
153142
cur_state = arc.nextstate;
154143
}
155-
if (is_final)
156-
fst_out->SetFinal(cur_state, fst_.Final(best_tok->arc_.nextstate));
157-
else
158-
fst_out->SetFinal(cur_state, Weight::One());
144+
if (is_final) {
145+
fst_out->SetFinal(cur_state,
146+
LatticeWeight(fst_.Final(best_tok->arc_.nextstate).Value(), 0.0));
147+
} else {
148+
fst_out->SetFinal(cur_state, LatticeWeight::One());
149+
}
150+
KALDI_LOG << "best path, final:" << fst_.Final(best_tok->arc_.nextstate);
151+
//fst::FstPrinter<LatticeArc> fstprinter(*fst_out, NULL, NULL, NULL, false, true);
152+
//fstprinter.Print(&std::cout, "standard output");
159153
RemoveEpsLocal(fst_out);
160154
return true;
161155
}
@@ -168,18 +162,15 @@ class FasterDecoder {
168162
Token *prev_;
169163
int32 ref_count_;
170164
Weight weight;
165+
Weight weight_a;
171166
inline Token(Arc &arc, Token *prev): arc_(arc), prev_(prev), ref_count_(1) {
172-
DEBUG_OUT2("advance: " << arc.nextstate << " " << arc.ilabel << ":"
173-
<< arc.olabel << "/" << arc.weight)
174-
DEBUG_OUT3("create t")
175167
if (prev) {
176-
DEBUG_OUT3("inc t(" << prev->weight << "):" << prev->ref_count_ )
177168
prev->ref_count_++;
178169
weight = Times(prev->weight, arc.weight);
179170
} else {
180171
weight = arc.weight;
181172
}
182-
DEBUG_OUT3("new weight t:" << weight)
173+
weight_a = Weight::One();
183174
}
184175
inline bool operator < (const Token &other) {
185176
return weight.Value() > other.weight.Value();
@@ -192,11 +183,9 @@ class FasterDecoder {
192183
}
193184
inline static void TokenDelete(Token *tok) {
194185
if (tok->ref_count_ == 1) {
195-
DEBUG_OUT3( "kill t" )
196186
delete tok;
197187
} else {
198188
tok->ref_count_--;
199-
DEBUG_OUT3("dec t:" << tok->ref_count_)
200189
}
201190
}
202191
};
@@ -206,7 +195,6 @@ class FasterDecoder {
206195
/// Gets the weight cutoff. Also counts the active tokens.
207196
BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
208197
BaseFloat *adaptive_beam, Elem **best_elem) {
209-
DEBUG_OUT1("GetCufoff")
210198
BaseFloat best_weight = 1.0e+10; // positive == high cost == bad.
211199
size_t count = 0;
212200
if (opts_.max_active == std::numeric_limits<int32>::max()) {
@@ -219,7 +207,6 @@ class FasterDecoder {
219207
}
220208
if (tok_count != NULL) *tok_count = count;
221209
if (adaptive_beam != NULL) *adaptive_beam = opts_.beam;
222-
DEBUG_OUT1("count:" << *tok_count << " best:" << best_weight << " cutoff:" << best_weight + opts_.beam << " adaptive:" << *adaptive_beam)
223210
return best_weight + opts_.beam;
224211
} else {
225212
tmp_array_.clear();
@@ -234,7 +221,6 @@ class FasterDecoder {
234221
if (tok_count != NULL) *tok_count = count;
235222
if (tmp_array_.size() <= static_cast<size_t>(opts_.max_active)) {
236223
if (adaptive_beam) *adaptive_beam = opts_.beam;
237-
DEBUG_OUT1("count:" << *tok_count << " best:" << best_weight << " cutoff:" << best_weight + opts_.beam << " adaptive:" << *adaptive_beam)
238224
return best_weight + opts_.beam;
239225
} else {
240226
// the lowest elements (lowest costs, highest likes)
@@ -248,7 +234,6 @@ class FasterDecoder {
248234
if (adaptive_beam)
249235
*adaptive_beam = std::min(opts_.beam,
250236
ans - best_weight + opts_.beam_delta);
251-
DEBUG_OUT1("count:" << *tok_count << " best:" << best_weight << " cutoff:" << ans << " adaptive:" << *adaptive_beam)
252237
return ans;
253238
}
254239
}
@@ -259,13 +244,11 @@ class FasterDecoder {
259244
* opts_.hash_ratio);
260245
if (new_sz > toks_.Size()) {
261246
toks_.SetSize(new_sz);
262-
DEBUG_OUT1("resize hash:" << new_sz)
263247
}
264248
}
265249

266250
// ProcessEmitting returns the likelihood cutoff used.
267251
BaseFloat ProcessEmitting(DecodableInterface *decodable, int frame) {
268-
DEBUG_OUT1("PropagateEmitting")
269252
Elem *last_toks = toks_.Clear();
270253
size_t tok_cnt;
271254
BaseFloat adaptive_beam;
@@ -308,7 +291,6 @@ class FasterDecoder {
308291
// because we delete "e" as we go.
309292
StateId state = e->key;
310293
Token *tok = e->val;
311-
DEBUG_OUT2("get token: " << " state:" << state << " weight:" << tok->weight)
312294
if (tok->weight.Value() < weight_cutoff) { // not pruned.
313295
// np++;
314296
assert(state == tok->arc_.nextstate);
@@ -317,32 +299,23 @@ class FasterDecoder {
317299
aiter.Next()) {
318300
Arc arc = aiter.Value();
319301
if (arc.ilabel != 0) { // propagate..
320-
arc.weight = Times(arc.weight,
321-
Weight(- decodable->LogLikelihood(frame, arc.ilabel)));
322-
DEBUG_OUT2("acoustic: " <<
323-
Weight(- decodable->LogLikelihood(frame, arc.ilabel)))
302+
Weight amscore(- decodable->LogLikelihood(frame, arc.ilabel));
303+
arc.weight = Times(arc.weight, amscore);
324304
BaseFloat new_weight = arc.weight.Value() + tok->weight.Value();
325305
if (new_weight < next_weight_cutoff) { // not pruned..
326306
Token *new_tok = new Token(arc, tok);
307+
new_tok->weight_a = amscore;
327308
Elem *e_found = toks_.Find(arc.nextstate);
328309
if (e_found == NULL) {
329-
DEBUG_OUT2("insert to: " << arc.nextstate)
330310
toks_.Insert(arc.nextstate, new_tok);
331311
} else {
332-
DEBUG_OUT2("combine: " << arc.nextstate)
333-
DEBUG_OUT2("combine: " << e_found->val->weight)
334-
DEBUG_OUT2("with: " << new_tok->weight)
335312
if ( *(e_found->val) < *new_tok ) {
336-
DEBUG_OUT2("delete first")
337313
Token::TokenDelete(e_found->val);
338314
e_found->val = new_tok;
339315
} else {
340-
DEBUG_OUT2("delete second")
341316
Token::TokenDelete(new_tok);
342317
}
343318
}
344-
} else {
345-
DEBUG_OUT2("prune")
346319
}
347320
}
348321
}
@@ -358,24 +331,20 @@ class FasterDecoder {
358331
void ProcessNonemitting(BaseFloat adaptive_beam) {
359332
// Processes nonemitting arcs for one frame. Propagates within
360333
// cur_toks_.
361-
DEBUG_OUT1("PropagateEpsilon")
362334
assert(queue_.empty());
363335
float best_weight = 1.0e+10;
364336
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
365337
queue_.push_back(e->key);
366338
best_weight = std::min(best_weight, e->val->weight.Value());
367339
}
368340
BaseFloat cutoff = best_weight + adaptive_beam;
369-
DEBUG_OUT1("queue:" << queue_.size() << " best:" << best_weight << " cutoff:" << cutoff)
370341

371342
while (!queue_.empty()) {
372343
StateId state = queue_.back();
373344
queue_.pop_back();
374345
Token *tok = toks_.Find(state)->val; // would segfault if state not
375-
DEBUG_OUT2("pop token: state:" << state << " weight:" << tok->weight)
376346
// in toks_ but this can't happen.
377347
if (tok->weight.Value() > cutoff) { // Don't bother processing successors.
378-
DEBUG_OUT2("prune")
379348
continue;
380349
}
381350
assert(tok != NULL && state == tok->arc_.nextstate);
@@ -390,20 +359,14 @@ class FasterDecoder {
390359
} else {
391360
Elem *e_found = toks_.Find(arc.nextstate);
392361
if (e_found == NULL) {
393-
DEBUG_OUT2("insert/queue to: " << arc.nextstate)
394362
toks_.Insert(arc.nextstate, new_tok);
395363
queue_.push_back(arc.nextstate);
396364
} else {
397-
DEBUG_OUT2("combine: " << arc.nextstate)
398-
DEBUG_OUT2("combine: " << e_found->val->weight)
399-
DEBUG_OUT2("with: " << new_tok->weight)
400365
if ( *(e_found->val) < *new_tok ) {
401-
DEBUG_OUT2("delete first")
402366
Token::TokenDelete(e_found->val);
403367
e_found->val = new_tok;
404368
queue_.push_back(arc.nextstate);
405369
} else {
406-
DEBUG_OUT2("delete second")
407370
Token::TokenDelete(new_tok);
408371
}
409372
}

src/gmmbin/gmm-decode-faster.cc

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "decoder/faster-decoder.h"
2525
#include "decoder/decodable-am-diag-gmm.h"
2626
#include "util/timer.h"
27+
#include "lat/kaldi-lattice.h" // for {Compact}LatticeArc
2728

2829
using namespace kaldi;
2930

@@ -66,7 +67,7 @@ int main(int argc, char *argv[]) {
6667

6768
const char *usage =
6869
"Decode features using GMM-based model.\n"
69-
"Usage: gmm-decode-faster [options] model-in fst-in features-rspecifier words-wspecifier [alignments-wspecifier]\n";
70+
"Usage: gmm-decode-faster [options] model-in fst-in features-rspecifier words-wspecifier [alignments-wspecifier lattice-wspecifier]\n";
7071
ParseOptions po(usage);
7172
bool allow_partial = true;
7273
BaseFloat acoustic_scale = 0.1;
@@ -82,7 +83,7 @@ int main(int argc, char *argv[]) {
8283
"Produce output even when final state was not reached");
8384
po.Read(argc, argv);
8485

85-
if (po.NumArgs() < 4 || po.NumArgs() > 5) {
86+
if (po.NumArgs() < 4 || po.NumArgs() > 6) {
8687
po.PrintUsage();
8788
exit(1);
8889
}
@@ -91,7 +92,8 @@ int main(int argc, char *argv[]) {
9192
fst_in_filename = po.GetArg(2),
9293
feature_rspecifier = po.GetArg(3),
9394
words_wspecifier = po.GetArg(4),
94-
alignment_wspecifier = po.GetOptArg(5);
95+
alignment_wspecifier = po.GetArg(5),
96+
lattice_wspecifier = po.GetOptArg(6);
9597

9698
TransitionModel trans_model;
9799
AmDiagGmm am_gmm;
@@ -106,6 +108,17 @@ int main(int argc, char *argv[]) {
106108

107109
Int32VectorWriter alignment_writer(alignment_wspecifier);
108110

111+
LatticeWriter lattice_writer;
112+
bool write_lattices = false;
113+
if (lattice_wspecifier != "") {
114+
if (lattice_writer.Open(lattice_wspecifier)) {
115+
write_lattices = true;
116+
} else {
117+
KALDI_WARN << "Could not open table for writing lattices: "
118+
<< lattice_wspecifier;
119+
}
120+
}
121+
109122
fst::SymbolTable *word_syms = NULL;
110123
if (word_syms_filename != "")
111124
if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
@@ -142,7 +155,7 @@ int main(int argc, char *argv[]) {
142155
acoustic_scale);
143156
decoder.Decode(&gmm_decodable);
144157

145-
fst::VectorFst<fst::StdArc> decoded; // linear FST.
158+
fst::VectorFst<LatticeArc> decoded; // linear FST.
146159

147160
if ( (allow_partial || decoder.ReachedFinal())
148161
&& decoder.GetBestPath(&decoded) ) {
@@ -154,14 +167,21 @@ int main(int argc, char *argv[]) {
154167
KALDI_WARN << "Decoder did not reach end-state, outputting partial traceback.";
155168
std::vector<int32> alignment;
156169
std::vector<int32> words;
157-
fst::StdArc::Weight weight;
170+
LatticeWeight weight;
158171
frame_count += features.NumRows();
159172

160173
GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
161174

162175
words_writer.Write(key, words);
163176
if (alignment_writer.IsOpen())
164177
alignment_writer.Write(key, alignment);
178+
179+
if (write_lattices) {
180+
//if (acoustic_scale != 0.0) // We'll write the lattice without acoustic scaling
181+
// fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &decoded);
182+
lattice_writer.Write(key, decoded);
183+
}
184+
165185
if (word_syms != NULL) {
166186
std::cerr << key << ' ';
167187
for (size_t i = 0; i < words.size(); i++) {
@@ -172,7 +192,7 @@ int main(int argc, char *argv[]) {
172192
}
173193
std::cerr << '\n';
174194
}
175-
BaseFloat like = -weight.Value();
195+
BaseFloat like = - weight.Value1() - weight.Value2(); // Times
176196
tot_like += like;
177197
KALDI_LOG << "Log-like per frame for utterance " << key << " is "
178198
<< (like / features.NumRows()) << " over "

0 commit comments

Comments
 (0)