2626
2727namespace kaldi {
2828
29- using std::vector;
30- using std::set;
31-
32- typedef unordered_set<fst::StdArc::Label> LabelSet;
29+ typedef fst::StdArc::Label Label;
30+ typedef std::vector<std::pair<Label, Label>> LabelPairVector;
3331
3432void ReadSymbolList (const std::string &rxfilename,
3533 fst::SymbolTable *word_syms,
36- LabelSet *lset ) {
34+ LabelPairVector *lpairs ) {
3735 Input ki (rxfilename);
3836 std::string line;
39- KALDI_ASSERT (lset != NULL );
40- lset ->clear ();
37+ KALDI_ASSERT (lpairs != NULL );
38+ lpairs ->clear ();
4139 while (getline (ki.Stream (), line)) {
4240 std::string sym;
4341 std::istringstream ss (line);
@@ -52,45 +50,22 @@ void ReadSymbolList(const std::string &rxfilename,
5250 << line << " , file is: "
5351 << PrintableRxfilename (rxfilename);
5452 }
55- lset->insert (lab);
56- }
57- }
58-
59- void MapWildCards (const LabelSet &wildcards, fst::StdVectorFst *ofst) {
60- // map all wildcards symbols to epsilons
61- for (fst::StateIterator<fst::StdVectorFst> siter (*ofst);
62- !siter.Done (); siter.Next ()) {
63- fst::StdArc::StateId s = siter.Value ();
64- for (fst::MutableArcIterator<fst::StdVectorFst> aiter (ofst, s);
65- !aiter.Done (); aiter.Next ()) {
66- fst::StdArc arc (aiter.Value ());
67- LabelSet::const_iterator it = wildcards.find (arc.ilabel );
68- if (it != wildcards.end ()) {
69- KALDI_VLOG (4 ) << " MapWildCards: mapping symbol " << arc.ilabel
70- << " to epsilon" << std::endl;
71- arc.ilabel = 0 ;
72- }
73- it = wildcards.find (arc.olabel );
74- if (it != wildcards.end ()) {
75- arc.olabel = 0 ;
76- }
77- aiter.SetValue (arc);
78- }
53+ lpairs->emplace_back (lab, 0 );
7954 }
8055}
8156
8257// convert from Lattice to standard FST
8358// also maps wildcard symbols to epsilons
8459// then removes epsilons
8560void ConvertLatticeToUnweightedAcceptor (const kaldi::Lattice &ilat,
86- const LabelSet &wildcards,
61+ const LabelPairVector &wildcards,
8762 fst::StdVectorFst *ofst) {
8863 // first convert from lattice to normal FST
8964 fst::ConvertLattice (ilat, ofst);
9065 // remove weights, project to output, sort according to input arg
9166 fst::Map (ofst, fst::RmWeightMapper<fst::StdArc>());
9267 fst::Project (ofst, fst::PROJECT_OUTPUT); // The words are on the output side
93- MapWildCards ( wildcards, ofst );
68+ fst::Relabel (ofst, wildcards, wildcards );
9469 fst::RmEpsilon (ofst); // Don't tolerate epsilons as they make it hard to
9570 // tally errors
9671 fst::ArcSort (ofst, fst::StdILabelCompare ());
@@ -259,7 +234,7 @@ int main(int argc, char *argv[]) {
259234 KALDI_ERR << " Could not read symbol table from file "
260235 << word_syms_filename;
261236
262- LabelSet wildcards;
237+ LabelPairVector wildcards;
263238 if (wild_syms_rxfilename != " " ) {
264239 KALDI_WARN << " --wildcard-symbols-list option deprecated." ;
265240 KALDI_ASSERT (wildcard_symbols.empty () && " Do not use both "
@@ -275,7 +250,7 @@ int main(int argc, char *argv[]) {
275250 << " --wildcard-symbols option, got: " << wildcard_symbols;
276251 }
277252 for (size_t i = 0 ; i < wildcard_symbols_vec.size (); i++)
278- wildcards.insert (wildcard_symbols_vec[i]);
253+ wildcards.emplace_back (wildcard_symbols_vec[i], 0 );
279254 }
280255
281256 int32 n_done = 0 , n_fail = 0 ;
@@ -301,9 +276,9 @@ int main(int argc, char *argv[]) {
301276 const std::vector<int32> &reference = reference_reader.Value (key);
302277 VectorFst<StdArc> reference_fst;
303278 MakeLinearAcceptor (reference, &reference_fst);
304- MapWildCards (wildcards, &reference_fst); // Remove any wildcards in
305- // reference.
306279
280+ // Remove any wildcards in reference.
281+ fst::Relabel (&reference_fst, wildcards, wildcards);
307282 CheckFst (reference_fst, " reference_fst_" , key);
308283
309284 // recreate edit distance fst if necessary
@@ -384,7 +359,10 @@ int main(int argc, char *argv[]) {
384359 CompactLattice clat;
385360 CompactLattice oracle_clat;
386361 ConvertLattice (lat, &clat);
387- fst::Compose (oracle_clat_mask, clat, &oracle_clat);
362+ fst::Relabel (&clat, wildcards, LabelPairVector ());
363+ fst::Compose (oracle_clat_mask, clat, &oracle_clat_mask);
364+ fst::ShortestPath (oracle_clat_mask, &oracle_clat);
365+ fst::Project (&oracle_clat, fst::PROJECT_OUTPUT);
388366
389367 if (oracle_clat.Start () == fst::kNoStateId ) {
390368 KALDI_WARN << " Failed to find the oracle path in the original "
0 commit comments