-
Notifications
You must be signed in to change notification settings - Fork 7
/
train.py
147 lines (127 loc) · 5.64 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import pytorch_lightning as pl
import torch.optim as optim
from alonet.callbacks import MetricsCallback
from alonet.raft.criterion import RAFTCriterion
from alonet.raft.callbacks import RAFTFlowImagesCallback, RAFTEPECallback, FlowVideoCallback
from aloscene import Frame
import alonet
class LitRAFT(pl.LightningModule):
    """PyTorch Lightning module wrapping the alonet RAFT model.

    The module builds the RAFT network and its criterion, implements the
    training/validation steps on pairs of consecutive frames, configures the
    optimizer (and, when a step budget is known, a one-cycle learning-rate
    schedule) and exposes helpers to launch training or validation.
    """

    def __init__(self, args):
        """Store the parsed arguments and build the model and the criterion.

        Args:
            args: Parsed arguments namespace. ``args.weights`` is read here;
                ``args.max_steps`` is read later by ``configure_optimizers``.
        """
        super().__init__()
        self.args = args
        self.weights = args.weights
        self.model = self.build_model(weights=args.weights)
        self.criterion = self.build_criterion()

    @staticmethod
    def add_argparse_args(parent_parser):
        """Register the LitRAFT command-line options on ``parent_parser``.

        Args:
            parent_parser: Parser that receives a new "LitRAFT" argument group.

        Returns:
            The same ``parent_parser``, for chaining.
        """
        parser = parent_parser.add_argument_group("LitRAFT")
        parser.add_argument("--weights", type=str, help="for example raft-things")
        # override pl.Trainer gradient_clip default value
        parser.add_argument("--gradient_clip_val", type=float, default=1.0, help="Gradient clipping value")
        return parent_parser

    def _prepare_frame_pair(self, frames, inference):
        """Batch, validate and split the input into its two time steps.

        Args:
            frames: Either a list of frames (batched with ``Frame.batch_list``)
                or an already batched frame.
            inference: Forwarded to ``assert_input``; when False a flow label
                is required on the frames.

        Returns:
            Tuple ``(frame1, frame2)``: the frames at index 0 and 1 of the
            second dimension (named "T" according to ``assert_input``).
        """
        if isinstance(frames, list):
            frames = Frame.batch_list(frames)
        self.assert_input(frames, inference=inference)
        frame1 = frames[:, 0, ...]
        frame2 = frames[:, 1, ...]
        return frame1, frame2

    def forward(self, frames, only_last=True):
        """Run the model on a pair of frames.

        Args:
            frames: List of frames or batched frame with names
                ('B', 'T', 'C', 'H', 'W').
            only_last: Forwarded to the model as ``only_last``.

        Returns:
            The raw model outputs.
        """
        frame1, frame2 = self._prepare_frame_pair(frames, inference=True)
        return self.model(frame1, frame2, only_last=only_last)

    def inference(self, m_outputs, only_last=True):
        """Delegate to ``self.model.inference`` on the given model outputs."""
        return self.model.inference(m_outputs, only_last=only_last)

    def training_step(self, frames, batch_idx):
        """Compute the training loss for one batch.

        Args:
            frames: List of frames or batched frame; a flow label must be
                attached (checked by ``assert_input``).
            batch_idx: Index of the batch (unused here).

        Returns:
            Dict with keys "loss", "metrics" and "epe_per_iter".
        """
        frame1, frame2 = self._prepare_frame_pair(frames, inference=False)
        # The criterion receives every model output, hence only_last=False.
        m_outputs = self.model(frame1, frame2, only_last=False)
        flow_loss, metrics, epe_per_iter = self.criterion(m_outputs, frame1)
        return {"loss": flow_loss, "metrics": metrics, "epe_per_iter": epe_per_iter}

    def validation_step(self, frames, batch_idx, dataloader_idx=None):
        """Compute the validation loss for one batch.

        Args:
            frames: List of frames or batched frame.
            batch_idx: Index of the batch (unused here).
            dataloader_idx: Index of the validation dataloader (unused here).

        Returns:
            Dict with keys "val_loss", "metrics" and "epe_per_iter".
        """
        # NOTE(review): the input is validated with inference=True, so the
        # presence of a flow label is not checked here although the criterion
        # is called on frame1 - confirm whether this is intentional.
        frame1, frame2 = self._prepare_frame_pair(frames, inference=True)
        m_outputs = self.model(frame1, frame2, only_last=False)
        flow_loss, metrics, epe_per_iter = self.criterion(m_outputs, frame1, compute_per_iter=True)
        return {"val_loss": flow_loss, "metrics": metrics, "epe_per_iter": epe_per_iter}

    def build_criterion(self):
        """Return the criterion used by the training and validation steps."""
        return RAFTCriterion()

    def build_model(self, weights=None, device="cpu", dropout=0):
        """Instantiate the RAFT model with the given weights, device and dropout."""
        return alonet.raft.RAFT(weights=weights, device=device, dropout=dropout)

    def configure_optimizers(self, lr=4e-4, weight_decay=1e-4, epsilon=1e-8, numsteps=100000):
        """Create the AdamW optimizer and, if possible, a one-cycle LR schedule.

        Args:
            lr: Learning rate of the optimizer and peak rate of the schedule.
            weight_decay: Weight decay passed to AdamW.
            epsilon: ``eps`` passed to AdamW.
            numsteps: Unused; kept so the signature stays backward compatible.

        Returns:
            The optimizer alone when ``args.max_steps`` is None or negative,
            otherwise a dict with the optimizer and a scheduler configuration
            that is stepped after every optimizer step.
        """
        optimizer = optim.AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay, eps=epsilon)

        max_steps = self.args.max_steps
        # A negative value (e.g. -1) is used by recent Lightning versions to
        # mean "no step limit"; a cycle cannot be sized in that case.
        if max_steps is None or max_steps < 0:
            return optimizer

        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            lr,
            max_steps + 100,
            pct_start=0.05,
            cycle_momentum=False,
            anneal_strategy="linear",
        )
        # The cycle length is expressed in optimizer steps, so the scheduler
        # must be stepped per step rather than per epoch (Lightning's default).
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def assert_input(self, frames, inference=False):
        """Check that ``frames`` has the normalization and names the model expects.

        Args:
            frames: Batched frame to validate.
            inference: When True, the flow-label check is skipped.

        Raises:
            AssertionError: If the normalization is not "minmax_sym", if the
                names are not ('B', 'T', 'C', 'H', 'W'), or if ``inference``
                is False and no flow label is attached.
        """
        assert (
            frames.normalization == "minmax_sym"
        ), f"frames.normalization should minmax_sym, not '{frames.normalization}'"
        assert frames.names == (
            "B",
            "T",
            "C",
            "H",
            "W",
        ), f"frames.names should be ('B','T','C','H','W'), not: '{frames.names}'"
        if inference:
            return
        assert frames.flow is not None, "A flow label should be attached to the frame"

    def callbacks(self, data_loader):
        """Given a data_loader, this method will return the default callbacks
        of the training loop.
        """
        metrics_callback = MetricsCallback(val_names=data_loader.val_names)
        flow_images_callback = RAFTFlowImagesCallback(data_loader)
        # flow_video_callback = FlowVideoCallback(data_loader, max_frames=30, fps=3)
        flow_epe_callback = RAFTEPECallback(data_loader)
        return [metrics_callback, flow_images_callback, flow_epe_callback]

    def run_train(self, data_loader, args, project="raft", expe_name="raft", callbacks: list = None):
        """Train the model using pytorch lightning.

        Args:
            data_loader: Data module forwarded to the training helper and used
                to build the default callbacks.
            args: Parsed arguments forwarded to the training helper.
            project: Project name forwarded to the training helper.
            expe_name: Experiment name forwarded to the training helper.
            callbacks: Callbacks to use; defaults to ``self.callbacks(data_loader)``
                when None.
        """
        # Set the default callbacks if not provided.
        if callbacks is None:
            callbacks = self.callbacks(data_loader)
        alonet.common.pl_helpers.run_pl_training(
            # Trainer, data & callbacks
            lit_model=self,
            data_loader=data_loader,
            callbacks=callbacks,
            # Project info
            args=args,
            project=project,
            expe_name=expe_name,
        )

    def run_validation(self, data_loader, args, project="raft", expe_name="raft", callbacks: list = None):
        """Validate the model using pytorch lightning.

        Args:
            data_loader: Data module forwarded to the validation helper and
                used to build the default callbacks.
            args: Parsed arguments forwarded to the validation helper.
            project: Project name forwarded to the validation helper.
            expe_name: Experiment name forwarded to the validation helper.
            callbacks: Callbacks to use; defaults to ``self.callbacks(data_loader)``
                when None.
        """
        # Set the default callbacks if not provided.
        if callbacks is None:
            callbacks = self.callbacks(data_loader)
        alonet.common.pl_helpers.run_pl_validate(
            # Trainer, data & callbacks
            lit_model=self,
            data_loader=data_loader,
            callbacks=callbacks,
            # Project info
            args=args,
            project=project,
            expe_name=expe_name,
        )