diff --git a/src/fast_align.cc b/src/fast_align.cc index 3b4f827..c45893a 100644 --- a/src/fast_align.cc +++ b/src/fast_align.cc @@ -58,6 +58,7 @@ void ParseLine(const string& line, string input; string conditional_probability_filename = ""; +string existing_probability_filename = ""; int is_reverse = 0; int ITERATIONS = 5; int favor_diagonal = 0; @@ -79,13 +80,14 @@ struct option options[] = { {"alpha", required_argument, 0, 'a'}, {"no_null_word", no_argument, &no_null_word, 1 }, {"conditional_probabilities", required_argument, 0, 'c'}, + {"existing_probabilities", required_argument, 0, 'e'}, {0,0,0,0} }; bool InitCommandLine(int argc, char** argv) { while (1) { int oi; - int c = getopt_long(argc, argv, "i:rI:dp:T:ova:Nc:", options, &oi); + int c = getopt_long(argc, argv, "i:rI:dp:T:ova:Nc:e:", options, &oi); if (c == -1) break; switch(c) { case 'i': input = optarg; break; @@ -99,6 +101,7 @@ bool InitCommandLine(int argc, char** argv) { case 'a': alpha = atof(optarg); break; case 'N': no_null_word = 1; break; case 'c': conditional_probability_filename = optarg; break; + case 'e': existing_probability_filename = optarg; break; default: return false; } } @@ -116,6 +119,7 @@ int main(int argc, char** argv) { << " -o: [USE] Optimize how close to the diagonal alignment points should be\n" << " -r: Run alignment in reverse (condition on target and predict source)\n" << " -c: Output conditional probability table\n" + << " -e: Start with existing conditional probability table\n" << " Advanced options:\n" << " -I: number of iterations in EM training (default = 5)\n" << " -p: p_null parameter (default = 0.08)\n" @@ -132,12 +136,19 @@ int main(int argc, char** argv) { double prob_align_not_null = 1.0 - prob_align_null; const unsigned kNULL = d.Convert(""); TTable s2t, t2s; + if (!existing_probability_filename.empty()) { + bool success = s2t.ImportFromFile(existing_probability_filename.c_str(), '\t', d); + if (!success) { + cerr << "Can't read " << existing_probability_filename << endl; + return 1; + } + } unordered_map, unsigned, PairHash> size_counts; double tot_len_ratio = 0; double mean_srclen_multiplier = 0; vector probs; - for (int iter = 0; iter < ITERATIONS; ++iter) { - const bool final_iteration = (iter == (ITERATIONS - 1)); + for (int iter = 0; iter < ITERATIONS || (iter==0 && ITERATIONS==0); ++iter) { + const bool final_iteration = (iter >= (ITERATIONS - 1)); cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; ifstream in(input.c_str()); if (!in) { diff --git a/src/ttables.h b/src/ttables.h index f5f1cb7..f2a2cb9 100644 --- a/src/ttables.h +++ b/src/ttables.h @@ -17,6 +17,9 @@ #include #include +#include +#include +#include #include struct Md { @@ -107,6 +110,35 @@ class TTable { } file.close(); } + bool ImportFromFile(const char* filename, char delim, Dict& d) { + std::ifstream in(filename); + if (!in) { + return false; + } else { + std::string line; + while(true) { + std::getline(in, line); + if (!in) break; + std::string sourceWord, targetWord, valueString; + std::stringstream stream(line); + + bool success = true; + success &= (std::getline(stream, sourceWord, delim) != NULL); + success &= (std::getline(stream, targetWord, delim) != NULL); + success &= (std::getline(stream, valueString, delim) != NULL); + + if (success) { + unsigned source = d.Convert(sourceWord); + unsigned target = d.Convert(targetWord); + double value = atof(valueString.c_str()); + ttable[source][target] = value; + } else { + return false; + } + } + } + return true; + } public: Word2Word2Double ttable; Word2Word2Double counts;