Skip to content

Commit 2c1b11a

Browse files
authored
[src] Change decodable code so nnet context does not have to be recomputed. (kaldi-asr#2549)
1 parent fdb6774 commit 2c1b11a

4 files changed

Lines changed: 29 additions & 4 deletions

File tree

src/nnet3/nnet-am-decodable-simple.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ DecodableNnetSimple::DecodableNnetSimple(
4646
(feats_.NumRows() + opts_.frame_subsampling_factor - 1) /
4747
opts_.frame_subsampling_factor;
4848
KALDI_ASSERT(IsSimpleNnet(nnet));
49-
ComputeSimpleNnetContext(nnet, &nnet_left_context_, &nnet_right_context_);
49+
compiler_.GetSimpleNnetContext(&nnet_left_context_, &nnet_right_context_);
5050
KALDI_ASSERT(!(ivector != NULL && online_ivectors != NULL));
5151
KALDI_ASSERT(!(online_ivectors != NULL && online_ivector_period <= 0 &&
5252
"You need to set the --online-ivector-period option!"));

src/nnet3/nnet-example-utils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ struct ExampleGenerationConfig {
150150
struct ChunkTimeInfo is used by class UtteranceSplitter to output
151151
information about how we split an utterance into chunks.
152152
*/
153-
154153
struct ChunkTimeInfo {
155154
int32 first_frame;
156155
int32 num_frames;

src/nnet3/nnet-optimize.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <iomanip>
2222
#include "nnet3/nnet-optimize.h"
2323
#include "nnet3/nnet-optimize-utils.h"
24+
#include "nnet3/nnet-utils.h"
2425
#include "base/timer.h"
2526

2627
namespace kaldi {
@@ -638,7 +639,8 @@ CachingOptimizingCompiler::CachingOptimizingCompiler(
638639
seconds_taken_total_(0.0), seconds_taken_compile_(0.0),
639640
seconds_taken_optimize_(0.0), seconds_taken_expand_(0.0),
640641
seconds_taken_check_(0.0), seconds_taken_indexes_(0.0),
641-
seconds_taken_io_(0.0), cache_(config.cache_capacity) { }
642+
seconds_taken_io_(0.0), cache_(config.cache_capacity),
643+
nnet_left_context_(-1), nnet_right_context_(-1) { }
642644

643645
CachingOptimizingCompiler::CachingOptimizingCompiler(
644646
const Nnet &nnet,
@@ -648,8 +650,18 @@ CachingOptimizingCompiler::CachingOptimizingCompiler(
648650
seconds_taken_total_(0.0), seconds_taken_compile_(0.0),
649651
seconds_taken_optimize_(0.0), seconds_taken_expand_(0.0),
650652
seconds_taken_check_(0.0), seconds_taken_indexes_(0.0),
651-
seconds_taken_io_(0.0), cache_(config.cache_capacity) { }
653+
seconds_taken_io_(0.0), cache_(config.cache_capacity),
654+
nnet_left_context_(-1), nnet_right_context_(-1) { }
652655

656+
void CachingOptimizingCompiler::GetSimpleNnetContext(
657+
int32 *nnet_left_context, int32 *nnet_right_context) {
658+
if (nnet_left_context_ == -1) {
659+
ComputeSimpleNnetContext(nnet_, &nnet_left_context_,
660+
&nnet_right_context_);
661+
}
662+
*nnet_left_context = nnet_left_context_;
663+
*nnet_right_context = nnet_right_context_;
664+
}
653665

654666
void CachingOptimizingCompiler::ReadCache(std::istream &is, bool binary) {
655667
{

src/nnet3/nnet-optimize.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,16 @@ class CachingOptimizingCompiler {
242242
void ReadCache(std::istream &is, bool binary);
243243
void WriteCache(std::ostream &os, bool binary);
244244

245+
246+
// GetSimpleNnetContext() is equivalent to calling:
247+
// ComputeSimpleNnetContext(nnet_, &nnet_left_context,
248+
// &nnet_right_context)
249+
// but it caches it inside this class. This functionality is independent of
250+
// the rest of the functionality of this class; it just happens to be a
251+
// convenient place to put this mechanism.
252+
void GetSimpleNnetContext(int32 *nnet_left_context,
253+
int32 *nnet_right_context);
254+
245255
private:
246256

247257
// This function just implements the work of Compile(); it's made a separate
@@ -290,6 +300,10 @@ class CachingOptimizingCompiler {
290300
double seconds_taken_io_;
291301

292302
ComputationCache cache_;
303+
304+
// These following two variables are only used by the function GetSimpleNnetContext().
305+
int32 nnet_left_context_;
306+
int32 nnet_right_context_;
293307
};
294308

295309

0 commit comments

Comments
 (0)