-
Notifications
You must be signed in to change notification settings - Fork 0
/
my_loss_eval_hook.py
executable file
·76 lines (65 loc) · 2.83 KB
/
my_loss_eval_hook.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from detectron2.engine.hooks import HookBase
# from detectron2.evaluation import inference_context
from detectron2.utils.logger import log_every_n_seconds
# from detectron2.data import DatasetMapper, build_detection_test_loader
import detectron2.utils.comm as comm
import torch
import time
import datetime
import numpy as np
import logging
class LossEvalHook(HookBase):
    """Trainer hook that periodically computes a loss over a held-out data loader.

    Every ``eval_period`` iterations (and on the final iteration) the hook
    feeds each batch of ``data_loader`` through ``model``, sums the values of
    the returned dict per batch, and stores the mean over all batches in the
    trainer's storage under the key ``validation_loss``.
    """

    def __init__(self, eval_period, model, data_loader):
        """Store the evaluation configuration.

        Args:
            eval_period: Run the evaluation every this many iterations; a
                value <= 0 disables periodic runs (the final iteration still
                triggers one).
            model: Callable invoked as ``model(batch)`` that returns a mapping
                of loss name to tensor or number.
            data_loader: Sized iterable of batches to evaluate on.
        """
        self._model = model
        self._period = eval_period
        self._data_loader = data_loader

    def _do_loss_eval(self):
        """Compute the per-batch losses and record their mean.

        The progress/ETA bookkeeping is modelled on ``inference_on_dataset``
        from detectron2's ``evaluator.py``.

        Returns:
            list: One summed loss value per batch, in loader order. Empty if
            the data loader yields nothing.
        """
        total = len(self._data_loader)
        # The first few iterations are treated as warm-up and excluded from
        # the timing statistics once they are over.
        num_warmup = min(5, total - 1)

        start_time = time.perf_counter()
        total_compute_time = 0
        losses = []
        for idx, inputs in enumerate(self._data_loader):
            if idx == num_warmup:
                # Warm-up finished: restart both clocks.
                start_time = time.perf_counter()
                total_compute_time = 0

            # Time the actual loss computation; the synchronize call makes
            # sure pending CUDA work is included in the measurement.
            start_compute_time = time.perf_counter()
            loss_batch = self._get_loss(inputs)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            losses.append(loss_batch)

            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (
                    time.perf_counter() - start_time
                ) / iters_after_start
                eta = datetime.timedelta(
                    seconds=int(total_seconds_per_img * (total - idx - 1))
                )
                log_every_n_seconds(
                    logging.INFO,
                    "Loss on Validation done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    ),
                    n=5,
                )

        if losses:
            # Skip the scalar for an empty loader: np.mean([]) would emit a
            # RuntimeWarning and record NaN.
            mean_loss = np.mean(losses)
            self.trainer.storage.put_scalar('validation_loss', mean_loss)
        # Always reached, even with no batches, so every process hits the
        # barrier.
        comm.synchronize()

        return losses

    def _get_loss(self, data):
        """Return the summed loss of a single batch as a Python number.

        Args:
            data: One batch, passed unchanged to the model.

        Returns:
            The sum of all values in the dict returned by the model, with
            tensors converted to Python scalars.
        """
        # No gradients are needed for validation; disabling autograd avoids
        # building and holding a graph for every batch.
        # NOTE(review): presumably the model must be in training mode to
        # return a dict of losses here - confirm against the caller.
        with torch.no_grad():
            metrics_dict = self._model(data)
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
        return sum(metrics_dict.values())

    def after_step(self):
        """Run the loss evaluation when the period elapses or training ends."""
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_loss_eval()
        # NOTE(review): constant scalar written on every step; looks like a
        # leftover debugging value - confirm before removing.
        self.trainer.storage.put_scalars(timetest=12)