# mirror of https://github.com/YifanXu74/MQ-Det.git
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch

# transpose
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1


class BoxList(object):
    """
    This class represents a set of bounding boxes.
    The bounding boxes are represented as an Nx4 Tensor.
    In order to uniquely determine the bounding boxes with respect
    to an image, we also store the corresponding image dimensions.
    They can contain extra information that is specific to each
    bounding box, such as labels.
    """

    def __init__(self, bbox, image_size, mode="xyxy"):
        device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device("cpu")
        # only call as_tensor if it isn't a no-op, because it hurts JIT tracing
        if (not isinstance(bbox, torch.Tensor)
                or bbox.dtype != torch.float32 or bbox.device != device):
            bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device)
        if bbox.ndimension() != 2:
            raise ValueError(
                "bbox should have 2 dimensions, got {}".format(bbox.ndimension())
            )
        if bbox.size(-1) != 4:
            raise ValueError(
                "last dimension of bbox should have a "
                "size of 4, got {}".format(bbox.size(-1))
            )
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")

        self.bbox = bbox
        self.size = image_size  # (image_width, image_height)
        self.mode = mode
        self.extra_fields = {}
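
    # Usage sketch (illustrative values): the constructor coerces lists or
    # arrays to a float32 Nx4 tensor, so either form works:
    #   boxes = BoxList(torch.tensor([[10., 10., 50., 50.]]), (640, 480), mode="xyxy")
    #   boxes.add_field("labels", torch.tensor([1]))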

    # note: _jit_wrap/_jit_unwrap only work if the keys and the sizes
    # don't change in between
    def _jit_unwrap(self):
        return (self.bbox,) + tuple(f for f in (self.get_field(field)
                                                for field in sorted(self.fields()))
                                    if isinstance(f, torch.Tensor))

    def _jit_wrap(self, input_stream):
        self.bbox = input_stream[0]
        num_consumed = 1
        for f in sorted(self.fields()):
            if isinstance(self.extra_fields[f], torch.Tensor):
                self.extra_fields[f] = input_stream[num_consumed]
                num_consumed += 1
        return self, input_stream[num_consumed:]

    def add_field(self, field, field_data):
        self.extra_fields[field] = field_data

    def get_field(self, field):
        return self.extra_fields[field]

    def has_field(self, field):
        return field in self.extra_fields

    def fields(self):
        return list(self.extra_fields.keys())
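
    # Field-access sketch (illustrative; assumes the "labels" field added above):
    #   if boxes.has_field("labels"):
    #       labels = boxes.get_field("labels")
    #   boxes.fields()  # -> ["labels"]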

    def _copy_extra_fields(self, bbox):
        for k, v in bbox.extra_fields.items():
            self.extra_fields[k] = v

    def convert(self, mode):
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")
        if mode == self.mode:
            return self
        # we only have two modes, so don't need to check
        # self.mode
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if mode == "xyxy":
            bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
            bbox = BoxList(bbox, self.size, mode=mode)
        else:
            TO_REMOVE = 1
            # NOTE: explicitly specify dim to avoid tracing error on GPU
            bbox = torch.cat(
                (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=1
            )
            bbox = BoxList(bbox, self.size, mode=mode)
        bbox._copy_extra_fields(self)
        return bbox
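
    # Conversion sketch (illustrative): under the inclusive-pixel convention
    # (TO_REMOVE = 1), the "xyxy" box [0, 0, 9, 9] has width and height 10:
    #   b = BoxList(torch.tensor([[0., 0., 9., 9.]]), (100, 100), mode="xyxy")
    #   b.convert("xywh").bbox  # -> tensor([[ 0.,  0., 10., 10.]])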

    def _split_into_xyxy(self):
        if self.mode == "xyxy":
            xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1)
            return xmin, ymin, xmax, ymax
        elif self.mode == "xywh":
            TO_REMOVE = 1
            xmin, ymin, w, h = self.bbox.split(1, dim=-1)
            return (
                xmin,
                ymin,
                xmin + (w - TO_REMOVE).clamp(min=0),
                ymin + (h - TO_REMOVE).clamp(min=0),
            )
        else:
            raise RuntimeError("Should not be here")

    def resize(self, size, *args, **kwargs):
        """
        Returns a resized copy of this bounding box

        :param size: The requested size in pixels, as a 2-tuple:
            (width, height).
        """

        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
        if ratios[0] == ratios[1]:
            ratio = ratios[0]
            scaled_box = self.bbox * ratio
            bbox = BoxList(scaled_box, size, mode=self.mode)
            # bbox._copy_extra_fields(self)
            for k, v in self.extra_fields.items():
                if not isinstance(v, torch.Tensor):
                    v = v.resize(size, *args, **kwargs)
                bbox.add_field(k, v)
            return bbox

        ratio_width, ratio_height = ratios
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        scaled_xmin = xmin * ratio_width
        scaled_xmax = xmax * ratio_width
        scaled_ymin = ymin * ratio_height
        scaled_ymax = ymax * ratio_height
        scaled_box = torch.cat(
            (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1
        )
        bbox = BoxList(scaled_box, size, mode="xyxy")
        # bbox._copy_extra_fields(self)
        for k, v in self.extra_fields.items():
            if not isinstance(v, torch.Tensor):
                v = v.resize(size, *args, **kwargs)
            bbox.add_field(k, v)

        return bbox.convert(self.mode)
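
    # Resize sketch (illustrative): halving a (10, 10) image takes the fast
    # path (equal ratios) and scales every coordinate by 0.5:
    #   b = BoxList(torch.tensor([[0., 0., 10., 10.]]), (10, 10))
    #   b.resize((5, 5)).bbox  # -> tensor([[0., 0., 5., 5.]])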

    def transpose(self, method):
        """
        Transpose bounding box (flip or rotate in 90 degree steps)
        :param method: One of :py:attr:`PIL.Image.FLIP_LEFT_RIGHT`,
            :py:attr:`PIL.Image.FLIP_TOP_BOTTOM`, :py:attr:`PIL.Image.ROTATE_90`,
            :py:attr:`PIL.Image.ROTATE_180`, :py:attr:`PIL.Image.ROTATE_270`,
            :py:attr:`PIL.Image.TRANSPOSE` or :py:attr:`PIL.Image.TRANSVERSE`.
        Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM are currently implemented.
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        image_width, image_height = self.size
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if method == FLIP_LEFT_RIGHT:
            TO_REMOVE = 1
            transposed_xmin = image_width - xmax - TO_REMOVE
            transposed_xmax = image_width - xmin - TO_REMOVE
            transposed_ymin = ymin
            transposed_ymax = ymax
        elif method == FLIP_TOP_BOTTOM:
            transposed_xmin = xmin
            transposed_xmax = xmax
            # NOTE: no TO_REMOVE offset is applied in the vertical flip,
            # unlike the horizontal flip above
            transposed_ymin = image_height - ymax
            transposed_ymax = image_height - ymin

        transposed_boxes = torch.cat(
            (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
        )
        bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
        # bbox._copy_extra_fields(self)
        for k, v in self.extra_fields.items():
            if not isinstance(v, torch.Tensor):
                v = v.transpose(method)
            bbox.add_field(k, v)
        return bbox.convert(self.mode)
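
    # Flip sketch (illustrative): a horizontal flip in a width-10 image maps
    # each x coordinate to (width - x - 1) under the inclusive-pixel convention:
    #   b = BoxList(torch.tensor([[2., 0., 4., 9.]]), (10, 10))
    #   b.transpose(FLIP_LEFT_RIGHT).bbox  # -> tensor([[5., 0., 7., 9.]])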

    def crop(self, box):
        """
        Crops a rectangular region from this bounding box. The box is a
        4-tuple defining the left, upper, right, and lower pixel
        coordinate.
        """
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        w, h = box[2] - box[0], box[3] - box[1]
        cropped_xmin = (xmin - box[0]).clamp(min=0, max=w)
        cropped_ymin = (ymin - box[1]).clamp(min=0, max=h)
        cropped_xmax = (xmax - box[0]).clamp(min=0, max=w)
        cropped_ymax = (ymax - box[1]).clamp(min=0, max=h)

        # TODO should I filter empty boxes here?
        cropped_box = torch.cat(
            (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1
        )
        bbox = BoxList(cropped_box, (w, h), mode="xyxy")
        # bbox._copy_extra_fields(self)
        for k, v in self.extra_fields.items():
            if not isinstance(v, torch.Tensor):
                v = v.crop(box)
            bbox.add_field(k, v)
        return bbox.convert(self.mode)
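
    # Crop sketch (illustrative): coordinates are shifted into the crop's
    # frame and clamped to its extent:
    #   b = BoxList(torch.tensor([[2., 2., 8., 8.]]), (10, 10))
    #   b.crop((4, 4, 10, 10)).bbox  # -> tensor([[0., 0., 4., 4.]])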

    # Tensor-like methods

    def to(self, device):
        bbox = BoxList(self.bbox.to(device), self.size, self.mode)
        for k, v in self.extra_fields.items():
            if hasattr(v, "to"):
                v = v.to(device)
            bbox.add_field(k, v)
        return bbox

    def __getitem__(self, item):
        bbox = BoxList(self.bbox[item], self.size, self.mode)
        for k, v in self.extra_fields.items():
            bbox.add_field(k, v[item])
        return bbox

    def __len__(self):
        return self.bbox.shape[0]

    def clip_to_image(self, remove_empty=True):
        TO_REMOVE = 1
        x1s = self.bbox[:, 0].clamp(min=0, max=self.size[0] - TO_REMOVE)
        y1s = self.bbox[:, 1].clamp(min=0, max=self.size[1] - TO_REMOVE)
        x2s = self.bbox[:, 2].clamp(min=0, max=self.size[0] - TO_REMOVE)
        y2s = self.bbox[:, 3].clamp(min=0, max=self.size[1] - TO_REMOVE)
        self.bbox = torch.stack((x1s, y1s, x2s, y2s), dim=-1)
        if remove_empty:
            box = self.bbox
            keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
            return self[keep]
        return self
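
    # Clip sketch (illustrative): clamps coordinates to [0, size - 1] in
    # place and, with remove_empty=True, drops boxes whose clipped extent
    # is empty:
    #   b = BoxList(torch.tensor([[-5., -5., 4., 4.], [20., 20., 30., 30.]]), (10, 10))
    #   len(b.clip_to_image())  # -> 1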

    def area(self):
        if self.mode == 'xyxy':
            TO_REMOVE = 1
            box = self.bbox
            area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
        elif self.mode == 'xywh':
            box = self.bbox
            area = box[:, 2] * box[:, 3]
        else:
            raise RuntimeError("Should not be here")

        return area
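
    # Area sketch (illustrative): in "xyxy" mode the inclusive-pixel
    # convention adds 1 to each side, so [0, 0, 9, 9] covers 10 x 10 = 100:
    #   BoxList(torch.tensor([[0., 0., 9., 9.]]), (10, 10)).area()  # -> tensor([100.])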

    def copy_with_fields(self, fields):
        bbox = BoxList(self.bbox, self.size, self.mode)
        if not isinstance(fields, (list, tuple)):
            fields = [fields]
        for field in fields:
            bbox.add_field(field, self.get_field(field))
        return bbox
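
    # Copy sketch (illustrative; assumes a BoxList with a "labels" field):
    # only the named fields are carried over to the copy:
    #   b2 = boxes.copy_with_fields("labels")
    #   b2.fields()  # -> ["labels"]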

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_boxes={}, ".format(len(self))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s

    @staticmethod
    def concate_box_list(list_of_boxes):
        boxes = torch.cat([i.bbox for i in list_of_boxes], dim=0)
        extra_fields_keys = list(list_of_boxes[0].extra_fields.keys())
        extra_fields = {}
        for key in extra_fields_keys:
            extra_fields[key] = torch.cat([i.extra_fields[key] for i in list_of_boxes], dim=0)

        final = list_of_boxes[0].copy_with_fields(extra_fields_keys)

        final.bbox = boxes
        final.extra_fields = extra_fields
        return final
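
    # Concatenation sketch (illustrative names; assumes every BoxList shares
    # the same image size, mode, and tensor-valued extra fields):
    #   merged = BoxList.concate_box_list([boxes_a, boxes_b])
    #   len(merged) == len(boxes_a) + len(boxes_b)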


@torch.jit.unused
def _onnx_clip_boxes_to_image(boxes, size):
    # type: (Tensor, Tuple[int, int])
    """
    Clip boxes so that they lie inside an image of size `size`.
    clamp's min/max would be traced as constants, so torch.min/torch.max
    are used to work around this issue.
    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        size (Tuple[height, width]): size of the image
    Returns:
        clipped_boxes (Tensor[N, 4])
    """
    TO_REMOVE = 1
    device = boxes.device
    dim = boxes.dim()
    boxes_x = boxes[..., 0::2]
    boxes_y = boxes[..., 1::2]

    boxes_x = torch.max(boxes_x, torch.tensor(0., dtype=torch.float).to(device))
    boxes_x = torch.min(boxes_x, torch.tensor(size[1] - TO_REMOVE, dtype=torch.float).to(device))
    boxes_y = torch.max(boxes_y, torch.tensor(0., dtype=torch.float).to(device))
    boxes_y = torch.min(boxes_y, torch.tensor(size[0] - TO_REMOVE, dtype=torch.float).to(device))

    clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
    return clipped_boxes.reshape(boxes.shape)
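
# ONNX-export sketch (illustrative): torch.min/torch.max keep the image
# bounds as tensors during tracing, unlike clamp's constant min/max; note
# that size is (height, width) here, unlike BoxList.size:
#   _onnx_clip_boxes_to_image(torch.tensor([[-5., -5., 20., 20.]]), (10, 10))
#   # -> tensor([[0., 0., 9., 9.]])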


if __name__ == "__main__":
    bbox = BoxList([[0, 0, 10, 10], [0, 0, 5, 5]], (10, 10))
    s_bbox = bbox.resize((5, 5))
    print(s_bbox)
    print(s_bbox.bbox)

    t_bbox = bbox.transpose(FLIP_LEFT_RIGHT)
    print(t_bbox)
    print(t_bbox.bbox)