Skip to content

Commit

Permalink
giou
Browse files Browse the repository at this point in the history
  • Loading branch information
nathantsoi committed Mar 1, 2019
1 parent b3158fb commit d0ff754
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 0 deletions.
3 changes: 3 additions & 0 deletions research/object_detection/builders/losses_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ def _build_localization_loss(loss_config):
return losses.WeightedSmoothL1LocalizationLoss(
loss_config.weighted_smooth_l1.delta)

if loss_type == 'weighted_giou':
return losses.WeightedGIoULocalizationLoss()

if loss_type == 'weighted_iou':
return losses.WeightedIOULocalizationLoss()

Expand Down
90 changes: 90 additions & 0 deletions research/object_detection/core/box_list_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,27 @@ def area(boxlist, scope=None):
value=boxlist.get(), num_or_size_splits=4, axis=1)
return tf.squeeze((y_max - y_min) * (x_max - x_min), [1])

def safe_boxlist(boxlist, scope=None):
  """Returns a BoxList whose min/max coordinates are in canonical order.

  Handles boxes whose ymin/ymax or xmin/xmax were supplied swapped: for every
  box, the smaller of the two y values becomes ymin and the larger becomes
  ymax, and likewise for x.

  Args:
    boxlist: BoxList holding N boxes
    scope: name scope.

  Returns:
    a new BoxList holding N boxes in [ymin, xmin, ymax, xmax] order, built
    from the reordered coordinates only.
  """
  with tf.name_scope(scope, 'SafeBoxlist'):
    y_a, x_a, y_b, x_b = tf.split(
        value=boxlist.get(), num_or_size_splits=4, axis=1)
    # Each column comes out of tf.split with shape [N, 1]; flatten them to [N]
    # so that stacking along axis 1 yields an [N, 4] tensor again.
    ordered_columns = [
        tf.minimum(y_a, y_b),
        tf.minimum(x_a, x_b),
        tf.maximum(y_a, y_b),
        tf.maximum(x_a, x_b),
    ]
    corners = tf.stack(
        [tf.reshape(column, [-1]) for column in ordered_columns], axis=1)
    return box_list.BoxList(corners)

def height_width(boxlist, scope=None):
"""Computes height and width of boxes in boxlist.
Expand Down Expand Up @@ -252,6 +273,40 @@ def matched_intersection(boxlist1, boxlist2, scope=None):
return tf.reshape(intersect_heights * intersect_widths, [-1])


def matched_containing(boxlist1, boxlist2, scope=None):
  """Computes the area of the smallest box enclosing each matched box pair.

  For every pair of corresponding boxes from the two boxlists, this is the
  area of the smallest axis-aligned bounding box that fully contains both.
  This enclosing area is called "C" in GIoU.

  If you find this useful in research, please consider citing:
    Generalized Intersection over Union: A Metric and A Loss for Bounding Box
    Regression
    H. Rezatofighi, N. Tsoi, J. Gwak, A. Sadeghian, I. Reid, and S. Savarese.
    CVPR 2019

  Args:
    boxlist1: BoxList holding N boxes
    boxlist2: BoxList holding N boxes
    scope: name scope.

  Returns:
    a tensor with shape [N] representing the area of the enclosing box of
    each pair.
  """
  with tf.name_scope(scope, 'MatchedContaining'):
    ymin_a, xmin_a, ymax_a, xmax_a = tf.split(
        value=boxlist1.get(), num_or_size_splits=4, axis=1)
    ymin_b, xmin_b, ymax_b, xmax_b = tf.split(
        value=boxlist2.get(), num_or_size_splits=4, axis=1)
    # Extents are clamped at zero so that the product is never negative.
    enclosing_heights = tf.maximum(
        0.0, tf.maximum(ymax_a, ymax_b) - tf.minimum(ymin_a, ymin_b))
    enclosing_widths = tf.maximum(
        0.0, tf.maximum(xmax_a, xmax_b) - tf.minimum(xmin_a, xmin_b))
    enclosing_areas = enclosing_heights * enclosing_widths
    return tf.reshape(enclosing_areas, [-1])


def iou(boxlist1, boxlist2, scope=None):
"""Computes pairwise intersection-over-union between box collections.
Expand Down Expand Up @@ -295,6 +350,41 @@ def matched_iou(boxlist1, boxlist2, scope=None):
tf.zeros_like(intersections), tf.truediv(intersections, unions))


