1818// See the Apache 2 License for the specific language governing permissions and
1919// limitations under the License.
2020
21- #include " nnet3/online-nnet3-decodable.h "
21+ #include < nnet3/online-nnet3-decodable-simple.h >
2222#include " nnet3/nnet-utils.h"
2323
2424namespace kaldi {
2525namespace nnet3 {
2626
// Constructor. The removed/added line pairs below show a rename only: the
// class DecodableNnet3Online becomes DecodableNnet3SimpleOnline, the parameter
// nnet becomes am_nnet, and the member nnet_ becomes am_nnet_.
//
// Parameters:
//   am_nnet      acoustic model; its GetNnet() and Priors() are read here.
//   trans_model  transition model, stored in trans_model_.
//   opts         decodable options, stored in opts_; opts.optimize_config is
//                also passed to compiler_.
//   input_feats  feature source, stored in features_. It is dereferenced in
//                the initializer list (input_feats->Dim()) with no null check,
//                so it must point to a valid object.
//
// The initializer list also reads the output dimension of the network node
// named in the OutputDim() call into num_pdfs_ and sets begin_frame_ to -1
// (NOTE(review): presumably meaning that no chunk has been computed yet;
// confirm against ComputeForFrame).
27- DecodableNnet3Online::DecodableNnet3Online (
28- const AmNnetSimple &nnet ,
27+ DecodableNnet3SimpleOnline::DecodableNnet3SimpleOnline (
28+ const AmNnetSimple &am_nnet ,
2929 const TransitionModel &trans_model,
3030 const DecodableNnet3OnlineOptions &opts,
3131 OnlineFeatureInterface *input_feats):
32- compiler_ (nnet .GetNnet(), opts.optimize_config),
32+ compiler_ (am_nnet .GetNnet(), opts.optimize_config),
3333 features_ (input_feats),
34- nnet_ (nnet ),
34+ am_nnet_ (am_nnet ),
3535 trans_model_ (trans_model),
3636 opts_ (opts),
3737 feat_dim_ (input_feats->Dim ()),
38- num_pdfs_(nnet .GetNnet().OutputDim(" output" )),
38+ num_pdfs_(am_nnet .GetNnet().OutputDim(" output" )),
3939 begin_frame_(-1 ) {
// The batch size must be strictly positive.
4040 KALDI_ASSERT (opts_.max_nnet_batch_size > 0 );
// Copy the priors out of the acoustic model. At this point log_priors_ holds
// whatever Priors() returns; ApplyLog() is only called on it at the end of
// this constructor.
41- log_priors_ = nnet_ .Priors ();
41+ log_priors_ = am_nnet_ .Priors ();
// An empty prior vector (Dim() == 0) is accepted; otherwise its dimension
// must equal the number of pdfs in the transition model.
4242 KALDI_ASSERT ((log_priors_.Dim () == 0 || log_priors_.Dim () == trans_model_.NumPdfs ()) &&
4343 " Priors in neural network must match with transition model (if exist)." );
4444
// Fill left_context_ and right_context_ from the network.
45- ComputeSimpleNnetContext (nnet_ .GetNnet (), &left_context_, &right_context_);
45+ ComputeSimpleNnetContext (am_nnet_ .GetNnet (), &left_context_, &right_context_);
4646 log_priors_.ApplyLog ();
4747}
4848
4949
5050
51- BaseFloat DecodableNnet3Online ::LogLikelihood (int32 frame, int32 index) {
51+ BaseFloat DecodableNnet3SimpleOnline ::LogLikelihood (int32 frame, int32 index) {
5252 ComputeForFrame (frame);
5353 int32 pdf_id = trans_model_.TransitionIdToPdf (index);
5454 KALDI_ASSERT (frame >= begin_frame_ &&
@@ -57,31 +57,31 @@ BaseFloat DecodableNnet3Online::LogLikelihood(int32 frame, int32 index) {
5757}
5858
5959
// Not implemented. The body asserts unconditionally (false && message), and
// the frame argument is never read. The trailing return false is reached only
// if the assertion does not stop execution (NOTE(review): presumably when
// KALDI_ASSERT is compiled out; confirm).
// NOTE(review): the assertion message misspells the word implemented.
60- bool DecodableNnet3Online ::IsLastFrame (int32 frame) const {
60+ bool DecodableNnet3SimpleOnline ::IsLastFrame (int32 frame) const {
6161 KALDI_ASSERT (false && " Method is not imlemented" );
6262 return false ;
6363}
6464
// Returns how many subsampled frames can be produced from the feature frames
// that features_ currently reports as ready.
//
//   - No feature frames ready: returns 0.
//   - opts_.pad_input set and the last ready frame is the final input frame:
//     every ready frame counts.
//   - opts_.pad_input set, input not finished: right_context_ frames are held
//     back.
//   - opts_.pad_input not set: both right_context_ and left_context_ frames
//     are held back.
//
// Each count is converted with the subsampling helper (renamed in this change
// from subsampling to NumSubsampledFrames). The two branches that subtract
// context clamp the result at zero with std::max.
65- int32 DecodableNnet3Online ::NumFramesReady () const {
65+ int32 DecodableNnet3SimpleOnline ::NumFramesReady () const {
6666 int32 features_ready = features_->NumFramesReady ();
// Early return keeps the index features_ready - 1 below from going negative.
6767 if (features_ready == 0 )
6868 return 0 ;
// Ask the feature source whether the most recent ready frame is its last one.
6969 bool input_finished = features_->IsLastFrame (features_ready - 1 );
7070 if (opts_.pad_input ) {
7171 // normal case... we'll pad with duplicates of first + last frame to get the
7272 // required left and right context.
73- if (input_finished) return subsampling (features_ready);
74- else return std::max<int32>(0 , subsampling (features_ready - right_context_));
73+ if (input_finished) return NumSubsampledFrames (features_ready);
74+ else return std::max<int32>(0 , NumSubsampledFrames (features_ready - right_context_));
7575 } else {
// Without padding, context on both sides must come from real input frames.
76- return std::max<int32>(0 , subsampling (features_ready - right_context_ - left_context_));
76+ return std::max<int32>(0 , NumSubsampledFrames (features_ready - right_context_ - left_context_));
7777 }
7878}
7979
// Converts a count of input frames into a count of subsampled frames by
// integer division by opts_.frame_subsampling_factor. This change renames the
// method from subsampling to NumSubsampledFrames; the body is unchanged.
//
// Integer division truncates toward zero, so any remainder is dropped and a
// negative num_frames gives a result that is zero or negative. NumFramesReady
// clamps such results with std::max.
// NOTE(review): nothing here guards against a zero frame_subsampling_factor;
// confirm that the options are validated elsewhere.
80- int32 DecodableNnet3Online::subsampling (int32 num_frames) const {
80+ int32 DecodableNnet3SimpleOnline::NumSubsampledFrames (int32 num_frames) const {
8181 return (num_frames) / opts_.frame_subsampling_factor ;
8282}
8383
84- void DecodableNnet3Online ::ComputeForFrame (int32 subsampled_frame) {
84+ void DecodableNnet3SimpleOnline ::ComputeForFrame (int32 subsampled_frame) {
8585 int32 features_ready = features_->NumFramesReady ();
8686 bool input_finished = features_->IsLastFrame (features_ready - 1 );
8787 KALDI_ASSERT (subsampled_frame >= 0 );
@@ -118,13 +118,13 @@ void DecodableNnet3Online::ComputeForFrame(int32 subsampled_frame) {
118118 features_->GetFrame (t_modified, &row);
119119 }
120120
121- int32 num_subsampled_frames = subsampling (input_frame_end - input_frame_begin -
121+ int32 num_subsampled_frames = NumSubsampledFrames (input_frame_end - input_frame_begin -
122122 left_context_ - right_context_);
123123 // I'm not checking if the input feature vector is ok.
124124 // It should be done, but I'm not sure if it is the best place.
125125 // Maybe a new "nnet3 feature pipeline"?
126- int32 mfcc_dim = nnet_ .GetNnet ().InputDim (" input" );
127- int32 ivector_dim = nnet_ .GetNnet ().InputDim (" ivector" );
126+ int32 mfcc_dim = am_nnet_ .GetNnet ().InputDim (" input" );
127+ int32 ivector_dim = am_nnet_ .GetNnet ().InputDim (" ivector" );
128128 // MFCCs in the left chunk
129129 SubMatrix<BaseFloat> mfcc_mat = features.ColRange (0 ,mfcc_dim);
130130
@@ -143,7 +143,7 @@ void DecodableNnet3Online::ComputeForFrame(int32 subsampled_frame) {
143143 begin_frame_ = subsampled_frame;
144144}
145145
146- void DecodableNnet3Online ::DoNnetComputation (
146+ void DecodableNnet3SimpleOnline ::DoNnetComputation (
147147 int32 input_t_start,
148148 const MatrixBase<BaseFloat> &input_feats,
149149 const VectorBase<BaseFloat> &ivector,
@@ -182,7 +182,7 @@ void DecodableNnet3Online::DoNnetComputation(
182182 const NnetComputation *computation = compiler_.Compile (request);
183183 Nnet *nnet_to_update = NULL ; // we're not doing any update.
184184 NnetComputer computer (opts_.compute_config , *computation,
185- nnet_ .GetNnet (), nnet_to_update);
185+ am_nnet_ .GetNnet (), nnet_to_update);
186186
187187 CuMatrix<BaseFloat> input_feats_cu (input_feats);
188188 computer.AcceptInput (" input" , &input_feats_cu);
0 commit comments