-
Notifications
You must be signed in to change notification settings - Fork 7
/
eval_on_coco.py
62 lines (48 loc) · 2.24 KB
/
eval_on_coco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from argparse import ArgumentParser
from tqdm import tqdm
import torch
from alonet.common import add_argparse_args
from alonet.detr import CocoPanoptic2Detr
from alonet.detr_panoptic import LitPanopticDetr
from alonet.metrics import PQMetrics, ApMetrics
from aloscene import Frame
def main(args):
    """Evaluate a panoptic DETR model on the COCO panoptic validation split.

    Each validation sample (the loader is built with ``batch_size=1``) is run
    through the model; predictions and ground truth are accumulated into a
    ``PQMetrics`` and an ``ApMetrics`` instance, and both results are printed
    once the loop is over.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. The namespace is forwarded as-is to
        ``CocoPanoptic2Detr`` and ``LitPanopticDetr``. ``args.ap_limit``
        (``int`` or ``None``) caps the number of evaluated samples; ``None``
        evaluates the whole validation loader.
    """
    device = torch.device("cuda")

    # Init the DetrPanoptic model with the dataset
    coco_loader = CocoPanoptic2Detr(args, batch_size=1)
    lit_panoptic = LitPanopticDetr(args=args)
    lit_panoptic.model = lit_panoptic.model.eval().to(device)

    # Define the metric used in evaluation process
    pq_metric = PQMetrics()
    ap_metric = ApMetrics()

    # Build the validation loader a single time and reuse it for both the
    # progress-bar length and the iteration.
    val_loader = coco_loader.val_dataloader()
    n_total = len(val_loader)
    if args.ap_limit is not None:
        # Never announce more samples than the loader can provide.
        n_total = min(n_total, max(args.ap_limit, 0))

    # Counts the samples actually evaluated, so the final report is exact and
    # is defined even when the loader yields nothing.
    n_evaluated = 0

    # Gradients are not needed for evaluation; the bar is closed on exit.
    with torch.no_grad(), tqdm(total=n_total) as tbar:
        for data in val_loader:
            # Check the limit BEFORE processing so that exactly `ap_limit`
            # samples are evaluated (not `ap_limit + 1`).
            if n_evaluated >= n_total:
                break

            frame = Frame.batch_list(data).to(device)
            # NOTE(review): threshold=0.85 is presumably a score threshold
            # applied by the model call; confirm against LitPanopticDetr.
            pred_boxes, pred_masks = lit_panoptic.inference(lit_panoptic(frame, threshold=0.85))
            # Batch size is 1: keep the only element of each prediction list.
            pred_boxes, pred_masks = pred_boxes[0], pred_masks[0]

            gt_boxes = frame.boxes2d[0]  # Get gt boxes as BoundingBoxes2D.
            gt_masks = frame.segmentation[0]  # Get gt masks as Mask

            # Add samples to evaluate metrics.
            # The PQ metric must be fed first: it receives the masks with
            # their original labels, before they are narrowed to the
            # "category" entry for the AP metric just below.
            pq_metric.add_sample(p_mask=pred_masks, t_mask=gt_masks)
            gt_boxes.labels, gt_masks.labels = gt_boxes.labels["category"], gt_masks.labels["category"]
            ap_metric.add_sample(p_bbox=pred_boxes, p_mask=pred_masks, t_bbox=gt_boxes, t_mask=gt_masks)

            n_evaluated += 1
            tbar.update()

    # Show the results
    print("Total eval batch:", n_evaluated)
    ap_metric.calc_map(print_result=True)
    pq_metric.calc_map(print_result=True)
if __name__ == "__main__":
    # Script entry point: assemble the command-line interface, then evaluate.
    # "resolve" lets a later registration override an option that an earlier
    # one already declared instead of raising a conflict error.
    parser = ArgumentParser(conflict_handler="resolve")

    # Each helper takes the parser and returns it; the order is kept:
    # common alonet options, COCO panoptic options, LitPanopticDetr options.
    option_registrars = (
        add_argparse_args,
        CocoPanoptic2Detr.add_argparse_args,
        LitPanopticDetr.add_argparse_args,
    )
    for register_options in option_registrars:
        parser = register_options(parser)

    # Script-specific option, read by main() as `args.ap_limit`.
    parser.add_argument(
        "--ap_limit",
        type=int,
        default=None,
        help="Limit AP computation at the given number of sample",
    )

    cli_args = parser.parse_args()
    main(cli_args)