def matched_giou(boxlist1, boxlist2, scope=None):
  """Computes generalized IoU between corresponding boxes in two boxlists.

  GIoU = IoU - (C - U) / C, where U is the area of the union of a box pair
  and C is the area of the smallest axis-aligned box enclosing both.

  If you find this useful in research, please consider citing:
    Generalized Intersection over Union: A Metric and A Loss for Bounding Box
    Regression
    H. Rezatofighi, N. Tsoi, J. Gwak, A. Sadeghian, I. Reid, and S. Savarese.
    CVPR 2019

  Args:
    boxlist1: BoxList holding N boxes
    boxlist2: BoxList holding N boxes
    scope: name scope.

  Returns:
    a tensor with shape [N] representing the GIoU score of each matched pair.
  """
  with tf.name_scope(scope, 'MatchedGIoU'):
    # Reorder swapped min/max coordinates first so that areas and
    # intersections are computed on well-formed boxes.
    safe_boxlist1 = safe_boxlist(boxlist1)
    safe_boxlist2 = safe_boxlist(boxlist2)
    intersections = matched_intersection(safe_boxlist1, safe_boxlist2)
    unions = area(safe_boxlist1) + area(safe_boxlist2) - intersections
    # div_no_nan yields 0 where the union is 0, which keeps both the value and
    # its gradient finite. Selecting with tf.where over a raw division still
    # backpropagates NaN from the branch that was not selected.
    iou_scores = tf.div_no_nan(intersections, unions)
    containings = matched_containing(safe_boxlist1, safe_boxlist2)
    # This is the `C_term` in `GIoU = IoU - C_term` as described in the GIoU
    # paper: the fraction of the enclosing box not covered by the union.
    unoccupied_fraction = tf.div_no_nan(containings - unions, containings)
    return tf.subtract(iou_scores, unoccupied_fraction)


