Skip to content

Commit 8fab002

Browse files
committed
Merge pull request kaldi-asr#579 from kkm000/arpa-2
Major rewrite of ARPA compiler
2 parents 939b343 + 6be3696 commit 8fab002

26 files changed

+761
-1384
lines changed

src/bin/Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11

22
all:
3+
-rm -f arpa2fst
34
EXTRA_CXXFLAGS = -Wno-sign-compare
45
include ../kaldi.mk
56

67
BINFILES = align-equal align-equal-compiled acc-tree-stats \
78
show-alignments compile-questions cluster-phones \
89
compute-wer compute-wer-bootci make-h-transducer \
910
add-self-loops convert-ali \
10-
compile-train-graphs compile-train-graphs-fsts arpa2fst \
11+
compile-train-graphs compile-train-graphs-fsts \
1112
make-pdf-to-tid-transducer make-ilabel-transducer show-transitions \
1213
ali-to-phones ali-to-post weight-silence-post acc-lda est-lda \
1314
ali-to-pdf est-mllt build-tree build-tree-two-level decode-faster \
@@ -37,4 +38,3 @@ ADDLIBS = ../lm/kaldi-lm.a ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a \
3738
TESTFILES =
3839

3940
include ../makefiles/default_rules.mk
40-

src/bin/arpa2fst.cc

Lines changed: 0 additions & 66 deletions
This file was deleted.

src/lm/Makefile

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,13 @@ EXTRA_CXXFLAGS = -Wno-sign-compare
22

33
all:
44

5-
# Disable linking math libs because not needed here. Just for compilation speed.
6-
MATHLIB = NONE
7-
8-
# Uncomment following line to use IRSTLM toolkit installed in ../lmtoolkit
9-
#include ./irstlm.mk
10-
115
include ../kaldi.mk
126

13-
TESTFILES = arpa-file-parser-test lm-lib-test
7+
TESTFILES = arpa-file-parser-test arpa-lm-compiler-test
148

15-
OBJFILES = arpa-file-parser.o const-arpa-lm.o kaldi-lmtable.o kaldi-lm.o \
9+
OBJFILES = arpa-file-parser.o arpa-lm-compiler.o const-arpa-lm.o \
1610
kaldi-rnnlm.o mikolov-rnnlm-lib.o
1711

18-
TESTOUTPUTS = composed.fst output.fst output1.fst output2.fst
19-
2012
LIBNAME = kaldi-lm
2113

2214
ADDLIBS = ../fstext/kaldi-fstext.a ../util/kaldi-util.a ../thread/kaldi-thread.a \

src/lm/arpa-file-parser-test.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@
2727
#include <string>
2828
#include <sstream>
2929
#include <vector>
30-
#include "lm/kaldi-lm.h"
3130

31+
#include "base/kaldi-common.h"
32+
#include "fst/fstlib.h"
3233
#include "lm/arpa-file-parser.h"
3334

