Skip to content
This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Reduce the space complexity of the HungarianMatcher module. #606

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 55 additions & 22 deletions models/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

Expand Down Expand Up @@ -52,34 +53,66 @@ def forward(self, outputs, targets):
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
"""
bs, num_queries = outputs["pred_logits"].shape[:2]

# We flatten to compute the cost matrices in a batch
out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]

# Also concat the target labels and boxes
tgt_ids = torch.cat([v["labels"] for v in targets])
tgt_bbox = torch.cat([v["boxes"] for v in targets])

# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, so it can be omitted.
cost_class = -out_prob[:, tgt_ids]
# In the comments below:
# - `bs` is the batch size, i.e. outputs["pred_logits"].shape[0];
# - `mo` is the maximum number of objects over all the targets,
# i.e. `max((len(v["labels"]) for v in targets))`;
# - `q` is the number of queries, i.e. outputs["pred_logits"].shape[1];
# - `cl` is the number of classes including no-object,
# i.e. outputs["pred_logits"].shape[2] or self.num_classes + 1.
if len(targets) == 1:
# This branch is just an optimization, not needed for correctness.
tgt_ids = targets[0]["labels"].unsqueeze(dim=0)
tgt_bbox = targets[0]["boxes"].unsqueeze(dim=0)
else:
tgt_ids = pad_sequence(
[target["labels"] for target in targets],
batch_first=True,
padding_value=0
) # (bs, mo)
tgt_bbox = pad_sequence(
[target["boxes"] for target in targets],
batch_first=True,
padding_value=0
) # (bs, mo, 4)

out_bbox = outputs["pred_boxes"] # (bs, q, 4)

# Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) # (bs, q, mo)

# Compute the giou cost between boxes
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
out_bbox_xyxy = box_cxcywh_to_xyxy(out_bbox)
tgt_bbox_xyxy = box_cxcywh_to_xyxy(tgt_bbox)
giou = generalized_box_iou(
out_bbox_xyxy, tgt_bbox_xyxy) # (bs, q, mo)

# Compute the classification cost. Contrary to the loss, we don't use
# the Negative Log Likelihood, but approximate it
# as `1 - proba[target class]`. The 1 is a constant that does not
# change the matching, so it can be omitted.
out_prob = outputs["pred_logits"].softmax(-1) # (bs, q, c)
prob_class = torch.gather(
out_prob,
dim=2,
index=tgt_ids.unsqueeze(dim=1).expand(-1, out_prob.shape[1], -1)
) # (bs, q, mo)

# Final cost matrix
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
C = C.view(bs, num_queries, -1).cpu()

sizes = [len(v["boxes"]) for v in targets]
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
C = self.cost_bbox * cost_bbox - self.cost_giou * giou - self.cost_class * prob_class
c = C.cpu()

indices = [
linear_sum_assignment(c[i, :, :len(v["labels"])])
for i, v in enumerate(targets)
]
return [
(
torch.as_tensor(i, dtype=torch.int64),
torch.as_tensor(j, dtype=torch.int64),
)
for i, j in indices
]


def build_matcher(args):
Expand Down
74 changes: 68 additions & 6 deletions test_all.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import io
import unittest
import functools
import operator

from itertools import combinations_with_replacement

import torch
from torch import nn, Tensor
from torchvision import ops
from typing import List

from models.matcher import HungarianMatcher
Expand Down Expand Up @@ -40,14 +45,21 @@ def test_hungarian(self):
matcher = HungarianMatcher()
targets = [{'labels': tgt_labels, 'boxes': tgt_boxes}]
indices_single = matcher({'pred_logits': logits, 'pred_boxes': boxes}, targets)
indices_batched = matcher({'pred_logits': logits.repeat(2, 1, 1),
'pred_boxes': boxes.repeat(2, 1, 1)}, targets * 2)
batch_size = 2
indices_batched = matcher(
{
'pred_logits': logits.repeat(batch_size, 1, 1),
'pred_boxes': boxes.repeat(batch_size, 1, 1),
},
targets * batch_size,
)
self.assertEqual(len(indices_single[0][0]), n_targets)
self.assertEqual(len(indices_single[0][1]), n_targets)
self.assertEqual(self.indices_torch2python(indices_single),
self.indices_torch2python([indices_batched[0]]))
self.assertEqual(self.indices_torch2python(indices_single),
self.indices_torch2python([indices_batched[1]]))
for i in range(batch_size):
self.assertEqual(
self.indices_torch2python(indices_single),
self.indices_torch2python([indices_batched[i]]),
)

# test with empty targets
tgt_labels_empty = torch.randint(high=n_classes, size=(0,))
Expand Down Expand Up @@ -102,6 +114,56 @@ def test_model_detection_different_inputs(self):
out = model([x])
self.assertIn('pred_logits', out)

def test_box_iou_multiple_dimensions(self):
    """Check box_ops.box_iou on inputs with extra leading dimensions.

    For 0, 1 and 2 leading dimensions (each of size 1 to 3) and for 0 to 2
    boxes on either side, the result must have shape
    ``leading + (n, m)`` and each leading-index slice must equal what
    torchvision's ``ops.box_iou`` returns for the matching pair of slices.
    """
    box_counts = range(3)
    for n_leading in range(3):
        for leading in combinations_with_replacement(range(1, 4), n_leading):
            # Number of independent (n, 4) x (m, 4) problems once the
            # leading dimensions are flattened into one.
            n_items = functools.reduce(operator.mul, leading, 1)
            for n in box_counts:
                boxes_a = torch.rand(leading + (n, 4))
                for m in box_counts:
                    boxes_b = torch.rand(leading + (m, 4))
                    iou, union = box_ops.box_iou(boxes_a, boxes_b)
                    self.assertTupleEqual(iou.shape, union.shape)
                    self.assertTupleEqual(iou.shape, leading + (n, m))
                    per_item = zip(
                        iou.view(n_items, n, m),
                        boxes_a.view(n_items, n, 4),
                        boxes_b.view(n_items, m, 4),
                    )
                    for iou_item, a_item, b_item in per_item:
                        reference = ops.box_iou(a_item, b_item)
                        self.assertTrue(torch.equal(iou_item, reference))

def test_generalized_box_iou_multiple_dimensions(self):
    """Check box_ops.generalized_box_iou with 1 to 3 leading dimensions.

    Uses two fixed boxes whose GIoU is -0.125 (they do not overlap: IoU 0,
    union area 7, enclosing-box area 8) and verifies both the values and
    the output shape for singleton leading dimensions and for a plain
    4 x 2 pairwise case.
    """
    box_a = torch.tensor([1, 1, 2, 2])
    box_b = torch.tensor([1, 2, 3, 5])
    giou_ab = -0.1250

    def check(boxes1, boxes2, expected):
        # Exact comparison: shape and every element must match.
        actual = box_ops.generalized_box_iou(boxes1, boxes2)
        self.assertTrue(torch.equal(actual, torch.Tensor(expected)))

    # One box on each side, wrapped in a growing number of singleton dims.
    check(box_a[None, :], box_b[None, :], [[giou_ab]])
    check(box_a[None, None, :], box_b[None, None, :], [[[giou_ab]]])
    check(
        box_a[None, None, None, :],
        box_b[None, None, None, :],
        [[[[giou_ab]]]],
    )

    # Plain pairwise case: 4 boxes against 2 boxes; a box against itself
    # has GIoU 1.
    check(
        torch.stack([box_a, box_a, box_b, box_b]),
        torch.stack([box_a, box_b]),
        [[1, giou_ab], [1, giou_ab], [giou_ab, 1], [giou_ab, 1]],
    )

def test_warpped_model_script_detection(self):
class WrappedDETR(nn.Module):
def __init__(self, model):
Expand Down
30 changes: 15 additions & 15 deletions util/box_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ def box_xyxy_to_cxcywh(x):
return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
# Modified from torchvision to also return the union and to operate on the
# last two dimensions, so both inputs may carry matching leading dimensions.
def box_iou(boxes1, boxes2):
    """Return the pairwise IoU and union area of two sets of boxes.

    Boxes are read as corner coordinates: the first two values of the last
    dimension are the top-left corner and the last two the bottom-right.

    Args:
        boxes1: tensor of shape [..., N, 4].
        boxes2: tensor of shape [..., M, 4]. Its leading dimensions must
            broadcast against those of ``boxes1``.

    Returns:
        A tuple ``(iou, union)``, both of shape [..., N, M].
    """
    # Intersection rectangle of every (boxes1[i], boxes2[j]) pair.
    lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2])  # [..., N,M,2]
    rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:])  # [..., N,M,2]

    # Non-overlapping pairs get a negative extent; clamp it to zero area.
    wh = (rb - lt).clamp(min=0)  # [..., N,M,2]
    inter = wh[..., 0] * wh[..., 1]  # [..., N,M]

    # box_area is fed a flat [K, 4] tensor and the result is folded back to
    # the leading shape. reshape (rather than view) also accepts
    # non-contiguous inputs such as slices.
    area1 = box_area(boxes1.reshape(-1, 4)).reshape(boxes1.shape[:-1])  # [..., N]
    area2 = box_area(boxes2.reshape(-1, 4)).reshape(boxes2.shape[:-1])  # [..., M]
    union = area1[..., None] + area2[..., None, :] - inter  # [..., N,M]

    iou = inter / union
    return iou, union
Expand All @@ -48,15 +48,15 @@ def generalized_box_iou(boxes1, boxes2):
"""
# degenerate boxes give inf / nan results,
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
assert (boxes1[..., 2:] >= boxes1[..., :2]).all()
assert (boxes2[..., 2:] >= boxes2[..., :2]).all()
iou, union = box_iou(boxes1, boxes2)

lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
lt = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
rb = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])

wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
wh = (rb - lt).clamp(min=0) # [..., N,M,2]
area = wh[..., 0] * wh[..., 1]

return iou - (area - union) / area

Expand Down