def ioa(boxlist1, boxlist2, scope=None):
"""Computes pairwise intersection-over-area between box collections.
Expand Down
24 changes: 24 additions & 0 deletions research/object_detection/core/box_list_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,30 @@ def test_matched_iou(self):
iou_output = sess.run(iou)
self.assertAllClose(iou_output, exp_output)

def test_matched_giou(self):
  """Checks GIoU for one overlapping and one disjoint pair of boxes."""
  boxes1 = box_list.BoxList(
      tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]]))
  boxes2 = box_list.BoxList(
      tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]]))
  # Pair 0: IoU 2/16 minus (20 - 16)/20. Pair 1: IoU 0 minus (90 - 6)/90.
  expected_giou = [-0.07500000298023224, -0.9333333373069763]
  giou = box_list_ops.matched_giou(boxes1, boxes2)
  with self.test_session() as sess:
    giou_output = sess.run(giou)
    self.assertAllClose(giou_output, expected_giou)

def test_matched_giou_when_swapped(self):
  """Checks that swapped ymin/ymax in a box gives the same GIoU scores."""
  # Coordinate order: ymin, xmin, ymax, xmax
  #                   top, left, bottom, right
  # The first box of the first list has ymin and ymax swapped.
  boxes1 = box_list.BoxList(
      tf.constant([[7.0, 3.0, 4.0, 5.0], [5.0, 6.0, 10.0, 7.0]]))
  boxes2 = box_list.BoxList(
      tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]]))
  expected_giou = [-0.07500000298023224, -0.9333333373069763]
  giou = box_list_ops.matched_giou(boxes1, boxes2)
  with self.test_session() as sess:
    giou_output = sess.run(giou)
    self.assertAllClose(giou_output, expected_giou)

def test_iouworks_on_empty_inputs(self):
corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
Expand Down
37 changes: 37 additions & 0 deletions research/object_detection/core/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,43 @@ def _compute_loss(self, prediction_tensor, target_tensor, weights):
), axis=2)


class WeightedGIoULocalizationLoss(Loss):
  """GIoU localization loss function.

  Computes the GIoU for corresponding pairs of predicted/groundtruth boxes
  and for each pair assigns a loss of 1 - GIoU. Each per-pair loss is then
  multiplied by its weight.

  If you find this useful in research, please consider citing:
    Generalized Intersection over Union: A Metric and A Loss for Bounding Box
    Regression
    H. Rezatofighi, N. Tsoi, J. Gwak, A. Sadeghian, I. Reid, and S. Savarese.
    CVPR 2019
  """

  def _compute_loss(self, prediction_tensor, target_tensor, weights):
    """Compute loss function.

    Args:
      prediction_tensor: A float tensor of shape [batch_size, num_anchors, 4]
        representing the decoded predicted boxes
      target_tensor: A float tensor of shape [batch_size, num_anchors, 4]
        representing the decoded target boxes
      weights: a float tensor of shape [batch_size, num_anchors]

    Returns:
      loss: a float tensor of shape [batch_size * num_anchors] representing
        the weighted value of the loss function for every anchor. The batch
        and anchor dimensions are flattened together.
    """
    predicted_boxes = box_list.BoxList(tf.reshape(prediction_tensor, [-1, 4]))
    target_boxes = box_list.BoxList(tf.reshape(target_tensor, [-1, 4]))
    # The loss is 1 - GIoU, so a perfectly matching pair contributes exactly
    # zero. GIoU must not be exponentiated here: 1 - exp(GIoU) is negative for
    # a perfect match.
    per_anchor_giou_loss = 1.0 - box_list_ops.matched_giou(
        predicted_boxes, target_boxes)
    return tf.reshape(weights, [-1]) * per_anchor_giou_loss



class WeightedIOULocalizationLoss(Loss):
"""IOU localization loss function.
Expand Down
99 changes: 99 additions & 0 deletions research/object_detection/core/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,105 @@ def testReturnsCorrectLossWithNoLabels(self):
self.assertAllClose(loss_output, exp_loss)


