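# train.py
#
# Trains an EDSR model (models.EDSR) with Adam on an MSE loss, writes a
# checkpoint after every epoch, and keeps the weights with the lowest RMSE
# on the eval set as best.pth.
#
# Example invocation (file paths are placeholders):
#   python train.py --train-file <train data> --eval-file <eval data> \
#       --output-dir outputs --scale 2
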
import argparse
import os
import copy
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch import nn
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from models import EDSR
from datasets import dataset
from utils import AverageMeter, calc_rmse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--output-dir', type=str, required=True)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-epoch', type=int, default=40)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--seed', type=int, default=911)
    args = parser.parse_args()

    args.output_dir = os.path.join(args.output_dir, 'x{}'.format(args.scale))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(args.seed)

    model = EDSR().to(device)
    params = model.parameters()
    optimizer = optim.Adam(params, lr=args.lr)
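
    # Data pipeline. The repo's `dataset` class is assumed to yield
    # (input, label) pairs of per-sample 2-D maps; the training loop below
    # adds and removes a channel dimension around the forward pass accordingly.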
    train_dataset = dataset(args.train_file)
    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    eval_dataset = dataset(args.eval_file)
    eval_dataloader = DataLoader(
        dataset=eval_dataset,
        batch_size=1
    )

    best_weight = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_rmse = 10000
    criterion = nn.MSELoss()
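
    # Train for args.num_epoch epochs; after every epoch, write a checkpoint
    # and evaluate RMSE on the eval set to track the best-performing weights.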
    for epoch in range(args.num_epoch):
        model.train()
        epoch_losses = AverageMeter()
        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
            t.set_description('epoch:{}/{}'.format(epoch, args.num_epoch - 1))
            for data in train_dataloader:
                inputs, label = data
                inputs = inputs.to(device)
                label = label.to(device)
                # Zero out NaN/Inf entries so they cannot poison the loss.
                inputs = torch.where(torch.isnan(inputs) | torch.isinf(inputs), torch.zeros_like(inputs), inputs)
                label = torch.where(torch.isnan(label) | torch.isinf(label), torch.zeros_like(label), label)
                inputs = inputs.unsqueeze(1)  # add a channel dimension
                pred = model(inputs)
                pred = pred.squeeze(1)
                loss = criterion(pred, label)
                epoch_losses.update(loss.item(), len(inputs))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                t.set_postfix(loss=f'{epoch_losses.avg:.6f}')
                t.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(args.output_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_rmse = AverageMeter()
        for data in eval_dataloader:
            with torch.no_grad():
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                inputs = torch.where(torch.isnan(inputs) | torch.isinf(inputs), torch.zeros_like(inputs), inputs)
                labels = torch.where(torch.isnan(labels) | torch.isinf(labels), torch.zeros_like(labels), labels)
                inputs = inputs.unsqueeze(1)
                preds = model(inputs)
                preds = preds.squeeze(1)
                # print(labels)
                # print(preds)
                epoch_rmse.update(calc_rmse(preds, labels), 1)

        if epoch_rmse.avg < best_rmse:
            best_epoch = epoch
            best_rmse = epoch_rmse.avg
            best_weight = copy.deepcopy(model.state_dict())

        print('eval_rmse {:.2f}'.format(epoch_rmse.avg))
        print('best_epoch:{} best_rmse:{:.2f}'.format(best_epoch, best_rmse))

    # Save the weights of the best-performing epoch.
    torch.save(best_weight, os.path.join(args.output_dir, 'best.pth'))
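
    # The saved state_dict can later be restored with standard PyTorch calls,
    # for example (the path is illustrative, for --scale 2):
    #   model = EDSR()
    #   model.load_state_dict(torch.load('outputs/x2/best.pth'))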