3435
namespace kaldi {

src/lm/arpa-file-parser.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929

3030
namespace kaldi {
3131

32-
ArpaFileParser::ArpaFileParser(ArpaParseOptions options, fst::SymbolTable* symbols)
32+
ArpaFileParser::ArpaFileParser(ArpaParseOptions options,
33+
fst::SymbolTable* symbols)
3334
: options_(options), symbols_(symbols), line_number_(0) {
3435
}
3536

src/lm/arpa-file-parser.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ struct NGram {
6565
/**
6666
ArpaFileParser is an abstract base class for ARPA LM file conversion.
6767
68-
See ConstArpaLmBuilder for a usage example.
68+
See ConstArpaLmBuilder and ArpaLmCompiler for usage examples.
6969
*/
7070
class ArpaFileParser {
7171
public:
@@ -85,6 +85,7 @@ class ArpaFileParser {
8585
/// supported.
8686
void Read(std::istream &is, bool binary);
8787

88+
/// Parser options.
8889
const ArpaParseOptions& Options() const { return options_; }
8990

9091
protected:
@@ -104,7 +105,7 @@ class ArpaFileParser {
104105
/// Override function called after the last n-gram has been consumed.
105106
virtual void ReadComplete() { }
106107

107-
/// Read-only access to symbol table.
108+
/// Read-only access to symbol table. Not owned, do not make public.
108109
const fst::SymbolTable* Symbols() const { return symbols_; }
109110

110111
/// Inside ConsumeNGram(), provides the current line number.

src/lm/arpa-lm-compiler-test.cc

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
// lm/arpa-lm-compiler-test.cc
2+
3+
// Copyright 2009-2011 Gilles Boulianne
4+
// Copyright 2016 Smart Action LLC (kkm)
5+
6+
// See ../../COPYING for clarification regarding multiple authors
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16+
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17+
// MERCHANTABLITY OR NON-INFRINGEMENT.
18+
// See the Apache 2 License for the specific language governing permissions and
19+
// limitations under the License.
20+
21+
#include <iostream>
22+
#include <string>
23+
#include <sstream>
24+
25+
#include "base/kaldi-error.h"
26+
#include "base/kaldi-math.h"
27+
#include "lm/arpa-lm-compiler.h"
28+
#include "util/kaldi-io.h"
29+
30+
namespace kaldi {
31+
32+
// Predefine some symbol values, because any integer is as good as any other.
// Compile() below registers these ids in its symbol table for " <eps>", " #0",
// "<s>" and "</s>" respectively.
33+
enum {
34+
kEps = 0,
35+
kDisambig,
36+
kBos,kEos,
37+
};
38+
39+
// Number of random sentences that CoverageTest() generates and composes with
// the compiled LM for each ARPA file; all of them must compose for a pass.
40+
static const int kRandomSentences = 50;
41+
42+
// Creates an FST that generates any sequence of symbols taken from the given
43+
// symbol table. The FST is then associated with the symbol table.
//
// seps == false: the FST gets distinct start and final states; every path is
//   a kBos arc into the middle state, any number of word loops, and a kEos
//   arc out to the final state.
// seps == true: the middle state is both start and final, so paths consist
//   of word loops only, with no kBos/kEos arcs.
//
// Returns a FST allocated with new; the caller must delete it.
44+
static fst::StdVectorFst* CreateGenFst(bool seps, const fst::SymbolTable* pst) {
45+
fst::StdVectorFst* genFst = new fst::StdVectorFst;
46+
genFst->SetInputSymbols(pst);
47+
genFst->SetOutputSymbols(pst);
48+
49+
// The middle state carries all the word self-loops added below.
fst::StdArc::StateId midId = genFst->AddState();
50+
if (!seps) {
51+
fst::StdArc::StateId initId = genFst->AddState();
52+
fst::StdArc::StateId finalId = genFst->AddState();
53+
genFst->SetStart(initId);
54+
genFst->SetFinal(finalId, fst::StdArc::Weight::One());
55+
genFst->AddArc(initId, fst::StdArc(kBos, kBos, 0, midId));
56+
genFst->AddArc(midId, fst::StdArc(kEos, kEos, 0, finalId));
57+
} else {
58+
genFst->SetStart(midId);
59+
genFst->SetFinal(midId, fst::StdArc::Weight::One());
60+
}
61+
62+
// Add a loop for each symbol in the table except the four special ones
// (kBos, kEos, kEps and kDisambig).
63+
fst::SymbolTableIterator si(*pst);
64+
for (si.Reset(); !si.Done(); si.Next()) {
65+
if (si.Value() == kBos || si.Value() == kEos ||
66+
si.Value() == kEps || si.Value() == kDisambig)
67+
continue;
68+
genFst->AddArc(midId, fst::StdArc(si.Value(), si.Value(),
69+
fst::StdArc::Weight::One(), midId));
70+
}
71+
return genFst;
72+
}
73+
74+
// Compile given ARPA file.
//
// Builds a symbol table holding only the four special symbols, lets the
// parser add every other word it meets (kAddToSymbols), and reads infile into
// a freshly allocated ArpaLmCompiler.
//
// seps selects the second constructor argument: kDisambig when true, 0 when
// false. Presumably this is the symbol substituted for epsilon, with 0
// disabling the substitution -- confirm against ArpaLmCompiler's constructor.
//
// Returns a compiler allocated with new; the caller must delete it.
75+
ArpaLmCompiler* Compile(bool seps, const string &infile) {
76+
ArpaParseOptions options;
77+
// NOTE(review): this table is a local, yet its address is handed to the
// compiler below. ArpaFileParser's constructor stores the pointer it is
// given, so the returned object appears to keep a pointer to a destroyed
// table once this function returns. The callers in this file only use Fst()
// afterwards; confirm that nothing dereferences the parser's symbol table
// after Read().
fst::SymbolTable symbols;
78+
// Use spaces on special symbols, so we rather fail than read them by mistake.
79+
symbols.AddSymbol(" <eps>", kEps);
80+
symbols.AddSymbol(" #0", kDisambig);
81+
options.bos_symbol = symbols.AddSymbol("<s>", kBos);
82+
options.eos_symbol = symbols.AddSymbol("</s>", kEos);
83+
options.oov_handling = ArpaParseOptions::kAddToSymbols;
84+
85+
// Tests in this form cannot be run with epsilon substitution, unless every
86+
// random path is also fitted with a #0-transducing self-loop.
87+
ArpaLmCompiler* lm_compiler =
88+
new ArpaLmCompiler(options,
89+
seps ? kDisambig : 0,
90+
&symbols);
91+
ReadKaldiObject(infile, lm_compiler);
92+
return lm_compiler;
93+
}
94+
95+
// Add a state to an FSA after last_state, add an arc from last_state to the new
96+
// state, and return the new state. The arc carries symbol as both its input
// and its output label, with weight 0.
97+
fst::StdArc::StateId AddToChainFsa(fst::StdMutableFst* fst,
98+
fst::StdArc::StateId last_state,
99+
int64 symbol) {
100+
fst::StdArc::StateId next_state = fst->AddState();
101+
fst->AddArc(last_state, fst::StdArc(symbol, symbol, 0, next_state));
102+
return next_state;
103+
}
104+
105+
// Add a disambiguator-generating self loop to every state of an FST.
// Each loop has input label kEps, output label kDisambig and weight 0.
106+
void AddSelfLoops(fst::StdMutableFst* fst) {
107+
for (fst::StateIterator<fst::StdMutableFst> siter(*fst);
108+
!siter.Done(); siter.Next()) {
109+
fst->AddArc(siter.Value(),
110+
fst::StdArc(kEps, kDisambig, 0, siter.Value()));
111+
}
112+
}
113+
114+
// Compiles infile and then runs kRandomSentences random coverage tests on the
115+
// compiled FST. Each test draws a random path from a generator FST built over
// the model's output symbols and composes it with the LM FST.
//
// Returns true only if every one of the kRandomSentences compositions has a
// start state; otherwise logs a warning with the success count and returns
// false.
116+
bool CoverageTest(bool seps, const string &infile) {
117+
// Compile ARPA model.
118+
ArpaLmCompiler* lm_compiler = Compile(seps, infile);
119+
120+
// Create an FST that generates any sequence of symbols taken from the model
121+
// output.
122+
fst::StdVectorFst* genFst =
123+
CreateGenFst(seps, lm_compiler->Fst().OutputSymbols());
124+
125+
int num_successes = 0;
126+
for (int32 i = 0; i < kRandomSentences; ++i) {
127+
// Generate a random sentence FST.
128+
fst::StdVectorFst sentence;
129+
RandGen(*genFst, &sentence);
130+
if (seps)
131+
AddSelfLoops(&sentence);
132+
133+
// The path must successfully compose with the LM FST. A composition with no
// start state (kNoStateId) counts as a failure.
134+
fst::StdVectorFst composition;
135+
Compose(sentence, lm_compiler->Fst(), &composition);
136+
if (composition.Start() != fst::kNoStateId)
137+
++num_successes;
138+
}
139+
140+
// Both objects were allocated with new by CreateGenFst() and Compile().
delete genFst;
141+
delete lm_compiler;
142+
143+
bool ok = num_successes == kRandomSentences;
144+
if (!ok) {
145+
KALDI_WARN << "Coverage test failed on " << infile << ": composed "
146+
<< num_successes << "/" << kRandomSentences;
147+
}
148+
return ok;
149+
}
150+
151+
// Compiles infile, scores a single sentence against the resulting LM FST and
// compares the score with an expected value.
//
// seps:     passed through to Compile(). When false the sentence is wrapped
//           in kBos/kEos arcs; when true, self-loops from AddSelfLoops() are
//           added to every state of the sentence FST instead.
// infile:   ARPA file to compile.
// sentence: whitespace-separated words. Every word must already be in the
//           model's input symbol table, otherwise KALDI_ASSERT fires.
// expected: value the shortest-distance weight at the start state of the
//           composed FST is compared against with ApproxEqual().
//
// Returns false (after logging a warning) if the sentence does not compose
// with the LM FST or the score differs from expected; true otherwise.
bool ScoringTest(bool seps, const string &infile, const string& sentence,
152+
float expected) {
153+
ArpaLmCompiler* lm_compiler = Compile(seps, infile);
154+
const fst::SymbolTable* symbols = lm_compiler->Fst().InputSymbols();
155+
156+
// Create a sentence FST for scoring.
157+
fst::StdVectorFst sentFst;
158+
fst::StdArc::StateId state = sentFst.AddState();
159+
sentFst.SetStart(state);
160+
if (!seps) {
161+
state = AddToChainFsa(&sentFst, state, kBos);
162+
}
163+
std::stringstream ss(sentence);
164+
string word;
165+
while (ss >> word) {
166+
int64 word_sym = symbols->Find(word);
167+
// Find() is expected to return -1 for a word missing from the table.
KALDI_ASSERT(word_sym != -1);
168+
state = AddToChainFsa(&sentFst, state, word_sym);
169+
}
170+
if (!seps) {
171+
state = AddToChainFsa(&sentFst, state, kEos);
172+
}
173+
if (seps) {
174+
AddSelfLoops(&sentFst);
175+
}
176+
// The last state of the chain is the only final state.
sentFst.SetFinal(state, 0);
177+
sentFst.SetOutputSymbols(symbols);
178+
179+
// Do the composition and extract final weight.
180+
fst::StdVectorFst composed;
181+
fst::Compose(sentFst, lm_compiler->Fst(), &composed);
182+
// The compiler is not used past this point; note that symbols above was
// obtained from its FST and is not used past this point either.
delete lm_compiler;
183+
184+
if (composed.Start() == fst::kNoStateId) {
185+
KALDI_WARN << "Test sentence " << sentence << " did not compose "
186+
<< "with the language model FST\n";
187+
return false;
188+
}
189+
190+
// The third argument (true) presumably selects reverse shortest distance,
// i.e. distance from each state to the final states, so that the entry at
// the start state is the score of the whole sentence -- confirm against the
// OpenFst ShortestDistance documentation.
std::vector<fst::StdArc::Weight> shortest;
191+
fst::ShortestDistance(composed, &shortest, true);
192+
// NOTE(review): indexed without checking shortest.size(); assumes
// ShortestDistance filled in an entry for the start state.
float actual = shortest[composed.Start()].Value();
193+
194+
bool ok = ApproxEqual(expected, actual);
195+
if (!ok) {
196+
KALDI_WARN << "Scored " << sentence << " in " << infile
197+
<< ": Expected=" << expected << " actual=" << actual;
198+
}
199+
return ok;
200+
}
201+
202+
} // namespace kaldi
203+
204+
// Runs the three coverage tests and the two scoring tests with the given seps
// setting. Results are accumulated with &=, so every test runs even after an
// earlier one has failed. Returns true only if all of them passed; otherwise
// logs a warning naming the mode that failed.
bool RunAllTests(bool seps) {
205+
bool ok = true;
206+
ok &= kaldi::CoverageTest(seps, "test_data/missing_backoffs.arpa");
207+
ok &= kaldi::CoverageTest(seps, "test_data/unused_backoffs.arpa");
208+
ok &= kaldi::CoverageTest(seps, "test_data/input.arpa");
209+
210+
ok &= kaldi::ScoringTest(seps, "test_data/input.arpa", "b b b a", 59.2649);
211+
ok &= kaldi::ScoringTest(seps, "test_data/input.arpa", "a b", 4.36082);
212+
if (!ok) {
213+
KALDI_WARN << "Tests " << (seps ? "with" : "without")
214+
<< " epsilon substitution FAILED";
215+
}
216+
return ok;
217+
}
218+
219+
// Test driver: runs the whole suite twice, once with seps == false and once
// with seps == true. Both runs always execute. Exits with 0 if everything
// passed and 1 otherwise. argc and argv are unused.
int main(int argc, char *argv[]) {
220+
bool ok = true;
221+
222+
ok &= RunAllTests(false); // Without disambiguators (old behavior).
223+
ok &= RunAllTests(true); // With epsilon substitution (new behavior).
224+
225+
if (ok) {
226+
KALDI_LOG << "All tests passed";
227+
return 0;
228+
}
229+
else {
230+
KALDI_WARN << "Test FAILED";
231+
return 1;
232+
}
233+
}

0 commit comments

Comments
 (0)