mirror of https://github.com/YifanXu74/MQ-Det.git
96 lines
3.4 KiB
Python
96 lines
3.4 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
|
import math
|
|
|
|
import torch
|
|
|
|
|
|
class BoxCoder(object):
    """
    Encode and decode bounding boxes to/from the delta representation
    (dx, dy, dw, dh) used for training the box regressors.

    Boxes are read as rows whose first four columns are x1, y1, x2, y2.
    Widths and heights are computed as ``x2 - x1 + to_remove`` (and likewise
    for y), so with the default ``to_remove=1`` the box corners are treated
    as inclusive pixel indices.
    """

    def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16), to_remove=1):
        """
        Arguments:
            weights (4-element tuple): scaling factors (wx, wy, ww, wh).
                ``encode`` multiplies the raw deltas by them and ``decode``
                divides by them.
            bbox_xform_clip (float): upper bound applied to dw/dh in
                ``decode`` before they are passed to ``torch.exp``.
            to_remove (number): offset added when turning corner
                coordinates into a width/height, and subtracted again when
                ``decode`` produces x2/y2. The default of 1 reproduces the
                historical hard-coded behaviour; 0 treats coordinates as
                continuous.
        """
        self.weights = weights
        self.bbox_xform_clip = bbox_xform_clip
        # Single source of truth for the "+1" convention that used to be a
        # duplicated local constant in encode() and decode().
        self.to_remove = to_remove

    def encode(self, reference_boxes, proposals):
        """
        Encode a set of proposals with respect to some reference boxes.

        Arguments:
            reference_boxes (Tensor): reference (target) boxes, indexed as
                ``[:, 0..3]`` = x1, y1, x2, y2.
            proposals (Tensor): boxes to be encoded, indexed the same way.

        Returns:
            Tensor: one row per box with columns (dx, dy, dw, dh), each
            already multiplied by the corresponding entry of
            ``self.weights``.
        """
        offset = self.to_remove

        # Size and centre of the boxes we encode relative to.
        ex_widths = proposals[:, 2] - proposals[:, 0] + offset
        ex_heights = proposals[:, 3] - proposals[:, 1] + offset
        ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths
        ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights

        # Size and centre of the regression targets.
        gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + offset
        gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + offset
        gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths
        gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights

        wx, wy, ww, wh = self.weights
        # Centre shifts are normalised by the proposal size; size changes
        # are expressed as log-ratios.
        targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
        targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
        targets_dw = ww * torch.log(gt_widths / ex_widths)
        targets_dh = wh * torch.log(gt_heights / ex_heights)

        targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
        return targets

    def decode(self, rel_codes, boxes):
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Arguments:
            rel_codes (Tensor): encoded boxes. Columns are read in groups of
                four (dx, dy, dw, dh) via strided slicing, so a row may hold
                several such groups back to back.
            boxes (Tensor): reference boxes, indexed as ``[:, 0..3]`` =
                x1, y1, x2, y2. They are cast to the dtype of ``rel_codes``.

        Returns:
            Tensor: same shape and dtype as ``rel_codes``; every group of
            four columns holds the decoded (x1, y1, x2, y2).
        """
        boxes = boxes.to(rel_codes.dtype)

        offset = self.to_remove
        widths = boxes[:, 2] - boxes[:, 0] + offset
        heights = boxes[:, 3] - boxes[:, 1] + offset
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        wx, wy, ww, wh = self.weights
        # Undo the weighting applied by encode(); the stride-4 slices pick
        # the same component out of every group of four columns.
        dx = rel_codes[:, 0::4] / wx
        dy = rel_codes[:, 1::4] / wy
        dw = rel_codes[:, 2::4] / ww
        dh = rel_codes[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.bbox_xform_clip)
        dh = torch.clamp(dh, max=self.bbox_xform_clip)

        # [:, None] adds a trailing axis so per-box values broadcast over
        # the column groups of rel_codes.
        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        pred_boxes = torch.zeros_like(rel_codes)
        # x1
        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
        # y1
        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
        # x2 (note: subtracting the offset here is correct; pred_w already
        # includes it, so don't be fooled by the asymmetry with x1)
        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w - offset
        # y2 (same reasoning as x2)
        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h - offset

        return pred_boxes
|