-
Notifications
You must be signed in to change notification settings - Fork 40
/
translate.py
88 lines (75 loc) · 4.06 KB
/
translate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright (C) 2018 Mikel Artetxe <[email protected]>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import argparse
import os
import subprocess
from shlex import quote
import train
# Directory that contains this script.
ROOT = os.path.dirname(os.path.abspath(__file__))

# Root of the third-party tools: the MONOSES_THIRD_PARTY environment variable
# (converted to an absolute path) when it is set, otherwise ROOT/third-party.
if 'MONOSES_THIRD_PARTY' in os.environ:
    THIRD_PARTY = os.path.abspath(os.environ['MONOSES_THIRD_PARTY'])
else:
    THIRD_PARTY = ROOT + '/third-party'

# Moses directory under the third-party root.
MOSES = THIRD_PARTY + '/moses'
def bash(command):
    """Run ``command`` with ``bash -c`` and return its exit status.

    No stream redirection is requested, so the child process shares this
    process's stdin, stdout and stderr.

    Args:
        command: Shell command line to execute.

    Returns:
        The integer exit status reported by bash. Callers that do not care
        about failures may ignore it.
    """
    completed = subprocess.run(['bash', '-c', command])
    # Hand the status back instead of dropping it, so that a failing command
    # can be detected by the caller.
    return completed.returncode
def _build_parser():
    """Create the argument parser for the translation command line."""
    parser = argparse.ArgumentParser(description='Translate text using a trained model')
    parser.add_argument('model', metavar='PATH', help='Working directory of the trained model')
    parser.add_argument('-r', '--reverse', action='store_true', help='Use the reverse model (trg->src)')
    parser.add_argument('--src', metavar='STR', required=True, help='Input language code')
    parser.add_argument('--trg', metavar='STR', required=True, help='Output language code')
    parser.add_argument('--tok', action='store_true', help='Tokenized input/output')
    # Only these steps have a decoding branch below. Any other value used to
    # launch moses2 without a configuration file while its stderr was
    # discarded, so the run failed without any visible error.
    parser.add_argument('--step', metavar='N', type=int, default=10, choices=(6, 7, 8, 10),
                        help='Step number: 6, 7, 8 or 10 (defaults to 10)')
    parser.add_argument('--nmt-checkpoints', nargs='+', metavar='N', default=[10, 20, 30, 40, 50, 60],
                        help='Use a checkpoint ensemble over the given iterations')
    parser.add_argument('--threads', metavar='N', type=int, default=20, help='Number of threads (defaults to 20)')
    parser.add_argument('--cpu', action='store_true', help='Force CPU decoding')
    parser.add_argument('--fp16', action='store_true', help='Enable FP16 decoding')
    return parser


def _build_command(args):
    """Assemble the bash pipeline that turns stdin text into its translation.

    Args:
        args: Parsed arguments as produced by the parser of this script.

    Returns:
        A single shell command line whose stages are connected with pipes.
    """
    direction = 'trg2src' if args.reverse else 'src2trg'
    side = direction[:3]  # 'src' or 'trg': selects the truecasing model

    stages = ['cat -']
    if not args.tok:
        stages.append(train.tokenize_command(args, args.src))
    stages.append(
        quote(MOSES + '/scripts/recaser/truecase.perl')
        + ' --model ' + quote(args.model + '/step1/truecase-model.' + side)
    )

    if args.step == 10:
        # Step 10: apply the step 9 BPE codes, then decode with the fairseq
        # checkpoints given in --nmt-checkpoints (colon-separated paths).
        stages.append(
            'python3 ' + quote(train.SUBWORD_NMT + '/subword_nmt/apply_bpe.py')
            + ' -c ' + quote(args.model + '/step9/bpe.codes')
        )
        checkpoints = ':'.join(
            quote(args.model + '/step10/' + direction + '.' + str(it) + '.pt')
            for it in args.nmt_checkpoints
        )
        decoder = 'python3 ' + quote(train.FAIRSEQ + '/interactive.py')
        decoder += ' ' + quote(args.model + '/step10/data.bin')
        decoder += ' --path ' + checkpoints
        # NOTE(review): the language labels stay src/trg even with --reverse;
        # confirm that this is intended.
        decoder += ' --source-lang src --target-lang trg'
        decoder += ' --beam 5 --max-tokens 1000 --buffer-size 10000'
        if args.cpu:
            decoder += ' --cpu'
        if args.fp16:
            decoder += ' --fp16'
        stages.append(decoder)
        # Keep the lines that start with "H<TAB>", take their third
        # tab-separated field and strip the "@@" subword joiners.
        stages.append("grep -P '^H\t'")
        stages.append('cut -f3')
        stages.append("sed -r 's/(@@ )|(@@ ?$)//g'")
    else:
        # Steps 6, 7 and 8 decode with moses2, each using the moses.ini
        # stored in its own step directory; moses2's stderr is discarded.
        config = args.model + '/step' + str(args.step) + '/' + direction + '.moses.ini'
        stages.append(
            quote(MOSES + '/bin/moses2')
            + ' -f ' + quote(config)
            + ' --threads ' + str(args.threads)
            + ' 2> /dev/null'
        )

    stages.append(quote(MOSES + '/scripts/recaser/detruecase.perl'))
    if not args.tok:
        stages.append(quote(MOSES + '/scripts/tokenizer/detokenizer.perl') + ' -q -l ' + quote(args.trg))
    return ' | '.join(stages)


def main(argv=None):
    """Translate the text read from stdin with a trained model.

    Parses the command line, builds the matching shell pipeline and runs it
    through bash.

    Args:
        argv: Optional list of argument strings. When None, argparse reads
            the arguments from ``sys.argv``.

    Raises:
        SystemExit: Raised by argparse when the arguments are invalid, which
            includes a ``--step`` value other than 6, 7, 8 or 10.
    """
    args = _build_parser().parse_args(argv)
    bash(_build_command(args))
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()