-
Notifications
You must be signed in to change notification settings - Fork 12
/
ngram-expand.cc
147 lines (136 loc) · 4.9 KB
/
ngram-expand.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2012 Aix-Marseille Univ.
// Author: [email protected] (Benoit Favre)
#include <cstdlib>
#include <deque>
#include <functional>
#include <iostream>
#include <list>
#include <memory>
#include <unordered_map>
#include <utility>

#include <fst/fstlib.h>

using namespace fst;
using namespace std;
typedef int64 State;
struct Context {
State inputState;
State outputState;
list<StdArc> seq;
size_t length;
Context() {}
Context(State _inputState, State _outputState, size_t _length) : inputState(_inputState), outputState(_outputState), length(_length) {}
Context& operator=(const Context& other) {
inputState = other.inputState;
outputState = other.outputState;
length = other.length;
seq = other.seq;
return *this;
}
void print() {
cerr << "state: in=" << inputState << " out=" << outputState << " labels:";
for(list<StdArc>::const_iterator i = seq.begin(); i != seq.end(); i++) {
cerr << " " << i->ilabel;
}
cerr << endl;
}
void push(const StdArc& arc) {
seq.push_back(arc);
inputState = arc.nextstate;
if(seq.size() > length) seq.pop_front();
}
StdArc::Weight weight() const {
return seq.back().weight;
}
int64 ilabel() const {
return seq.back().ilabel;
}
int64 olabel() const {
return seq.back().ilabel;
}
};
// Hash and equality for Context so it can key an unordered_map. Specializing
// std::hash / std::equal_to for a program-defined type is permitted.
//
// Both functors must agree on the key: inputState plus the ilabel/olabel of the
// arcs in seq, skipping the oldest arc when the window is full
// (seq.size() == length). outputState, arc weights and arc destinations are not
// part of the key.
namespace std
{
template <> struct hash<Context>
{
    typedef Context argument_type;
    typedef size_t result_type;

    size_t operator()(const Context& a) const {
        size_t seed = hash<State>()(a.inputState);
        list<StdArc>::const_iterator i = a.seq.begin();
        if (a.seq.size() == a.length) ++i;  // oldest arc is not part of the key
        for (; i != a.seq.end(); ++i) {
            combine(seed, i->ilabel);
            combine(seed, i->olabel);
        }
        return seed;
    }

    // Order-sensitive mixing in the style of boost::hash_combine. The former
    // XOR fold (ilabel ^ olabel) was order-insensitive and cancelled to 0
    // whenever ilabel == olabel, so for acceptors every context of a given
    // input state hashed to the same bucket.
    static void combine(size_t& seed, State value) {
        seed ^= hash<State>()(value) + static_cast<size_t>(0x9e3779b97f4a7c15ULL) +
                (seed << 6) + (seed >> 2);
    }
};

template <> struct equal_to<Context>
{
    typedef Context first_argument_type;
    typedef Context second_argument_type;
    typedef bool result_type;

    bool operator()(const Context& a, const Context& b) const {
        if (a.inputState != b.inputState) return false;
        if (a.seq.size() != b.seq.size()) return false;
        list<StdArc>::const_iterator i = a.seq.begin();
        list<StdArc>::const_iterator j = b.seq.begin();
        // Skip the oldest arc of a full window, exactly as hash<Context> does.
        if (a.seq.size() == a.length) ++i;
        if (b.seq.size() == b.length) ++j;
        for (; i != a.seq.end() && j != b.seq.end(); ++i, ++j) {
            if (i->ilabel != j->ilabel || i->olabel != j->olabel) return false;
        }
        // Both keys must end together; otherwise one side skipped an arc the
        // other kept (only possible when a.length != b.length).
        return i == a.seq.end() && j == b.seq.end();
    }
};
}
int main(int argc, char** argv) {
int ngram_size = 2;
if(argc >= 2) ngram_size = atoi(argv[1]);
StdVectorFst *input = StdVectorFst::Read("");
if(ngram_size < 2) { // nothing to do
input->Write("");
return 0;
}
StdVectorFst output;
unordered_map<Context, State> outputStates;
list<Context> queue; // queue all unprocessed contexts
State outputStart = 0;
output.AddState();
output.SetStart(outputStart);
queue.push_back(Context(input->Start(), outputStart, ngram_size));
while(queue.size() > 0) {
Context current = queue.front();
queue.pop_front();
State inputState = current.inputState;
State arcStartState = current.outputState;
for(ArcIterator<StdVectorFst> aiter(*input, inputState); !aiter.Done(); aiter.Next()) {
const StdArc &arc = aiter.Value();
Context next = current;
next.push(arc);
unordered_map<Context, State>::iterator found = outputStates.find(next);
int arcEndState = output.NumStates();
if(found == outputStates.end()) {
outputStates[next] = arcEndState;
output.AddState();
next.outputState = arcEndState;
queue.push_back(next);
} else {
arcEndState = found->second;
}
if(input->Final(arc.nextstate) != arc.weight.Zero()) {
output.SetFinal(arcEndState, input->Final(arc.nextstate));
}
output.AddArc(arcStartState, StdArc(next.ilabel(), next.olabel(), next.weight(), arcEndState));
}
}
output.SetInputSymbols(input->InputSymbols());
output.SetOutputSymbols(input->OutputSymbols());
output.Write("");
delete input;
}