Skip to content

Commit 461746a

Browse files
author
Ilya Platonov
committed
Preparing code to move into kaldi repository.
Fixing stuff and renaming into am_nnet.
1 parent 75f8b02 commit 461746a

File tree

7 files changed

+39
-76
lines changed

7 files changed

+39
-76
lines changed

egs/apiai_decode/s5/README

Lines changed: 0 additions & 37 deletions
This file was deleted.

src/nnet3/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \
2828
discriminative-supervision.o nnet-discriminative-example.o \
2929
nnet-discriminative-diagnostics.o \
3030
discriminative-training.o nnet-discriminative-training.o \
31-
online-nnet3-decodable.o
31+
online-nnet3-decodable-simple.o
3232

3333

3434
LIBNAME = kaldi-nnet3

src/nnet3/online-nnet3-decodable.cc renamed to src/nnet3/online-nnet3-decodable-simple.cc

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,37 +18,37 @@
1818
// See the Apache 2 License for the specific language governing permissions and
1919
// limitations under the License.
2020

21-
#include "nnet3/online-nnet3-decodable.h"
21+
#include <nnet3/online-nnet3-decodable-simple.h>
2222
#include "nnet3/nnet-utils.h"
2323

2424
namespace kaldi {
2525
namespace nnet3 {
2626

27-
DecodableNnet3Online::DecodableNnet3Online(
28-
const AmNnetSimple &nnet,
27+
DecodableNnet3SimpleOnline::DecodableNnet3SimpleOnline(
28+
const AmNnetSimple &am_nnet,
2929
const TransitionModel &trans_model,
3030
const DecodableNnet3OnlineOptions &opts,
3131
OnlineFeatureInterface *input_feats):
32-
compiler_(nnet.GetNnet(), opts.optimize_config),
32+
compiler_(am_nnet.GetNnet(), opts.optimize_config),
3333
features_(input_feats),
34-
nnet_(nnet),
34+
am_nnet_(am_nnet),
3535
trans_model_(trans_model),
3636
opts_(opts),
3737
feat_dim_(input_feats->Dim()),
38-
num_pdfs_(nnet.GetNnet().OutputDim("output")),
38+
num_pdfs_(am_nnet.GetNnet().OutputDim("output")),
3939
begin_frame_(-1) {
4040
KALDI_ASSERT(opts_.max_nnet_batch_size > 0);
41-
log_priors_ = nnet_.Priors();
41+
log_priors_ = am_nnet_.Priors();
4242
KALDI_ASSERT((log_priors_.Dim() == 0 || log_priors_.Dim() == trans_model_.NumPdfs()) &&
4343
"Priors in neural network must match with transition model (if exist).");
4444

45-
ComputeSimpleNnetContext(nnet_.GetNnet(), &left_context_, &right_context_);
45+
ComputeSimpleNnetContext(am_nnet_.GetNnet(), &left_context_, &right_context_);
4646
log_priors_.ApplyLog();
4747
}
4848

4949

5050

51-
BaseFloat DecodableNnet3Online::LogLikelihood(int32 frame, int32 index) {
51+
BaseFloat DecodableNnet3SimpleOnline::LogLikelihood(int32 frame, int32 index) {
5252
ComputeForFrame(frame);
5353
int32 pdf_id = trans_model_.TransitionIdToPdf(index);
5454
KALDI_ASSERT(frame >= begin_frame_ &&
@@ -57,31 +57,31 @@ BaseFloat DecodableNnet3Online::LogLikelihood(int32 frame, int32 index) {
5757
}
5858

5959

60-
bool DecodableNnet3Online::IsLastFrame(int32 frame) const {
60+
bool DecodableNnet3SimpleOnline::IsLastFrame(int32 frame) const {
6161
KALDI_ASSERT(false && "Method is not imlemented");
6262
return false;
6363
}
6464

65-
int32 DecodableNnet3Online::NumFramesReady() const {
65+
int32 DecodableNnet3SimpleOnline::NumFramesReady() const {
6666
int32 features_ready = features_->NumFramesReady();
6767
if (features_ready == 0)
6868
return 0;
6969
bool input_finished = features_->IsLastFrame(features_ready - 1);
7070
if (opts_.pad_input) {
7171
// normal case... we'll pad with duplicates of first + last frame to get the
7272
// required left and right context.
73-
if (input_finished) return subsampling(features_ready);
74-
else return std::max<int32>(0, subsampling(features_ready - right_context_));
73+
if (input_finished) return NumSubsampledFrames(features_ready);
74+
else return std::max<int32>(0, NumSubsampledFrames(features_ready - right_context_));
7575
} else {
76-
return std::max<int32>(0, subsampling(features_ready - right_context_ - left_context_));
76+
return std::max<int32>(0, NumSubsampledFrames(features_ready - right_context_ - left_context_));
7777
}
7878
}
7979

80-
int32 DecodableNnet3Online::subsampling(int32 num_frames) const {
80+
int32 DecodableNnet3SimpleOnline::NumSubsampledFrames(int32 num_frames) const {
8181
return (num_frames) / opts_.frame_subsampling_factor;
8282
}
8383

84-
void DecodableNnet3Online::ComputeForFrame(int32 subsampled_frame) {
84+
void DecodableNnet3SimpleOnline::ComputeForFrame(int32 subsampled_frame) {
8585
int32 features_ready = features_->NumFramesReady();
8686
bool input_finished = features_->IsLastFrame(features_ready - 1);
8787
KALDI_ASSERT(subsampled_frame >= 0);
@@ -118,13 +118,13 @@ void DecodableNnet3Online::ComputeForFrame(int32 subsampled_frame) {
118118
features_->GetFrame(t_modified, &row);
119119
}
120120

121-
int32 num_subsampled_frames = subsampling(input_frame_end - input_frame_begin -
121+
int32 num_subsampled_frames = NumSubsampledFrames(input_frame_end - input_frame_begin -
122122
left_context_ - right_context_);
123123
// I'm not checking if the input feature vector is ok.
124124
// It should be done, but I'm not sure if it is the best place.
125125
// Maybe a new "nnet3 feature pipeline"?
126-
int32 mfcc_dim = nnet_.GetNnet().InputDim("input");
127-
int32 ivector_dim = nnet_.GetNnet().InputDim("ivector");
126+
int32 mfcc_dim = am_nnet_.GetNnet().InputDim("input");
127+
int32 ivector_dim = am_nnet_.GetNnet().InputDim("ivector");
128128
// MFCCs in the left chunk
129129
SubMatrix<BaseFloat> mfcc_mat = features.ColRange(0,mfcc_dim);
130130

@@ -143,7 +143,7 @@ void DecodableNnet3Online::ComputeForFrame(int32 subsampled_frame) {
143143
begin_frame_ = subsampled_frame;
144144
}
145145

146-
void DecodableNnet3Online::DoNnetComputation(
146+
void DecodableNnet3SimpleOnline::DoNnetComputation(
147147
int32 input_t_start,
148148
const MatrixBase<BaseFloat> &input_feats,
149149
const VectorBase<BaseFloat> &ivector,
@@ -182,7 +182,7 @@ void DecodableNnet3Online::DoNnetComputation(
182182
const NnetComputation *computation = compiler_.Compile(request);
183183
Nnet *nnet_to_update = NULL; // we're not doing any update.
184184
NnetComputer computer(opts_.compute_config, *computation,
185-
nnet_.GetNnet(), nnet_to_update);
185+
am_nnet_.GetNnet(), nnet_to_update);
186186

187187
CuMatrix<BaseFloat> input_feats_cu(input_feats);
188188
computer.AcceptInput("input", &input_feats_cu);
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// nnet3/online-nnet3-decodable.h
1+
// nnet3/online-nnet3-decodable-simple.h
22

33
// Copyright 2014 Johns Hopkins Universithy (author: Daniel Povey)
44
// 2016 Api.ai (Author: Ilya Platonov)
@@ -62,9 +62,9 @@ struct DecodableNnet3OnlineOptions {
6262
"frames (this will rarely make a difference)");
6363

6464
opts->Register("frame-subsampling-factor", &frame_subsampling_factor,
65-
"Required if the frame-rate of the output (e.g. in 'chain' "
66-
"models) is less than the frame-rate of the original "
67-
"alignment.");
65+
"Required if the frame-rate of the output (e.g. in 'chain' "
66+
"models) is less than the frame-rate of the original "
67+
"alignment.");
6868

6969
// register the optimization options with the prefix "optimization".
7070
ParseOptions optimization_opts("optimization", opts);
@@ -84,9 +84,9 @@ struct DecodableNnet3OnlineOptions {
8484
feature input from a matrix.
8585
*/
8686

87-
class DecodableNnet3Online: public DecodableInterface {
87+
class DecodableNnet3SimpleOnline: public DecodableInterface {
8888
public:
89-
DecodableNnet3Online(const AmNnetSimple &nnet,
89+
DecodableNnet3SimpleOnline(const AmNnetSimple &am_nnet,
9090
const TransitionModel &trans_model,
9191
const DecodableNnet3OnlineOptions &opts,
9292
OnlineFeatureInterface *input_feats);
@@ -108,7 +108,7 @@ class DecodableNnet3Online: public DecodableInterface {
108108
/// them (and possibly for some succeeding frames)
109109
void ComputeForFrame(int32 frame);
110110
// corrects number of frames by frame_subsampling_factor;
111-
int32 subsampling(int32) const;
111+
int32 NumSubsampledFrames(int32) const;
112112

113113
void DoNnetComputation(
114114
int32 input_t_start,
@@ -120,7 +120,7 @@ class DecodableNnet3Online: public DecodableInterface {
120120
CachingOptimizingCompiler compiler_;
121121

122122
OnlineFeatureInterface *features_;
123-
const AmNnetSimple &nnet_;
123+
const AmNnetSimple &am_nnet_;
124124
const TransitionModel &trans_model_;
125125
DecodableNnet3OnlineOptions opts_;
126126
CuVector<BaseFloat> log_priors_; // log-priors taken from the model.
@@ -143,7 +143,7 @@ class DecodableNnet3Online: public DecodableInterface {
143143
// opts_.max_nnet_batch_size.
144144
Matrix<BaseFloat> scaled_loglikes_;
145145

146-
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableNnet3Online);
146+
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableNnet3SimpleOnline);
147147
};
148148

149149
} // namespace nnet3

src/online2/online-nnet3-decoding.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ namespace kaldi {
2727
SingleUtteranceNnet3Decoder::SingleUtteranceNnet3Decoder(
2828
const OnlineNnet3DecodingConfig &config,
2929
const TransitionModel &tmodel,
30-
const nnet3::AmNnetSimple &model,
30+
const nnet3::AmNnetSimple &am_model,
3131
const fst::Fst<fst::StdArc> &fst,
3232
OnlineNnet2FeaturePipeline *feature_pipeline):
3333
config_(config),
3434
feature_pipeline_(feature_pipeline),
3535
tmodel_(tmodel),
36-
decodable_(model, tmodel, config.decodable_opts, feature_pipeline),
36+
decodable_(am_model, tmodel, config.decodable_opts, feature_pipeline),
3737
decoder_(fst, config.decoder_opts) {
3838
decoder_.InitDecoding();
3939
}

src/online2/online-nnet3-decoding.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
#include <vector>
2727
#include <deque>
2828

29+
#include "../nnet3/online-nnet3-decodable-simple.h"
2930
#include "matrix/matrix-lib.h"
3031
#include "util/common-utils.h"
3132
#include "base/kaldi-error.h"
32-
#include "nnet3/online-nnet3-decodable.h"
3333
#include "online2/online-nnet2-feature-pipeline.h"
3434
#include "online2/online-endpoint.h"
3535
#include "decoder/lattice-faster-online-decoder.h"
@@ -71,7 +71,7 @@ class SingleUtteranceNnet3Decoder {
7171
// class, it's owned externally.
7272
SingleUtteranceNnet3Decoder(const OnlineNnet3DecodingConfig &config,
7373
const TransitionModel &tmodel,
74-
const nnet3::AmNnetSimple &model,
74+
const nnet3::AmNnetSimple &am_model,
7575
const fst::Fst<fst::StdArc> &fst,
7676
OnlineNnet2FeaturePipeline *feature_pipeline);
7777

@@ -116,7 +116,7 @@ class SingleUtteranceNnet3Decoder {
116116

117117
const TransitionModel &tmodel_;
118118

119-
nnet3::DecodableNnet3Online decodable_;
119+
nnet3::DecodableNnet3SimpleOnline decodable_;
120120

121121
LatticeFasterOnlineDecoder decoder_;
122122

src/online2bin/online2-wav-nnet3-latgen-faster.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,12 @@ int main(int argc, char *argv[]) {
151151
}
152152

153153
TransitionModel trans_model;
154-
nnet3::AmNnetSimple nnet;
154+
nnet3::AmNnetSimple am_nnet;
155155
{
156156
bool binary;
157157
Input ki(nnet3_rxfilename, &binary);
158158
trans_model.Read(ki.Stream(), binary);
159-
nnet.Read(ki.Stream(), binary);
159+
am_nnet.Read(ki.Stream(), binary);
160160
}
161161

162162
fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldi(fst_rxfilename);
@@ -203,7 +203,7 @@ int main(int argc, char *argv[]) {
203203

204204
SingleUtteranceNnet3Decoder decoder(nnet3_decoding_config,
205205
trans_model,
206-
nnet,
206+
am_nnet,
207207
*decode_fst,
208208
&feature_pipeline);
209209
OnlineTimer decoding_timer(utt);

0 commit comments

Comments
 (0)