Skip to content
This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Commit

Permalink
Reduce HungarianMatcher's space complexity.
Browse files Browse the repository at this point in the history
The memory reduction factor of the cost matrix is sum(#target objects) / max(#target objects).

That is achieved by no longer computing and storing matching costs between predictions and targets at different positions inside the batch. More exactly the original matrix of shape [batch_size * queries, sum(#target objects)] is shrinked to a tensor of shape [batch_size, queries, max(#target objects)].

Besides allowing much larger batch sizes, tested on the table structure recognition task using the Table Transformer (TATR) (125 queries, 7 classes) with pubmed data, this change also results a) on CUDA at all batch sizes and on CPU with small batchs in a small but meaningful speedup, b) on CPU with larger batch sizes in much higher speedups.

The processing time decrease computed as (1 - new_time / old_time) is shown below in various configuration:

Batch |   Device
 size | cuda    cpu
------------------
1       8.2%   1.6%
2       1.6%   9.3%
3       1.6%   7.7%
4       0.9%  11.2%
5       0.8%  13.9%
6       0.9%  15.5%
7       0.9%  23.1%
8             47.1%
16            70.6%
32            88.3%
64            95.0%
  • Loading branch information
alcinos authored and dai20242024 committed Sep 18, 2023
1 parent 3af9fa8 commit 4833f70
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 43 deletions.
77 changes: 55 additions & 22 deletions models/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

Expand Down Expand Up @@ -52,34 +53,66 @@ def forward(self, outputs, targets):
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
"""
bs, num_queries = outputs["pred_logits"].shape[:2]

# We flatten to compute the cost matrices in a batch
out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]

# Also concat the target labels and boxes
tgt_ids = torch.cat([v["labels"] for v in targets])
tgt_bbox = torch.cat([v["boxes"] for v in targets])

# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
cost_class = -out_prob[:, tgt_ids]
# In the comments below:
# - `bs` is the batch size, i.e. outputs["pred_logits"].shape[0];
# - `mo` is the maximum number of objects over all the targets,
# i.e. `max((len(v["labels"]) for v in targets))`;
# - `q` is the number of queries, i.e. outputs["pred_logits"].shape[1];
# - `cl` is the number of classes including no-object,
# i.e. outputs["pred_logits"].shape[2] or self.num_classes + 1.
if len(targets) == 1:
# This branch is just an optimization, not needed for correctness.
tgt_ids = targets[0]["labels"].unsqueeze(dim=0)
tgt_bbox = targets[0]["boxes"].unsqueeze(dim=0)
else:
tgt_ids = pad_sequence(
[target["labels"] for target in targets],
batch_first=True,
padding_value=0
) # (bs, mo)
tgt_bbox = pad_sequence(
[target["boxes"] for target in targets],
batch_first=True,
padding_value=0
) # (bs, mo, 4)

out_bbox = outputs["pred_boxes"] # (bs, q, 4)

# Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) # (bs, q, mo)

# Compute the giou cost betwen boxes
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
out_bbox_xyxy = box_cxcywh_to_xyxy(out_bbox)
tgt_bbox_xyxy = box_cxcywh_to_xyxy(tgt_bbox)
giou = generalized_box_iou(
out_bbox_xyxy, tgt_bbox_xyxy) # (bs, q, mo)

# Compute the classification cost. Contrary to the loss, we don't use
# the Negative Log Likelihood, but approximate it
# in `1 - proba[target class]`. The 1 is a constant that does not
# change the matching, it can be ommitted.
out_prob = outputs["pred_logits"].softmax(-1) # (bs, q, c)
prob_class = torch.gather(
out_prob,
dim=2,
index=tgt_ids.unsqueeze(dim=1).expand(-1, out_prob.shape[1], -1)
) # (bs, q, mo)

# Final cost matrix
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
C = C.view(bs, num_queries, -1).cpu()

sizes = [len(v["boxes"]) for v in targets]
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
C = self.cost_bbox * cost_bbox - self.cost_giou * giou - self.cost_class * prob_class
c = C.cpu()

indices = [
linear_sum_assignment(c[i, :, :len(v["labels"])])
for i, v in enumerate(targets)
]
return [
(
torch.as_tensor(i, dtype=torch.int64),
torch.as_tensor(j, dtype=torch.int64),
)
for i, j in indices
]


def build_matcher(args):
Expand Down
74 changes: 68 additions & 6 deletions test_all.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import io
import unittest
import functools
import operator

from itertools import combinations_with_replacement

import torch
from torch import nn, Tensor
from torchvision import ops
from typing import List

from models.matcher import HungarianMatcher
Expand Down Expand Up @@ -40,14 +45,21 @@ def test_hungarian(self):
matcher = HungarianMatcher()
targets = [{'labels': tgt_labels, 'boxes': tgt_boxes}]
indices_single = matcher({'pred_logits': logits, 'pred_boxes': boxes}, targets)
indices_batched = matcher({'pred_logits': logits.repeat(2, 1, 1),
'pred_boxes': boxes.repeat(2, 1, 1)}, targets * 2)
batch_size = 2
indices_batched = matcher(
{
'pred_logits': logits.repeat(batch_size, 1, 1),
'pred_boxes': boxes.repeat(batch_size, 1, 1),
},
targets * batch_size,
)
self.assertEqual(len(indices_single[0][0]), n_targets)
self.assertEqual(len(indices_single[0][1]), n_targets)
self.assertEqual(self.indices_torch2python(indices_single),
self.indices_torch2python([indices_batched[0]]))
self.assertEqual(self.indices_torch2python(indices_single),
self.indices_torch2python([indices_batched[1]]))
for i in range(batch_size):
self.assertEqual(
self.indices_torch2python(indices_single),
self.indices_torch2python([indices_batched[i]]),
)

# test with empty targets
tgt_labels_empty = torch.randint(high=n_classes, size=(0,))
Expand Down Expand Up @@ -102,6 +114,56 @@ def test_model_detection_different_inputs(self):
out = model([x])
self.assertIn('pred_logits', out)

def test_box_iou_multiple_dimensions(self):
for extra_dims in range(3):
for extra_lengths in combinations_with_replacement(range(1, 4), extra_dims):
p = functools.reduce(operator.mul, extra_lengths, 1)
for n in range(3):
a = torch.rand(extra_lengths + (n, 4))
for m in range(3):
b = torch.rand(extra_lengths + (m, 4))
iou, union = box_ops.box_iou(a, b)
self.assertTupleEqual(iou.shape, union.shape)
self.assertTupleEqual(iou.shape, extra_lengths + (n, m))
iou_it = iter(iou.view(p, n, m))
for x, y in zip(a.view(p, n, 4), b.view(p, m, 4), strict=True):
self.assertTrue(
torch.equal(next(iou_it), ops.box_iou(x, y))
)

def test_generalized_box_iou_multiple_dimensions(self):
a = torch.tensor([1, 1, 2, 2])
b = torch.tensor([1, 2, 3, 5])
ab = -0.1250
self.assertTrue(
torch.equal(
box_ops.generalized_box_iou(a[None, :], b[None, :]),
torch.Tensor([[ab]]),
)
)
self.assertTrue(
torch.equal(
box_ops.generalized_box_iou(a[None, None, :], b[None, None, :]),
torch.Tensor([[[ab]]]),
)
)
self.assertTrue(
torch.equal(
box_ops.generalized_box_iou(
a[None, None, None, :], b[None, None, None, :]
),
torch.Tensor([[[[ab]]]]),
)
)
self.assertTrue(
torch.equal(
box_ops.generalized_box_iou(
torch.stack([a, a, b, b]), torch.stack([a, b])
),
torch.Tensor(torch.Tensor([[1, ab], [1, ab], [ab, 1], [ab, 1]])),
)
)

def test_warpped_model_script_detection(self):
class WrappedDETR(nn.Module):
def __init__(self, model):
Expand Down
30 changes: 15 additions & 15 deletions util/box_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ def box_xyxy_to_cxcywh(x):
return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
# Modified from torchvision to also return the union and to work only on the
# last two dimensions, assuming the other ones are identical.
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [..., N,M,2]
rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) # [..., N,M,2]

lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [..., N,M,2]
inter = wh[..., 0] * wh[..., 1] # [..., N,M]

wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]

union = area1[:, None] + area2 - inter
area1 = box_area(boxes1.view(-1, 4)).view(boxes1.shape[:-1])
area2 = box_area(boxes2.view(-1, 4)).view(boxes2.shape[:-1])
union = area1[..., None] + area2[..., None, :] - inter

iou = inter / union
return iou, union
Expand All @@ -48,15 +48,15 @@ def generalized_box_iou(boxes1, boxes2):
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
assert (boxes1[..., 2:] >= boxes1[..., :2]).all()
assert (boxes2[..., 2:] >= boxes2[..., :2]).all()
iou, union = box_iou(boxes1, boxes2)

lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
lt = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
rb = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])

wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
wh = (rb - lt).clamp(min=0) # [..., N,M,2]
area = wh[..., 0] * wh[..., 1]

return iou - (area - union) / area

Expand Down

0 comments on commit 4833f70

Please sign in to comment.