@@ -68,9 +68,10 @@ GenericNumeratorComputation::GenericNumeratorComputation(
6868 }
6969
7070 offsets_.Resize (num_sequences);
71- std::unordered_map<int , MatrixIndexT> pdf_to_index;
71+ std::unordered_map<int32 , MatrixIndexT> pdf_to_index;
7272 int32 pdf_stride = nnet_output_.Stride ();
7373 int32 view_stride = nnet_output_.Stride () * num_sequences;
74+ pdf_to_index.reserve (view_stride);
7475 nnet_output_stride_ = pdf_stride;
7576 for (int seq = 0 ; seq < num_sequences; seq++) {
7677 for (int32 s = 0 ; s < supervision_.e2e_fsts [seq].NumStates (); s++) {
@@ -161,49 +162,46 @@ BaseFloat GenericNumeratorComputation::AlphaRemainingFrames(int seq,
161162
162163 KALDI_ASSERT (seq >= 0 && seq < num_sequences);
163164
164- SubMatrix<BaseFloat> alpha_view (*alpha,
165- 0 , alpha->NumRows (),
166- 0 , alpha->NumCols ());
167-
168165 // variables for log_likelihood computation
169166 double log_scale_product = 0 ,
170167 log_prob_product = 0 ;
171168
172169 for (int t = 1 ; t <= num_frames; ++t) {
173- SubMatrix<BaseFloat> prev_alpha_t (alpha_view, t - 1 , 1 , 0 ,
174- alpha_view.NumCols () - 1 );
175- SubMatrix<BaseFloat> this_alpha_t (alpha_view, t, 1 , 0 ,
176- alpha_view.NumCols () - 1 );
170+ const BaseFloat *probs_tm1 = probs.RowData (t - 1 );
171+ BaseFloat *alpha_t = alpha->RowData (t);
172+ const BaseFloat *alpha_tm1 = alpha->RowData (t - 1 );
177173
178174 for (int32 h = 0 ; h < supervision_.e2e_fsts [seq].NumStates (); h++) {
179175 for (auto tr = in_transitions_[seq][h].begin ();
180- tr != in_transitions_[seq][h].end (); tr++ ) {
176+ tr != in_transitions_[seq][h].end (); ++tr ) {
181177 BaseFloat transition_prob = tr->transition_prob ;
182178 int32 pdf_id = tr->pdf_id ,
183179 prev_hmm_state = tr->hmm_state ;
184- BaseFloat prob = probs (t- 1 , pdf_id) ;
185- alpha_view (t, h) = LogAdd (alpha_view (t, h) ,
186- alpha_view (t- 1 , prev_hmm_state) + transition_prob + prob);
180+ BaseFloat prob = probs_tm1[ pdf_id] ;
181+ alpha_t [h] = LogAdd (alpha_t [h] ,
182+ alpha_tm1[ prev_hmm_state] + transition_prob + prob);
187183 }
188184 }
189- double sum = alpha_view (t-1 , alpha_view.NumCols () - 1 );
190- this_alpha_t .Add (-sum);
191- sum = this_alpha_t .LogSumExp ();
185+ double sum = alpha_tm1[alpha->NumCols () - 1 ];
186+ SubMatrix<BaseFloat> alpha_t_mat (*alpha, t, 1 , 0 ,
187+ alpha->NumCols () - 1 );
188+ alpha_t_mat.Add (-sum);
189+ sum = alpha_t_mat.LogSumExp ();
192190
193- alpha_view (t, alpha_view. NumCols () - 1 ) = sum;
191+ alpha_t [alpha-> NumCols () - 1 ] = sum;
194192 log_scale_product += sum;
195193 }
196- SubMatrix<BaseFloat> last_alpha (alpha_view, alpha_view. NumRows () - 1 , 1 ,
197- 0 , alpha_view. NumCols () - 1 );
194+ SubMatrix<BaseFloat> last_alpha (*alpha, alpha-> NumRows () - 1 , 1 ,
195+ 0 , alpha-> NumCols () - 1 );
198196 SubVector<BaseFloat> final_probs (final_probs_.RowData (seq),
199- alpha_view. NumCols () - 1 );
197+ alpha-> NumCols () - 1 );
200198
201199 // adjust last_alpha
202- double sum = alpha_view (alpha_view. NumRows () - 1 , alpha_view. NumCols () - 1 );
200+ double sum = (*alpha)(alpha-> NumRows () - 1 , alpha-> NumCols () - 1 );
203201 log_scale_product -= sum;
204202 last_alpha.AddVecToRows (1.0 , final_probs);
205203 sum = last_alpha.LogSumExp ();
206- alpha_view (alpha_view. NumRows () - 1 , alpha_view. NumCols () - 1 ) = sum;
204+ (*alpha)(alpha-> NumRows () - 1 , alpha-> NumCols () - 1 ) = sum;
207205
208206 // second part of criterion
209207 log_prob_product = sum - offsets_ (seq);
@@ -242,7 +240,8 @@ bool GenericNumeratorComputation::ForwardBackward(
242240 // Backward part
243241 BetaLastFrame (seq, alpha, &beta);
244242 BetaRemainingFrames (seq, probs, alpha, &beta, &derivs);
245- ok = ok || CheckValues (seq, probs, alpha, beta, derivs);
243+ if (GetVerboseLevel () >= 1 )
244+ ok = ok && CheckValues (seq, probs, alpha, beta, derivs);
246245 }
247246 // Transfer and add the derivatives to the values in the matrix
248247 AddSpecificPdfsIndirect (&derivs, index_to_pdf_, nnet_output_deriv);
@@ -268,7 +267,6 @@ BaseFloat GenericNumeratorComputation::ComputeObjf() {
268267 return partial_loglike;
269268}
270269
271-
272270BaseFloat GenericNumeratorComputation::GetTotalProb (
273271 const Matrix<BaseFloat> &alpha) {
274272 return alpha (alpha.NumRows () - 1 , alpha.NumCols () - 1 );
@@ -306,36 +304,33 @@ void GenericNumeratorComputation::BetaRemainingFrames(int seq,
306304 num_states = supervision_.e2e_fsts [seq].NumStates ();
307305 KALDI_ASSERT (seq >= 0 && seq < num_sequences);
308306
309- SubMatrix<BaseFloat> log_prob_deriv (*derivs,
310- 0 , derivs->NumRows (),
311- 0 , derivs->NumCols ());
312-
313307 for (int t = num_frames - 1 ; t >= 0 ; --t) {
314- SubVector< BaseFloat> this_beta (beta-> RowData (t % 2 ), num_states);
315- const SubVector<BaseFloat> next_beta ( beta->RowData ((t + 1 ) % 2 ),
316- num_states );
317-
318- BaseFloat inv_arbitrary_scale = alpha (t, num_states );
308+ const BaseFloat * alpha_t = alpha. RowData (t),
309+ *beta_tp1 = beta->RowData ((t + 1 ) % 2 ),
310+ * probs_t = probs. RowData (t );
311+ BaseFloat * log_prob_deriv_t = derivs-> RowData (t),
312+ * beta_t = beta-> RowData (t % 2 );
319313
314+ BaseFloat inv_arbitrary_scale = alpha_t [num_states];
320315 for (int32 h = 0 ; h < supervision_.e2e_fsts [seq].NumStates (); h++) {
321316 BaseFloat tot_variable_factor;
322317 tot_variable_factor = -std::numeric_limits<BaseFloat>::infinity ();
323318 for (auto tr = out_transitions_[seq][h].begin ();
324- tr != out_transitions_[seq][h].end (); tr++ ) {
319+ tr != out_transitions_[seq][h].end (); ++tr ) {
325320 BaseFloat transition_prob = tr->transition_prob ;
326321 int32 pdf_id = tr->pdf_id ,
327322 next_hmm_state = tr->hmm_state ;
328323 BaseFloat variable_factor = transition_prob +
329- next_beta ( next_hmm_state) +
330- probs (t, pdf_id) - inv_arbitrary_scale;
324+ beta_tp1[ next_hmm_state] +
325+ probs_t [ pdf_id] - inv_arbitrary_scale;
331326 tot_variable_factor = LogAdd (tot_variable_factor,
332327 variable_factor);
333328
334- BaseFloat occupation_prob = variable_factor + alpha (t, h) ;
335- log_prob_deriv (t, pdf_id) = LogAdd (log_prob_deriv (t, pdf_id) ,
329+ BaseFloat occupation_prob = variable_factor + alpha_t [h] ;
330+ log_prob_deriv_t [ pdf_id] = LogAdd (log_prob_deriv_t [ pdf_id] ,
336331 occupation_prob);
337332 }
338- this_beta (h) = tot_variable_factor;
333+ beta_t [h] = tot_variable_factor;
339334 }
340335 }
341336}
@@ -381,7 +376,29 @@ bool GenericNumeratorComputation::CheckValues(int seq,
381376 const Matrix<BaseFloat> &alpha,
382377 const Matrix<BaseFloat> &beta,
383378 const Matrix<BaseFloat> &derivs) const {
384- // empty checks for now
379+ const int32 num_frames = supervision_.frames_per_sequence ;
380+ // only check the derivs for the first and last frames
381+ const std::vector<int32> times = {0 , num_frames - 1 };
382+ for (const int32 t: times) {
383+ BaseFloat deriv_sum = 0.0 ;
384+ for (int32 n = 0 ; n < probs.NumCols (); n++) {
385+ int32 pdf_stride = nnet_output_.Stride ();
386+ int32 pdf2seq = index_to_pdf_[n] / pdf_stride;
387+ if (pdf2seq != seq) // this pdf is not in the space of this sequence
388+ continue ;
389+ deriv_sum += Exp (derivs (t, n));
390+ }
391+
392+ if (!ApproxEqual (deriv_sum, 1.0 )) {
393+ KALDI_WARN << " On time " << t
394+ << " for seq " << seq << " , deriv sum "
395+ << deriv_sum << " != 1.0" ;
396+ if (fabs (deriv_sum - 1.0 ) > 0.05 || deriv_sum - deriv_sum != 0 ) {
397+ KALDI_WARN << " Excessive error detected, will abandon this minibatch" ;
398+ return false ;
399+ }
400+ }
401+ }
385402 return true ;
386403}
387404
0 commit comments