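"""Train a linear probe to decode graph states from a sequence model's hidden states.

For every token of every sampled sequence, the script extracts the final-layer
representation, labels it with the corresponding graph state, then fits a
multinomial logistic regression probe and reports held-out accuracy per epoch.
"""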
from model import SimpleTokenizer, collate_fn
from utils import get_state_sequence, load_model, load_train_data, load_heldout_data
from torch.utils.data import DataLoader
import torch
import numpy as np
import argparse
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
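
# Parse command-line options: which dataset to probe, and whether to probe a
# randomly initialized (untrained) model instead of the trained one.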
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='shortest-paths')
parser.add_argument('--use-untrained-model', action='store_true')
args = parser.parse_args()
data = args.data
use_untrained_model = args.use_untrained_model
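

# The probe: multinomial logistic regression, i.e. a single linear layer that
# maps a hidden-state vector to logits over graph states.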
class MultinomialLogisticRegression(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.linear(x)
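

# Load the model (trained by default, untrained if requested) along with its
# tokenizer and the graph structure needed to recover state labels.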
model = load_model(data, use_untrained_model)
tokenizer = model.tokenizer
valid_turns = tokenizer.valid_turns
node_and_direction_to_neighbor = tokenizer.node_and_direction_to_neighbor
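
# Build the probing dataset from training sequences; the commented-out lines
# switch to more samples or to held-out data.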
# num_samples = 100000
num_samples = 25000
dataset = load_train_data(data, tokenizer, num_samples=num_samples)
# dataset = load_heldout_data(data, tokenizer)
# Iterate through dataset and get representations.
batch_size = 128
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)
representations = []
labels = []
print("Getting representations...")
bar = tqdm(dataloader)
for batch in bar:
    bsz, seq_len = batch['input_ids'].shape
    with torch.no_grad():
        input_ids = batch['input_ids'].to(model.device)
        mask = batch['attention_mask'].to(model.device)
        outputs = model.model(
            input_ids,
            attention_mask=mask,
            labels=input_ids,  # the LM loss is computed but unused; we only need hidden states
            output_hidden_states=True,
        )
        hidden_states = outputs.hidden_states
        # Stack per-layer hidden states: (bsz, num_layers + 1, seq_len, hidden_dim).
        hidden_states = torch.stack(hidden_states, dim=1)
    for i in range(bsz):
        sequence_str = tokenizer.decode(batch['input_ids'][i])
        sequence_states = get_state_sequence(sequence_str, node_and_direction_to_neighbor)
        # First state is read off the last layer (index -1) at the first token.
        labels.append(sequence_states[0])
        representations.append(hidden_states[i, -1, 0, :].cpu().numpy())
        # Later states align with token position j + 1.
        for j in range(1, len(sequence_states)):
            labels.append(sequence_states[j])
            representations.append(hidden_states[i, -1, j + 1, :].cpu().numpy())
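
# Stack per-token features and labels into arrays for the probe.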
label_array = np.array(labels)
representations_array = np.array(representations)
# Convert labels to {0, 1, ..., num_classes-1}
label_encoder = LabelEncoder()
encoded_labels = label_encoder.fit_transform(label_array)

# Split train and test (random 80/20 split; unseeded, so indices vary across runs).
train_inds = np.random.choice(len(representations_array), size=int(0.8 * len(representations_array)), replace=False)
test_inds = np.setdiff1d(np.arange(len(representations_array)), train_inds)
X_train, X_test = representations_array[train_inds], representations_array[test_inds]
y_train, y_test = encoded_labels[train_inds], encoded_labels[test_inds]
X_train, X_test = torch.tensor(X_train).float(), torch.tensor(X_test).float()
y_train, y_test = torch.tensor(y_train), torch.tensor(y_test)
# Initialize probe.
input_dim = X_train.shape[1]
num_classes = len(label_encoder.classes_)
device = model.device
net = MultinomialLogisticRegression(input_dim, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)
# Train probe.
num_epochs = 100
batch_size = 2048  # larger batch for probe training; note this reuses the name from the dataloader above
bar = tqdm(range(num_epochs))
for epoch in bar:
    for i in range(0, len(X_train), batch_size):
        batch_X = X_train[i:i + batch_size].to(device)
        batch_y = y_train[i:i + batch_size].to(device)
        outputs = net(batch_X)
        loss = criterion(outputs, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Evaluate on the held-out split after every epoch.
    with torch.no_grad():
        X_test = X_test.to(device)
        y_test = y_test.to(device)
        outputs = net(X_test)
        _, predicted = torch.max(outputs, 1)
        accuracy = accuracy_score(y_test.cpu().numpy(), predicted.cpu().numpy())
        # Binomial standard error of the accuracy estimate.
        num_accurate = (y_test.cpu().numpy() == predicted.cpu().numpy()).sum()
        p = num_accurate / len(y_test)
        std = np.sqrt(p * (1 - p) / len(y_test))
        bar.set_description(f"Probe accuracy: {accuracy:.3f} ({std:.3f})")
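
# Example invocation:
#   python probe_test.py --data shortest-paths
#   python probe_test.py --data shortest-paths --use-untrained-model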