Skip to content

Commit a25d92c

Browse files
author
Ilya Platonov
committed
Adding "dirty" nnet3 online decoder based on nnet2 code.
No iVector support yet (but easy to add). Tested on a chain model only.
1 parent b5665cc commit a25d92c

File tree

9 files changed

+991
-4
lines changed

9 files changed

+991
-4
lines changed
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
#!/bin/bash

# Copyright 2014  Johns Hopkins University (Author: Daniel Povey)
# Apache 2.0

# Decodes test data with an online-nnet3 setup (no iVectors yet), simulating
# online decoding.  Expects the model directory to be one level above the
# decode directory, as prepared by steps/online/nnet2/prepare_online_decoding.sh.

# Begin configuration section.
stage=0
nj=4
cmd=run.pl
max_active=7000
threaded=false
modify_ivector_config=false # only relevant to threaded decoder.
beam=15.0
lattice_beam=6.0
acwt=0.1  # note: only really affects adaptation and pruning (scoring is on
          # lattices).
post_decode_acwt=1.0  # can be used in 'chain' systems to scale acoustics by 10 so the
                      # regular scoring script works.
per_utt=false
online=true  # only relevant to non-threaded decoder.
do_endpointing=false
do_speex_compressing=false
scoring_opts=
skip_scoring=false
silence_weight=1.0  # set this to a value less than 1 (e.g. 0) to enable silence weighting.
max_state_duration=40  # This only has an effect if you are doing silence
   # weighting.  This default is probably reasonable.  transition-ids repeated
   # more than this many times in an alignment are treated as silence.
iter=final
# End configuration section.

echo "$0 $@"  # Print the command line for logging

[ -f ./path.sh ] && . ./path.sh; # source the path.
. parse_options.sh || exit 1;

