# torch imports
import torch
import torch.nn.functional as F
# python miscellaneous
from collections import defaultdict
import pdb
# my imports
from aml_utils import select_inputs_, calculate_eucledian_metrics
from losses import l2_loss, deep_loss_, cycons_func_, chamfer_distance, local_flow_consistency, knn_loss,\
laplacian_regularization
from initializations import initialize_flow_extractor_loss


def train_step_flow_extractor(args, params, use_flow_signal=False, supervised=False, n_points=None):
    """Builds and returns a function that performs ONE training step of the flow extractor."""
    # initialize some recurrent parameters
    n_points = params["n_points"] if n_points is None else n_points
    use_deep_loss = not args.use_shallow_loss
    device = params["device"]

    # parse input for the loss module
    # use_half_clouds = False if args.loss_type in ["triplet_l", "triplet_hinge", "js"] else True
    select_inputs = select_inputs_(use_flow_signal, n_points, half_cloud=True)
    # running scalar logs; note these dicts are closed over and shared across
    # calls of the returned step function
    out_acc_dict = defaultdict(lambda: 0.0)
    out_error_dict = defaultdict(lambda: 0.0)
    # initialize loss functions
    if not supervised:
        loss_func_shallow = initialize_flow_extractor_loss(loss_type=args.loss_type, params=params)
        loss_func_deep = initialize_flow_extractor_loss(loss_type=params["loss_type_deep"], params=params)
        deep_loss = deep_loss_(loss_func_deep, use_deep_loss, device, scale_type=args.deep_loss_scale)
    else:
        # fall back to the plain L2 flow loss whenever the knn loss is not
        # requested, so loss_func_supervised is always defined on this path
        if args.train_type == 'knn' and args.loss_type == 'knn':
            loss_func_supervised = knn_loss
        else:
            loss_func_supervised = l2_loss(dim=1, norm=params["norm_FE"])

    if (not supervised) and (params["cycle_consistency"] is not None):
        cycons_func = cycons_func_(params["cycle_consistency"])
    elif params["cycle_consistency_sup"] is not None:
        cycons_func = cycons_func_(params["cycle_consistency_sup"])
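
    # NOTE (editorial sketch): in the unsupervised step below, the flow extractor
    # is trained against the frozen cloud embedder. `select_inputs_` from
    # aml_utils decides which clouds play the anchor / positive / negative roles
    # of the triplet; the exact assignment is that helper's concern, not this
    # function's. Roughly, the step is:
    #
    #   flow_pred = flow_extractor(c1, c2)   # predicted scene flow
    #   c2_pred   = c1 + flow_pred           # warped source cloud
    #   loss      = triplet(f(anchor), f(positive), f(negative))
    #
    # pushing the extractor to make the warped cloud indistinguishable from the
    # real target cloud in the embedder's feature space.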
    def train_step_FE_func_uns(flow_extractor, cloud_embedder, c1, c2, flow_t1=None):
        # train the flow extractor
        flow_extractor.train()
        cloud_embedder.eval()

        flow_pred = flow_extractor(c1, c2)
        c2_pred = c1 + flow_pred

        c1_, c_anchor, c_negative, c_positive = select_inputs(c1, c2, c2_pred, flow_t1)

        f_0, hidden_feats_0 = cloud_embedder(c1_, c_anchor)
        f_p, hidden_feats_p = cloud_embedder(c1_, c_positive)
        f_n, hidden_feats_n = cloud_embedder(c1_, c_negative)

        loss_hidden_feats = deep_loss(hidden_feats_0, hidden_feats_p, hidden_feats_n)
        loss_feat = loss_func_shallow(f_0, f_p, f_n)
        loss_FE = loss_feat + loss_hidden_feats
        out_error_dict["loss_feat"], out_error_dict["loss_hidden_feats"] = loss_feat.item(), loss_hidden_feats.item()
if params["cycle_consistency"] is not None:
c2_pred = c2_pred.detach()
if args.cycon_aug:
c2_pred = torch.cat((c2_pred, c2), dim=-1)
flow_pred_backwards = flow_extractor(c2_pred, c1)[..., :flow_pred.shape[-1]]
loss_cycons = cycons_func(flow_pred, flow_pred_backwards)
loss_FE = loss_cycons * args.cycon_contribution + loss_FE
out_error_dict["loss_cycons"] = loss_cycons.item()
c2_pred = c2_pred[..., :flow_pred.shape[-1]]
if params["local_consistency"]:
# loss_loccons = local_flow_consistency(pc1=c1, flow_pred=flow_pred)
loss_loccons = local_flow_consistency(c1, flow_pred)
loss_FE = loss_loccons * params["local_consistency"] + loss_FE
out_error_dict["loss_loccons"] = loss_loccons.item()
if params["chamfer"] > 0:
chamfer_dist = chamfer_distance(pc_pred=c2_pred, pc_target=c2)
loss_FE += params["chamfer"] * chamfer_dist
out_error_dict["chamfer"] = chamfer_dist.item()
if params["laplace"] > 0:
laplace_loss = laplacian_regularization(pc_pred=c2_pred, pc_target=c2)
loss_FE += params["laplace"] * laplace_loss
out_error_dict["laplace"] = laplace_loss.item()
if params["static_penalty"] > 0:
static_penalty = (c2_pred[..., :c1.shape[-1]] - c1).norm(dim=1).clamp(max=1).mean()
loss_FE += params["static_penalty"] * static_penalty
out_error_dict["static"] = static_penalty.item()
if flow_t1 is not None:
flow_errors_dict, flow_accs_dict = calculate_eucledian_metrics(flow_pred, flow_t1)
out_error_dict.update(flow_errors_dict)
out_acc_dict.update(flow_accs_dict)
out_dict = {"error": out_error_dict, "acc": out_acc_dict}
return loss_FE, out_dict
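
    # Cycle-consistency intuition (illustrative): if flow_pred warps c1 towards
    # c2, predicting flow from c2_pred back to c1 should roughly invert it, i.e.
    # flow_pred_backwards ≈ -flow_pred. The concrete penalty is whatever
    # cycons_func_ builds from params["cycle_consistency"]; a minimal penalty in
    # this spirit would be
    #
    #   loss_cycons = (flow_pred + flow_pred_backwards).norm(dim=1).mean()
    #
    # which vanishes exactly when the backward flow undoes the forward one.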
    def train_step_FE_func_sup(flow_extractor, c1, c2, flow_t1=None):
        # train the flow extractor
        flow_extractor.train()

        flow_pred = flow_extractor(c1, c2)
        c2_pred = c1 + flow_pred

        loss_FE = 0
        # pdb.set_trace()
        loss_FE += args.sup_scale * loss_func_supervised(flow_pred, flow_t1)

        if params["cycle_consistency_sup"] is not None:
            c2_pred_ = c2_pred.detach()
            flow_pred_backwards = flow_extractor(c2_pred_, c1)
            loss_cycons = cycons_func(flow_pred, flow_pred_backwards)
            loss_FE = loss_cycons * args.cycon_contribution + loss_FE
            out_error_dict["loss_cycons"] = loss_cycons.item()

        if params["local_consistency"]:
            loss_loccons = local_flow_consistency(c1, flow_pred)
            loss_FE = loss_loccons * params["local_consistency"] + loss_FE
            out_error_dict["loss_loccons"] = loss_loccons.item()

        if params["chamfer"] > 0:
            chamfer_dist = chamfer_distance(pc_pred=c2_pred, pc_target=c2)
            loss_FE += params["chamfer"] * chamfer_dist
            out_error_dict["chamfer"] = chamfer_dist.item()

        if params["laplace"] > 0:
            laplace_loss = laplacian_regularization(pc_pred=c2_pred, pc_target=c2)
            loss_FE += params["laplace"] * laplace_loss
            out_error_dict["laplace"] = laplace_loss.item()

        if params["static_penalty"] > 0:
            static_penalty = (c2_pred[..., :c1.shape[-1]] - c1).norm(dim=1).clamp(max=1).mean()
            loss_FE += params["static_penalty"] * static_penalty
            out_error_dict["static"] = static_penalty.item()

        flow_errors_dict, flow_accs_dict = calculate_eucledian_metrics(flow_pred, flow_t1)
        out_error_dict.update(flow_errors_dict)
        out_acc_dict.update(flow_accs_dict)

        out_dict = {"error": out_error_dict, "acc": out_acc_dict}
        return loss_FE, out_dict
    def train_step_FE_func_knn_sup(flow_extractor, c1, c2, flow_t1=None):
        # train the flow extractor using the knn loss
        flow_extractor.train()

        flow_pred = flow_extractor(c1, c2)
        c2_pred = c1 + flow_pred
        loss_FE = loss_func_supervised(c2_pred, c2)

        if params["cycle_consistency_sup"] is not None:
            c2_pred_ = c2_pred.detach()
            flow_pred_backwards = flow_extractor(c2_pred_, c1)
            loss_cycons = cycons_func(flow_pred, flow_pred_backwards)
            loss_FE = loss_cycons * args.cycon_contribution + loss_FE
            out_error_dict["loss_cycons"] = loss_cycons.item()

        if params["local_consistency"]:
            loss_loccons = local_flow_consistency(c1, flow_pred)
            loss_FE = loss_loccons * params["local_consistency"] + loss_FE
            out_error_dict["loss_loccons"] = loss_loccons.item()

        if params["chamfer"] > 0:
            chamfer_dist = chamfer_distance(pc_pred=c2_pred, pc_target=c2)
            loss_FE += params["chamfer"] * chamfer_dist
            out_error_dict["chamfer"] = chamfer_dist.item()

        if params["laplace"] > 0:
            laplace_loss = laplacian_regularization(pc_pred=c2_pred, pc_target=c2)
            loss_FE += params["laplace"] * laplace_loss
            out_error_dict["laplace"] = laplace_loss.item()

        if params["static_penalty"] > 0:
            static_penalty = (c2_pred[..., :c1.shape[-1]] - c1).norm(dim=1).clamp(max=1).mean()
            loss_FE += params["static_penalty"] * static_penalty
            out_error_dict["static"] = static_penalty.item()

        flow_errors_dict, flow_accs_dict = calculate_eucledian_metrics(flow_pred, flow_t1)
        out_error_dict.update(flow_errors_dict)
        out_acc_dict.update(flow_accs_dict)

        out_dict = {"error": out_error_dict, "acc": out_acc_dict}
        return loss_FE, out_dict
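
    # Note: despite living on the supervised path, the knn variant above does not
    # use flow_t1 in its main loss; knn_loss(c2_pred, c2) compares the warped
    # cloud against the target cloud directly (presumably via nearest-neighbour
    # matching inside losses.py), so the ground-truth flow only feeds the
    # evaluation metrics.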
    if supervised:
        if args.train_type == "knn":
            return train_step_FE_func_knn_sup
        return train_step_FE_func_sup
    else:
        return train_step_FE_func_uns
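
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `loader`, `optimizer`, `flow_net` and
# `embedder` are placeholder names for objects built elsewhere; they are
# assumptions of this comment, not exports of this module.
#
#   train_step = train_step_flow_extractor(args, params, supervised=False)
#   for c1, c2, flow_t1 in loader:
#       optimizer.zero_grad()
#       loss_FE, logs = train_step(flow_net, embedder, c1, c2, flow_t1=flow_t1)
#       loss_FE.backward()
#       optimizer.step()
#       # logs["error"] / logs["acc"] hold scalar diagnostics for logging
# ---------------------------------------------------------------------------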