forked from lvapeab/nmt-keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsample_ensemble.py
62 lines (56 loc) · 3.32 KB
/
sample_ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# -*- coding: utf-8 -*-
from __future__ import print_function
import argparse
import logging
import ast
from keras_wrapper.extra.read_write import pkl2dict
from nmt_keras import check_params
from nmt_keras.apply_model import sample_ensemble
# Configure logging at import time: DEBUG level and a timestamped
# "[dd/mm/YYYY HH:MM:SS] message" line format.
logging.basicConfig(level=logging.DEBUG,
                    format='[%(asctime)s] %(message)s',
                    datefmt='%d/%m/%Y %H:%M:%S')
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def parse_args(argv=None):
    """Parse the command-line options of the ensemble sampling script.

    :param argv: Optional list of argument strings to parse. When None (the
                 default), argparse reads the arguments from ``sys.argv[1:]``.
    :return: ``argparse.Namespace`` with the attributes ``dataset``, ``text``,
             ``splits``, ``dest``, ``verbose``, ``config``, ``n_best``,
             ``weights``, ``glossary``, ``models`` and ``changes``.
    """
    # The text must be passed as `description=`: the first positional
    # parameter of ArgumentParser is `prog`, which would put the whole
    # sentence in the usage line in place of the program name.
    parser = argparse.ArgumentParser(
        description="Use several translation models for obtaining predictions from a source text file.")
    parser.add_argument("-ds", "--dataset", required=True, help="Dataset instance with data")
    parser.add_argument("-t", "--text", required=True, help="Text file with source sentences")
    parser.add_argument("-s", "--splits", nargs='+', required=False, default=['val'],
                        help="Splits to sample. "
                             "Should be already included "
                             "into the dataset object.")
    parser.add_argument("-d", "--dest", required=False,
                        help="File to save translations in. If not specified, "
                             "translations are outputted in STDOUT.")
    parser.add_argument("-v", "--verbose", required=False, default=0, type=int, help="Verbosity level")
    parser.add_argument("-c", "--config", required=False,
                        help="Config pkl for loading the model configuration. "
                             "If not specified, hyperparameters "
                             "are read from config.py")
    parser.add_argument("-n", "--n-best", action="store_true", default=False,
                        help="Write n-best list (n = beam size)")
    parser.add_argument("-w", "--weights", nargs="*", default=[],
                        help="Weight given to each model in the ensemble. "
                             "You should provide the same number of weights as models. "
                             "By default, it applies the same weight to each model (1/N).")
    parser.add_argument("-g", "--glossary", required=False, help="Glossary file for overwriting translations.")
    parser.add_argument("-m", "--models", nargs="+", required=True, help="Path to the models")
    # Default is an empty list (not an empty string) so the attribute always
    # has the same type as when the option is given on the command line.
    parser.add_argument("-ch", "--changes", nargs="*", default=[],
                        help="Changes to the config. Following the syntax Key=Value")
    return parser.parse_args(argv)
def apply_changes(params, changes):
    """Overwrite entries of ``params`` with command-line ``Key=Value`` pairs.

    Each item is split at its first '=' only, so the value itself may contain
    '=' characters. The value is parsed with ``ast.literal_eval``; when it is
    not a valid Python literal (e.g. a bare word or a file path) the raw
    string is stored instead.

    :param params: Dictionary of parameters. It is modified in place.
    :param changes: Iterable of 'Key=Value' strings (may be empty or None).
    :return: The same ``params`` dictionary.
    :raises ValueError: If an item does not contain an '=' separator.
    """
    for change in changes or ():
        key, separator, raw_value = change.partition('=')
        if not separator:
            raise ValueError('Malformed change (expected Key=Value): %r' % (change,))
        try:
            params[key] = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError, TypeError):
            # literal_eval raises SyntaxError (not only ValueError) for text
            # such as '/tmp/model' or 'two words'; keep those as plain strings.
            params[key] = raw_value
    return params


if __name__ == "__main__":
    args = parse_args()
    if args.config is None:
        logger.info("Reading parameters from config.py")
        # Imported lazily: config.py is only needed when no pkl is given.
        from config import load_parameters
        params = load_parameters()
    else:
        logger.info("Loading parameters from %s", args.config)
        params = pkl2dict(args.config)
    try:
        apply_changes(params, args.changes)
    except ValueError:
        print('Overwritten arguments must have the form key=Value. \n Currently are: %s' % str(args.changes))
        # SystemExit gives the same exit status as exit(1) without relying on
        # the `site` module's interactive helper.
        raise SystemExit(1)
    params = check_params(params)
    sample_ensemble(args, params)