Skip to content

Commit 775c770

Browse files
hhadiandanpovey
authored andcommitted
[src] Minor optimizations in "e2e" numerator code (kaldi-asr#2508)
1 parent 598b177 commit 775c770

File tree

2 files changed

+59
-43
lines changed

2 files changed

+59
-43
lines changed

src/chain/chain-generic-numerator.cc

Lines changed: 57 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ GenericNumeratorComputation::GenericNumeratorComputation(
6868
}
6969

7070
offsets_.Resize(num_sequences);
71-
std::unordered_map<int, MatrixIndexT> pdf_to_index;
71+
std::unordered_map<int32, MatrixIndexT> pdf_to_index;
7272
int32 pdf_stride = nnet_output_.Stride();
7373
int32 view_stride = nnet_output_.Stride() * num_sequences;
74+
pdf_to_index.reserve(view_stride);
7475
nnet_output_stride_ = pdf_stride;
7576
for (int seq = 0; seq < num_sequences; seq++) {
7677
for (int32 s = 0; s < supervision_.e2e_fsts[seq].NumStates(); s++) {
@@ -161,49 +162,46 @@ BaseFloat GenericNumeratorComputation::AlphaRemainingFrames(int seq,
161162

162163
KALDI_ASSERT(seq >= 0 && seq < num_sequences);
163164

164-
SubMatrix<BaseFloat> alpha_view(*alpha,
165-
0, alpha->NumRows(),
166-
0, alpha->NumCols());
167-
168165
// variables for log_likelihood computation
169166
double log_scale_product = 0,
170167
log_prob_product = 0;
171168

172169
for (int t = 1; t <= num_frames; ++t) {
173-
SubMatrix<BaseFloat> prev_alpha_t(alpha_view, t - 1, 1, 0,
174-
alpha_view.NumCols() - 1);
175-
SubMatrix<BaseFloat> this_alpha_t(alpha_view, t, 1, 0,
176-
alpha_view.NumCols() - 1);
170+
const BaseFloat *probs_tm1 = probs.RowData(t - 1);
171+
BaseFloat *alpha_t = alpha->RowData(t);
172+
const BaseFloat *alpha_tm1 = alpha->RowData(t - 1);
177173

178174
for (int32 h = 0; h < supervision_.e2e_fsts[seq].NumStates(); h++) {
179175
for (auto tr = in_transitions_[seq][h].begin();
180-
tr != in_transitions_[seq][h].end(); tr++) {
176+
tr != in_transitions_[seq][h].end(); ++tr) {
181177
BaseFloat transition_prob = tr->transition_prob;
182178
int32 pdf_id = tr->pdf_id,
183179
prev_hmm_state = tr->hmm_state;
184-
BaseFloat prob = probs(t-1, pdf_id);
185-
alpha_view(t, h) = LogAdd(alpha_view(t, h),
186-
alpha_view(t-1, prev_hmm_state) + transition_prob + prob);
180+
BaseFloat prob = probs_tm1[pdf_id];
181+
alpha_t[h] = LogAdd(alpha_t[h],
182+
alpha_tm1[prev_hmm_state] + transition_prob + prob);
187183
}
188184
}
189-
double sum = alpha_view(t-1, alpha_view.NumCols() - 1);
190-
this_alpha_t.Add(-sum);
191-
sum = this_alpha_t.LogSumExp();
185+
double sum = alpha_tm1[alpha->NumCols() - 1];
186+
SubMatrix<BaseFloat> alpha_t_mat(*alpha, t, 1, 0,
187+
alpha->NumCols() - 1);
188+
alpha_t_mat.Add(-sum);
189+
sum = alpha_t_mat.LogSumExp();
192190

193-
alpha_view(t, alpha_view.NumCols() - 1) = sum;
191+
alpha_t[alpha->NumCols() - 1] = sum;
194192
log_scale_product += sum;
195193
}
196-
SubMatrix<BaseFloat> last_alpha(alpha_view, alpha_view.NumRows() - 1, 1,
197-
0, alpha_view.NumCols() - 1);
194+
SubMatrix<BaseFloat> last_alpha(*alpha, alpha->NumRows() - 1, 1,
195+
0, alpha->NumCols() - 1);
198196
SubVector<BaseFloat> final_probs(final_probs_.RowData(seq),
199-
alpha_view.NumCols() - 1);
197+
alpha->NumCols() - 1);
200198

201199
// adjust last_alpha
202-
double sum = alpha_view(alpha_view.NumRows() - 1, alpha_view.NumCols() - 1);
200+
double sum = (*alpha)(alpha->NumRows() - 1, alpha->NumCols() - 1);
203201
log_scale_product -= sum;
204202
last_alpha.AddVecToRows(1.0, final_probs);
205203
sum = last_alpha.LogSumExp();
206-
alpha_view(alpha_view.NumRows() - 1, alpha_view.NumCols() - 1) = sum;
204+
(*alpha)(alpha->NumRows() - 1, alpha->NumCols() - 1) = sum;
207205

208206
// second part of criterion
209207
log_prob_product = sum - offsets_(seq);
@@ -242,7 +240,8 @@ bool GenericNumeratorComputation::ForwardBackward(
242240
// Backward part
243241
BetaLastFrame(seq, alpha, &beta);
244242
BetaRemainingFrames(seq, probs, alpha, &beta, &derivs);
245-
ok = ok || CheckValues(seq, probs, alpha, beta, derivs);
243+
if (GetVerboseLevel() >= 1)
244+
ok = ok && CheckValues(seq, probs, alpha, beta, derivs);
246245
}
247246
// Transfer and add the derivatives to the values in the matrix
248247
AddSpecificPdfsIndirect(&derivs, index_to_pdf_, nnet_output_deriv);
@@ -268,7 +267,6 @@ BaseFloat GenericNumeratorComputation::ComputeObjf() {
268267
return partial_loglike;
269268
}
270269

271-
272270
BaseFloat GenericNumeratorComputation::GetTotalProb(
273271
const Matrix<BaseFloat> &alpha) {
274272
return alpha(alpha.NumRows() - 1, alpha.NumCols() - 1);
@@ -306,36 +304,33 @@ void GenericNumeratorComputation::BetaRemainingFrames(int seq,
306304
num_states = supervision_.e2e_fsts[seq].NumStates();
307305
KALDI_ASSERT(seq >= 0 && seq < num_sequences);
308306

309-
SubMatrix<BaseFloat> log_prob_deriv(*derivs,
310-
0, derivs->NumRows(),
311-
0, derivs->NumCols());
312-
313307
for (int t = num_frames - 1; t >= 0; --t) {
314-
SubVector<BaseFloat> this_beta(beta->RowData(t % 2), num_states);
315-
const SubVector<BaseFloat> next_beta(beta->RowData((t + 1) % 2),
316-
num_states);
317-
318-
BaseFloat inv_arbitrary_scale = alpha(t, num_states);
308+
const BaseFloat *alpha_t = alpha.RowData(t),
309+
*beta_tp1 = beta->RowData((t + 1) % 2),
310+
*probs_t = probs.RowData(t);
311+
BaseFloat *log_prob_deriv_t = derivs->RowData(t),
312+
*beta_t = beta->RowData(t % 2);
319313

314+
BaseFloat inv_arbitrary_scale = alpha_t[num_states];
320315
for (int32 h = 0; h < supervision_.e2e_fsts[seq].NumStates(); h++) {
321316
BaseFloat tot_variable_factor;
322317
tot_variable_factor = -std::numeric_limits<BaseFloat>::infinity();
323318
for (auto tr = out_transitions_[seq][h].begin();
324-
tr != out_transitions_[seq][h].end(); tr++) {
319+
tr != out_transitions_[seq][h].end(); ++tr) {
325320
BaseFloat transition_prob = tr->transition_prob;
326321
int32 pdf_id = tr->pdf_id,
327322
next_hmm_state = tr->hmm_state;
328323
BaseFloat variable_factor = transition_prob +
329-
next_beta(next_hmm_state) +
330-
probs(t, pdf_id) - inv_arbitrary_scale;
324+
beta_tp1[next_hmm_state] +
325+
probs_t[pdf_id] - inv_arbitrary_scale;
331326
tot_variable_factor = LogAdd(tot_variable_factor,
332327
variable_factor);
333328

334-
BaseFloat occupation_prob = variable_factor + alpha(t, h);
335-
log_prob_deriv(t, pdf_id) = LogAdd(log_prob_deriv(t, pdf_id),
329+
BaseFloat occupation_prob = variable_factor + alpha_t[h];
330+
log_prob_deriv_t[pdf_id] = LogAdd(log_prob_deriv_t[pdf_id],
336331
occupation_prob);
337332
}
338-
this_beta(h) = tot_variable_factor;
333+
beta_t[h] = tot_variable_factor;
339334
}
340335
}
341336
}
@@ -381,7 +376,29 @@ bool GenericNumeratorComputation::CheckValues(int seq,
381376
const Matrix<BaseFloat> &alpha,
382377
const Matrix<BaseFloat> &beta,
383378
const Matrix<BaseFloat> &derivs) const {
384-
// empty checks for now
379+
const int32 num_frames = supervision_.frames_per_sequence;
380+
// only check the derivs for the first and last frames
381+
const std::vector<int32> times = {0, num_frames - 1};
382+
for (const int32 t: times) {
383+
BaseFloat deriv_sum = 0.0;
384+
for (int32 n = 0; n < probs.NumCols(); n++) {
385+
int32 pdf_stride = nnet_output_.Stride();
386+
int32 pdf2seq = index_to_pdf_[n] / pdf_stride;
387+
if (pdf2seq != seq) // this pdf is not in the space of this sequence
388+
continue;
389+
deriv_sum += Exp(derivs(t, n));
390+
}
391+
392+
if (!ApproxEqual(deriv_sum, 1.0)) {
393+
KALDI_WARN << "On time " << t
394+
<< " for seq " << seq << ", deriv sum "
395+
<< deriv_sum << " != 1.0";
396+
if (fabs(deriv_sum - 1.0) > 0.05 || deriv_sum - deriv_sum != 0) {
397+
KALDI_WARN << "Excessive error detected, will abandon this minibatch";
398+
return false;
399+
}
400+
}
401+
}
385402
return true;
386403
}
387404

src/chain/chain-generic-numerator.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ class GenericNumeratorComputation {
170170
BaseFloat GetTotalProb(const Matrix<BaseFloat> &alpha);
171171

172172
// some checking that we can do if debug mode is activated, or on frame zero.
173-
// Sets ok_ to false if a bad problem is detected.
173+
// Returns false if a bad problem is detected.
174174
bool CheckValues(int32 seq,
175175
const Matrix<BaseFloat> &probs,
176176
const Matrix<BaseFloat> &alpha,
@@ -196,8 +196,7 @@ class GenericNumeratorComputation {
196196
Matrix<BaseFloat> final_probs_; // indexed by seq, state
197197

198198
// an offset subtracted from the logprobs of transitions out of the first
199-
// state of each graph to help reduce numerical problems. Note the
200-
// generic forward-backward computations cannot be done in log-space.
199+
// state of each graph to help reduce numerical problems.
201200
Vector<BaseFloat> offsets_;
202201
};
203202

0 commit comments

Comments
 (0)