if [ $# != 3 ]; then
  echo "Usage: $0 [options] <graph-dir> <data-dir> <decode-dir>"
  echo "... where <decode-dir> is assumed to be a sub-directory of the directory"
  echo " where the models are, as prepared by steps/online/nnet2/prepare_online_decoding.sh"
  echo "e.g.: $0 exp/tri3b/graph data/test exp/tri3b_online/decode/"
  echo ""
  echo ""
  echo "main options (for others, see top of script file)"
  echo "  --config <config-file>                   # config containing options"
  echo "  --nj <nj>                                # number of parallel jobs"
  echo "  --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs."
  echo "  --acwt <float>                           # acoustic scale used for lattice generation "
  echo "  --per-utt <true|false>                   # If true, decode per utterance without"
  echo "                                           # carrying forward adaptation info from previous"
  echo "                                           # utterances of each speaker.  Default: false"
  echo "  --online <true|false>                    # Set this to false if you don't really care about"
  echo "                                           # simulating online decoding and just want the best"
  echo "                                           # results.  This will use all the data within each"
  echo "                                           # utterance (plus any previous utterance, if not in"
  echo "                                           # per-utterance mode) to estimate the iVectors."
  echo "  --scoring-opts <string>                  # options to local/score.sh"
  echo "  --iter <iter>                            # Iteration of model to decode; default is final."
  exit 1;
fi

graphdir=$1
data=$2
dir=$3
srcdir=`dirname $dir`; # The model directory is one level up from decoding directory.
sdata=$data/split$nj;

mkdir -p $dir/log
[[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1;
echo $nj > $dir/num_jobs

# Check that all the files we depend on exist before launching jobs.
for f in $srcdir/conf/online_nnet2_decoding.conf $srcdir/${iter}.mdl \
    $graphdir/HCLG.fst $graphdir/words.txt $data/wav.scp; do
  if [ ! -f $f ]; then
    echo "$0: no such file $f"
    exit 1;
  fi
done

if ! $per_utt; then
  spk2utt_rspecifier="ark:$sdata/JOB/spk2utt"
else
  # Fake speaker-to-utterance maps that treat each utterance as its own
  # speaker, so no adaptation state is carried across utterances.
  mkdir -p $dir/per_utt
  for j in $(seq $nj); do
    awk '{print $1, $1}' <$sdata/$j/utt2spk >$dir/per_utt/utt2spk.$j || exit 1;
  done
  spk2utt_rspecifier="ark:$dir/per_utt/utt2spk.JOB"
fi

if [ -f $data/segments ]; then
  wav_rspecifier="ark,s,cs:extract-segments scp,p:$sdata/JOB/wav.scp $sdata/JOB/segments ark:- |"
else
  wav_rspecifier="ark,s,cs:wav-copy scp,p:$sdata/JOB/wav.scp ark:- |"
fi
if $do_speex_compressing; then
  wav_rspecifier="$wav_rspecifier compress-uncompress-speex ark:- ark:- |"
fi
if $do_endpointing; then
  wav_rspecifier="$wav_rspecifier extend-wav-with-silence ark:- ark:- |"
fi

if [ "$silence_weight" != "1.0" ]; then
  silphones=$(cat $graphdir/phones/silence.csl) || exit 1
  # Note: Kaldi option names use hyphens; "silence-phones" (not
  # "silence_phones") is the correct spelling of this option.
  silence_weighting_opts="--ivector-silence-weighting.max-state-duration=$max_state_duration --ivector-silence-weighting.silence-phones=$silphones --ivector-silence-weighting.silence-weight=$silence_weight"
else
  silence_weighting_opts=
fi


if $threaded; then
  # WARNING: this branch still invokes the *nnet2* threaded decoder; no
  # threaded nnet3 online decoder exists yet.  Only the non-threaded branch
  # below uses the nnet3 decoder this script is intended for.
  decoder=online2-wav-nnet2-latgen-threaded
  # note: the decoder actually uses 4 threads, but the average usage will normally
  # be more like 2.
  parallel_opts="--num-threads 2"
  opts="--modify-ivector-config=$modify_ivector_config --verbose=1"
else
  decoder=online2-wav-nnet3-latgen-faster
  parallel_opts=
  opts="--online=$online"
fi

if [ "$post_decode_acwt" == 1.0 ]; then
  lat_wspecifier="ark:|gzip -c >$dir/lat.JOB.gz"
else
  lat_wspecifier="ark:|lattice-scale --acoustic-scale=$post_decode_acwt ark:- ark:- | gzip -c >$dir/lat.JOB.gz"
fi

frame_subsampling_opt=  # empty unless the model dir specifies a factor.
if [ -f $srcdir/frame_subsampling_factor ]; then
  # e.g. for 'chain' systems
  frame_subsampling_opt="--frame-subsampling-factor=$(cat $srcdir/frame_subsampling_factor)"
fi

if [ $stage -le 0 ]; then
  $cmd $parallel_opts JOB=1:$nj $dir/log/decode.JOB.log \
    $decoder $opts $silence_weighting_opts --do-endpointing=$do_endpointing $frame_subsampling_opt \
     --config=$srcdir/conf/online_nnet2_decoding.conf \
     --max-active=$max_active --beam=$beam --lattice-beam=$lattice_beam \
     --acoustic-scale=$acwt --word-symbol-table=$graphdir/words.txt \
     $srcdir/${iter}.mdl $graphdir/HCLG.fst $spk2utt_rspecifier "$wav_rspecifier" \
     $lat_wspecifier || exit 1;
fi

if ! $skip_scoring ; then
  [ ! -x local/score.sh ] && \
    echo "Not scoring because local/score.sh does not exist or not executable." && exit 1;
  local/score.sh --cmd "$cmd" $scoring_opts $data $graphdir $dir
fi

exit 0;

src/nnet3/Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \
2424
nnet-example-utils.o nnet-training.o \
2525
nnet-diagnostics.o nnet-combine.o nnet-am-decodable-simple.o \
2626
nnet-optimize-utils.o nnet-chain-example.o \
27-
nnet-chain-training.o nnet-chain-diagnostics.o nnet-chain-combine.o
27+
nnet-chain-training.o nnet-chain-diagnostics.o nnet-chain-combine.o \
28+
online-nnet3-decodable.o
2829

2930
LIBNAME = kaldi-nnet3
3031

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
// nnet3/online-nnet3-decodable.cc
2+
3+
// Copyright 2014 Johns Hopkins University (author: Daniel Povey)
4+
5+
// See ../../COPYING for clarification regarding multiple authors
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16+
// MERCHANTABLITY OR NON-INFRINGEMENT.
17+
// See the Apache 2 License for the specific language governing permissions and
18+
// limitations under the License.
19+
20+
#include "nnet3/online-nnet3-decodable.h"
21+
#include "nnet3/nnet-utils.h"
22+
23+
namespace kaldi {
24+
namespace nnet3 {
25+
26+
// Constructor: wraps an nnet3 acoustic model (and its transition model) as a
// Decodable object fed by an online feature pipeline.  No iVector input is
// supported.  Caches the network's left/right context and the (log) pdf
// priors so LogLikelihood() can return scaled pseudo-log-likelihoods.
DecodableNnet3Online::DecodableNnet3Online(
    const AmNnetSimple &nnet,
    const TransitionModel &trans_model,
    const DecodableNnet3OnlineOptions &opts,
    OnlineFeatureInterface *input_feats):
    // Note: we must use the 'opts' argument here, not the 'opts_' member:
    // members are initialized in class-declaration order, so 'opts_' is not
    // guaranteed to be initialized yet when 'compiler_' is constructed.
    compiler_(nnet.GetNnet(), opts.optimize_config),
    features_(input_feats),
    nnet_(nnet),
    trans_model_(trans_model),
    opts_(opts),
    feat_dim_(input_feats->Dim()),
    num_pdfs_(nnet.GetNnet().OutputDim("output")),
    begin_frame_(-1) {  // -1 means no frames computed yet.
  KALDI_ASSERT(opts_.max_nnet_batch_size > 0);
  log_priors_ = nnet_.Priors();
  // Empty priors are allowed (e.g. 'chain' models); otherwise they must
  // match the number of pdfs in the transition model.
  KALDI_ASSERT((log_priors_.Dim() == 0 ||
                log_priors_.Dim() == trans_model_.NumPdfs()) &&
               "Priors in neural network must match with transition model (if exist).");

  ComputeSimpleNnetContext(nnet_.GetNnet(), &left_context_, &right_context_);
  log_priors_.ApplyLog();
}
47+
48+
49+
50+
// Returns the scaled pseudo-log-likelihood for transition-id 'index' at
// 'frame' (a subsampled/output frame index), computing a new batch of
// frames on demand.
BaseFloat DecodableNnet3Online::LogLikelihood(int32 frame, int32 index) {
  ComputeForFrame(frame);  // ensure 'frame' is inside the cached batch.
  const int32 pdf_id = trans_model_.TransitionIdToPdf(index);
  const int32 row = frame - begin_frame_;
  KALDI_ASSERT(row >= 0 && row < scaled_loglikes_.NumRows());
  return scaled_loglikes_(row, pdf_id);
}
57+
58+
59+
// Deliberately unsupported: callers should use NumFramesReady() instead.
// (Fixes the typo "imlemented" in the assertion message.)
bool DecodableNnet3Online::IsLastFrame(int32 frame) const {
  KALDI_ASSERT(false && "Method is not implemented");
  return false;
}
63+
64+
// Returns the number of output (subsampled) frames that can currently be
// computed, given how many input feature frames are available and how much
// left/right acoustic context the network needs.
int32 DecodableNnet3Online::NumFramesReady() const {
  const int32 features_ready = features_->NumFramesReady();
  if (features_ready == 0)
    return 0;
  const bool input_finished = features_->IsLastFrame(features_ready - 1);
  if (!opts_.pad_input) {
    // No padding: a frame is only ready once a full window of left plus
    // right context is available around it.
    return std::max<int32>(
        0, subsampling(features_ready - right_context_ - left_context_));
  }
  // Normal case: we'll pad with duplicates of first + last frame to get the
  // required left and right context.
  if (input_finished)
    return subsampling(features_ready);
  return std::max<int32>(0, subsampling(features_ready - right_context_));
}
78+
79+
// Converts a count of input feature frames into a count of output frames,
// dividing by the frame-subsampling factor (rounding down).
int32 DecodableNnet3Online::subsampling(int32 num_frames) const {
  const int32 factor = opts_.frame_subsampling_factor;
  return num_frames / factor;
}
82+
83+
// Ensures that scaled_loglikes_ contains the output for 'subsampled_frame'.
// If the frame is already inside the cached batch this is a no-op; otherwise
// it assembles a batch of input features (padded at the edges if
// opts_.pad_input) and runs the network via DoNnetComputation().
void DecodableNnet3Online::ComputeForFrame(int32 subsampled_frame) {
  int32 features_ready = features_->NumFramesReady();
  bool input_finished = features_->IsLastFrame(features_ready - 1);
  KALDI_ASSERT(subsampled_frame >= 0);
  // Already computed in the current cached batch: nothing to do.
  if (subsampled_frame >= begin_frame_ &&
      subsampled_frame < begin_frame_ + scaled_loglikes_.NumRows())
    return;
  KALDI_ASSERT(subsampled_frame < NumFramesReady());

  int32 subsample = opts_.frame_subsampling_factor;

  // First input frame we need; with padding we reach left_context_ frames
  // before the requested output frame (possibly negative -> duplicated below).
  int32 input_frame_begin;
  if (opts_.pad_input)
    input_frame_begin = subsampled_frame * subsample - left_context_;
  else
    input_frame_begin = subsampled_frame * subsample;
  int32 max_possible_input_frame_end = features_ready /* - ( features_ready - right_context_) % subsample */;
  // Once input has finished we may read past the end; those frames are
  // duplicates of the last real frame (see the loop below).
  if (input_finished && opts_.pad_input)
    max_possible_input_frame_end += right_context_;
  // Cap the batch at max_nnet_batch_size output-producing frames plus the
  // context needed to compute them.
  int32 input_frame_end = std::min<int32>(max_possible_input_frame_end,
                                          input_frame_begin +
                                          left_context_ + right_context_ +
                                          opts_.max_nnet_batch_size);
  KALDI_ASSERT(input_frame_end > input_frame_begin);
  Matrix<BaseFloat> features(input_frame_end - input_frame_begin,
                             feat_dim_);
  // Copy the input frames, clamping out-of-range indexes to the first/last
  // available frame (this implements the edge padding).
  for (int32 t = input_frame_begin; t < input_frame_end; t++) {
    SubVector<BaseFloat> row(features, t - input_frame_begin);
    int32 t_modified = t;
    // The next two if-statements take care of "pad_input"
    if (t_modified < 0)
      t_modified = 0;
    if (t_modified >= features_ready)
      t_modified = features_ready - 1;
    features_->GetFrame(t_modified, &row);
  }

  // Number of output frames this batch yields after removing the context and
  // applying subsampling.
  int32 num_subsampled_frames = subsampling(input_frame_end - input_frame_begin -
                                            left_context_ - right_context_);
  DoNnetComputation(input_frame_begin,
                    features, subsampled_frame * subsample, num_subsampled_frames);

  // Record where the cached batch starts (in subsampled-frame units).
  begin_frame_ = subsampled_frame;
}
127+
128+
// Runs the nnet3 forward computation on one batch of input frames and leaves
// the acoustic-scaled, prior-divided log-likelihoods in scaled_loglikes_.
//
//  input_t_start:         't' index of the first row of 'input_feats'.
//  input_feats:           consecutive input feature frames for this batch.
//  output_t_start:        't' index of the first output frame we want.
//  num_subsampled_frames: number of output frames to request, spaced
//                         frame_subsampling_factor apart in input time.
//
// Change from the original: removed the unused local
// 'CuMatrix<BaseFloat> ivector_feats_cu' (iVectors are not supported by this
// decodable object) and a dead commented-out assignment.
void DecodableNnet3Online::DoNnetComputation(
    int32 input_t_start,
    const MatrixBase<BaseFloat> &input_feats,
    int32 output_t_start,
    int32 num_subsampled_frames) {
  ComputationRequest request;
  request.need_model_derivative = false;  // decoding only; no training.
  request.store_component_stats = false;

  bool shift_time = true; // shift the 'input' and 'output' to a consistent
                          // time, to take advantage of caching in the compiler.
                          // An optimization.
  int32 time_offset = (shift_time ? -output_t_start : 0);

  // First add the regular features-- named "input".
  request.inputs.reserve(2);
  request.inputs.push_back(
      IoSpecification("input", time_offset + input_t_start,
                      time_offset + input_t_start + input_feats.NumRows()));
  IoSpecification output_spec;
  output_spec.name = "output";
  output_spec.has_deriv = false;
  int32 subsample = opts_.frame_subsampling_factor;
  output_spec.indexes.resize(num_subsampled_frames);
  // leave n and x values at 0 (the constructor sets these).
  for (int32 i = 0; i < num_subsampled_frames; i++)
    output_spec.indexes[i].t = time_offset + output_t_start + i * subsample;
  request.outputs.resize(1);
  request.outputs[0].Swap(&output_spec);

  // Compile (with caching inside 'compiler_') and execute the computation.
  const NnetComputation *computation = compiler_.Compile(request);
  Nnet *nnet_to_update = NULL;  // we're not doing any update.
  NnetComputer computer(opts_.compute_config, *computation,
                        nnet_.GetNnet(), nnet_to_update);

  CuMatrix<BaseFloat> input_feats_cu(input_feats);
  computer.AcceptInput("input", &input_feats_cu);
  computer.Forward();
  CuMatrix<BaseFloat> cu_output;
  computer.GetOutputDestructive("output", &cu_output);
  // subtract log-prior (divide by prior)
  if (log_priors_.Dim() != 0)
    cu_output.AddVecToRows(-1.0, log_priors_);
  // apply the acoustic scale
  cu_output.Scale(opts_.acoustic_scale);
  scaled_loglikes_.Resize(0, 0);
  // the following statement just swaps the pointers if we're not using a GPU.
  cu_output.Swap(&scaled_loglikes_);
}
179+
180+
} // namespace nnet3
181+
} // namespace kaldi

0 commit comments

Comments
 (0)