-
Notifications
You must be signed in to change notification settings - Fork 2
/
penguins_novel_pruning.py
91 lines (70 loc) · 2.8 KB
/
penguins_novel_pruning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import torch.nn as nn
from torch.nn.utils import prune
# Prefer the GPU when one is present, but fall back to the CPU so the script
# still runs on machines without CUDA (a hard-coded "cuda" fails there).
device = "cuda" if torch.cuda.is_available() else "cpu"
class PenguinsDataset(Dataset):
    """Expose lines ``start`` .. ``end - 1`` of a pipe-delimited file as a torch Dataset.

    Each line is expected to look like ``"<labels>|<features>"``, where both
    fields are comma-separated floats. Only the first two ``|``-separated
    fields are used. The whole file is read once on construction. Tensors
    are built per item in ``__getitem__``.

    Args:
        start: 0-based index of the first line to include.
        end: One past the index of the last line to include. Lines are
            indexed directly, so an ``end`` beyond the file length raises
            ``IndexError``.
        path: File to read. Defaults to ``"penguins_processed.txt"``,
            resolved against the current working directory.
    """

    def __init__(self, start, end, path="penguins_processed.txt"):
        # The context manager guarantees the file handle is closed.
        with open(path, encoding="utf-8") as f:
            lines = f.readlines()
        rows = [lines[n].split("|") for n in range(start, end)]
        label_fields = [row[0] for row in rows]
        data_fields = [row[1] for row in rows]
        # float() tolerates the trailing newline left on the last field.
        self.labels = [[float(v) for v in field.split(",")] for field in label_fields]
        self.data = [[float(v) for v in field.split(",")] for field in data_fields]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(features, labels)`` as float32 tensors on the module-level ``device``."""
        features = torch.tensor(self.data[idx], dtype=torch.float32, device=device)
        labels = torch.tensor(self.labels[idx], dtype=torch.float32, device=device)
        return features, labels
class PenguinModel(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear -> Tanh.

    The defaults (12 inputs, 12 hidden units, 3 outputs) reproduce the
    original hard-coded sizes. ``PenguinModel()`` therefore builds the same
    network with the same ``state_dict`` keys (``stack1.*``, ``stack2.*``).

    NOTE(review): the training script feeds this output to
    ``nn.CrossEntropyLoss``, which expects unnormalized logits. The final
    tanh bounds every logit to [-1, 1], which caps the achievable softmax
    confidence, so the loss cannot approach 0. It is kept for behavior
    compatibility; consider removing it.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer.
        out_features: Number of output scores (classes).
    """

    def __init__(self, in_features=12, hidden_features=12, out_features=3):
        super().__init__()
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.stack1 = nn.Linear(in_features, hidden_features)
        self.stack2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Map inputs with trailing size ``in_features`` to ``out_features`` scores in [-1, 1]."""
        return self.tanh(self.stack2(self.relu(self.stack1(x))))
# Train on lines [0, 200), validate on [200, 300), test on [300, 333).
train_dataset = PenguinsDataset(0, 200)
val_dataset = PenguinsDataset(200, 300)
model = PenguinModel()
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=50, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False)
train_performance = []
val_performance = []
epochs = 1000
# Single source of truth for the pruning schedule; the plot title is derived from it.
prune_epochs = (333, 666)
prune_amount = 0.17


def _evaluate(loader):
    """Return the sample-weighted mean loss of ``model`` over every batch in ``loader``.

    Runs under ``torch.no_grad()`` so no autograd graph is built, and
    restores the model's previous train/eval mode afterwards. Returns NaN
    for an empty loader.
    """
    was_training = model.training
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for X, y in loader:
            total += loss_fn(model(X), y).item() * len(X)
            count += len(X)
    model.train(was_training)
    return total / count if count else float("nan")


for t in range(epochs):
    if t in prune_epochs:
        print(t)
        # Global L1 pruning: the weights of stack1 and stack2 are ranked
        # together and the smallest-magnitude fraction is masked out.
        # NOTE(review): on the second call `amount` presumably applies to the
        # still-unpruned weights; confirm against the torch.nn.utils.prune docs.
        to_prune = ((model.stack1, "weight"), (model.stack2, "weight"))
        prune.global_unstructured(to_prune, pruning_method=prune.L1Unstructured, amount=prune_amount)
    epoch_loss, epoch_count = 0.0, 0
    for X, y in train_dataloader:
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * len(X)
        epoch_count += len(X)
    # Record the epoch's mean training loss, not just the loss of the final
    # (shuffled, hence noisy) mini-batch.
    train_performance.append(epoch_loss / epoch_count if epoch_count else float("nan"))
    val_performance.append(_evaluate(val_dataloader))

plt.plot(train_performance, label="Training Loss")
plt.plot(val_performance, label="Validation Loss")
plt.legend()
prune_label = ", ".join(f"t={e}" for e in prune_epochs)
plt.title(f"Model Training (Pruned During Training {prune_label})")
plt.show()

test_dataset = PenguinsDataset(300, 333)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=33, shuffle=False)
test_loss = _evaluate(test_dataloader)
print(test_loss)