class WeightedGIoULocalizationLossTest(tf.test.TestCase):
  """Tests for WeightedGIoULocalizationLoss.

  Every test uses three box pairs: two identical pairs (loss 0) and one
  disjoint pair whose weighted loss 2.0 * (1 - GIoU) is 3.097665.
  """

  def testReturnsCorrectLoss(self):
    prediction_tensor = tf.constant([[[1.5, 0, 2.4, 1],
                                      [0, 0, 1, 1],
                                      [0, 0, .5, .25]]])
    target_tensor = tf.constant([[[1.5, 0, 2.4, 1],
                                  [0, 0, 1, 1],
                                  [50, 50, 500.5, 100.25]]])
    weights = [[1.0, .5, 2.0]]
    loss_op = losses.WeightedGIoULocalizationLoss()
    loss = loss_op(prediction_tensor, target_tensor, weights=weights)
    exp_loss = [0, 0, 3.097665]
    with self.test_session() as sess:
      loss_output = sess.run(loss)
      self.assertAllClose(loss_output, exp_loss)

  def testReturnsCorrectLossWhenSwapped(self):
    # Coordinate order: ymin, xmin, ymax, xmax
    #                   top, left, bottom, right
    # The first box of both tensors has ymin and ymax swapped.
    prediction_tensor = tf.constant([[[2.4, 0, 1.5, 1],
                                      [0, 0, 1, 1],
                                      [0, 0, .5, .25]]])
    target_tensor = tf.constant([[[2.4, 0, 1.5, 1],
                                  [0, 0, 1, 1],
                                  [50, 50, 500.5, 100.25]]])
    weights = [[1.0, .5, 2.0]]
    loss_op = losses.WeightedGIoULocalizationLoss()
    loss = loss_op(prediction_tensor, target_tensor, weights=weights)
    exp_loss = [0, 0, 3.097665]
    with self.test_session() as sess:
      loss_output = sess.run(loss)
      self.assertAllClose(loss_output, exp_loss)

  def testReturnsCorrectLossSum(self):
    prediction_tensor = tf.constant([[[1.5, 0, 2.4, 1],
                                      [0, 0, 1, 1],
                                      [0, 0, .5, .25]]])
    target_tensor = tf.constant([[[1.5, 0, 2.4, 1],
                                  [0, 0, 1, 1],
                                  [50, 50, 500.5, 100.25]]])
    weights = [[1.0, .5, 2.0]]
    loss_op = losses.WeightedGIoULocalizationLoss()
    loss = loss_op(prediction_tensor, target_tensor, weights=weights)
    loss = tf.reduce_sum(loss)
    exp_loss = 3.097665
    with self.test_session() as sess:
      loss_output = sess.run(loss)
      self.assertAllClose(loss_output, exp_loss)

  def testReturnsCorrectLossGrad(self):
    prediction = np.array([[[1.5, 0, 2.4, 1],
                            [0, 0, 1, 1],
                            [0, 0, .5, .25]]], dtype=np.float32)
    target_tensor = tf.constant([[[1.5, 0, 2.4, 1],
                                  [0, 0, 1, 1],
                                  [50, 50, 500.5, 100.25]]])
    weights = np.array([[1.0, .5, 2.0]], dtype=np.float32)
    # Feed the predictions through a placeholder so that the gradient is taken
    # with respect to a tensor the loss actually depends on.
    prediction_ph = tf.placeholder(tf.float32, prediction.shape)
    loss_op = losses.WeightedGIoULocalizationLoss()
    loss = loss_op(prediction_ph, target_tensor, weights=weights)
    loss_sum = tf.reduce_sum(loss)
    grad = tf.gradients(loss_sum, prediction_ph)[0]
    exp_loss = 3.097665
    with self.test_session() as sess:
      loss_sum_output, grad_output = sess.run(
          [loss_sum, grad], feed_dict={prediction_ph: prediction})
      self.assertAllClose(loss_sum_output, exp_loss)
      # There is one partial derivative per predicted coordinate, and none of
      # them may be NaN or infinite.
      self.assertAllEqual(grad_output.shape, prediction.shape)
      self.assertTrue(np.all(np.isfinite(grad_output)))

  def testReturnsCorrectLossWithNoLabels(self):
    prediction_tensor = tf.constant([[[1.5, 0, 2.4, 1],
                                      [0, 0, 1, 1],
                                      [0, 0, .5, .25]]])
    target_tensor = tf.constant([[[1.5, 0, 2.4, 1],
                                  [0, 0, 1, 1],
                                  [50, 50, 500.5, 100.25]]])
    weights = [[1.0, .5, 2.0]]
    # The only image in the batch is masked out, so the total loss is zero.
    losses_mask = tf.constant([False], tf.bool)
    loss_op = losses.WeightedGIoULocalizationLoss()
    loss = loss_op(prediction_tensor, target_tensor, weights=weights,
                   losses_mask=losses_mask)
    loss = tf.reduce_sum(loss)
    exp_loss = 0.0
    with self.test_session() as sess:
      loss_output = sess.run(loss)
      self.assertAllClose(loss_output, exp_loss)


class WeightedSigmoidClassificationLossTest(tf.test.TestCase):

def testReturnsCorrectLoss(self):
Expand Down
5 changes: 5 additions & 0 deletions research/object_detection/protos/losses.proto
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ message LocalizationLoss {
WeightedL2LocalizationLoss weighted_l2 = 1;
WeightedSmoothL1LocalizationLoss weighted_smooth_l1 = 2;
WeightedIOULocalizationLoss weighted_iou = 3;
WeightedGIoULocalizationLoss weighted_giou = 4;
}
}

Expand Down Expand Up @@ -96,6 +97,10 @@ message WeightedSmoothL1LocalizationLoss {
message WeightedIOULocalizationLoss {
}

// Generalized Intersection over Union location loss: 1 - GIoU
message WeightedGIoULocalizationLoss {
}

// Configuration for class prediction loss function.
message ClassificationLoss {
oneof classification_loss {
Expand Down

0 comments on commit d0ff754

Please sign in to comment.