# Mirror of https://github.com/YifanXu74/MQ-Det.git (169 lines, 5.9 KiB, Python)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
from maskrcnn_benchmark.layers import ROIAlign, ROIAlignV2
|
|
|
|
from .utils import cat
|
|
|
|
|
|
class LevelMapper(object):
    """Determine which FPN level each RoI in a set of RoIs should map to based
    on the heuristic in the FPN paper (Eqn. 1).
    """

    def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
        """
        Arguments:
            k_min (int): lowest level an RoI may be assigned to.
            k_max (int): highest level an RoI may be assigned to.
            canonical_scale (int): RoI size (square root of its area) that maps
                exactly to ``canonical_level``.
            canonical_level (int): level assigned to an RoI of size
                ``canonical_scale``.
            eps (float): small constant added inside the log so its argument
                stays strictly positive for zero-area boxes.
        """
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists):
        """
        Arguments:
            boxlists (list[BoxList]): boxes of every image in the batch.

        Returns:
            Tensor (int64): one entry per box, in the order the boxes appear
            when ``boxlists`` are concatenated, giving the box's level relative
            to ``k_min`` (a value in ``[0, k_max - k_min]``).
        """
        # RoI scale = sqrt(area) of every box across the batch.
        s = torch.sqrt(cat([boxlist.area() for boxlist in boxlists]))

        # Eqn.(1) in FPN paper: k = floor(k0 + log2(sqrt(wh) / s0)).
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        # Subtract before casting so the ids are int64 even when k_min is a
        # Python float (Pooler computes k_min/k_max with ``.item()``).
        return (target_lvls - self.k_min).to(torch.int64)
|
|
|
|
|
|
class Pooler(nn.Module):
    """
    Pooler for Detection with or without FPN.
    It currently hard-codes ROIAlign in the implementation,
    but that can be made more generic later on.
    Also, the requirement of passing the scales is not strictly necessary, as they
    can be inferred from the size of the feature map / size of original image,
    which is available thanks to the BoxList.
    """

    def __init__(self, output_size, scales, sampling_ratio, use_v2=False):
        """
        Arguments:
            output_size (int or tuple[int, int]): (height, width) of the pooled
                region; an int means a square output.
            scales (list[float]): spatial scale of each feature level. scales[0]
                defines the lowest level and scales[-1] the highest, so they are
                expected to go from the finest to the coarsest map.
            sampling_ratio (int): sampling ratio for ROIAlign
            use_v2 (bool): build ROIAlignV2 instead of ROIAlign for every level.
        """
        super(Pooler, self).__init__()
        align_cls = ROIAlignV2 if use_v2 else ROIAlign
        poolers = [
            align_cls(output_size, spatial_scale=scale, sampling_ratio=sampling_ratio)
            for scale in scales
        ]
        self.poolers = nn.ModuleList(poolers)
        self.output_size = output_size
        # get the levels in the feature map by leveraging the fact that the network always
        # downsamples by a factor of 2 at each level.
        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
        self.map_levels = LevelMapper(lvl_min, lvl_max)

    def convert_to_roi_format(self, boxes):
        '''
        Concatenate the boxes of all images and prepend the image (batch) index
        as the first column of each row.
        boxes: [(K1, 4), (K2, 4)]
        rois: (K1+K2, 5)
        '''
        concat_boxes = cat([b.bbox for b in boxes], dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        ids = cat(
            [
                torch.full((len(b), 1), i, dtype=dtype, device=device)
                for i, b in enumerate(boxes)
            ],
            dim=0,
        )
        rois = torch.cat([ids, concat_boxes], dim=1)
        return rois

    def forward(self, x, boxes):
        """
        Arguments:
            x (list[Tensor]): feature maps for each level
            boxes (list[BoxList]): boxes to be used to perform the pooling operation.
        Returns:
            result (Tensor): pooled features; in the multi-level case of shape
            (num_rois, C, H, W), each RoI pooled from the level chosen by
            ``self.map_levels``.
        """
        num_levels = len(self.poolers)
        rois = self.convert_to_roi_format(boxes)
        if num_levels == 1:
            return self.poolers[0](x[0], rois)

        levels = self.map_levels(boxes)

        num_rois = len(rois)
        num_channels = x[0].shape[1]
        output_size = self.output_size
        if isinstance(output_size, int):
            # ROIAlign accepts an int for a square output; mirror that here.
            out_h = out_w = output_size
        else:
            out_h, out_w = output_size[0], output_size[1]

        dtype, device = x[0].dtype, x[0].device
        result = torch.zeros(
            (num_rois, num_channels, out_h, out_w),
            dtype=dtype,
            device=device,
        )
        for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
            idx_in_level = torch.nonzero(levels == level).squeeze(1)
            rois_per_level = rois[idx_in_level]
            # The pooler may return a different precision than the features
            # (e.g. under mixed precision); indexed assignment needs matching dtypes.
            result[idx_in_level] = pooler(per_level_feature, rois_per_level).to(dtype)

        return result
|
|
|
|
class CustomPooler(Pooler):
    """
    Pooler that pools every RoI from every FPN level, instead of assigning each
    RoI to a single level. Construction is inherited unchanged from Pooler
    (positional or keyword arguments).
    """

    def forward(self, x, boxes):
        """
        Arguments:
            x (list[Tensor]): feature maps for each level
            boxes (list[BoxList]): boxes to be used to perform the pooling operation.
        Returns:
            result (Tensor): with a single pooler, the (num_rois, C, H, W) output of
            that pooler; otherwise a stacked tensor of shape
            (num_levels, num_rois, C, H, W) where ``result[l]`` holds every RoI
            pooled from level ``l``.
        """
        num_levels = len(self.poolers)
        rois = self.convert_to_roi_format(boxes)
        if num_levels == 1:
            # NOTE(review): unlike the multi-level path this has no level axis.
            return self.poolers[0](x[0], rois)

        num_rois = len(rois)
        num_channels = x[0].shape[1]
        output_size = self.output_size
        if isinstance(output_size, int):
            # ROIAlign accepts an int for a square output; mirror that here.
            out_h = out_w = output_size
        else:
            out_h, out_w = output_size[0], output_size[1]

        dtype, device = x[0].dtype, x[0].device
        result = torch.zeros(
            (num_levels, num_rois, num_channels, out_h, out_w),
            dtype=dtype,
            device=device,
        )
        # Every RoI is pooled at every level; levels without a feature map in
        # ``x`` (zip stops at the shorter input) stay zero.
        for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
            # Cast in case the pooler returns a different precision than the features.
            result[level] = pooler(per_level_feature, rois).to(dtype)

        return result