# Mirror of https://github.com/hero-y/BHRL (Python, 341 lines, 13 KiB).
from abc import abstractmethod
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmcv.cnn import ConvModule
|
|
from mmcv.runner import force_fp32
|
|
|
|
from mmdet.core import multi_apply
|
|
from ..builder import HEADS, build_loss
|
|
from .base_dense_head import BaseDenseHead
|
|
from .dense_test_mixins import BBoxTestMixin
|
|
|
|
|
|
@HEADS.register_module()
class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
    """Anchor-free head (FCOS, Fovea, RepPoints, etc.).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        stacked_convs (int): Number of stacking convs of the head.
        strides (tuple): Downsample factor of each feature map.
        dcn_on_last_conv (bool): If true, use dcn in the last layer of
            towers. Default: False.
        conv_bias (bool | str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Default: "auto".
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        train_cfg (dict): Training config of anchor-free head.
        test_cfg (dict): Testing config of anchor-free head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    # Module version. PyTorch stores it in the state-dict metadata; a
    # checkpoint without a version is treated as the legacy format by
    # `_load_from_state_dict` below.
    _version = 1

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 dcn_on_last_conv=False,
                 conv_bias='auto',
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                 conv_cfg=None,
                 norm_cfg=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='conv_cls',
                         std=0.01,
                         bias_prob=0.01))):
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.num_classes = num_classes
        # One output channel per foreground class (no background channel).
        self.cls_out_channels = num_classes
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.dcn_on_last_conv = dcn_on_last_conv
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias

        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False

        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the head."""
        self._init_cls_convs()
        self._init_reg_convs()
        self._init_predictor()

    def _make_conv_tower(self):
        """Build a tower of ``stacked_convs`` 3x3 ``ConvModule`` layers.

        The first conv maps ``in_channels`` to ``feat_channels``; the
        remaining ones keep ``feat_channels``. When ``dcn_on_last_conv`` is
        set, the last conv uses a ``DCNv2`` conv config instead of
        ``self.conv_cfg``.

        Returns:
            nn.ModuleList: The stacked conv layers (empty when
            ``stacked_convs`` is 0).
        """
        tower = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            tower.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))
        return tower

    def _init_cls_convs(self):
        """Initialize classification conv layers of the head."""
        self.cls_convs = self._make_conv_tower()

    def _init_reg_convs(self):
        """Initialize bbox regression conv layers of the head."""
        self.reg_convs = self._make_conv_tower()

    def _init_predictor(self):
        """Initialize predictor layers of the head.

        ``conv_cls`` outputs ``cls_out_channels`` maps and ``conv_reg``
        outputs 4 maps (one per box coordinate/distance).
        """
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Hack some keys of the model state dict so that checkpoints of the
        previous version can be loaded.

        Checkpoints saved without a ``version`` entry in ``local_metadata``
        name the predictor convs differently (e.g. ``fcos_cls`` instead of
        ``conv_cls``). For such checkpoints, every key under ``prefix`` whose
        attribute name (the first component after ``prefix``) ends with
        ``cls``, ``reg`` or ``centerness`` is renamed to ``conv_cls``,
        ``conv_reg`` or ``conv_centerness`` respectively. All other keys are
        left untouched. ``state_dict`` is modified in place before the parent
        implementation is called.
        """
        version = local_metadata.get('version', None)
        if version is None:
            # `prefix` is either '' or a dotted module path ending in '.', so
            # the number of dots is the index of the attribute-name component
            # right after it: 0 for '', 1 for 'bbox_head.', 2 for
            # 'roi_head.bbox_head.'. (Hard-coding index 1 broke nested heads.)
            name_idx = prefix.count('.')
            # Ordered mapping old key -> new key; applied in insertion order.
            renames = {}
            for ori_key in [k for k in state_dict if k.startswith(prefix)]:
                parts = ori_key.split('.')
                name = parts[name_idx]
                if name.endswith('cls'):
                    parts[name_idx] = 'conv_cls'
                elif name.endswith('reg'):
                    parts[name_idx] = 'conv_reg'
                elif name.endswith('centerness'):
                    parts[name_idx] = 'conv_centerness'
                else:
                    # Not a predictor layer (e.g. the conv towers): keep it.
                    continue
                renames[ori_key] = '.'.join(parts)
            for ori_key, new_key in renames.items():
                state_dict[new_key] = state_dict.pop(ori_key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually contain classification scores and bbox predictions.
                cls_scores (list[Tensor]): Box scores for each scale level,
                    each is a 4D-tensor, the channel number is
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * 4.
        """
        # `forward_single` also returns the tower features; keep only the
        # first two outputs (scores and bbox predictions).
        return multi_apply(self.forward_single, feats)[:2]

    def forward_single(self, x):
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.

        Returns:
            tuple: ``(cls_score, bbox_pred, cls_feat, reg_feat)`` -- scores
                for each class, bbox predictions, and the features after the
                classification and regression conv towers, which some models
                (e.g. FCOS) need.
        """
        cls_feat = x
        reg_feat = x

        for cls_layer in self.cls_convs:
            cls_feat = cls_layer(cls_feat)
        cls_score = self.conv_cls(cls_feat)

        for reg_layer in self.reg_convs:
            reg_feat = reg_layer(reg_feat)
        bbox_pred = self.conv_reg(reg_feat)
        return cls_score, bbox_pred, cls_feat, reg_feat

    @abstractmethod
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

        Raises:
            NotImplementedError: Always; subclasses must implement this.
        """

        raise NotImplementedError

    @abstractmethod
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg=None,
                   rescale=None):
        """Transform network output for a batch into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_points * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_points * 4, H, W)
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used
            rescale (bool): If True, return boxes in original image space

        Raises:
            NotImplementedError: Always; subclasses must implement this.
        """

        raise NotImplementedError

    @abstractmethod
    def get_targets(self, points, gt_bboxes_list, gt_labels_list):
        """Compute regression, classification and centerness targets for points
        in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
                each has shape (num_gt, 4).
            gt_labels_list (list[Tensor]): Ground truth labels of each box,
                each has shape (num_gt,).

        Raises:
            NotImplementedError: Always; subclasses must implement this.
        """
        raise NotImplementedError

    def _get_points_single(self,
                           featmap_size,
                           stride,
                           dtype,
                           device,
                           flatten=False):
        """Get points of a single scale level.

        Args:
            featmap_size (tuple[int]): ``(h, w)`` of the feature map.
            stride (int): Stride of this level. Unused in this base
                implementation; kept so subclasses can override with the
                same signature.
            dtype (torch.dtype): Type of the returned coordinates.
            device (torch.device): Device of the returned coordinates.
            flatten (bool): If True, flatten both grids to 1-D.

        Returns:
            tuple[Tensor, Tensor]: ``(y, x)`` integer-valued grid coordinates
                in feature-map units, each of shape ``(h, w)`` or ``(h * w,)``
                when ``flatten`` is True.
        """
        h, w = featmap_size
        # First create Range with the default dtype, then convert to
        # target `dtype` for onnx exporting.
        x_range = torch.arange(w, device=device).to(dtype)
        y_range = torch.arange(h, device=device).to(dtype)
        y, x = torch.meshgrid(y_range, x_range)
        if flatten:
            y = y.flatten()
            x = x.flatten()
        return y, x

    def get_points(self, featmap_sizes, dtype, device, flatten=False):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.
            flatten (bool): Forwarded to ``_get_points_single``.

        Returns:
            list: Points of each level, as returned by
                ``_get_points_single``, in the order of ``featmap_sizes``.
        """
        # Index into `self.strides` (rather than zip) so that a too-short
        # `strides` still raises instead of silently dropping levels.
        return [
            self._get_points_single(featmap_size, self.strides[i], dtype,
                                    device, flatten)
            for i, featmap_size in enumerate(featmap_sizes)
        ]

    def aug_test(self, feats, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class
        """
        return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
|