forked from luojunwei/DGDTA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
training_validation.py
111 lines (96 loc) · 4.65 KB
/
training_validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import numpy as np
import pandas as pd
import sys, os
from random import shuffle
import torch
import torch.nn as nn
from models.gatv2 import GATv2
from models.gatv2_gcn import GATv2_GCN
from utils import *
# training function at each epoch
def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over every batch in ``train_loader``.

    Relies on the module-level names ``loss_fn`` and ``LOG_INTERVAL``, which
    must be defined before this function is called.

    Args:
        model: network invoked as ``model(batch)``; switched to training mode.
        device: device each batch (and its targets) is moved to.
        train_loader: sized iterable of batches exposing ``.to(device)`` and
            ``.y``; it must also provide a sized ``.dataset`` attribute.
        optimizer: optimizer whose gradients are cleared and stepped per batch.
        epoch: epoch number, used only in the progress message.
    """
    total_samples = len(train_loader.dataset)
    total_batches = len(train_loader)
    print('Training on {} samples...'.format(total_samples))
    model.train()
    # Samples consumed by the batches *before* the current one.  Counted from
    # the target rows fed to the loss rather than from ``len(data.x)``, which
    # is a different tensor and need not have one row per sample.
    samples_seen = 0
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        # One target row per prediction.
        target = data.y.view(-1, 1).float().to(device)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            print('Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                samples_seen,
                total_samples,
                100. * batch_idx / total_batches,
                loss.item()))
        samples_seen += target.size(0)
def predicting(model, device, loader):
    """Evaluate ``model`` on every batch of ``loader`` without gradients.

    Args:
        model: network invoked as ``model(batch)``; switched to eval mode.
        device: device each batch is moved to before the forward pass.
        loader: iterable of batches exposing ``.to(device)`` and ``.y``; it
            must also provide a sized ``.dataset`` attribute.

    Returns:
        A ``(labels, predictions)`` tuple of flattened NumPy arrays, in the
        order the loader yielded them.
    """
    model.eval()
    # Collect per-batch chunks and concatenate once at the end; concatenating
    # onto a growing tensor inside the loop re-copies everything seen so far.
    # Each list is seeded with an empty float tensor so the result dtype and
    # the empty-loader result are the same as with incremental concatenation.
    pred_chunks = [torch.Tensor()]
    label_chunks = [torch.Tensor()]
    print('Make prediction for {} samples...'.format(len(loader.dataset)))
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data)
            pred_chunks.append(output.cpu())
            label_chunks.append(data.y.view(-1, 1).cpu())
    total_preds = torch.cat(pred_chunks, 0)
    total_labels = torch.cat(label_chunks, 0)
    return total_labels.numpy().flatten(), total_preds.numpy().flatten()
# ---------------------------------------------------------------------------
# Experiment configuration.  The index after each list selects the dataset
# and the model class used for this run.
# ---------------------------------------------------------------------------
datasets = [['davis', 'kiba'][1]]
modeling = [GATv2, GATv2_GCN][1]
model_st = modeling.__name__

cuda_name = "cuda:0"

TRAIN_BATCH_SIZE = 512
TEST_BATCH_SIZE = 512
LR = 0.0005
LOG_INTERVAL = 20  # read by train(): print progress every LOG_INTERVAL batches
NUM_EPOCHS = 1000

print('Learning rate: ', LR)
print('Epochs: ', NUM_EPOCHS)

# Main program: iterate over different datasets
for dataset in datasets:
    print('\nrunning on ', model_st + '_' + dataset)
    processed_data_file_train = 'data/processed/' + dataset + '_train.pt'
    processed_data_file_test = 'data/processed/' + dataset + '_test.pt'
    if not (os.path.isfile(processed_data_file_train)
            and os.path.isfile(processed_data_file_test)):
        print('please run create_data.py to prepare data in pytorch format!')
        continue

    train_data = TestbedDataset(root='data', dataset=dataset + '_train')
    test_data = TestbedDataset(root='data', dataset=dataset + '_test')

    # Hold out 20% of the training set for validation / model selection.
    train_size = int(0.8 * len(train_data))
    valid_size = len(train_data) - train_size
    train_data, valid_data = torch.utils.data.random_split(
        train_data, [train_size, valid_size])

    # make data PyTorch mini-batch processing ready
    train_loader = DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(valid_data, batch_size=TEST_BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=TEST_BATCH_SIZE, shuffle=False)

    # training the model
    device = torch.device(cuda_name if torch.cuda.is_available() else "cpu")
    model = modeling().to(device)
    loss_fn = nn.MSELoss()  # read by train() as a module-level name
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # Start from +inf so that any finite validation MSE counts as an
    # improvement, whatever its magnitude.
    best_mse = float('inf')
    best_test_mse = float('inf')
    best_test_ci = 0
    best_epoch = -1
    model_file_name = 'model_' + model_st + '_' + dataset + '.model'
    result_file_name = 'result_' + model_st + '_' + dataset + '.csv'

    for epoch in range(NUM_EPOCHS):
        train(model, device, train_loader, optimizer, epoch + 1)

        print('predicting for valid data')
        G, P = predicting(model, device, valid_loader)
        val = mse(G, P)

        if val < best_mse:
            # New best validation MSE: checkpoint the weights, then evaluate
            # on the test set and overwrite the result file with its metrics.
            best_mse = val
            best_epoch = epoch + 1
            torch.save(model.state_dict(), model_file_name)
            print('predicting for test data')
            G, P = predicting(model, device, test_loader)
            ret = [rmse(G, P), mse(G, P), pearson(G, P), spearman(G, P), ci(G, P)]
            with open(result_file_name, 'w') as f:
                f.write(','.join(map(str, ret)))
            best_test_mse = ret[1]
            best_test_ci = ret[-1]
            print('rmse improved at epoch ', best_epoch, '; best_test_mse,best_test_ci:', best_test_mse, best_test_ci, model_st, dataset)
        else:
            # Report the current validation MSE.  ``ret`` is only bound once
            # an epoch has improved, so it must not be read here (a NaN
            # validation MSE in the first epoch would otherwise raise
            # NameError); its value is already shown as best_test_mse.
            print(val, 'No improvement since epoch ', best_epoch, '; best_test_mse,best_test_ci:', best_test_mse, best_test_ci, model_st, dataset)