Skip to content

Commit 59736ea

Browse files
committed
trunk: add multi-threaded online-nnet2 decoding program, online2-wav-nnet2-latgen-threaded, which does decoding and nnet evaluation in different threads. Usage is otherwise similar to online2-wav-nnet2-latgen-faster.
git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@4844 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
1 parent 41a8f9b commit 59736ea

34 files changed

+1839
-88
lines changed

egs/librispeech/s5/local/online/run_nnet2.sh

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,26 @@ if [ $stage -le 13 ]; then
148148
fi
149149

150150
exit 0;
151+
###### Comment out the "exit 0" above to run the multi-threaded decoding. #####
152+
153+
if [ $stage -le 14 ]; then
154+
# Demonstrate the multi-threaded decoding.
155+
# put back the pp
156+
test=dev_clean
157+
steps/online/nnet2/decode.sh --threaded true \
158+
--config conf/decode.config --cmd "$decode_cmd" --nj 30 \
159+
--per-utt true exp/tri6b/graph_pp_tgsmall data/$test \
160+
${dir}_online/decode_pp_${test}_tgsmall_utt_threaded || exit 1;
161+
fi
162+
163+
if [ $stage -le 15 ]; then
164+
# Demonstrate the multi-threaded decoding with endpointing.
165+
# put back the pp
166+
test=dev_clean
167+
steps/online/nnet2/decode.sh --threaded true --do-endpointing true \
168+
--config conf/decode.config --cmd "$decode_cmd" --nj 30 \
169+
--per-utt true exp/tri6b/graph_pp_tgsmall data/$test \
170+
${dir}_online/decode_pp_${test}_tgsmall_utt_threaded_ep || exit 1;
171+
fi
172+
173+
exit 0;

egs/sprakbanken/s5/run.sh

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,11 @@ utils/mkgraph.sh data/lang_test_4g exp/tri3b exp/tri3b/graph_4g || exit 1;
151151
steps/decode_fmllr.sh --cmd "$decode_cmd" --nj 7 \
152152
exp/tri3b/graph_4g data/test1k exp/tri3b/decode_4g_test1k || exit 1;
153153

154-
# Train RNN for reranking
155-
local/sprak_train_rnnlms.sh data/local/dict data/dev/transcripts.uniq data/local/rnnlms/g_c380_d1k_h100_v130k
156-
# Consumes a lot of memory! Do not run in parallel
157-
local/sprak_run_rnnlms_tri3b.sh data/lang_test_3g data/local/rnnlms/g_c380_d1k_h100_v130k data/test1k exp/tri3b/decode_3g_test1k
154+
# This is commented out for now as it's not important for the main recipe.
155+
## Train RNN for reranking
156+
#local/sprak_train_rnnlms.sh data/local/dict data/dev/transcripts.uniq data/local/rnnlms/g_c380_d1k_h100_v130k
157+
## Consumes a lot of memory! Do not run in parallel
158+
#local/sprak_run_rnnlms_tri3b.sh data/lang_test_3g data/local/rnnlms/g_c380_d1k_h100_v130k data/test1k exp/tri3b/decode_3g_test1k
158159

159160

160161
# From 3b system

egs/wsj/s5/steps/online/nnet2/decode.sh

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ stage=0
88
nj=4
99
cmd=run.pl
1010
max_active=7000
11+
threaded=false
12+
modify_ivector_config=false # only relevant to threaded decoder.
1113
beam=15.0
1214
lattice_beam=6.0
1315
acwt=0.1 # note: only really affects adaptation and pruning (scoring is on
1416
# lattices).
1517
per_utt=false
16-
online=true
18+
online=true # only relevant to non-threaded decoder.
1719
do_endpointing=false
1820
do_speex_compressing=false
1921
scoring_opts=
@@ -92,9 +94,23 @@ if $do_endpointing; then
9294
wav_rspecifier="$wav_rspecifier extend-wav-with-silence ark:- ark:- |"
9395
fi
9496

97+
98+
99+
if $threaded; then
100+
decoder=online2-wav-nnet2-latgen-threaded
101+
# note: the decoder actually uses 4 threads, but the average usage will normally
102+
# be more like 2.
103+
parallel_opts="--num-threads 2"
104+
opts="--modify-ivector-config=$modify_ivector_config --verbose=1"
105+
else
106+
decoder=online2-wav-nnet2-latgen-faster
107+
parallel_opts=
108+
opts="--online=$online"
109+
fi
110+
95111
if [ $stage -le 0 ]; then
96-
$cmd JOB=1:$nj $dir/log/decode.JOB.log \
97-
online2-wav-nnet2-latgen-faster --online=$online --do-endpointing=$do_endpointing \
112+
$cmd $parallel_opts JOB=1:$nj $dir/log/decode.JOB.log \
113+
$decoder $opts --do-endpointing=$do_endpointing \
98114
--config=$srcdir/conf/online_nnet2_decoding.conf \
99115
--max-active=$max_active --beam=$beam --lattice-beam=$lattice_beam \
100116
--acoustic-scale=$acwt --word-symbol-table=$graphdir/words.txt \

