-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_sample_relativity.py
148 lines (120 loc) · 5.02 KB
/
generate_sample_relativity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import argparse
import logging
import torch
from torch import cuda
import options
import data
from generator_sample_relativity import LSTMModel
from sequence_generator_sample_relativity import SequenceGenerator
import os
# Pin this process to the second physical GPU; torch then sees it as device 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Example invocation:
# python3 generate.py --data data-bin/iwslt14.tokenized.de-en/WMT/preprocessed/ --src_lang en --trg_lang de --batch-size 64 --gpuid 0

# Verbose logging with timestamps for the whole run.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')

parser = argparse.ArgumentParser(
    description="Driver program for JHU Adversarial-NMT.")

# Register every shared option group on the single driver parser.
for _register in (
        options.add_general_args,
        options.add_dataset_args,
        options.add_checkpoint_args,
        options.add_distributed_training_args,
        options.add_generation_args,
        options.add_generator_model_args,
        options.add_discriminator_model_args,
):
    _register(parser)
def main(args):
    """Translate the test split with a pretrained generator checkpoint.

    Loads the LSTM generator from ``checkpoints/sample_relativity/best_gmodel.pt``,
    beam-decodes every test sentence, and writes hypotheses to
    ``predictions.txt`` with the aligned references in ``real.txt``.
    The reference is repeated once per hypothesis so the two files stay
    line-aligned when ``--nbest`` > 1.

    Args:
        args: parsed argparse namespace from the module-level parser.

    Raises:
        FileNotFoundError: if the generator checkpoint is missing.
    """
    use_cuda = (len(args.gpuid) >= 1)
    if args.gpuid:
        cuda.set_device(args.gpuid[0])
    print(args.replace_unk)

    # Load dataset: the raw-text loader is used when --replace-unk is set
    # (presumably so <unk> tokens can be recovered from the raw source —
    # TODO confirm against data.load_raw_text_dataset).
    if args.replace_unk is None:
        dataset = data.load_dataset(
            args.data,
            ['test'],
            args.src_lang,
            args.trg_lang,
        )
    else:
        dataset = data.load_raw_text_dataset(
            args.data,
            ['test'],
            args.src_lang,
            args.trg_lang,
        )
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(
        dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(
        dataset.dst, len(dataset.dst_dict)))
    print('| {} {} {} examples'.format(
        args.data, 'test', len(dataset.splits['test'])))

    # Set model hyper-parameters -- these must match the trained checkpoint.
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # Load model. A fresh LSTMModel is built, then only the checkpoint
    # parameters whose names it recognizes are copied in.
    g_model_path = 'checkpoints/sample_relativity/best_gmodel.pt'
    if not os.path.exists(g_model_path):
        # raise instead of assert: assertions vanish under `python -O`
        raise FileNotFoundError(g_model_path)
    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    model_dict = generator.state_dict()
    model = torch.load(g_model_path)
    pretrained_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    generator.eval()
    print("Generator loaded successfully!")

    if use_cuda:
        generator.cuda()
    else:
        generator.cpu()

    max_positions = generator.encoder.max_positions()
    testloader = dataset.eval_dataloader(
        'test',
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
    )

    translator = SequenceGenerator(
        generator, beam_size=args.beam, stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
        unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Open both output files for the lifetime of the decode loop.
    with open('predictions.txt', 'wb') as translation_writer, \
            open('real.txt', 'wb') as ground_truth_writer:
        translations = translator.generate_batched_itr(
            testloader, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b, cuda=use_cuda)
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Detokenize the reference once per sample. (The original also
            # built an unused src_str from src_tokens; dropped here.)
            target_tokens = target_tokens.int().cpu()
            target_str = dataset.dst_dict.string(
                target_tokens, args.remove_bpe, escape_unk=True)
            # Process the top-n predictions for this sample.
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                hypo_tokens = hypo['tokens'].int().cpu()
                hypo_str = dataset.dst_dict.string(
                    hypo_tokens, args.remove_bpe)
                # BUG FIX: the original did `target_str += '\n'` inside this
                # loop, so with --nbest > 1 each successive reference line
                # accumulated an extra blank line. Build the output line
                # without mutating the loop-invariant strings.
                translation_writer.write((hypo_str + '\n').encode('utf-8'))
                ground_truth_writer.write((target_str + '\n').encode('utf-8'))
if __name__ == "__main__":
    # parse_known_args tolerates options registered for other drivers that
    # share the options module; warn about anything it did not recognize.
    # BUG FIX: the original re-ran parser.parse_known_args() a second time
    # just to build the warning message; reuse the captured leftovers.
    args, unknown = parser.parse_known_args()
    if unknown:
        logging.warning("unknown arguments: {0}".format(unknown))
    main(args)