@@ -45,9 +45,9 @@ void LatticeFasterDecoder::InitDecoding() {
4545 cost_offsets_.clear ();
4646 ClearActiveTokens ();
4747 warned_ = false ;
48- final_active_ = false ;
49- final_costs_.clear ();
5048 num_toks_ = 0 ;
49+ decoding_finalized_ = false ;
50+ final_costs_.clear ();
5151 StateId start_state = fst_.Start ();
5252 KALDI_ASSERT (start_state != fst::kNoStateId );
5353 active_toks_.resize (1 );
@@ -70,15 +70,15 @@ bool LatticeFasterDecoder::Decode(DecodableInterface *decodable) {
7070
7171 while (!decodable->IsLastFrame (NumFramesDecoded () - 1 )) {
7272 if (NumFramesDecoded () % config_.prune_interval == 0 ) {
73- PruneActiveTokens (config_.lattice_beam * 0.1 ); // use larger delta.
73+ PruneActiveTokens (config_.lattice_beam * config_. prune_scale );
7474 }
7575 ProcessEmitting (decodable); // Note: the value returned by
7676 // NumFramesDecoded() is incremented by
7777 // ProcessEmitting().
7878 ProcessNonemitting ();
7979 }
80- PruneActiveTokensFinal ();
81-
80+ FinalizeDecoding ();
81+
8282 // Returns true if we have any kind of traceback available (not necessarily
8383 // to the end state; query ReachedFinal() for that).
8484 return !final_costs_.empty ();
@@ -103,6 +103,14 @@ bool LatticeFasterDecoder::GetRawLattice(fst::MutableFst<LatticeArc> *ofst) cons
103103 typedef Arc::StateId StateId;
104104 typedef Arc::Weight Weight;
105105 typedef Arc::Label Label;
106+
107+ unordered_map<Token*, BaseFloat> final_costs_local;
108+
109+ const unordered_map<Token*, BaseFloat> &final_costs =
110+ (decoding_finalized_ ? final_costs_ : final_costs_local);
111+ if (!decoding_finalized_)
112+ ComputeFinalCosts (&final_costs_local, NULL , NULL );
113+
106114 ofst->DeleteStates ();
107115 // num-frames plus one (since frames are one-based, and we have
108116 // an extra frame for the start-state).
@@ -152,9 +160,9 @@ bool LatticeFasterDecoder::GetRawLattice(fst::MutableFst<LatticeArc> *ofst) cons
152160 ofst->AddArc (cur_state, arc);
153161 }
154162 if (f == num_frames) {
155- std::map <Token*, BaseFloat>::const_iterator iter =
156- final_costs_ .find (tok);
157- if (iter != final_costs_ .end ())
163+ unordered_map <Token*, BaseFloat>::const_iterator iter =
164+ final_costs .find (tok);
165+ if (iter != final_costs .end ())
158166 ofst->SetFinal (cur_state, LatticeWeight (iter->second , 0 ));
159167 }
160168 }
@@ -323,34 +331,24 @@ void LatticeFasterDecoder::PruneForwardLinks(
323331// the final-probs for pruning, otherwise it treats all tokens as final.
324332void LatticeFasterDecoder::PruneForwardLinksFinal () {
325333 int32 frame_plus_one = active_toks_.size () - 1 ;
326-
334+
327335 if (active_toks_[frame_plus_one].toks == NULL ) // empty list; should not happen.
328336 KALDI_WARN << " No tokens alive at end of file\n " ;
329337
330- // First go through, working out the best token (do it in parallel
331- // including final-probs and not including final-probs; we'll take
332- // the one with final-probs if it's valid).
333- const BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity ();
334- BaseFloat best_cost_final = infinity,
335- best_cost_nofinal = infinity;
336- unordered_map<Token*, BaseFloat> tok_to_final_cost;
337338
338- Elem *cur_toks = toks_.Clear (); // analogous to swapping prev_toks_ / cur_toks_
339- for (Elem *e = cur_toks; e != NULL ; e = e->tail ) {
340- StateId state = e->key ;
341- Token *tok = e->val ;
342- BaseFloat final_cost = fst_.Final (state).Value ();
343- best_cost_final = std::min (best_cost_final, tok->tot_cost + final_cost);
344- tok_to_final_cost[tok] = final_cost;
345- best_cost_nofinal = std::min (best_cost_nofinal, tok->tot_cost );
346- }
347- DeleteElems (cur_toks);
348- final_active_ = (best_cost_final != infinity);
349-
350- // Now go through tokens on this frame, pruning forward links... may have
351- // to iterate a few times until there is no more change, because the list is
352- // not in topological order.
339+ typedef unordered_map<Token*, BaseFloat>::const_iterator IterType;
340+ ComputeFinalCosts (&final_costs_, &final_relative_cost_, &final_best_cost_);
341+ decoding_finalized_ = true ;
342+ // We call DeleteElems() as a nicety, not because it's really necessary;
343+ // otherwise there would be a time, after calling PruneTokensForFrame() on the
344+ // final frame, when toks_.GetList() or toks_.Clear() would contain pointers
345+ // to nonexistent tokens.
346+ DeleteElems (toks_.Clear ());
353347
348+ // Now go through tokens on this frame, pruning forward links... may have to
349+ // iterate a few times until there is no more change, because the list is not
350+ // in topological order. This is a modified version of the code in
351+ // PruneForwardLinks, but here we also take account of the final-probs.
354352 bool changed = true ;
355353 BaseFloat delta = 1.0e-05 ;
356354 while (changed) {
@@ -362,13 +360,14 @@ void LatticeFasterDecoder::PruneForwardLinksFinal() {
362360 // to the "final-prob", so instead of initializing tok_extra_cost to infinity
363361 // below we set it to the difference between the (score+final_prob) of this token,
364362 // and the best such (score+final_prob).
365- BaseFloat tok_extra_cost;
366- if (final_active_) {
367- BaseFloat final_cost = tok_to_final_cost[tok];
368- tok_extra_cost = (tok->tot_cost + final_cost) - best_cost_final;
369- } else
370- tok_extra_cost = tok->tot_cost - best_cost_nofinal;
371-
363+
364+ IterType iter = final_costs_.find (tok);
365+ KALDI_ASSERT (iter != final_costs_.end ());
366+ BaseFloat final_cost = iter->second ; // is zero if were no final-probs.
367+ BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
368+ // tok_extra_cost will be a "min" over either directly being final, or
369+ // being indirectly final through other links, and the loop below may
370+ // decrease its value:
372371 for (link = tok->links ; link != NULL ; ) {
373372 // See if we need to excise this link...
374373 Token *next_tok = link->next_tok ;
@@ -398,31 +397,29 @@ void LatticeFasterDecoder::PruneForwardLinksFinal() {
398397 // showed up as having no forward links. Here, the tok_extra_cost has
399398 // an extra component relating to the final-prob.
400399 if (tok_extra_cost > config_.lattice_beam )
401- tok_extra_cost = infinity;
400+ tok_extra_cost = std::numeric_limits<BaseFloat>:: infinity () ;
402401 // to be pruned in PruneTokensForFrame
403402
404403 if (!ApproxEqual (tok->extra_cost , tok_extra_cost, delta))
405404 changed = true ;
406405 tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_.
407406 }
408407 } // while changed
408+
409+ }
409410
410- // Now put surviving Tokens in the final_costs_ hash, which is a class
411- // member (unlike tok_to_final_costs).
412- for (Token *tok = active_toks_[frame_plus_one].toks ;
413- tok != NULL ; tok = tok->next ) {
414- if (tok->extra_cost != infinity) {
415- // If the token was not pruned away,
416- if (final_active_) {
417- BaseFloat final_cost = tok_to_final_cost[tok];
418- if (final_cost != infinity)
419- final_costs_[tok] = final_cost;
420- } else {
421- final_costs_[tok] = 0 ;
422- }
423- }
411+ BaseFloat LatticeFasterDecoder::FinalRelativeCost () const {
412+ if (!decoding_finalized_) {
413+ BaseFloat relative_cost;
414+ ComputeFinalCosts (NULL , &relative_cost, NULL );
415+ return relative_cost;
416+ } else {
417+ // we're not allowed to call that function if FinalizeDecoding() has
418+ // been called; return a cached value.
419+ return final_relative_cost_;
424420 }
425421}
422+
426423
427424// Prune away any tokens on this frame that have no forward links.
428425// [we don't do this in PruneForwardLinks because it would give us
@@ -485,12 +482,88 @@ void LatticeFasterDecoder::PruneActiveTokens(BaseFloat delta) {
485482 << " to " << num_toks_;
486483}
487484
488- // Version of PruneActiveTokens that we call on the final frame.
489- // Takes into account the final-prob of tokens.
490- void LatticeFasterDecoder::PruneActiveTokensFinal () {
485+ void LatticeFasterDecoder::ComputeFinalCosts (
486+ unordered_map<Token*, BaseFloat> *final_costs,
487+ BaseFloat *final_relative_cost,
488+ BaseFloat *final_best_cost) const {
489+ KALDI_ASSERT (!decoding_finalized_);
490+ if (final_costs != NULL )
491+ final_costs->clear ();
492+ const Elem *final_toks = toks_.GetList ();
493+ BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity ();
494+ BaseFloat best_cost = infinity,
495+ best_cost_with_final = infinity;
496+ while (final_toks != NULL ) {
497+ StateId state = final_toks->key ;
498+ Token *tok = final_toks->val ;
499+ const Elem *next = final_toks->tail ;
500+ BaseFloat final_cost = fst_.Final (state).Value ();
501+ BaseFloat cost = tok->tot_cost ,
502+ cost_with_final = cost + final_cost;
503+ best_cost = std::min (cost, best_cost);
504+ best_cost_with_final = std::min (cost_with_final, best_cost_with_final);
505+ if (final_costs != NULL )
506+ (*final_costs)[tok] = final_cost;
507+ final_toks = next;
508+ }
509+ if (best_cost_with_final == infinity && final_costs != NULL ) {
510+ // No states were final, so set all the costs in *final_costs to zero.
511+ typedef unordered_map<Token*, BaseFloat>::iterator IterType;
512+ for (IterType iter = final_costs->begin ();
513+ iter != final_costs->end (); ++iter)
514+ iter->second = 0.0 ;
515+ }
516+ if (final_relative_cost != NULL ) {
517+ if (best_cost == infinity && best_cost_with_final == infinity) {
518+ // Likely this will only happen if there are no tokens surviving.
519+ // This seems the least bad way to handle it.
520+ *final_relative_cost = infinity;
521+ } else {
522+ *final_relative_cost = best_cost_with_final - best_cost;
523+ }
524+ }
525+ if (final_best_cost != NULL ) {
526+ if (best_cost_with_final != infinity) { // final-state exists.
527+ *final_best_cost = best_cost_with_final;
528+ } else { // no final-state exists.
529+ *final_best_cost = best_cost;
530+ }
531+ }
532+ }
533+
534+ void LatticeFasterDecoder::DecodeNonblocking (DecodableInterface *decodable,
535+ int32 max_num_frames) {
536+ KALDI_ASSERT (!active_toks_.empty () && !decoding_finalized_ &&
537+ " You must call InitDecoding() before DecodeNonblocking" );
538+ int32 num_frames_ready = decodable->NumFramesReady ();
539+ // num_frames_ready must be >= num_frames_decoded, or else
540+ // the number of frames ready must have decreased (which doesn't
541+ // make sense) or the decodable object changed between calls
542+ // (which isn't allowed).
543+ KALDI_ASSERT (num_frames_ready >= NumFramesDecoded ());
544+ int32 target_frames_decoded = num_frames_ready;
545+ if (max_num_frames >= 0 )
546+ target_frames_decoded = std::min (target_frames_decoded,
547+ NumFramesDecoded () + max_num_frames);
548+ while (NumFramesDecoded () < target_frames_decoded) {
549+ if (NumFramesDecoded () % config_.prune_interval == 0 ) {
550+ PruneActiveTokens (config_.lattice_beam * config_.prune_scale );
551+ }
552+ // note: ProcessEmitting() increments NumFramesDecoded().
553+ ProcessEmitting (decodable);
554+ ProcessNonemitting ();
555+ }
556+ }
557+
558+
559+ // FinalizeDecoding() is a version of PruneActiveTokens that we call
560+ // (optionally) on the final frame. Takes into account the final-prob of
561+ // tokens. This function used to be called PruneActiveTokensFinal().
562+ void LatticeFasterDecoder::FinalizeDecoding () {
491563 int32 final_frame_plus_one = NumFramesDecoded ();
492564 int32 num_toks_begin = num_toks_;
493- // prune final frame (with final-probs); sets final_active_ and final_probs_
565+ // PruneForwardLinksFinal() prunes final frame (with final-probs), and
566+ // sets decoding_finalized_.
494567 PruneForwardLinksFinal ();
495568 for (int32 f = final_frame_plus_one - 1 ; f >= 0 ; f--) {
496569 bool b1, b2; // values not used.
@@ -576,13 +649,14 @@ void LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) {
576649 // (zero-based) used to get likelihoods
577650 // from the decodable object.
578651 active_toks_.resize (active_toks_.size () + 1 );
579- // Processes emitting arcs for one frame. Propagates from prev_toks_ to cur_toks_.
580- Elem *last_toks = toks_.Clear (); // analogous to swapping prev_toks_ / cur_toks_
581- // in simple-decoder.h.
652+
653+ Elem *final_toks = toks_.Clear (); // analogous to swapping prev_toks_ / cur_toks_
654+ // in simple-decoder.h. Removes the Elems from
655+ // being indexed in the hash in toks_.
582656 Elem *best_elem = NULL ;
583657 BaseFloat adaptive_beam;
584658 size_t tok_cnt;
585- BaseFloat cur_cutoff = GetCutoff (last_toks , &tok_cnt, &adaptive_beam, &best_elem);
659+ BaseFloat cur_cutoff = GetCutoff (final_toks , &tok_cnt, &adaptive_beam, &best_elem);
586660 PossiblyResizeHash (tok_cnt); // This makes sure the hash is always big enough.
587661
588662 BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity ();
@@ -619,10 +693,10 @@ void LatticeFasterDecoder::ProcessEmitting(DecodableInterface *decodable) {
619693 cost_offsets_.resize (frame + 1 , 0.0 );
620694 cost_offsets_[frame] = cost_offset;
621695
622- // the tokens are now owned here, in last_toks , and the hash is empty.
696+ // the tokens are now owned here, in final_toks , and the hash is empty.
623697 // 'owned' is a complex thing here; the point is we need to call DeleteElem
624698 // on each elem 'e' to let toks_ know we're done with them.
625- for (Elem *e = last_toks , *e_tail; e != NULL ; e = e_tail) {
699+ for (Elem *e = final_toks , *e_tail; e != NULL ; e = e_tail) {
626700 // loop this way because we delete "e" as we go.
627701 StateId state = e->key ;
628702 Token *tok = e->val ;
@@ -674,7 +748,7 @@ void LatticeFasterDecoder::ProcessNonemitting() {
674748
675749 KALDI_ASSERT (queue_.empty ());
676750 BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity ();
677- for (Elem *e = toks_.GetList (); e != NULL ; e = e->tail ) {
751+ for (const Elem *e = toks_.GetList (); e != NULL ; e = e->tail ) {
678752 queue_.push_back (e->key );
679753 // for pruning with current best token
680754 best_cost = std::min (best_cost, static_cast <BaseFloat>(e->val ->tot_cost ));
0 commit comments