src/base/Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ all:
33

44
include ../kaldi.mk
55

6-
TESTFILES = kaldi-math-test io-funcs-test kaldi-error-test
6+
TESTFILES = kaldi-math-test io-funcs-test kaldi-error-test timer-test
77

88
OBJFILES = kaldi-math.o kaldi-error.o io-funcs.o kaldi-utils.o
99

1010
LIBNAME = kaldi-base
1111

12-
ADDLIBS =
12+
ADDLIBS =
1313

1414
include ../makefiles/default_rules.mk
1515

src/base/kaldi-error.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,19 @@ std::string KaldiGetStackTrace() {
128128

129129
void KaldiAssertFailure_(const char *func, const char *file,
130130
int32 line, const char *cond_str) {
131-
std::cerr << "KALDI_ASSERT: at " << GetProgramName() << func << ':'
132-
<< GetShortFileName(file)
133-
<< ':' << line << ", failed: " << cond_str << '\n';
131+
std::ostringstream ss;
132+
ss << "KALDI_ASSERT: at " << GetProgramName() << func << ':'
133+
<< GetShortFileName(file)
134+
<< ':' << line << ", failed: " << cond_str << '\n';
134135
#ifdef HAVE_EXECINFO_H
135-
std::cerr << "Stack trace is:\n" << KaldiGetStackTrace();
136+
ss << "Stack trace is:\n" << KaldiGetStackTrace();
136137
#endif
138+
std::cerr << ss.str();
137139
std::cerr.flush();
138-
abort(); // Will later throw instead if needed.
140+
// We used to call abort() here, but switch to throwing an exception
141+
// (like KALDI_ERR) because it's easier to deal with in multi-threaded
142+
// code.
143+
throw std::runtime_error(ss.str());
139144
}
140145

141146

src/base/kaldi-utils.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@
1919
#include <string>
2020
#include "base/kaldi-common.h"
2121

22+
23+
#ifdef _WIN32_WINNT_WIN8
24+
#include <Synchapi.h>
25+
#elif defined (_WIN32) || defined(_MSC_VER) || defined(MINGW)
26+
#include <Windows.h>
27+
#else
28+
#include <unistd.h>
29+
#endif
30+
2231
namespace kaldi {
2332

2433
std::string CharToString(const char &c) {
@@ -30,4 +39,12 @@ std::string CharToString(const char &c) {
3039
return (std::string) buf;
3140
}
3241

42+
void Sleep(float seconds) {
43+
#if defined(_MSC_VER) || defined(MINGW)
44+
::Sleep(static_cast<int>(seconds * 1000.0));
45+
#else
46+
usleep(static_cast<int>(seconds * 1000000.0));
47+
#endif
48+
}
49+
3350
} // end namespace kaldi

src/base/kaldi-utils.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ inline int MachineIsLittleEndian() {
7878
return (*reinterpret_cast<char*>(&check) != 0);
7979
}
8080

81+
// This function kaldi::Sleep() provides a portable way to sleep for a possibly fractional
82+
// number of seconds. On Windows it's only accurate to microseconds.
83+
void Sleep(float seconds);
84+
8185
}
8286

8387
#define KALDI_SWAP8(a) { \
@@ -108,7 +112,7 @@ inline int MachineIsLittleEndian() {
108112
template<bool B> class KaldiCompileTimeAssert { };
109113
template<> class KaldiCompileTimeAssert<true> {
110114
public:
111-
static inline void Check() { }
115+
static inline void Check() { }
112116
};
113117

114118
#define KALDI_COMPILE_TIME_ASSERT(b) KaldiCompileTimeAssert<(b)>::Check()

src/base/timer-test.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// base/timer-test.cc
22

33
// Copyright 2009-2011 Microsoft Corporation
4+
// 2014 Johns Hopkins University (author: Daniel Povey)
45

56
// See ../../COPYING for clarification regarding multiple authors
67
//
@@ -19,28 +20,27 @@
1920

2021
#include "base/timer.h"
2122
#include "base/kaldi-common.h"
22-
23+
#include "base/kaldi-utils.h"
2324

2425

2526
namespace kaldi {
2627

2728
void TimerTest() {
28-
29+
float time_secs = 0.025 * (rand() % 10);
30+
std::cout << "target is " << time_secs << "\n";
2931
Timer timer;
30-
#if defined(_MSC_VER) || defined(MINGW)
31-
Sleep(1000);
32-
#else
33-
sleep(1);
34-
#endif
32+
Sleep(time_secs);
3533
BaseFloat f = timer.Elapsed();
36-
std::cout << "time is " << f;
37-
KALDI_ASSERT(fabs(1.0 - f) < 0.1);
34+
std::cout << "time is " << f << std::endl;
35+
if (fabs(time_secs - f) > 0.05)
36+
KALDI_ERR << "Timer fail: waited " << f << " seconds instead of "
37+
<< time_secs << " secs.";
3838
}
3939

4040
}
4141

4242

4343
int main() {
44-
kaldi::TimerTest();
44+
for (int i = 0; i < 4; i++)
45+
kaldi::TimerTest();
4546
}
46-

src/base/timer.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
#ifndef KALDI_BASE_TIMER_H_
2020
#define KALDI_BASE_TIMER_H_
2121

22-
#if defined(_MSC_VER) || defined(MINGW)
23-
2422
#include "base/kaldi-utils.h"
23+
// Note: Sleep(float secs) is included in base/kaldi-utils.h.
24+
25+
26+
#if defined(_MSC_VER) || defined(MINGW)
2527

2628
namespace kaldi
2729
{

src/decoder/decodable-matrix.h

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,88 @@ class DecodableMatrixScaledMapped: public DecodableInterface {
8383
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixScaledMapped);
8484
};
8585

86+
/**
87+
This decodable class returns log-likes stored in a matrix; it supports
88+
repeatedly writing to the matrix and setting a time-offset representing the
89+
frame-index of the first row of the matrix. It's intended for use in
90+
multi-threaded decoding; mutex and semaphores are not included. External
91+
code will call SetLoglikes() each time more log-likelihods are available.
92+
If you try to access a log-likelihood that's no longer available because
93+
the frame index is less than the current offset, it is of course an error.
94+
*/
95+
class DecodableMatrixMappedOffset: public DecodableInterface {
96+
public:
97+
DecodableMatrixMappedOffset(const TransitionModel &tm):
98+
trans_model_(tm), frame_offset_(0), input_is_finished_(false) { }
99+
100+
101+
102+
virtual int32 NumFramesReady() { return frame_offset_ + loglikes_.NumRows(); }
103+
104+
// this is not part of the generic Decodable interface.
105+
int32 FirstAvailableFrame() { return frame_offset_; }
106+
107+
// This function is destructive of the input "loglikes" because it may
108+
// under some circumstances do a shallow copy using Swap(). This function
109+
// appends loglikes to any existing likelihoods you've previously supplied.
110+
// frames_to_discard, if nonzero, will discard that number of previously
111+
// available frames, from the left, advancing FirstAvailableFrame() by
112+
// a number equal to frames_to_discard. You should only set frames_to_discard
113+
// to nonzero if you know your decoder won't want to access the loglikes
114+
// for older frames.
115+
void AcceptLoglikes(Matrix<BaseFloat> *loglikes,
116+
int32 frames_to_discard) {
117+
if (loglikes->NumRows() == 0) return;
118+
KALDI_ASSERT(loglikes->NumCols() == trans_model_.NumPdfs());
119+
KALDI_ASSERT(frames_to_discard <= loglikes_.NumRows() &&
120+
frames_to_discard >= 0);
121+
if (frames_to_discard == loglikes_.NumRows()) {
122+
loglikes_.Swap(loglikes);
123+
loglikes->Resize(0, 0);
124+
} else {
125+
int32 old_rows_kept = loglikes_.NumRows() - frames_to_discard,
126+
new_num_rows = old_rows_kept + loglikes->NumRows();
127+
Matrix<BaseFloat> new_loglikes(new_num_rows, loglikes->NumCols());
128+
new_loglikes.RowRange(0, old_rows_kept).CopyFromMat(
129+
loglikes_.RowRange(frames_to_discard, old_rows_kept));
130+
new_loglikes.RowRange(old_rows_kept, loglikes->NumRows()).CopyFromMat(
131+
*loglikes);
132+
loglikes_.Swap(&new_loglikes);
133+
}
134+
frame_offset_ += frames_to_discard;
135+
}
136+
137+
void InputIsFinished() { input_is_finished_ = true; }
138+
139+
virtual int32 NumFramesReady() const {
140+
return loglikes_.NumRows() + frame_offset_;
141+
}
142+
143+
virtual bool IsLastFrame(int32 frame) const {
144+
KALDI_ASSERT(frame < NumFramesReady());
145+
return (frame == NumFramesReady() - 1 && input_is_finished_);
146+
}
147+
148+
virtual BaseFloat LogLikelihood(int32 frame, int32 tid) {
149+
int32 index = frame - frame_offset_;
150+
KALDI_ASSERT(index >= 0 && index < loglikes_.NumRows());
151+
return loglikes_(index, trans_model_.TransitionIdToPdf(tid));
152+
}
153+
154+
155+
156+
virtual int32 NumIndices() const { return trans_model_.NumTransitionIds(); }
157+
158+
// nothing special to do in destructor.
159+
virtual ~DecodableMatrixMappedOffset() { }
160+
private:
161+
const TransitionModel &trans_model_; // for tid to pdf mapping
162+
Matrix<BaseFloat> loglikes_;
163+
int32 frame_offset_;
164+
bool input_is_finished_;
165+
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixMappedOffset);
166+
};
167+
86168

87169
class DecodableMatrixScaled: public DecodableInterface {
88170
public:

0 commit comments

Comments
 (0)