diff --git a/models/matcher.py b/models/matcher.py
index 0c2914739..4a7255a4e 100644
--- a/models/matcher.py
+++ b/models/matcher.py
@@ -5,6 +5,7 @@
 import torch
 from scipy.optimize import linear_sum_assignment
 from torch import nn
+from torch.nn.utils.rnn import pad_sequence
 
 from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
 
@@ -52,34 +53,66 @@ def forward(self, outputs, targets):
             For each batch element, it holds:
                 len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
         """
-        bs, num_queries = outputs["pred_logits"].shape[:2]
-
-        # We flatten to compute the cost matrices in a batch
-        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
-        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
-
-        # Also concat the target labels and boxes
-        tgt_ids = torch.cat([v["labels"] for v in targets])
-        tgt_bbox = torch.cat([v["boxes"] for v in targets])
-
-        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
-        # but approximate it in 1 - proba[target class].
-        # The 1 is a constant that doesn't change the matching, it can be ommitted.
-        cost_class = -out_prob[:, tgt_ids]
+        # In the comments below:
+        # - `bs` is the batch size, i.e. outputs["pred_logits"].shape[0];
+        # - `mo` is the maximum number of objects over all the targets,
+        #   i.e. `max(len(v["labels"]) for v in targets)`;
+        # - `q` is the number of queries, i.e. outputs["pred_logits"].shape[1];
+        # - `cl` is the number of classes including no-object,
+        #   i.e. outputs["pred_logits"].shape[2] or self.num_classes + 1.
+        if len(targets) == 1:
+            # This branch is just an optimization, not needed for correctness.
+            tgt_ids = targets[0]["labels"].unsqueeze(dim=0)
+            tgt_bbox = targets[0]["boxes"].unsqueeze(dim=0)
+        else:
+            tgt_ids = pad_sequence(
+                [target["labels"] for target in targets],
+                batch_first=True,
+                padding_value=0
+            )  # (bs, mo)
+            tgt_bbox = pad_sequence(
+                [target["boxes"] for target in targets],
+                batch_first=True,
+                padding_value=0
+            )  # (bs, mo, 4)
+
+        out_bbox = outputs["pred_boxes"]  # (bs, q, 4)
 
         # Compute the L1 cost between boxes
-        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
+        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)  # (bs, q, mo)
 
         # Compute the giou cost betwen boxes
-        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
+        out_bbox_xyxy = box_cxcywh_to_xyxy(out_bbox)
+        tgt_bbox_xyxy = box_cxcywh_to_xyxy(tgt_bbox)
+        giou = generalized_box_iou(
+            out_bbox_xyxy, tgt_bbox_xyxy)  # (bs, q, mo)
+
+        # Compute the classification cost. Contrary to the loss, we don't use
+        # the Negative Log Likelihood, but approximate it
+        # in `1 - proba[target class]`. The 1 is a constant that does not
+        # change the matching, it can be omitted.
+        out_prob = outputs["pred_logits"].softmax(-1)  # (bs, q, cl)
+        prob_class = torch.gather(
+            out_prob,
+            dim=2,
+            index=tgt_ids.unsqueeze(dim=1).expand(-1, out_prob.shape[1], -1)
+        )  # (bs, q, mo)
 
         # Final cost matrix
-        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
-        C = C.view(bs, num_queries, -1).cpu()
-
-        sizes = [len(v["boxes"]) for v in targets]
-        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
-        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+        C = self.cost_bbox * cost_bbox - self.cost_giou * giou - self.cost_class * prob_class
+        c = C.cpu()
+
+        indices = [
+            linear_sum_assignment(c[i, :, :len(v["labels"])])
+            for i, v in enumerate(targets)
+        ]
+        return [
+            (
+                torch.as_tensor(i, dtype=torch.int64),
+                torch.as_tensor(j, dtype=torch.int64),
+            )
+            for i, j in indices
+        ]
 
 
 def build_matcher(args):
diff --git a/test_all.py b/test_all.py
index 7153892ff..c4d2c68f8 100644
--- a/test_all.py
+++ b/test_all.py
@@ -1,9 +1,14 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import io
 import unittest
+import functools
+import operator
+
+from itertools import combinations_with_replacement
 
 import torch
 from torch import nn, Tensor
+from torchvision import ops
 from typing import List
 
 from models.matcher import HungarianMatcher
@@ -40,14 +45,21 @@ def test_hungarian(self):
         matcher = HungarianMatcher()
         targets = [{'labels': tgt_labels, 'boxes': tgt_boxes}]
         indices_single = matcher({'pred_logits': logits, 'pred_boxes': boxes}, targets)
-        indices_batched = matcher({'pred_logits': logits.repeat(2, 1, 1),
-                                   'pred_boxes': boxes.repeat(2, 1, 1)}, targets * 2)
+        batch_size = 2
+        indices_batched = matcher(
+            {
+                'pred_logits': logits.repeat(batch_size, 1, 1),
+                'pred_boxes': boxes.repeat(batch_size, 1, 1),
+            },
+            targets * batch_size,
+        )
         self.assertEqual(len(indices_single[0][0]), n_targets)
         self.assertEqual(len(indices_single[0][1]), n_targets)
-        self.assertEqual(self.indices_torch2python(indices_single),
-                         self.indices_torch2python([indices_batched[0]]))
-        self.assertEqual(self.indices_torch2python(indices_single),
-                         self.indices_torch2python([indices_batched[1]]))
+        for i in range(batch_size):
+            self.assertEqual(
+                self.indices_torch2python(indices_single),
+                self.indices_torch2python([indices_batched[i]]),
+            )
 
         # test with empty targets
         tgt_labels_empty = torch.randint(high=n_classes, size=(0,))
@@ -102,6 +114,56 @@ def test_model_detection_different_inputs(self):
         out = model([x])
         self.assertIn('pred_logits', out)
 
+    def test_box_iou_multiple_dimensions(self):
+        for extra_dims in range(3):
+            for extra_lengths in combinations_with_replacement(range(1, 4), extra_dims):
+                p = functools.reduce(operator.mul, extra_lengths, 1)
+                for n in range(3):
+                    a = torch.rand(extra_lengths + (n, 4))
+                    for m in range(3):
+                        b = torch.rand(extra_lengths + (m, 4))
+                        iou, union = box_ops.box_iou(a, b)
+                        self.assertTupleEqual(iou.shape, union.shape)
+                        self.assertTupleEqual(iou.shape, extra_lengths + (n, m))
+                        iou_it = iter(iou.view(p, n, m))
+                        for x, y in zip(a.view(p, n, 4), b.view(p, m, 4)):
+                            self.assertTrue(
+                                torch.equal(next(iou_it), ops.box_iou(x, y))
+                            )
+
+    def test_generalized_box_iou_multiple_dimensions(self):
+        a = torch.tensor([1, 1, 2, 2])
+        b = torch.tensor([1, 2, 3, 5])
+        ab = -0.1250
+        self.assertTrue(
+            torch.equal(
+                box_ops.generalized_box_iou(a[None, :], b[None, :]),
+                torch.Tensor([[ab]]),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                box_ops.generalized_box_iou(a[None, None, :], b[None, None, :]),
+                torch.Tensor([[[ab]]]),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                box_ops.generalized_box_iou(
+                    a[None, None, None, :], b[None, None, None, :]
+                ),
+                torch.Tensor([[[[ab]]]]),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                box_ops.generalized_box_iou(
+                    torch.stack([a, a, b, b]), torch.stack([a, b])
+                ),
+                torch.Tensor([[1, ab], [1, ab], [ab, 1], [ab, 1]]),
+            )
+        )
+
     def test_warpped_model_script_detection(self):
         class WrappedDETR(nn.Module):
             def __init__(self, model):
diff --git a/util/box_ops.py b/util/box_ops.py
index 9c088e5ba..0d6c89b58 100644
--- a/util/box_ops.py
+++ b/util/box_ops.py
@@ -20,18 +20,18 @@ def box_xyxy_to_cxcywh(x):
     return torch.stack(b, dim=-1)
 
 
-# modified from torchvision to also return the union
+# Modified from torchvision to also return the union and to work on the
+# last two dimensions, assuming any leading dimensions are identical.
 def box_iou(boxes1, boxes2):
-    area1 = box_area(boxes1)
-    area2 = box_area(boxes2)
+    lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2])  # [..., N,M,2]
+    rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:])  # [..., N,M,2]
 
-    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
-    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
+    wh = (rb - lt).clamp(min=0)  # [..., N,M,2]
+    inter = wh[..., 0] * wh[..., 1]  # [..., N,M]
 
-    wh = (rb - lt).clamp(min=0)  # [N,M,2]
-    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
-
-    union = area1[:, None] + area2 - inter
+    area1 = box_area(boxes1.view(-1, 4)).view(boxes1.shape[:-1])
+    area2 = box_area(boxes2.view(-1, 4)).view(boxes2.shape[:-1])
+    union = area1[..., None] + area2[..., None, :] - inter
 
     iou = inter / union
     return iou, union
@@ -48,15 +48,15 @@ def generalized_box_iou(boxes1, boxes2):
     """
     # degenerate boxes gives inf / nan results
     # so do an early check
-    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
-    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+    assert (boxes1[..., 2:] >= boxes1[..., :2]).all()
+    assert (boxes2[..., 2:] >= boxes2[..., :2]).all()
     iou, union = box_iou(boxes1, boxes2)
 
-    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
-    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+    lt = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
+    rb = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])
 
-    wh = (rb - lt).clamp(min=0)  # [N,M,2]
-    area = wh[:, :, 0] * wh[:, :, 1]
+    wh = (rb - lt).clamp(min=0)  # [..., N,M,2]
+    area = wh[..., 0] * wh[..., 1]
 
     return iou - (area - union) / area
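
A minimal usage sketch of the broadcasted helpers (not part of the patch; it assumes the patched util.box_ops is importable, and the shapes and random tensors are illustrative only): for every batch element, the [..., N, M] slice of the batched result should agree with the pre-existing pairwise 2-D call.

import torch
from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

bs, q, mo = 2, 5, 3              # batch size, queries, max objects per image
pred = torch.rand(bs, q, 4)      # predicted boxes in cxcywh format
tgt = torch.rand(bs, mo, 4)      # (padded) target boxes in cxcywh format

# One (q, mo) GIoU matrix per batch element, computed in a single call.
giou = generalized_box_iou(box_cxcywh_to_xyxy(pred), box_cxcywh_to_xyxy(tgt))  # (bs, q, mo)

for b in range(bs):
    # The original 2-D path, applied image by image, should match the slice.
    per_image = generalized_box_iou(box_cxcywh_to_xyxy(pred[b]), box_cxcywh_to_xyxy(tgt[b]))  # (q, mo)
    assert torch.allclose(giou[b], per_image)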