Skip to content

Commit e014051

Browse files
dogancan authored and danpovey committed
[src] Fix bug in lattice-oracle relating to wildcards in lattice output. (kaldi-asr#2461)
1 parent 3f4f425 commit e014051

File tree

1 file changed

+16
-38
lines changed

1 file changed

+16
-38
lines changed

src/latbin/lattice-oracle.cc

Lines changed: 16 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -26,18 +26,16 @@
2626

2727
namespace kaldi {
2828

29-
using std::vector;
30-
using std::set;
31-
32-
typedef unordered_set<fst::StdArc::Label> LabelSet;
29+
typedef fst::StdArc::Label Label;
30+
typedef std::vector<std::pair<Label, Label>> LabelPairVector;
3331

3432
void ReadSymbolList(const std::string &rxfilename,
3533
fst::SymbolTable *word_syms,
36-
LabelSet *lset) {
34+
LabelPairVector *lpairs) {
3735
Input ki(rxfilename);
3836
std::string line;
39-
KALDI_ASSERT(lset != NULL);
40-
lset->clear();
37+
KALDI_ASSERT(lpairs != NULL);
38+
lpairs->clear();
4139
while (getline(ki.Stream(), line)) {
4240
std::string sym;
4341
std::istringstream ss(line);
@@ -52,45 +50,22 @@ void ReadSymbolList(const std::string &rxfilename,
5250
<< line << ", file is: "
5351
<< PrintableRxfilename(rxfilename);
5452
}
55-
lset->insert(lab);
56-
}
57-
}
58-
59-
void MapWildCards(const LabelSet &wildcards, fst::StdVectorFst *ofst) {
60-
// map all wildcards symbols to epsilons
61-
for (fst::StateIterator<fst::StdVectorFst> siter(*ofst);
62-
!siter.Done(); siter.Next()) {
63-
fst::StdArc::StateId s = siter.Value();
64-
for (fst::MutableArcIterator<fst::StdVectorFst> aiter(ofst, s);
65-
!aiter.Done(); aiter.Next()) {
66-
fst::StdArc arc(aiter.Value());
67-
LabelSet::const_iterator it = wildcards.find(arc.ilabel);
68-
if (it != wildcards.end()) {
69-
KALDI_VLOG(4) << "MapWildCards: mapping symbol " << arc.ilabel
70-
<< " to epsilon" << std::endl;
71-
arc.ilabel = 0;
72-
}
73-
it = wildcards.find(arc.olabel);
74-
if (it != wildcards.end()) {
75-
arc.olabel = 0;
76-
}
77-
aiter.SetValue(arc);
78-
}
53+
lpairs->emplace_back(lab, 0);
7954
}
8055
}
8156

8257
// convert from Lattice to standard FST
8358
// also maps wildcard symbols to epsilons
8459
// then removes epsilons
8560
void ConvertLatticeToUnweightedAcceptor(const kaldi::Lattice &ilat,
86-
const LabelSet &wildcards,
61+
const LabelPairVector &wildcards,
8762
fst::StdVectorFst *ofst) {
8863
// first convert from lattice to normal FST
8964
fst::ConvertLattice(ilat, ofst);
9065
// remove weights, project to output, sort according to input arg
9166
fst::Map(ofst, fst::RmWeightMapper<fst::StdArc>());
9267
fst::Project(ofst, fst::PROJECT_OUTPUT); // The words are on the output side
93-
MapWildCards(wildcards, ofst);
68+
fst::Relabel(ofst, wildcards, wildcards);
9469
fst::RmEpsilon(ofst); // Don't tolerate epsilons as they make it hard to
9570
// tally errors
9671
fst::ArcSort(ofst, fst::StdILabelCompare());
@@ -259,7 +234,7 @@ int main(int argc, char *argv[]) {
259234
KALDI_ERR << "Could not read symbol table from file "
260235
<< word_syms_filename;
261236

262-
LabelSet wildcards;
237+
LabelPairVector wildcards;
263238
if (wild_syms_rxfilename != "") {
264239
KALDI_WARN << "--wildcard-symbols-list option deprecated.";
265240
KALDI_ASSERT(wildcard_symbols.empty() && "Do not use both "
@@ -275,7 +250,7 @@ int main(int argc, char *argv[]) {
275250
<< "--wildcard-symbols option, got: " << wildcard_symbols;
276251
}
277252
for (size_t i = 0; i < wildcard_symbols_vec.size(); i++)
278-
wildcards.insert(wildcard_symbols_vec[i]);
253+
wildcards.emplace_back(wildcard_symbols_vec[i], 0);
279254
}
280255

281256
int32 n_done = 0, n_fail = 0;
@@ -301,9 +276,9 @@ int main(int argc, char *argv[]) {
301276
const std::vector<int32> &reference = reference_reader.Value(key);
302277
VectorFst<StdArc> reference_fst;
303278
MakeLinearAcceptor(reference, &reference_fst);
304-
MapWildCards(wildcards, &reference_fst); // Remove any wildcards in
305-
// reference.
306279

280+
// Remove any wildcards in reference.
281+
fst::Relabel(&reference_fst, wildcards, wildcards);
307282
CheckFst(reference_fst, "reference_fst_", key);
308283

309284
// recreate edit distance fst if necessary
@@ -384,7 +359,10 @@ int main(int argc, char *argv[]) {
384359
CompactLattice clat;
385360
CompactLattice oracle_clat;
386361
ConvertLattice(lat, &clat);
387-
fst::Compose(oracle_clat_mask, clat, &oracle_clat);
362+
fst::Relabel(&clat, wildcards, LabelPairVector());
363+
fst::Compose(oracle_clat_mask, clat, &oracle_clat_mask);
364+
fst::ShortestPath(oracle_clat_mask, &oracle_clat);
365+
fst::Project(&oracle_clat, fst::PROJECT_OUTPUT);
388366

389367
if (oracle_clat.Start() == fst::kNoStateId) {
390368
KALDI_WARN << "Failed to find the oracle path in the original "

0 commit comments

Comments
 (0)