train_model.py (forked from khurramjaved96/Recursive-CNNs)
''' Document Localization using Recursive CNN
Maintainer : Khurram Javed
Email : [email protected] '''
from __future__ import print_function
import argparse
import torch
import torch.utils.data as td
import dataprocessor
import experiment as ex
import model
import trainer
import utils
parser = argparse.ArgumentParser(description='Recursive-CNNs')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('--lr', type=float, default=0.005, metavar='LR',
                    help='learning rate (default: 0.005)')
parser.add_argument('--schedule', type=int, nargs='+', default=[10, 20, 30],
                    help='Decrease learning rate at these epochs.')
parser.add_argument('--gammas', type=float, nargs='+', default=[0.2, 0.2, 0.2],
                    help='LR is multiplied by gamma[k] at schedule[k]; provide as many gammas as schedule entries')
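# With the defaults above, and assuming each gamma is applied multiplicatively at its
# scheduled epoch (as the help text describes), the learning rate of 0.005 becomes
# 0.005 * 0.2 = 0.001 at epoch 10, 0.0002 at epoch 20, and 0.00004 at epoch 30.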
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--pretrain', action='store_true', default=False,
                    help='Pretrain the model on CIFAR dataset?')
parser.add_argument('--load-ram', action='store_true', default=False,
                    help='Load data in ram: TODO : Remove this')
parser.add_argument('--debug', action='store_true', default=True,
                    help='Debug messages')
parser.add_argument('--seed', type=int, default=2323,
                    help='Seeds values to be used')
parser.add_argument('--log-interval', type=int, default=5, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model-type', default="resnet",
                    help='model type to be used. Example : resnet32, resnet20, densenet, test')
parser.add_argument('--name', default="noname",
                    help='Name of the experiment')
parser.add_argument('--output-dir', default="../",
                    help='Directory to store the results; a new folder "DDMMYYYY" will be created '
                         'in the specified directory to save the results.')
parser.add_argument('--decay', type=float, default=0.00001, help='Weight decay (L2 penalty).')
parser.add_argument('--epochs', type=int, default=40, help='Number of epochs for training')
parser.add_argument('--dataset', default="document", help='Dataset to be used; example document, corner')
parser.add_argument('--loader', default="hdd",
                    help='Loader to load data; hdd for reading from the hdd and ram for loading all data in the memory')
parser.add_argument("-i", "--data-dirs", nargs='+', default="/Users/khurramjaved96/documentTest64",
                    help="input Directory of train data")
parser.add_argument("-v", "--validation-dirs", nargs='+', default="/Users/khurramjaved96/documentTest64",
                    help="input Directory of val data")
args = parser.parse_args()
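# Example invocation (a sketch; the directory paths are placeholders, not paths that
# ship with the repository, and every flag used below is defined in the parser above):
#   python train_model.py --name doc-corner-exp --dataset document --model-type resnet \
#       --lr 0.005 --schedule 10 20 30 --gammas 0.2 0.2 0.2 \
#       -i /path/to/train_data -v /path/to/val_data --output-dir ../results/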
# Define an experiment.
my_experiment = ex.experiment(args.name, args, args.output_dir)
# Add logging support
logger = utils.utils.setup_logger(my_experiment.path)
args.cuda = not args.no_cuda and torch.cuda.is_available()
dataset = dataprocessor.DatasetFactory.get_dataset(args.data_dirs, args.dataset)
dataset_val = dataprocessor.DatasetFactory.get_dataset(args.validation_dirs, args.dataset)
# Fix the seed.
seed = args.seed
torch.manual_seed(seed)
if args.cuda:
    torch.cuda.manual_seed(seed)
# Loader used for training data
train_dataset_loader = dataprocessor.LoaderFactory.get_loader(args.loader, dataset.myData,
                                                              transform=dataset.train_transform,
                                                              cuda=args.cuda)
# Loader used for validation data
val_dataset_loader = dataprocessor.LoaderFactory.get_loader(args.loader, dataset_val.myData,
                                                            transform=dataset_val.test_transform,
                                                            cuda=args.cuda)
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
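# pin_memory=True places loader batches in page-locked host memory, which makes
# host-to-GPU copies faster when CUDA is available.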
# Iterator to iterate over training data.
train_iterator = torch.utils.data.DataLoader(train_dataset_loader,
                                             batch_size=args.batch_size, shuffle=True, **kwargs)
# Iterator to iterate over validation data.
val_iterator = torch.utils.data.DataLoader(val_dataset_loader,
                                           batch_size=args.batch_size, shuffle=True, **kwargs)
# Get the required model
myModel = model.ModelFactory.get_model(args.model_type, args.dataset)
if args.cuda:
    myModel.cuda()
# Should I pretrain the model on CIFAR?
if args.pretrain:
    trainset = dataprocessor.DatasetFactory.get_dataset(None, "CIFAR")
    train_iterator_cifar = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
    # Define the optimizer used in the experiment
    cifar_optimizer = torch.optim.SGD(myModel.parameters(), args.lr, momentum=args.momentum,
                                      weight_decay=args.decay, nesterov=True)
    # Trainer object used for training
    cifar_trainer = trainer.CIFARTrainer(train_iterator_cifar, myModel, args.cuda, cifar_optimizer)
    for epoch in range(0, 70):
        logger.info("Epoch : %d", epoch)
        cifar_trainer.update_lr(epoch, [30, 45, 60], args.gammas)
        cifar_trainer.train(epoch)
    # Freeze the model
    counter = 0
    for name, param in myModel.named_parameters():
        # Getting the length of total layers so I can freeze x% of layers
        gen_len = sum(1 for _ in myModel.parameters())
        if counter < int(gen_len * 0.5):
            param.requires_grad = False
            logger.warning(name)
        else:
            logger.info(name)
        counter += 1
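# Note: when --pretrain is set, roughly the first half of the parameter tensors (in
# registration order) remain frozen for the main training phase below, so only the
# later layers are fine-tuned on the document/corner data.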
# Define the optimizer used in the experiment
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, myModel.parameters()), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.decay, nesterov=True)
# Trainer object used for training
my_trainer = trainer.Trainer(train_iterator, myModel, args.cuda, optimizer)
# Evaluator
my_eval = trainer.EvaluatorFactory.get_evaluator("rmse", args.cuda)
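# The "rmse" evaluator is taken here to report root-mean-square error between predicted
# and ground-truth coordinates on the validation set; see trainer.EvaluatorFactory for
# the exact definition.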
# Run training for args.epochs epochs
for epoch in range(0, args.epochs):
    logger.info("Epoch : %d", epoch)
    my_trainer.update_lr(epoch, args.schedule, args.gammas)
    my_trainer.train(epoch)
    my_eval.evaluate(my_trainer.model, val_iterator)

torch.save(myModel.state_dict(), my_experiment.path + args.dataset + "_" + args.model_type + ".pb")
my_experiment.store_json()
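# With the defaults, the saved weights end up at a path like
# <my_experiment.path>document_resnet.pb; the exact prefix depends on how
# experiment.experiment builds my_experiment.path (a dated results folder under --output-dir).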