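"""Probe how much word-order information survives in RoBERTa sentence encodings.

Loads a (possibly shuffle-trained) RoBERTa model, encodes UD sentences whose
word order is either original (label 'o') or permuted (label 'p'), and trains
a logistic-regression classifier to tell the two apart.

Example invocation (paths are hypothetical):
    python classify_word_orders.py \
        --dataset_path data/en-ud-train.conllu \
        --model_path checkpoints/roberta.base.shuffled \
        --shuffle_mode pre_encode
"""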
import argparse
import random

import numpy as np
import torch
import tqdm
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.utils import shuffle

from roberta.helpers import load_shuffled_model
from utils.rand_word_order_utils import ud_load_classify
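
# --shuffle_mode is matched by substring; the recognised modes are:
#   pre_encode    - shuffle whitespace-separated words before BPE encoding
#   mid_encode    - shuffle BPE token ids after encoding the whole sentence
#   post_encode   - BPE-encode each word separately and shuffle the word units
#                   ('whitespaced'/'safe' prefixes control leading spaces)
#   only_position - represent a sentence by its mean positional embedding
# Orthogonal flags, also matched by substring:
#   scramble_position / norm_position - perturb the positional-embedding table
#   baseline*     - shuffle every sentence, not only the permuted ('p') ones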


def classify(args, all_examples, all_labels):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    roberta = load_shuffled_model(args.model_path)
    roberta.to(device)  # use the GPU when one is available
    roberta.eval()
    if 'scramble_position' in args.shuffle_mode:
        # randomly permute the positional-embedding table; the first two
        # (reserved) rows stay in place
        d = roberta.model.encoder.sentence_encoder.embed_positions.weight.data
        d = torch.cat((d[0:2], d[2:][torch.randperm(d.size(0) - 2)]))
        roberta.model.encoder.sentence_encoder.embed_positions.weight.data = d
    if 'norm_position' in args.shuffle_mode:
        # replace each positional embedding with a Gaussian sample that
        # matches its per-row mean and standard deviation
        d = roberta.model.encoder.sentence_encoder.embed_positions.weight.data
        mean = d.mean(dim=1).unsqueeze(-1).repeat(1, d.size(-1))
        std = d.std(dim=1).unsqueeze(-1).repeat(1, d.size(-1))
        roberta.model.encoder.sentence_encoder.embed_positions.weight.data = torch.normal(mean, std)
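    # Encode each sentence into a single fixed-size vector: the mean of the
    # final-layer token representations (or of the positional embeddings in
    # 'only_position' mode).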
    all_sent_encodings = []
    for sentence, label in tqdm.tqdm(zip(all_examples, all_labels), total=len(all_examples)):
        with torch.no_grad():
            if 'pre_encode' in args.shuffle_mode:
                # shuffle whitespace-separated words, then BPE-encode
                sentence = sentence.lower().split()
                if args.shuffle_mode.startswith('baseline') or label == 'p':
                    random.shuffle(sentence)
                tokens = roberta.encode(" ".join(sentence))
            elif 'mid_encode' in args.shuffle_mode:
                # shuffle BPE token ids (BOS/EOS stripped, re-added below);
                # randperm is used because random.shuffle is unsafe on
                # tensors (the swaps go through views and duplicate values)
                tokens = roberta.encode(sentence)[1:-1]
                if args.shuffle_mode.startswith('baseline') or label == 'p':
                    tokens = tokens[torch.randperm(tokens.size(0))]
                tokens = torch.cat((torch.tensor([0]), tokens, torch.tensor([2])))
            elif 'post_encode' in args.shuffle_mode:
                # BPE-encode each word on its own, shuffle the per-word token
                # groups, then flatten back into one sequence
                words = sentence.split()
                if args.shuffle_mode.startswith('whitespaced'):
                    tokens = [roberta.encode(" " + w)[1:-1] for w in words]
                elif args.shuffle_mode.startswith('safe'):
                    # first word without a leading space, the rest with one,
                    # mirroring how the words are tokenised in a full sentence
                    tokens = [roberta.encode(words[0])[1:-1]] + \
                             [roberta.encode(" " + w)[1:-1] for w in words[1:]]
                else:
                    tokens = [roberta.encode(w)[1:-1] for w in words]
                if args.shuffle_mode.startswith('baseline') or label == 'p':
                    random.shuffle(tokens)
                tokens = torch.stack([item for sublist in tokens for item in sublist])
                tokens = torch.cat((torch.tensor([0]), tokens, torch.tensor([2])))
            elif 'only_position' in args.shuffle_mode:
                # ignore token identities entirely: represent the sentence by
                # the mean of the positional embeddings it would occupy
                s_len = len(roberta.encode(sentence))
                features = roberta.model.encoder.sentence_encoder.embed_positions.weight[:s_len].mean(dim=0)
            else:
                print(f"unknown shuffle_mode: {args.shuffle_mode}")
                return
            if 'only_position' not in args.shuffle_mode:
                features = roberta.extract_features(tokens)
                features = features.squeeze(0).mean(dim=0)
            all_sent_encodings.append(features.detach().cpu().numpy())
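
    # Linear-probe evaluation: if logistic regression can separate 'o'
    # (original) from 'p' (permuted) encodings, the sentence representations
    # still carry word-order information.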
    clf = LogisticRegression(random_state=42)
    X, y = np.vstack(all_sent_encodings), all_labels
    scores = cross_val_score(clf, X, y, cv=5)
    print(scores)  # per-fold accuracy
    print(f"{np.mean(scores)} ± {np.std(scores)}")
    # earlier manual train/dev split, kept for reference:
    # dev_size = len(all_sent_encodings) // 6
    # if not args.hold_out_words:
    #     train_features, train_labels = np.vstack(all_sent_encodings[:-dev_size]), all_labels[:-dev_size]
    #     dev_features, dev_labels = np.vstack(all_sent_encodings[-dev_size:]), all_labels[-dev_size:]
    # o_count_train = len([l for l in train_labels if l == 'o'])
    # p_count_train = len([l for l in train_labels if l == 'p'])
    # o_count_dev = len([l for l in dev_labels if l == 'o'])
    # p_count_dev = len([l for l in dev_labels if l == 'p'])
    # print("O-train: {} P-train: {} O-dev: {} P-dev: {}".format(o_count_train, p_count_train, o_count_dev, p_count_dev))
    # acc = clf.score(dev_features, dev_labels)
    # print(acc, ": acc")


def main():
    parser = argparse.ArgumentParser(
        description="classify original vs. permuted word order from sentence encodings")
    parser.add_argument('-d', "--dataset_path", type=str)
    parser.add_argument('-m', "--model_path", type=str)
    parser.add_argument('-l', "--max_sentence_len", type=int, default=10)
    parser.add_argument('-p', "--no_perms", type=int, default=1)
    parser.add_argument("--shuffle_mode", action='store', default='tokens')
    parser.add_argument('-hw', "--hold_out_words", action='store_true', default=False)
    arguments = parser.parse_args()
    with open(arguments.dataset_path, 'r') as f:
        dataset_file = f.read()
    all_examples, all_labels = ud_load_classify(dataset_file, sentence_len_limit=arguments.max_sentence_len)
    print(f'read {len(all_examples)} examples')
    # sklearn.utils.shuffle keeps examples and labels aligned
    all_examples, all_labels = shuffle(np.array(all_examples), np.array(all_labels))
    classify(arguments, all_examples, all_labels)
    # compute correlation between ppl and Levenshtein distance
    # corr = spearmanr(all_sent_ppl, leven_distances_to_orig)
    # print(corr, " :correlation of perplexity to Levenshtein distance to orig order.")
    # compute correlation between ppl and BLEU-4
    # corr = spearmanr(all_sent_ppl, bleu_to_orig)
    # print(corr, " :correlation of perplexity to BLEU to orig order.")


if __name__ == '__main__':
    main()