forked from facebookresearch/detr
-
Notifications
You must be signed in to change notification settings - Fork 2
/
engine.py
226 lines (184 loc) · 9.37 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Train and eval functions used in main.py
"""
import math
import random
import os
import sys
from typing import Iterable
import wandb
import torch
import util.misc as utils
from util.io import create_wandb_img
from datasets.coco_eval import CocoEvaluator
from datasets.panoptic_eval import PanopticEvaluator
# Global list to make image logging consistent over epochs: evaluate() appends
# the id of each image it picks for wandb logging here and, once the list is
# full, only logs images whose id is already present.
LOG_IDX = []
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0):
    """Train ``model`` for one pass over ``data_loader``.

    For every batch the weighted sum of the criterion's losses is
    back-propagated and the optimizer is stepped. The values that are logged
    come from ``utils.reduce_dict`` so that they reflect all processes, while
    the gradient is computed from the local (un-reduced) loss.

    Args:
        model: network being trained; put into train mode.
        criterion: loss module; called as ``criterion(outputs, targets)`` and
            expected to expose ``weight_dict``. Put into train mode.
        data_loader: iterable yielding ``(samples, targets)`` pairs, where
            ``targets`` is a sequence of dicts of tensors.
        optimizer: optimizer stepped once per batch.
        device: device the samples and target tensors are moved to.
        epoch: epoch number, used only in the progress header.
        max_norm: gradient-norm clipping threshold; clipping is skipped when
            this is not strictly positive.

    Returns:
        Dict mapping every meter name to its ``global_avg``.

    The process exits with status 1 if the reduced, weighted loss of a batch
    is not finite.
    """
    model.train()
    criterion.train()

    meters = utils.MetricLogger(delimiter=" ")
    # Both meters show the latest value only (window of one).
    for meter_name, value_fmt in (('lr', '{value:.6f}'), ('class_error', '{value:.2f}')):
        meters.add_meter(meter_name, utils.SmoothedValue(window_size=1, fmt=value_fmt))

    batches = meters.log_every(data_loader, 20, 'Epoch: [{}]'.format(epoch))
    for samples, targets in batches:
        samples = samples.to(device)
        targets = [{name: tensor.to(device) for name, tensor in entry.items()}
                   for entry in targets]

        outputs = model(samples)
        raw_losses = criterion(outputs, targets)
        weights = criterion.weight_dict
        # Only losses that have a configured weight contribute to the objective.
        total_loss = sum(raw_losses[name] * weights[name]
                         for name in raw_losses.keys() if name in weights)

        # reduce losses over all GPUs for logging purposes
        reduced = utils.reduce_dict(raw_losses)
        unscaled = {f'{name}_unscaled': value for name, value in reduced.items()}
        scaled = {name: value * weights[name]
                  for name, value in reduced.items() if name in weights}
        loss_value = sum(scaled.values()).item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(reduced)
            sys.exit(1)

        optimizer.zero_grad()
        total_loss.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        meters.update(loss=loss_value, **scaled, **unscaled)
        meters.update(class_error=reduced['class_error'])
        meters.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    meters.synchronize_between_processes()
    print("Averaged stats:", meters)
    return {name: meter.global_avg for name, meter in meters.meters.items()}
def _register_attention_hooks(model, conv_features, enc_attn_weights, dec_attn_weights):
    """Attach forward hooks that record backbone features and last-layer attention.

    Each hook appends to the list passed in: the output of ``backbone[-2]``
    goes to ``conv_features``, and element ``[1]`` of the outputs of the last
    encoder ``self_attn`` and last decoder ``multihead_attn`` go to
    ``enc_attn_weights`` and ``dec_attn_weights`` respectively.

    Returns the list of hook handles; the caller must ``remove()`` each one.
    """
    # Parallel wrappers expose the real network as `.module`; a bare model
    # has no such attribute and is used directly.
    net = model.module if hasattr(model, "module") else model
    # Taken from https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb
    return [
        net.backbone[-2].register_forward_hook(
            lambda _module, _inputs, output: conv_features.append(output)
        ),
        net.transformer.encoder.layers[-1].self_attn.register_forward_hook(
            lambda _module, _inputs, output: enc_attn_weights.append(output[1])
        ),
        net.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(
            lambda _module, _inputs, output: dec_attn_weights.append(output[1])
        ),
    ]


@torch.no_grad()
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir, log_step=0):
    """Evaluate ``model`` on ``data_loader`` and optionally log images to wandb.

    Losses are computed with ``criterion`` and tracked in a metric logger,
    predictions are post-processed and fed to a ``CocoEvaluator`` (and to a
    ``PanopticEvaluator`` when a ``'panoptic'`` post-processor is present).

    On the main process, when ``output_dir`` is truthy and ``log_step`` is 0
    or ``(log_step + 1) % 50 == 0``, up to 15 images are rendered with
    ``create_wandb_img`` and sent to wandb. The ids of the chosen images are
    kept in the module-level ``LOG_IDX`` so later calls log the same images.
    Only the first element of a batch is ever considered for logging.

    Args:
        model: network to evaluate; may be wrapped so that the real network
            is reachable via ``.module``. Put into eval mode.
        criterion: loss module exposing ``weight_dict``. Put into eval mode.
        postprocessors: mapping that must contain ``'bbox'`` and may contain
            ``'segm'`` and ``'panoptic'``.
        data_loader: iterable of ``(samples, targets)``; its ``dataset`` is
            used for annotation paths and, when logging, image metadata.
        base_ds: ground-truth object handed to ``CocoEvaluator``.
        device: device the samples and target tensors are moved to.
        output_dir: base directory for panoptic output; also gates logging.
        log_step: wandb step under which the images are logged.

    Returns:
        ``(stats, coco_evaluator)`` where ``stats`` maps meter names to their
        global averages plus the COCO / panoptic summary entries.
    """
    model.eval()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Test:'

    iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors)
    coco_evaluator = CocoEvaluator(base_ds, iou_types)

    panoptic_evaluator = None
    if 'panoptic' in postprocessors:
        panoptic_evaluator = PanopticEvaluator(
            data_loader.dataset.ann_file,
            data_loader.dataset.ann_folder,
            output_dir=os.path.join(output_dir, "panoptic_eval"),
        )

    # Log every 50 steps and in step 0
    log_this = bool(output_dir) and utils.is_main_process() and ((log_step + 1) % 50 == 0 or log_step == 0)
    max_logged_images = 15

    dataset = data_loader.dataset
    # Category names are only used when rendering images, so the dataset's
    # COCO metadata is not touched unless this call actually logs.
    classes = {}
    if log_this:
        classes = {cat["id"]: cat["name"] for cat in dataset.coco.dataset["categories"]}
    wandb_imgs = {"images": [], "self_attention": [], "attention": []}

    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        log_image = False
        first_image_id = None
        if log_this:
            # Plain ints keep LOG_IDX free of device tensors and make the
            # membership test an ordinary integer comparison.
            first_image_id = targets[0]["image_id"].item()
            if len(LOG_IDX) >= max_logged_images:
                log_image = first_image_id in LOG_IDX
            elif random.random() < 0.3 and len(targets[0]["labels"]) > 3:
                LOG_IDX.append(first_image_id)
                log_image = True

        # Fresh per-batch buffers: nothing captured for one image can outlive
        # the iteration that logged it.
        conv_features, enc_attn_weights, dec_attn_weights = [], [], []
        hooks = []
        if log_image:
            hooks = _register_attention_hooks(model, conv_features, enc_attn_weights, dec_attn_weights)
        try:
            outputs = model(samples)
        finally:
            # The hooks are only needed for this forward pass; detach them
            # even if it raises so they never stay installed on the model.
            for hook in hooks:
                hook.remove()

        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)

        # Gather images to log to wandb
        if log_image:
            # get the HxW shape of the feature maps of the CNN
            f_map = conv_features[-1]['0'].tensors.cpu()
            shape = f_map.shape[-2:]
            sattn = enc_attn_weights[-1][0].reshape(shape + shape).cpu()
            dec_att = dec_attn_weights[-1].cpu()

            target = targets[0]
            pred = {"pred_logits": outputs["pred_logits"][0],
                    "pred_boxes": outputs["pred_boxes"][0]}
            name = dataset.coco.imgs[first_image_id]["file_name"]
            path = os.path.join(dataset.root, name)
            img, self_attention, att_map = create_wandb_img(classes, path, target, pred, sattn, f_map, dec_att)
            wandb_imgs["images"].append(img)
            wandb_imgs["self_attention"].append(self_attention)
            wandb_imgs["attention"].append(att_map)

        if 'segm' in postprocessors:
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
            results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        if coco_evaluator is not None:
            coco_evaluator.update(res)

        if panoptic_evaluator is not None:
            # NOTE(review): target_sizes is only bound when 'segm' is also
            # configured; this branch assumes both post-processors are present.
            res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
            for i, target in enumerate(targets):
                image_id = target["image_id"].item()
                file_name = f"{image_id:012d}.png"
                res_pano[i]["image_id"] = image_id
                res_pano[i]["file_name"] = file_name
            panoptic_evaluator.update(res_pano)

    # Log all images to wandb
    if log_this:
        wandb.log({"Images": wandb_imgs["images"]}, step=log_step)
        wandb.log({"Self Attention": wandb_imgs["self_attention"]}, step=log_step)
        wandb.log({"Attention": wandb_imgs["attention"]}, step=log_step)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
    if panoptic_evaluator is not None:
        panoptic_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    if coco_evaluator is not None:
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
    panoptic_res = None
    if panoptic_evaluator is not None:
        panoptic_res = panoptic_evaluator.summarize()

    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if coco_evaluator is not None:
        if 'bbox' in postprocessors:
            stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
        if 'segm' in postprocessors:
            stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
    if panoptic_res is not None:
        stats['PQ_all'] = panoptic_res["All"]
        stats['PQ_th'] = panoptic_res["Things"]
        stats['PQ_st'] = panoptic_res["Stuff"]
    return stats, coco_evaluator