-
Notifications
You must be signed in to change notification settings - Fork 7
/
criterion.py
70 lines (62 loc) · 2.72 KB
/
criterion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from torch import nn
import torch
import aloscene
from aloscene import Flow
class RAFTCriterion(nn.Module):
    """Criterion for a sequence of iteratively refined optical-flow predictions.

    The loss follows the RAFT reference implementation: an L1 distance between
    every prediction of the sequence and the ground truth, weighted so that
    later predictions count more.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _masked_epe(flow_pred, flow_gt, valid):
        """Return the end-point error of the pixels selected by ``valid`` (1-D tensor)."""
        # Euclidean norm of the error over the channel dimension (dim=1).
        epe = torch.sum((flow_pred - flow_gt) ** 2, dim=1).sqrt()
        # `valid` has a singleton channel dim, so it has one entry per pixel of `epe`.
        return epe.reshape(-1)[valid.reshape(-1)]

    # loss from RAFT implementation
    @staticmethod
    def sequence_loss(m_outputs, flow_gt, valid=None, gamma=0.8, max_flow=400, compute_per_iter=False):
        """Loss function defined over a sequence of flow predictions.

        Parameters
        ----------
        m_outputs : sequence of dict
            One dict per refinement iteration; each must hold the predicted
            flow under the key ``"up_flow"``, with the same shape as `flow_gt`.
        flow_gt : torch.Tensor
            Ground-truth flow. Dimension 1 is reduced as the channel dimension
            (``forward`` passes a ("B", "C", "H", "W") tensor).
        valid : torch.Tensor, optional
            Validity mask, either without the channel dimension (B, H, W) or
            with a singleton one (B, 1, H, W). Entries ``>= 0.5`` are valid.
            When None, every pixel is kept and `max_flow` is not applied.
        gamma : float
            Decay factor; prediction ``i`` out of ``n`` is weighted by
            ``gamma ** (n - i - 1)``.
        max_flow : float
            Pixels whose ground-truth magnitude is not below this value are
            excluded (only when `valid` is given).
        compute_per_iter : bool
            If True, also return the masked end-point error of every iteration.

        Returns
        -------
        flow_loss : torch.Tensor
            Weighted sum of the masked L1 losses.
        metrics : dict
            Python floats for "loss", "epe", "1px", "3px" and "5px", computed
            on the last prediction.
        epe_per_iter : list of torch.Tensor or None
            Masked end-point error per iteration, or None when
            `compute_per_iter` is False.
        """
        n_predictions = len(m_outputs)
        flow_loss = 0.0

        # exclude invalid pixels and extremely large displacements
        mag = torch.sum(flow_gt ** 2, dim=1, keepdim=True).sqrt()
        if valid is None:
            valid = torch.ones_like(mag, dtype=torch.bool)
        else:
            # `mag` keeps a singleton channel dim. A mask without that dim must
            # be aligned first, otherwise `&` broadcasts batch against batch
            # and yields a (B, B, H, W) mask that is wrong for batch size > 1.
            if valid.dim() == mag.dim() - 1:
                valid = valid.unsqueeze(1)
            valid = (valid >= 0.5) & (mag < max_flow)

        for i, m_dict in enumerate(m_outputs):
            i_weight = gamma ** (n_predictions - i - 1)
            i_loss = (m_dict["up_flow"] - flow_gt).abs()
            # the mask broadcasts over the channel dimension of the loss
            flow_loss += i_weight * (valid * i_loss).mean()

        if compute_per_iter:
            epe_per_iter = [
                RAFTCriterion._masked_epe(m_dict["up_flow"], flow_gt, valid) for m_dict in m_outputs
            ]
            # metrics are reported on the final prediction
            epe = epe_per_iter[-1]
        else:
            epe_per_iter = None
            epe = RAFTCriterion._masked_epe(m_outputs[-1]["up_flow"], flow_gt, valid)

        metrics = {
            "loss": flow_loss.item(),
            "epe": epe.mean().item(),
            "1px": (epe < 1).float().mean().item(),
            "3px": (epe < 3).float().mean().item(),
            "5px": (epe < 5).float().mean().item(),
        }
        return flow_loss, metrics, epe_per_iter

    def forward(self, m_outputs, frame1, use_valid=True, compute_per_iter=False):
        """Compute the sequence loss of `m_outputs` against the forward flow of `frame1`.

        Parameters
        ----------
        m_outputs : sequence of dict
            Model outputs, forwarded unchanged to `sequence_loss`.
        frame1 : aloscene.Frame
            Frame whose ``flow["flow_forward"]`` entries provide the ground truth.
        use_valid : bool
            If True, pixels whose ground-truth flow reaches 1000 in absolute
            value on either of the first two channels are masked out.
        compute_per_iter : bool
            Forwarded to `sequence_loss`.

        Returns
        -------
        tuple
            ``(flow_loss, metrics, epe_per_iter)`` as returned by `sequence_loss`.
        """
        assert isinstance(frame1, aloscene.Frame)
        flow_gt = [f.batch() for f in frame1.flow["flow_forward"]]
        flow_gt = torch.cat(flow_gt, dim=0)
        # The occlusion mask is not used here (nor in the original RAFT repo):
        # in RAFT, valid removes only pixels with ground-truth flow > 1000 on one dimension.
        assert flow_gt.names == ("B", "C", "H", "W")
        flow_gt = flow_gt.as_tensor()
        if use_valid:
            flow_x, flow_y = flow_gt[:, 0, ...], flow_gt[:, 1, ...]
            valid = (flow_x.abs() < 1000) & (flow_y.abs() < 1000)
        else:
            valid = None
        flow_loss, metrics, epe_per_iter = RAFTCriterion.sequence_loss(
            m_outputs, flow_gt, valid, compute_per_iter=compute_per_iter
        )
        return flow_loss, metrics, epe_per_iter