Skip to content

Commit 77e28ee

Browse files
committed
sandbox/online: various decoder changes (moving decoders to 'new interface')
git-svn-id: https://svn.code.sf.net/p/kaldi/code/sandbox/online@3362 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
1 parent e2f0467 commit 77e28ee

18 files changed

+327
-149
lines changed

src/decoder/biglm-faster-decoder.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class BiglmFasterDecoder {
9696
}
9797

9898
bool ReachedFinal() {
99-
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
99+
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
100100
PairId state_pair = e->key;
101101
StateId state = PairToState(state_pair),
102102
lm_state = PairToLmState(state_pair);
@@ -121,12 +121,12 @@ class BiglmFasterDecoder {
121121
// to the best final token (i.e. the one with best weight best_weight, below).
122122
bool is_final = ReachedFinal();
123123
if (!is_final) {
124-
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail)
124+
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
125125
if (best_tok == NULL || *best_tok < *(e->val) )
126126
best_tok = e->val;
127127
} else {
128128
Weight best_weight = Weight::Zero();
129-
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
129+
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
130130
Weight fst_final = fst_.Final(PairToState(e->key)),
131131
lm_final = lm_diff_fst_->Final(PairToLmState(e->key)),
132132
final = Times(fst_final, lm_final);
@@ -404,7 +404,7 @@ class BiglmFasterDecoder {
404404
void ProcessNonemitting(BaseFloat cutoff) {
405405
// Processes nonemitting arcs for one frame.
406406
KALDI_ASSERT(queue_.empty());
407-
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail)
407+
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
408408
queue_.push_back(e->key);
409409
while (!queue_.empty()) {
410410
PairId state_pair = queue_.back();

src/decoder/faster-decoder.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ void FasterDecoder::DecodeNonblocking(DecodableInterface *decodable,
7676

7777

7878
bool FasterDecoder::ReachedFinal() {
79-
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
79+
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
8080
Weight this_weight = Times(e->val->weight_, fst_.Final(e->key));
8181
if (this_weight != Weight::Zero())
8282
return true;
@@ -94,12 +94,12 @@ bool FasterDecoder::GetBestPath(fst::MutableFst<LatticeArc> *fst_out) {
9494
Token *best_tok = NULL;
9595
bool is_final = ReachedFinal();
9696
if (!is_final) {
97-
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail)
97+
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
9898
if (best_tok == NULL || *best_tok < *(e->val) )
9999
best_tok = e->val;
100100
} else {
101101
Weight best_weight = Weight::Zero();
102-
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
102+
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
103103
Weight this_weight = Times(e->val->weight_, fst_.Final(e->key));
104104
if (this_weight != Weight::Zero() &&
105105
this_weight.Value() < best_weight.Value()) {
@@ -305,7 +305,7 @@ BaseFloat FasterDecoder::ProcessEmitting(DecodableInterface *decodable) {
305305
void FasterDecoder::ProcessNonemitting(BaseFloat cutoff) {
306306
// Processes nonemitting arcs for one frame.
307307
KALDI_ASSERT(queue_.empty());
308-
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail)
308+
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
309309
queue_.push_back(e->key);
310310
while (!queue_.empty()) {
311311
StateId state = queue_.back();

src/decoder/faster-decoder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,13 @@ class FasterDecoder {
8080
~FasterDecoder() { ClearToks(toks_.Clear()); }
8181

8282
void Decode(DecodableInterface *decodable);
83-
83+
84+
/// Returns true if a final state was active on the last frame.
8485
bool ReachedFinal();
8586

87+
/// Returns true if the output best path was not the empty
88+
/// FST (will only return false in unusual circumstances where
89+
/// no tokens survived).
8690
bool GetBestPath(fst::MutableFst<LatticeArc> *fst_out);
8791

8892
/// As a new alternative to Decode(), you can call InitDecoding

src/decoder/lattice-biglm-faster-decoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,7 @@ class LatticeBiglmFasterDecoder {
756756

757757
KALDI_ASSERT(queue_.empty());
758758
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
759-
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
759+
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
760760
queue_.push_back(e->key);
761761
// for pruning with current best token
762762
best_cost = std::min(best_cost, static_cast<BaseFloat>(e->val->tot_cost));

src/decoder/lattice-faster-decoder.cc

Lines changed: 138 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ void LatticeFasterDecoder::InitDecoding() {
4545
cost_offsets_.clear();
4646
ClearActiveTokens();
4747
warned_ = false;
48-
final_active_ = false;
49-
final_costs_.clear();
5048
num_toks_ = 0;
49+
decoding_finalized_ = false;
50+
final_costs_.clear();
5151
StateId start_state = fst_.Start();
5252
KALDI_ASSERT(start_state != fst::kNoStateId);
5353
active_toks_.resize(1);
@@ -70,15 +70,15 @@ bool LatticeFasterDecoder::Decode(DecodableInterface *decodable) {
7070

7171
while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {
7272
if (NumFramesDecoded() % config_.prune_interval == 0) {
73-
PruneActiveTokens(config_.lattice_beam * 0.1); // use larger delta.
73+
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
7474
}
7575
ProcessEmitting(decodable); // Note: the value returned by
7676
// NumFramesDecoded() is incremented by
7777
// ProcessEmitting().
7878
ProcessNonemitting();
7979
}
80-
PruneActiveTokensFinal();
81-
80+
FinalizeDecoding();
81+
8282
// Returns true if we have any kind of traceback available (not necessarily
8383
// to the end state; query ReachedFinal() for that).
8484
return !final_costs_.empty();
@@ -103,6 +103,14 @@ bool LatticeFasterDecoder::GetRawLattice(fst::MutableFst<LatticeArc> *ofst) cons
103103
typedef Arc::StateId StateId;
104104
typedef Arc::Weight Weight;
105105
typedef Arc::Label Label;
106+
107+
unordered_map<Token*, BaseFloat> final_costs_local;
108+
109+
const unordered_map<Token*, BaseFloat> &final_costs =
110+
(decoding_finalized_ ? final_costs_ : final_costs_local);
111+
if (!decoding_finalized_)
112+
ComputeFinalCosts(&final_costs_local, NULL, NULL);
113+
106114
ofst->DeleteStates();
107115
// num-frames plus one (since frames are one-based, and we have
108116
// an extra frame for the start-state).
@@ -152,9 +160,9 @@ bool LatticeFasterDecoder::GetRawLattice(fst::MutableFst<LatticeArc> *ofst) cons
152160
ofst->AddArc(cur_state, arc);
153161
}
154162
if (f == num_frames) {
155-
std::map<Token*, BaseFloat>::const_iterator iter =
156-
final_costs_.find(tok);
157-
if (iter != final_costs_.end())
163+
unordered_map<Token*, BaseFloat>::const_iterator iter =
164+
final_costs.find(tok);
165+
if (iter != final_costs.end())
158166
ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
159167
}
160168
}
@@ -323,34 +331,24 @@ void LatticeFasterDecoder::PruneForwardLinks(
323331
// the final-probs for pruning, otherwise it treats all tokens as final.
324332
void LatticeFasterDecoder::PruneForwardLinksFinal() {
325333
int32 frame_plus_one = active_toks_.size() - 1;
326-
334+
327335
if (active_toks_[frame_plus_one].toks == NULL ) // empty list; should not happen.
328336
KALDI_WARN << "No tokens alive at end of file\n";
329337

330-
// First go through, working out the best token (do it in parallel
331-
// including final-probs and not including final-probs; we'll take
332-
// the one with final-probs if it's valid).
333-
const BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
334-
BaseFloat best_cost_final = infinity,
335-
best_cost_nofinal = infinity;
336-
unordered_map<Token*, BaseFloat> tok_to_final_cost;
337338

338-
Elem *cur_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_
339-
for (Elem *e = cur_toks; e != NULL; e = e->tail) {
340-
StateId state = e->key;
341-
Token *tok = e->val;
342-
BaseFloat final_cost = fst_.Final(state).Value();
343-
best_cost_final = std::min(best_cost_final, tok->tot_cost + final_cost);
344-
tok_to_final_cost[tok] = final_cost;
345-
best_cost_nofinal = std::min(best_cost_nofinal, tok->tot_cost);
346-
}
347-
DeleteElems(cur_toks);
348-
final_active_ = (best_cost_final != infinity);
349-
350-
// Now go through tokens on this frame, pruning forward links... may have
351-
// to iterate a few times until there is no more change, because the list is
352-
// not in topological order.
339+
typedef unordered_map<Token*, BaseFloat>::const_iterator IterType;
340+
ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
341+
decoding_finalized_ = true;
342+
// We call DeleteElems() as a nicety, not because it's really necessary;
343+
// otherwise there would be a time, after calling PruneTokensForFrame() on the
344+
// final frame, when toks_.GetList() or toks_.Clear() would contain pointers
345+
// to nonexistent tokens.
346+
DeleteElems(toks_.Clear());
353347

348+
// Now go through tokens on this frame, pruning forward links... may have to
349+
// iterate a few times until there is no more change, because the list is not
350+
// in topological order. This is a modified version of the code in
351+
// PruneForwardLinks, but here we also take account of the final-probs.
354352
bool changed = true;
355353
BaseFloat delta = 1.0e-05;
356354
while (changed) {
@@ -362,13 +360,14 @@ void LatticeFasterDecoder::PruneForwardLinksFinal() {
362360
// to the "final-prob", so instead of initializing tok_extra_cost to infinity
363361
// below we set it to the difference between the (score+final_prob) of this token,
364362
// and the best such (score+final_prob).
365-
BaseFloat tok_extra_cost;
366-
if (final_active_) {
367-
BaseFloat final_cost = tok_to_final_cost[tok];
368-
tok_extra_cost = (tok->tot_cost + final_cost) - best_cost_final;
369-
} else
370-
tok_extra_cost = tok->tot_cost - best_cost_nofinal;
371-
363+
364+
IterType iter = final_costs_.find(tok);
365+
KALDI_ASSERT(iter != final_costs_.end());
366+
BaseFloat final_cost = iter->second; // is zero if there were no final-probs.
367+
BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
368+
// tok_extra_cost will be a "min" over either directly being final, or
369+
// being indirectly final through other links, and the loop below may
370+
// decrease its value:
372371
for (link = tok->links; link != NULL; ) {
373372
// See if we need to excise this link...
374373
Token *next_tok = link->next_tok;
@@ -398,31 +397,29 @@ void LatticeFasterDecoder::PruneForwardLinksFinal() {
398397
// showed up as having no forward links. Here, the tok_extra_cost has
399398
// an extra component relating to the final-prob.
400399
if (tok_extra_cost > config_.lattice_beam)
401-
tok_extra_cost = infinity;
400+
tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
402401
// to be pruned in PruneTokensForFrame
403402

404403
if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta))
405404
changed = true;
406405
tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_.
407406
}
408407
} // while changed
408+
409+
}
409410

410-
// Now put surviving Tokens in the final_costs_ hash, which is a class
411-
// member (unlike tok_to_final_costs).
412-
for (Token *tok = active_toks_[frame_plus_one].toks;
413-
tok != NULL; tok = tok->next) {
414-
if (tok->extra_cost != infinity) {
415-
// If the token was not pruned away,
416-
if (final_active_) {
417-
BaseFloat final_cost = tok_to_final_cost[tok];
418-
if (final_cost != infinity)
419-
final_costs_[tok] = final_cost;
420-
} else {
421-
final_costs_[tok] = 0;
422-
}
423-
}
411+
BaseFloat LatticeFasterDecoder::FinalRelativeCost() const {
412+
if (!decoding_finalized_) {
413+
BaseFloat relative_cost;
414+
ComputeFinalCosts(NULL, &relative_cost, NULL);
415+
return relative_cost;
416+
} else {
417+
// we're not allowed to call that function if FinalizeDecoding() has
418+
// been called; return a cached value.
419+
return final_relative_cost_;
424420
}
425421
}
422+
426423

427424
// Prune away any tokens on this frame that have no forward links.
428425
// [we don't do this in PruneForwardLinks because it would give us
@@ -485,12 +482,88 @@ void LatticeFasterDecoder::PruneActiveTokens(BaseFloat delta) {
485482
<< " to " << num_toks_;
486483
}
487484

488-
// Version of PruneActiveTokens that we call on the final frame.
489-
// Takes into account the final-prob of tokens.
490-
void LatticeFasterDecoder::PruneActiveTokensFinal() {
485+
void LatticeFasterDecoder::ComputeFinalCosts(
486+
unordered_map<Token*, BaseFloat> *final_costs,
487+
BaseFloat *final_relative_cost,
488+
BaseFloat *final_best_cost) const {
489+
KALDI_ASSERT(!decoding_finalized_);
490+
if (final_costs != NULL)
491+
final_costs->clear();
492+
const Elem *final_toks = toks_.GetList();
493+
BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
494+
BaseFloat best_cost = infinity,
495+
best_cost_with_final = infinity;
496+
while (final_toks != NULL) {
497+
StateId state = final_toks->key;
498+
Token *tok = final_toks->val;
499+
const Elem *next = final_toks->tail;
500+
BaseFloat final_cost = fst_.Final(state).Value();
501+
BaseFloat cost = tok->tot_cost,
502+
cost_with_final = cost + final_cost;
503+
best_cost = std::min(cost, best_cost);
504+
best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
505+
if (final_costs != NULL)
506+
(*final_costs)[tok] = final_cost;
507+
final_toks = next;
508+
}
509+
if (best_cost_with_final == infinity && final_costs != NULL) {
510+
// No states were final, so set all the costs in *final_costs to zero.
511+
typedef unordered_map<Token*, BaseFloat>::iterator IterType;
512+
for (IterType iter = final_costs->begin();
513+
iter != final_costs->end(); ++iter)
514+
iter->second = 0.0;
515+
}
516+
if (final_relative_cost != NULL) {
517+
if (best_cost == infinity && best_cost_with_final == infinity) {
518+
// Likely this will only happen if there are no tokens surviving.
519+
// This seems the least bad way to handle it.
520+
*final_relative_cost = infinity;
521+
} else {
522+
*final_relative_cost = best_cost_with_final - best_cost;
523+
}
524+
}
525+
if (final_best_cost != NULL) {
526+
if (best_cost_with_final != infinity) { // final-state exists.
527+
*final_best_cost = best_cost_with_final;
528+
} else { // no final-state exists.
529+
*final_best_cost = best_cost;
530+
}
531+
}
532+
}
533+
534+
void LatticeFasterDecoder::DecodeNonblocking(DecodableInterface *decodable,
535+
int32 max_num_frames) {
536+
KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ &&
537+
"You must call InitDecoding() before DecodeNonblocking");
538+
int32 num_frames_ready = decodable->NumFramesReady();
539+
// num_frames_ready must be >= num_frames_decoded, or else
540+
// the number of frames ready must have decreased (which doesn't
541+
// make sense) or the decodable object changed between calls
542+
// (which isn't allowed).
543+
KALDI_ASSERT(num_frames_ready >= NumFramesDecoded());
544+
int32 target_frames_decoded = num_frames_ready;
545+
if (max_num_frames >= 0)
546+
target_frames_decoded = std::min(target_frames_decoded,
547+
NumFramesDecoded() + max_num_frames);
548+
while (NumFramesDecoded() < target_frames_decoded) {
549+
if (NumFramesDecoded() % config_.prune_interval == 0) {
550+
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
551+
}
552+
// note: ProcessEmitting() increments NumFramesDecoded().
553+
ProcessEmitting(decodable);
554+
ProcessNonemitting();
555+
}
556+
}
557+
558+
559+
// FinalizeDecoding() is a version of PruneActiveTokens that we call
560+
// (optionally) on the final frame. Takes into account the final-prob of
561+
// tokens. This function used to be called PruneActiveTokensFinal().
562+
void LatticeFasterDecoder::FinalizeDecoding() {
491563
int32 final_frame_plus_one = NumFramesDecoded();
492564
int32 num_toks_begin = num_toks_;
493-
// prune final frame (with final-probs); sets final_active_ and final_probs_
565+
// PruneForwardLinksFinal() prunes final frame (with final-probs), and
566+
// sets decoding_finalized_.
494567
PruneForwardLinksFinal();
495568
for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
496569
bool b1, b2; // values not used.
@@ -576,13 +649,14 @@ void LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) {
576649
// (zero-based) used to get likelihoods
577650
// from the decodable object.
578651
active_toks_.resize(active_toks_.size() + 1);
579-
// Processes emitting arcs for one frame. Propagates from prev_toks_ to cur_toks_.
580-
Elem *last_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_
581-
// in simple-decoder.h.
652+
653+
Elem *final_toks = toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_
654+
// in simple-decoder.h. Removes the Elems from
655+
// being indexed in the hash in toks_.
582656
Elem *best_elem = NULL;
583657
BaseFloat adaptive_beam;
584658
size_t tok_cnt;
585-
BaseFloat cur_cutoff = GetCutoff(last_toks, &tok_cnt, &adaptive_beam, &best_elem);
659+
BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
586660
PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough.
587661

588662
BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
@@ -619,10 +693,10 @@ void LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) {
619693
cost_offsets_.resize(frame + 1, 0.0);
620694
cost_offsets_[frame] = cost_offset;
621695

622-
// the tokens are now owned here, in last_toks, and the hash is empty.
696+
// the tokens are now owned here, in final_toks, and the hash is empty.
623697
// 'owned' is a complex thing here; the point is we need to call DeleteElem
624698
// on each elem 'e' to let toks_ know we're done with them.
625-
for (Elem *e = last_toks, *e_tail; e != NULL; e = e_tail) {
699+
for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) {
626700
// loop this way because we delete "e" as we go.
627701
StateId state = e->key;
628702
Token *tok = e->val;
@@ -674,7 +748,7 @@ void LatticeFasterDecoder::ProcessNonemitting() {
674748

675749
KALDI_ASSERT(queue_.empty());
676750
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
677-
for (Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
751+
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
678752
queue_.push_back(e->key);
679753
// for pruning with current best token
680754
best_cost = std::min(best_cost, static_cast<BaseFloat>(e->val->tot_cost));

0 commit comments

Comments
 (0)