mirror of https://github.com/alibaba/EasyCV.git

remove RPNHeadNorm, adapted in mmlab_utils (#90)

* remove RPNHeadNorm, adapted in mmlab_utils

parent 0f69dbe902
commit 5110de7635

@@ -28,7 +28,7 @@ model = dict(
         norm_cfg=norm_cfg,
         num_outs=5),
     rpn_head=dict(
-        type='RPNHeadNorm',
+        type='RPNHead',
         in_channels=256,
         feat_channels=256,
         num_convs=2,

@@ -137,5 +137,6 @@ model = dict(
 
 mmlab_modules = [
     dict(type='mmdet', name='MaskRCNN', module='model'),
+    dict(type='mmdet', name='RPNHead', module='head'),
     dict(type='mmdet', name='StandardRoIHead', module='head'),
 ]

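For context, a minimal usage sketch (not part of this commit) of how a config carrying the `mmlab_modules` list above gets adapted at runtime; the import path of `dynamic_adapt_for_mmlab` and the config filename are assumptions:

from mmcv import Config

# Assumed import path; `dynamic_adapt_for_mmlab` is the entry point extended
# later in this commit.
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab

# Hypothetical config path; any config defining `mmlab_modules` as above works.
cfg = Config.fromfile('configs/detection/mask_rcnn_vitdet.py')

# Wraps mmdet's MaskRCNN / RPNHead / StandardRoIHead for use inside EasyCV,
# including the re-registered RPNHead that accepts `norm_cfg`.
dynamic_adapt_for_mmlab(cfg)
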
@@ -1,9 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from mmcv.runner.hooks import HOOKS
 from mmcv.runner.hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
                                           annealing_cos)
 
+from easycv.hooks import HOOKS
 
 # initial_lr 0.01 = self.exp.basic_lr_per_img * self.args.batch_size
 # min_lr_ratio default 0.05, 0.2
 # total_iters = iters_per_epoch * total_epochs

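A worked example of the relations in the comments above (illustrative numbers only, not taken from the repo):

# Hypothetical values showing how the quantities in the comments relate.
basic_lr_per_img = 0.01 / 64                  # assumed per-image learning rate
batch_size = 64
initial_lr = basic_lr_per_img * batch_size    # 0.01, as in the first comment

iters_per_epoch = 1000                        # assumed length of the train dataloader
total_epochs = 100
total_iters = iters_per_epoch * total_epochs  # 100000 cosine-annealing steps

min_lr_ratio = 0.05                           # the comment lists 0.05 and 0.2
min_lr = initial_lr * min_lr_ratio            # floor reached at the end of annealing
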
@@ -9,11 +9,11 @@ import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 from mmcv.cnn import build_norm_layer, constant_init, kaiming_init
 from mmcv.runner import get_dist_info
-from mmdet.utils import get_root_logger
 from timm.models.layers import drop_path, to_2tuple, trunc_normal_
 from torch.nn.modules.batchnorm import _BatchNorm
 
 from easycv.utils.checkpoint import load_checkpoint
+from easycv.utils.logger import get_root_logger
 from ..registry import BACKBONES
 from ..utils import build_conv_layer

@@ -1,7 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import logging
 
-from .vitdet import SFP, RPNHeadNorm
+from .vitdet import SFP
 
 try:
     from .yolox.yolox import YOLOX

@@ -1,2 +1 @@
-from .rpn_head_norm import RPNHeadNorm
 from .sfp import SFP

@@ -1,260 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.ops import batched_nms
from mmdet.models import AnchorHead

from easycv.models.registry import HEADS


@HEADS.register_module()
class RPNHeadNorm(AnchorHead):
    """RPN head with norm.

    Args:
        in_channels (int): Number of channels in the input feature map.
        init_cfg (dict or list[dict], optional): Initialization config dict.
        num_convs (int): Number of convolution layers in the head. Default 1.
    """  # noqa: W605

    def __init__(self,
                 in_channels,
                 init_cfg=dict(type='Normal', layer='Conv2d', std=0.01),
                 num_convs=1,
                 norm_cfg=None,
                 **kwargs):
        self.num_convs = num_convs
        self.norm_cfg = norm_cfg
        super(RPNHeadNorm, self).__init__(
            1, in_channels, init_cfg=init_cfg, **kwargs)

    def _init_layers(self):
        """Initialize layers of the head."""
        if self.num_convs > 1:
            rpn_convs = []
            for i in range(self.num_convs):
                if i == 0:
                    in_channels = self.in_channels
                else:
                    in_channels = self.feat_channels
                # use ``inplace=False`` to avoid error: one of the variables
                # needed for gradient computation has been modified by an
                # inplace operation.
                rpn_convs.append(
                    ConvModule(
                        in_channels,
                        self.feat_channels,
                        3,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        inplace=False))
            self.rpn_conv = nn.Sequential(*rpn_convs)
        else:
            self.rpn_conv = nn.Conv2d(
                self.in_channels, self.feat_channels, 3, padding=1)
        self.rpn_cls = nn.Conv2d(self.feat_channels,
                                 self.num_base_priors * self.cls_out_channels,
                                 1)
        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_base_priors * 4,
                                 1)

    def forward_single(self, x):
        """Forward feature map of a single scale level."""
        x = self.rpn_conv(x)
        x = F.relu(x.clone(), inplace=True)
        rpn_cls_score = self.rpn_cls(x)
        rpn_bbox_pred = self.rpn_reg(x)
        return rpn_cls_score, rpn_bbox_pred

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        losses = super(RPNHeadNorm, self).loss(
            cls_scores,
            bbox_preds,
            gt_bboxes,
            None,
            img_metas,
            gt_bboxes_ignore=gt_bboxes_ignore)
        return dict(
            loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])

    def _get_bboxes_single(self,
                           cls_score_list,
                           bbox_pred_list,
                           score_factor_list,
                           mlvl_anchors,
                           img_meta,
                           cfg,
                           rescale=False,
                           with_nms=True,
                           **kwargs):
        """Transform outputs of a single image into bbox predictions.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_anchors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has
                shape (num_anchors * 4, H, W).
            score_factor_list (list[Tensor]): Score factor from all scale
                levels of a single image. RPN head does not need this value.
            mlvl_anchors (list[Tensor]): Anchors of all scale level
                each item has shape (num_anchors, 4).
            img_meta (dict): Image meta info.
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before return boxes.
                Default: True.

        Returns:
            Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
                are bounding box positions (tl_x, tl_y, br_x, br_y) and the
                5-th column is a score between 0 and 1.
        """
        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)
        img_shape = img_meta['img_shape']

        # bboxes from different level should be independent during NMS,
        # level_ids are used as labels for batched NMS to separate them
        level_ids = []
        mlvl_scores = []
        mlvl_bbox_preds = []
        mlvl_valid_anchors = []
        nms_pre = cfg.get('nms_pre', -1)
        for level_idx in range(len(cls_score_list)):
            rpn_cls_score = cls_score_list[level_idx]
            rpn_bbox_pred = bbox_pred_list[level_idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                # We set FG labels to [0, num_class-1] and BG label to
                # num_class in RPN head since mmdet v2.5, which is unified to
                # be consistent with other head since mmdet v2.0. In mmdet v2.0
                # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
                scores = rpn_cls_score.softmax(dim=1)[:, 0]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)

            anchors = mlvl_anchors[level_idx]
            if 0 < nms_pre < scores.shape[0]:
                # sort is faster than topk
                # _, topk_inds = scores.topk(cfg.nms_pre)
                ranked_scores, rank_inds = scores.sort(descending=True)
                topk_inds = rank_inds[:nms_pre]
                scores = ranked_scores[:nms_pre]
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]

            mlvl_scores.append(scores)
            mlvl_bbox_preds.append(rpn_bbox_pred)
            mlvl_valid_anchors.append(anchors)
            level_ids.append(
                scores.new_full((scores.size(0), ),
                                level_idx,
                                dtype=torch.long))

        return self._bbox_post_process(mlvl_scores, mlvl_bbox_preds,
                                       mlvl_valid_anchors, level_ids, cfg,
                                       img_shape)

    def _bbox_post_process(self, mlvl_scores, mlvl_bboxes, mlvl_valid_anchors,
                           level_ids, cfg, img_shape, **kwargs):
        """bbox post-processing method.

        The boxes would be rescaled to the original image scale and do
        the nms operation. Usually with_nms is False is used for aug test.

        Args:
            mlvl_scores (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_bboxes, num_class).
            mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale
                levels of a single image, each item has shape (num_bboxes, 4).
            mlvl_valid_anchors (list[Tensor]): Anchors of all scale level
                each item has shape (num_bboxes, 4).
            level_ids (list[Tensor]): Indexes from all scale levels of a
                single image, each item has shape (num_bboxes, ).
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            img_shape (tuple(int)): Shape of current image.

        Returns:
            Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
                are bounding box positions (tl_x, tl_y, br_x, br_y) and the
                5-th column is a score between 0 and 1.
        """
        scores = torch.cat(mlvl_scores)
        anchors = torch.cat(mlvl_valid_anchors)
        rpn_bbox_pred = torch.cat(mlvl_bboxes)
        proposals = self.bbox_coder.decode(
            anchors, rpn_bbox_pred, max_shape=img_shape)
        ids = torch.cat(level_ids)

        if cfg.min_bbox_size >= 0:
            w = proposals[:, 2] - proposals[:, 0]
            h = proposals[:, 3] - proposals[:, 1]
            valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
            if not valid_mask.all():
                proposals = proposals[valid_mask]
                scores = scores[valid_mask]
                ids = ids[valid_mask]

        if proposals.numel() > 0:
            dets, _ = batched_nms(proposals, scores, ids, cfg.nms)
        else:
            return proposals.new_zeros(0, 5)

        return dets[:cfg.max_per_img]

    def onnx_export(self, x, img_metas):
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            img_metas (list[dict]): Meta info of each image.

        Returns:
            Tensor: dets of shape [N, num_det, 5].
        """
        cls_scores, bbox_preds = self(x)

        assert len(cls_scores) == len(bbox_preds)

        batch_bboxes, batch_scores = super(RPNHeadNorm, self).onnx_export(
            cls_scores, bbox_preds, img_metas=img_metas, with_nms=False)
        # Use ONNX::NonMaxSuppression in deployment
        from mmdet.core.export import add_dummy_nms_for_onnx
        cfg = copy.deepcopy(self.test_cfg)
        score_threshold = cfg.nms.get('score_thr', 0.0)
        nms_pre = cfg.get('deploy_nms_pre', -1)
        # Different from the normal forward doing NMS level by level,
        # we do NMS across all levels when exporting ONNX.
        dets, _ = add_dummy_nms_for_onnx(batch_bboxes, batch_scores,
                                         cfg.max_per_img,
                                         cfg.nms.iou_threshold,
                                         score_threshold, nms_pre,
                                         cfg.max_per_img)
        return dets

@@ -5,6 +5,8 @@ import logging
 import mmcv
 import numpy as np
 import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
 
 from easycv.models.registry import BACKBONES, HEADS, MODELS, NECKS
 from .test_util import run_in_subprocess

@@ -146,12 +148,18 @@ class MMAdapter:
 
 class MMDetWrapper:
 
+    def __init__(self):
+        self.refactor_modules()
+
     def wrap_module(self, cls, module_type):
         if module_type == 'model':
             self._wrap_model_init(cls)
             self._wrap_model_forward(cls)
             self._wrap_model_forward_test(cls)
 
+    def refactor_modules(self):
+        update_rpn_head()
+
     def _wrap_model_init(self, cls):
         origin_init = cls.__init__

@@ -252,3 +260,60 @@ def dynamic_adapt_for_mmlab(cfg)
    if len(mmlab_modules_cfg) > 1:
        adapter = MMAdapter(mmlab_modules_cfg)
        adapter.adapt_mmlab_modules()


def update_rpn_head():
    logging.warning('refactor mmdet.models.RPNHead, add `norm_cfg`')
    from mmdet.models.builder import HEADS
    HEADS._module_dict.pop('RPNHead', None)
    from mmdet.models import RPNHead as _RPNHead

    @HEADS.register_module()
    class RPNHead(_RPNHead):
        """RPN head with norm.

        Args:
            in_channels (int): Number of channels in the input feature map.
            init_cfg (dict or list[dict], optional): Initialization config dict.
            num_convs (int): Number of convolution layers in the head. Default 1.
        """  # noqa: W605

        def __init__(self,
                     in_channels,
                     init_cfg=dict(type='Normal', layer='Conv2d', std=0.01),
                     num_convs=1,
                     norm_cfg=None,
                     **kwargs):
            self.num_convs = num_convs
            self.norm_cfg = norm_cfg
            super(RPNHead, self).__init__(
                in_channels, init_cfg=init_cfg, **kwargs)

        def _init_layers(self):
            """Initialize layers of the head."""
            if self.num_convs > 1:
                rpn_convs = []
                for i in range(self.num_convs):
                    if i == 0:
                        in_channels = self.in_channels
                    else:
                        in_channels = self.feat_channels
                    # use ``inplace=False`` to avoid error: one of the variables
                    # needed for gradient computation has been modified by an
                    # inplace operation.
                    rpn_convs.append(
                        ConvModule(
                            in_channels,
                            self.feat_channels,
                            3,
                            padding=1,
                            norm_cfg=self.norm_cfg,
                            inplace=False))
                self.rpn_conv = nn.Sequential(*rpn_convs)
            else:
                self.rpn_conv = nn.Conv2d(
                    self.in_channels, self.feat_channels, 3, padding=1)
            self.rpn_cls = nn.Conv2d(
                self.feat_channels,
                self.num_base_priors * self.cls_out_channels, 1)
            self.rpn_reg = nn.Conv2d(self.feat_channels,
                                     self.num_base_priors * 4, 1)

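A hypothetical smoke test (not part of the commit) for the re-registration performed by `update_rpn_head()`; the import path of `update_rpn_head` is assumed:

from mmdet.models.builder import HEADS

from easycv.utils.mmlab_utils import update_rpn_head  # assumed module path

update_rpn_head()                # pops mmdet's 'RPNHead' and registers the subclass above
RPNHead = HEADS.get('RPNHead')   # now resolves to the patched class
head = RPNHead(in_channels=256, num_convs=2, norm_cfg=dict(type='BN'))
print(type(head).__bases__[0])   # the original mmdet RPNHead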