mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Support RTMDet Ins Segmentation Inference (#583)
* update config * reproduce mAP in mmyolo * reproduce mAP in mmyolo * add masks to collate fn and data preprocessor * reproduce result * beautify code * beautify code * beautify code * del yolov5_seg_head.py * beautify config * add doc and typehint * del objectness * fix ut; add empty res
pull/612/head
parent
414deaea9a
commit
146cd930c5
|
@ -0,0 +1,31 @@
|
|||
# RTMDet-Ins (instance segmentation) config; everything not overridden here
# is inherited from the detection config named in `_base_`.
_base_ = './rtmdet_s_syncbn_fast_8xb32-300e_coco.py'

# Forwarded to the head module below.
widen_factor = 0.5

model = dict(
    bbox_head=dict(
        # Swap in the instance-segmentation head and its head module.
        type='RTMDetInsSepBNHead',
        head_module=dict(
            type='RTMDetInsSepBNHeadModule',
            use_sigmoid_cls=True,
            widen_factor=widen_factor),
        loss_mask=dict(
            type='mmdet.DiceLoss', loss_weight=2.0, eps=5e-6,
            reduction='mean')),
    test_cfg=dict(
        multi_label=True,
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.6),
        max_per_img=100,
        # Threshold applied to the sigmoid mask probabilities when the head
        # binarizes predicted masks.
        mask_thr_binary=0.5))

# Replace the second-to-last transform of the inherited test pipeline with an
# annotation loader that also loads masks.
_base_.test_pipeline[-2] = dict(
    type='LoadAnnotations', with_bbox=True, with_mask=True, _scope_='mmdet')

val_dataloader = dict(dataset=dict(pipeline=_base_.test_pipeline))
test_dataloader = val_dataloader

# Evaluate both box and mask metrics.
val_evaluator = dict(metric=['bbox', 'segm'])
test_evaluator = val_evaluator
|
|
@ -19,6 +19,7 @@ def yolov5_collate(data_batch: Sequence,
|
|||
"""
|
||||
batch_imgs = []
|
||||
batch_bboxes_labels = []
|
||||
batch_masks = []
|
||||
for i in range(len(data_batch)):
|
||||
datasamples = data_batch[i]['data_samples']
|
||||
inputs = data_batch[i]['inputs']
|
||||
|
@ -26,6 +27,10 @@ def yolov5_collate(data_batch: Sequence,
|
|||
|
||||
gt_bboxes = datasamples.gt_instances.bboxes.tensor
|
||||
gt_labels = datasamples.gt_instances.labels
|
||||
if 'masks' in datasamples.gt_instances:
|
||||
masks = datasamples.gt_instances.masks.to_tensor(
|
||||
dtype=torch.bool, device=gt_bboxes.device)
|
||||
batch_masks.append(masks)
|
||||
batch_idx = gt_labels.new_full((len(gt_labels), 1), i)
|
||||
bboxes_labels = torch.cat((batch_idx, gt_labels[:, None], gt_bboxes),
|
||||
dim=1)
|
||||
|
@ -36,6 +41,8 @@ def yolov5_collate(data_batch: Sequence,
|
|||
'bboxes_labels': torch.cat(batch_bboxes_labels, 0)
|
||||
}
|
||||
}
|
||||
if len(batch_masks) > 0:
|
||||
collated_results['data_samples']['masks'] = torch.cat(batch_masks, 0)
|
||||
|
||||
if use_ms_training:
|
||||
collated_results['inputs'] = batch_imgs
|
||||
|
|
|
@ -96,12 +96,14 @@ class YOLOv5DetDataPreprocessor(DetDataPreprocessor):
|
|||
inputs, data_samples = batch_aug(inputs, data_samples)
|
||||
|
||||
img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
|
||||
data_samples = {
|
||||
data_samples_output = {
|
||||
'bboxes_labels': data_samples['bboxes_labels'],
|
||||
'img_metas': img_metas
|
||||
}
|
||||
if 'masks' in data_samples:
|
||||
data_samples_output['masks'] = data_samples['masks']
|
||||
|
||||
return {'inputs': inputs, 'data_samples': data_samples}
|
||||
return {'inputs': inputs, 'data_samples': data_samples_output}
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .ppyoloe_head import PPYOLOEHead, PPYOLOEHeadModule
|
||||
from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
|
||||
from .rtmdet_ins_head import RTMDetInsSepBNHead, RTMDetInsSepBNHeadModule
|
||||
from .rtmdet_rotated_head import (RTMDetRotatedHead,
|
||||
RTMDetRotatedSepBNHeadModule)
|
||||
from .yolov5_head import YOLOv5Head, YOLOv5HeadModule
|
||||
|
@ -14,5 +15,6 @@ __all__ = [
|
|||
'YOLOv6HeadModule', 'YOLOXHeadModule', 'RTMDetHead',
|
||||
'RTMDetSepBNHeadModule', 'YOLOv7Head', 'PPYOLOEHead', 'PPYOLOEHeadModule',
|
||||
'YOLOv7HeadModule', 'YOLOv7p6HeadModule', 'YOLOv8Head', 'YOLOv8HeadModule',
|
||||
'RTMDetRotatedHead', 'RTMDetRotatedSepBNHeadModule'
|
||||
'RTMDetRotatedHead', 'RTMDetRotatedSepBNHeadModule', 'RTMDetInsSepBNHead',
|
||||
'RTMDetInsSepBNHeadModule'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,725 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, is_norm
|
||||
from mmcv.ops import batched_nms
|
||||
from mmdet.models.utils import filter_scores_and_topk
|
||||
from mmdet.structures.bbox import get_box_tensor, get_box_wh, scale_boxes
|
||||
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
|
||||
OptInstanceList, OptMultiConfig)
|
||||
from mmengine import ConfigDict
|
||||
from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
|
||||
normal_init)
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmyolo.registry import MODELS
|
||||
from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
|
||||
|
||||
|
||||
class MaskFeatModule(BaseModule):
    """Mask prototype branch of RTMDet-Ins (adapted from mmdet).

    The first ``num_levels`` input feature maps are resized to the
    resolution of the first one, concatenated along the channel axis, fused
    by a 1x1 convolution, refined by a stack of 3x3 conv blocks and finally
    projected to ``num_prototypes`` channels.

    Args:
        in_channels (int): Number of channels of each input feature map.
        feat_channels (int): Number of hidden channels of the mask feature
            branch. Defaults to 256.
        stacked_convs (int): Number of 3x3 conv blocks in the branch.
            Defaults to 4.
        num_levels (int): Number of leading feature levels that are fused.
            Defaults to 3.
        num_prototypes (int): Number of output channels, i.e. the channel
            count of the mask feature map that is later dynamically
            convolved with the predicted kernels. Defaults to 8.
        act_cfg (:obj:`ConfigDict` or dict): Config dict for the activation
            layer. Defaults to dict(type='ReLU', inplace=True).
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for the
            normalization layer. Defaults to dict(type='BN').
    """

    def __init__(
        self,
        in_channels: int,
        feat_channels: int = 256,
        stacked_convs: int = 4,
        num_levels: int = 3,
        num_prototypes: int = 8,
        act_cfg: ConfigType = dict(type='ReLU', inplace=True),
        norm_cfg: ConfigType = dict(type='BN')
    ) -> None:
        super().__init__(init_cfg=None)
        self.num_levels = num_levels

        # 1x1 conv that squeezes the concatenated levels back to in_channels.
        self.fusion_conv = nn.Conv2d(num_levels * in_channels, in_channels, 1)

        # Only the first block consumes `in_channels`; the remaining blocks
        # map feat_channels -> feat_channels.
        tower = [
            ConvModule(
                in_channels if depth == 0 else feat_channels,
                feat_channels,
                3,
                padding=1,
                act_cfg=act_cfg,
                norm_cfg=norm_cfg) for depth in range(stacked_convs)
        ]
        self.stacked_convs = nn.Sequential(*tower)

        self.projection = nn.Conv2d(
            feat_channels, num_prototypes, kernel_size=1)

    def forward(self, features: Tuple[Tensor, ...]) -> Tensor:
        """Fuse the multi-level features and predict the mask prototypes.

        Args:
            features (tuple[Tensor]): Multi-level feature maps; only the
                first ``num_levels`` entries are used.

        Returns:
            Tensor: Mask prototype features at the spatial size of
            ``features[0]`` with ``num_prototypes`` channels.
        """
        reference = features[0]
        target_size = reference.shape[-2:]

        # Bring every other level to the resolution of the first level.
        upsampled = [
            F.interpolate(features[lvl], size=target_size, mode='bilinear')
            for lvl in range(1, self.num_levels)
        ]
        fused = self.fusion_conv(torch.cat([reference, *upsampled], dim=1))

        return self.projection(self.stacked_convs(fused))
|
||||
|
||||
|
||||
@MODELS.register_module()
class RTMDetInsSepBNHeadModule(RTMDetSepBNHeadModule):
    """Detection and Instance Segmentation Head of RTMDet.

    On top of the classification and regression branches, every feature
    level gets a kernel branch that predicts the flattened parameters of a
    small per-instance dynamic-conv network, and a shared
    :class:`MaskFeatModule` produces the mask prototype features.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        num_prototypes (int): Number of mask prototype features extracted
            from the mask head. Defaults to 8.
        dyconv_channels (int): Channel of the dynamic conv layers.
            Defaults to 8.
        num_dyconvs (int): Number of the dynamic convolution layers.
            Defaults to 3.
        use_sigmoid_cls (bool): Use sigmoid for class prediction.
            Defaults to True.
        *args, **kwargs: Forwarded unchanged to the parent head module.
    """

    def __init__(self,
                 num_classes: int,
                 *args,
                 num_prototypes: int = 8,
                 dyconv_channels: int = 8,
                 num_dyconvs: int = 3,
                 use_sigmoid_cls: bool = True,
                 **kwargs):
        # These attributes are read by `_init_layers`. They are assigned
        # before `super().__init__()` on the assumption that the parent
        # constructor triggers `_init_layers` -- TODO confirm in
        # RTMDetSepBNHeadModule.
        self.num_prototypes = num_prototypes
        self.num_dyconvs = num_dyconvs
        self.dyconv_channels = dyconv_channels
        self.use_sigmoid_cls = use_sigmoid_cls
        if self.use_sigmoid_cls:
            self.cls_out_channels = num_classes
        else:
            # One extra output channel when sigmoid is not used.
            self.cls_out_channels = num_classes + 1
        super().__init__(num_classes=num_classes, *args, **kwargs)

    def _init_layers(self):
        """Initialize layers of the head."""
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.kernel_convs = nn.ModuleList()

        self.rtm_cls = nn.ModuleList()
        self.rtm_reg = nn.ModuleList()
        self.rtm_kernel = nn.ModuleList()
        # Stays empty in this module; kept so the attribute remains available.
        self.rtm_obj = nn.ModuleList()

        # Number of dynamic parameters per instance. The first dynamic conv
        # consumes the prototypes plus 2 relative-coordinate channels, the
        # last one emits a single mask channel.
        weight_nums, bias_nums = [], []
        for i in range(self.num_dyconvs):
            if i == 0:
                weight_nums.append(
                    (self.num_prototypes + 2) * self.dyconv_channels)
                bias_nums.append(self.dyconv_channels)
            elif i == self.num_dyconvs - 1:
                weight_nums.append(self.dyconv_channels)
                bias_nums.append(1)
            else:
                weight_nums.append(self.dyconv_channels * self.dyconv_channels)
                bias_nums.append(self.dyconv_channels)
        self.weight_nums = weight_nums
        self.bias_nums = bias_nums
        self.num_gen_params = sum(weight_nums) + sum(bias_nums)
        pred_pad_size = self.pred_kernel_size // 2

        def _tower_conv(chn: int) -> ConvModule:
            """Build one 3x3 conv block shared in shape by all towers."""
            return ConvModule(
                chn,
                self.feat_channels,
                3,
                stride=1,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

        for _ in range(len(self.featmap_strides)):
            cls_convs = nn.ModuleList()
            reg_convs = nn.ModuleList()
            kernel_convs = nn.ModuleList()
            for i in range(self.stacked_convs):
                chn = self.in_channels if i == 0 else self.feat_channels
                cls_convs.append(_tower_conv(chn))
                reg_convs.append(_tower_conv(chn))
                kernel_convs.append(_tower_conv(chn))
            self.cls_convs.append(cls_convs)
            # Register the regression tower itself. Appending `cls_convs`
            # here would make the regression branch an alias of the
            # classification tower and leave `reg_convs` unused. Parameter
            # names in the state dict are the same either way.
            self.reg_convs.append(reg_convs)
            self.kernel_convs.append(kernel_convs)

            self.rtm_cls.append(
                nn.Conv2d(
                    self.feat_channels,
                    self.num_base_priors * self.cls_out_channels,
                    self.pred_kernel_size,
                    padding=pred_pad_size))
            self.rtm_reg.append(
                nn.Conv2d(
                    self.feat_channels,
                    self.num_base_priors * 4,
                    self.pred_kernel_size,
                    padding=pred_pad_size))
            self.rtm_kernel.append(
                nn.Conv2d(
                    self.feat_channels,
                    self.num_gen_params,
                    self.pred_kernel_size,
                    padding=pred_pad_size))

        if self.share_conv:
            # Share the conv weights of the cls/reg towers across levels;
            # only the `.conv` sub-module is re-bound, the rest of each
            # ConvModule stays per level.
            for n in range(len(self.featmap_strides)):
                for i in range(self.stacked_convs):
                    self.cls_convs[n][i].conv = self.cls_convs[0][i].conv
                    self.reg_convs[n][i].conv = self.reg_convs[0][i].conv

        self.mask_head = MaskFeatModule(
            in_channels=self.in_channels,
            feat_channels=self.feat_channels,
            stacked_convs=4,
            num_levels=len(self.featmap_strides),
            num_prototypes=self.num_prototypes,
            act_cfg=self.act_cfg,
            norm_cfg=self.norm_cfg)

    def init_weights(self) -> None:
        """Initialize weights of the head."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, mean=0, std=0.01)
            if is_norm(m):
                constant_init(m, 1)
        bias_cls = bias_init_with_prob(0.01)
        # The kernel prediction convs keep the generic Conv2d init applied
        # above; only the cls/reg predictors get dedicated biases.
        for rtm_cls, rtm_reg in zip(self.rtm_cls, self.rtm_reg):
            normal_init(rtm_cls, std=0.01, bias=bias_cls)
            normal_init(rtm_reg, std=0.01, bias=1)

    def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores and bbox prediction
            - cls_scores (list[Tensor]): Classification scores for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * num_classes.
            - bbox_preds (list[Tensor]): Box energies / deltas for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * 4.
            - kernel_preds (list[Tensor]): Dynamic conv kernels for all scale
              levels, each is a 4D-tensor, the channels number is
              num_gen_params.
            - mask_feat (Tensor): Mask prototype features.
                Has shape (batch_size, num_prototypes, H, W).
        """
        mask_feat = self.mask_head(feats)

        cls_scores = []
        bbox_preds = []
        kernel_preds = []
        # Zipping with the strides limits the loop to the configured levels.
        for idx, (x, _) in enumerate(zip(feats, self.featmap_strides)):
            cls_feat = x
            reg_feat = x
            kernel_feat = x

            for cls_layer in self.cls_convs[idx]:
                cls_feat = cls_layer(cls_feat)
            cls_score = self.rtm_cls[idx](cls_feat)

            for kernel_layer in self.kernel_convs[idx]:
                kernel_feat = kernel_layer(kernel_feat)
            kernel_pred = self.rtm_kernel[idx](kernel_feat)

            for reg_layer in self.reg_convs[idx]:
                reg_feat = reg_layer(reg_feat)
            reg_dist = self.rtm_reg[idx](reg_feat)

            cls_scores.append(cls_score)
            bbox_preds.append(reg_dist)
            kernel_preds.append(kernel_pred)
        return tuple(cls_scores), tuple(bbox_preds), tuple(
            kernel_preds), mask_feat
|
||||
|
||||
|
||||
@MODELS.register_module()
class RTMDetInsSepBNHead(RTMDetHead):
    """RTMDet Instance Segmentation head.

    Only inference is implemented: :meth:`loss_by_feat` raises
    ``NotImplementedError``.

    Args:
        head_module(ConfigType): Base module used for RTMDetInsSepBNHead
        prior_generator: Points generator feature maps in
            2D points-based detectors.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        loss_mask (:obj:`ConfigDict` or dict): Config of mask loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            anchor head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 head_module: ConfigType,
                 prior_generator: ConfigType = dict(
                     type='mmdet.MlvlPointGenerator',
                     offset=0,
                     strides=[8, 16, 32]),
                 bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.QualityFocalLoss',
                     use_sigmoid=True,
                     beta=2.0,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='mmdet.GIoULoss', loss_weight=2.0),
                 loss_mask: ConfigType = dict(
                     type='mmdet.DiceLoss',
                     loss_weight=2.0,
                     eps=5e-6,
                     reduction='mean'),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):

        super().__init__(
            head_module=head_module,
            prior_generator=prior_generator,
            bbox_coder=bbox_coder,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)

        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        # The loss and the head module must agree on sigmoid vs. softmax.
        if isinstance(self.head_module, RTMDetInsSepBNHeadModule):
            assert self.use_sigmoid_cls == self.head_module.use_sigmoid_cls
        self.loss_mask = MODELS.build(loss_mask)

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        kernel_preds: List[Tensor],
                        mask_feats: Tensor,
                        score_factors: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = True,
                        with_nms: bool = True) -> List[InstanceData]:
        """Transform a batch of output features extracted from the head into
        bbox and mask results.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            kernel_preds (list[Tensor]): Kernel predictions of dynamic
                convs for all scale levels, each is a 4D-tensor, has shape
                (batch_size, num_params, H, W).
            mask_feats (Tensor): Mask prototype features extracted from the
                mask head, has shape (batch_size, num_prototypes, H, W).
            score_factors (list[Tensor], optional): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, num_priors * 1, H, W). Accepted for interface
                compatibility; not used by this method. Defaults to None.
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Each dict must provide ``ori_shape`` and ``scale_factor``;
                ``pad_param`` is optional. Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes and masks in original
                image space. Defaults to True.
            with_nms (bool): If True, do nms before return boxes.
                Must be True for this head. Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection and instance
            segmentation results of each image after the post process.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, h, w).
        """
        cfg = self.test_cfg if cfg is None else cfg
        # Work on a copy: `multi_label` / `max_per_img` are overwritten below.
        cfg = copy.deepcopy(cfg)

        multi_label = cfg.multi_label
        multi_label &= self.num_classes > 1
        cfg.multi_label = multi_label

        num_imgs = len(batch_img_metas)
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

        # If the shape does not change, use the previous mlvl_priors
        if featmap_sizes != self.featmap_sizes:
            self.mlvl_priors = self.prior_generator.grid_priors(
                featmap_sizes,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device,
                with_stride=True)
            self.featmap_sizes = featmap_sizes
        flatten_priors = torch.cat(self.mlvl_priors)

        mlvl_strides = [
            flatten_priors.new_full(
                (featmap_size.numel() * self.num_base_priors, ), stride) for
            featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
        ]
        flatten_stride = torch.cat(mlvl_strides)

        # flatten cls_scores, bbox_preds and kernel_preds
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.num_classes)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_kernel_preds = [
            kernel_pred.permute(0, 2, 3,
                                1).reshape(num_imgs, -1,
                                           self.head_module.num_gen_params)
            for kernel_pred in kernel_preds
        ]

        flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_decoded_bboxes = self.bbox_coder.decode(
            flatten_priors[..., :2].unsqueeze(0), flatten_bbox_preds,
            flatten_stride)

        flatten_kernel_preds = torch.cat(flatten_kernel_preds, dim=1)

        results_list = []
        # Iterating `mask_feats` yields one (num_prototypes, H, W) map per
        # image.
        for (bboxes, scores, kernel_pred, mask_feat,
             img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
                              flatten_kernel_preds, mask_feats,
                              batch_img_metas):
            ori_shape = img_meta['ori_shape']
            scale_factor = img_meta['scale_factor']
            if 'pad_param' in img_meta:
                pad_param = img_meta['pad_param']
            else:
                pad_param = None

            score_thr = cfg.get('score_thr', -1)
            if scores.shape[0] == 0:
                empty_results = InstanceData()
                empty_results.bboxes = bboxes
                empty_results.scores = scores[:, 0]
                empty_results.labels = scores[:, 0].int()
                h, w = ori_shape[:2] if rescale else img_meta['img_shape'][:2]
                empty_results.masks = torch.zeros(
                    size=(0, h, w), dtype=torch.bool, device=bboxes.device)
                results_list.append(empty_results)
                continue

            nms_pre = cfg.get('nms_pre', 100000)
            if cfg.multi_label is False:
                # Single-label mode: keep only the best class per prior.
                scores, labels = scores.max(1, keepdim=True)
                scores, _, keep_idxs, results = filter_scores_and_topk(
                    scores,
                    score_thr,
                    nms_pre,
                    results=dict(
                        labels=labels[:, 0],
                        kernel_pred=kernel_pred,
                        priors=flatten_priors))
                labels = results['labels']
                kernel_pred = results['kernel_pred']
                priors = results['priors']
            else:
                out = filter_scores_and_topk(
                    scores,
                    score_thr,
                    nms_pre,
                    results=dict(
                        kernel_pred=kernel_pred, priors=flatten_priors))
                scores, labels, keep_idxs, filtered_results = out
                kernel_pred = filtered_results['kernel_pred']
                priors = filtered_results['priors']

            results = InstanceData(
                scores=scores,
                labels=labels,
                bboxes=bboxes[keep_idxs],
                kernels=kernel_pred,
                priors=priors)

            if rescale:
                if pad_param is not None:
                    # Subtract the padding offsets; indices 2 and 0 are used
                    # as the x and y offsets respectively.
                    results.bboxes -= results.bboxes.new_tensor([
                        pad_param[2], pad_param[0], pad_param[2], pad_param[0]
                    ])
                results.bboxes /= results.bboxes.new_tensor(
                    scale_factor).repeat((1, 2))

            if cfg.get('yolox_style', False):
                # do not need max_per_img
                cfg.max_per_img = len(results)

            # Boxes were already rescaled above, hence rescale_bbox=False.
            results = self._bbox_mask_post_process(
                results=results,
                mask_feat=mask_feat,
                cfg=cfg,
                rescale_bbox=False,
                rescale_mask=rescale,
                with_nms=with_nms,
                pad_param=pad_param,
                img_meta=img_meta)
            results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
            results.bboxes[:, 1::2].clamp_(0, ori_shape[0])

            results_list.append(results)
        return results_list

    def _bbox_mask_post_process(
            self,
            results: InstanceData,
            mask_feat: Tensor,
            cfg: ConfigDict,
            rescale_bbox: bool = False,
            rescale_mask: bool = True,
            with_nms: bool = True,
            pad_param: Optional[np.ndarray] = None,
            img_meta: Optional[dict] = None) -> InstanceData:
        """bbox and mask post-processing method.

        Optionally rescales the boxes, filters small boxes, runs NMS and
        then predicts, resizes and binarizes the masks of the kept
        instances.

        Args:
            results (:obj:`InstaceData`): Detection instance results,
                each item has shape (num_bboxes, ).
            mask_feat (Tensor): Mask prototype features of a single image,
                as passed by ``predict_by_feat``; has shape
                (num_prototypes, H, W).
            cfg (ConfigDict): Test / postprocessing configuration.
            rescale_bbox (bool): If True, return boxes in original image space.
                Default to False.
            rescale_mask (bool): If True, return masks in original image space.
                Default to True.
            with_nms (bool): If True, do nms before return boxes.
                Must be True, otherwise an ``AssertionError`` is raised.
                Default to True.
            pad_param (np.ndarray, optional): Padding amounts used to crop
                the masks before resizing them to the original shape;
                indices 0/1 crop the start/end of the height axis and
                indices 2/3 the start/end of the width axis. Cropping is
                only applied when this is a ``np.ndarray``.
                Defaults to None.
            img_meta (dict, optional): Image meta info. Defaults to None.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, h, w).
        """
        if rescale_bbox:
            assert img_meta.get('scale_factor') is not None
            scale_factor = [1 / s for s in img_meta['scale_factor']]
            results.bboxes = scale_boxes(results.bboxes, scale_factor)

        if hasattr(results, 'score_factors'):
            # TODO: Add sqrt operation in order to be consistent with
            #  the paper.
            score_factors = results.pop('score_factors')
            results.scores = results.scores * score_factors

        # filter small size bboxes
        if cfg.get('min_bbox_size', -1) >= 0:
            w, h = get_box_wh(results.bboxes)
            valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
            if not valid_mask.all():
                results = results[valid_mask]

        # TODO: deal with `with_nms` and `nms_cfg=None` in test_cfg
        assert with_nms, 'with_nms must be True for RTMDet-Ins'
        if results.bboxes.numel() > 0:
            bboxes = get_box_tensor(results.bboxes)
            det_bboxes, keep_idxs = batched_nms(bboxes, results.scores,
                                                results.labels, cfg.nms)
            results = results[keep_idxs]
            # some nms would reweight the score, such as softnms
            results.scores = det_bboxes[:, -1]
            results = results[:cfg.max_per_img]

            # process masks
            mask_logits = self._mask_predict_by_feat(mask_feat,
                                                     results.kernels,
                                                     results.priors)

            # Upsample the logits by the stride of the first prior level.
            stride = self.prior_generator.strides[0][0]
            mask_logits = F.interpolate(
                mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear')
            if rescale_mask:
                # TODO: When use mmdet.Resize or mmdet.Pad, will meet bug
                # Use img_meta to crop and resize
                ori_h, ori_w = img_meta['ori_shape'][:2]
                if isinstance(pad_param, np.ndarray):
                    pad_param = pad_param.astype(np.int32)
                    crop_y1, crop_y2 = pad_param[
                        0], mask_logits.shape[-2] - pad_param[1]
                    crop_x1, crop_x2 = pad_param[
                        2], mask_logits.shape[-1] - pad_param[3]
                    mask_logits = mask_logits[..., crop_y1:crop_y2,
                                              crop_x1:crop_x2]
                mask_logits = F.interpolate(
                    mask_logits,
                    size=[ori_h, ori_w],
                    mode='bilinear',
                    align_corners=False)

            masks = mask_logits.sigmoid().squeeze(0)
            masks = masks > cfg.mask_thr_binary
            results.masks = masks
        else:
            # No boxes left: attach an empty mask tensor of matching size.
            h, w = img_meta['ori_shape'][:2] if rescale_mask else img_meta[
                'img_shape'][:2]
            results.masks = torch.zeros(
                size=(results.bboxes.shape[0], h, w),
                dtype=torch.bool,
                device=results.bboxes.device)
        return results

    def _mask_predict_by_feat(self, mask_feat: Tensor, kernels: Tensor,
                              priors: Tensor) -> Tensor:
        """Generate mask logits from mask features with dynamic convs.

        Args:
            mask_feat (Tensor): Mask prototype features.
                Has shape (num_prototypes, H, W).
            kernels (Tensor): Kernel parameters for each instance.
                Has shape (num_instance, num_params)
            priors (Tensor): Center priors for each instance.
                Has shape (num_instance, 4).
        Returns:
            Tensor: Instance segmentation masks for each instance.
                Has shape (num_instance, H, W).
        """
        num_inst = kernels.shape[0]
        h, w = mask_feat.size()[-2:]
        if num_inst < 1:
            return torch.empty(
                size=(num_inst, h, w),
                dtype=mask_feat.dtype,
                device=mask_feat.device)
        if mask_feat.dim() < 4:
            # `unsqueeze` is out-of-place, so its result has to be kept.
            mask_feat = mask_feat.unsqueeze(0)

        coord = self.prior_generator.single_level_grid_priors(
            (h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2)
        num_inst = priors.shape[0]
        points = priors[:, :2].reshape(-1, 1, 2)
        strides = priors[:, 2:].reshape(-1, 1, 2)
        # Per-instance offset of every grid location from the instance's
        # prior point, normalised by 8x the prior stride.
        relative_coord = (points - coord).permute(0, 2, 1) / (
            strides[..., 0].reshape(-1, 1, 1) * 8)
        relative_coord = relative_coord.reshape(num_inst, 2, h, w)

        mask_feat = torch.cat(
            [relative_coord,
             mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
        weights, biases = self.parse_dynamic_params(kernels)

        n_layers = len(weights)
        # Fold the instances into the channel axis so that a grouped conv
        # applies each instance's own kernels in a single call.
        x = mask_feat.reshape(1, -1, h, w)
        for i, (weight, bias) in enumerate(zip(weights, biases)):
            x = F.conv2d(
                x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
            if i < n_layers - 1:
                x = F.relu(x)
        x = x.reshape(num_inst, h, w)
        return x

    def parse_dynamic_params(self, flatten_kernels: Tensor) -> tuple:
        """split kernel head prediction to conv weight and bias.

        Args:
            flatten_kernels (Tensor): Flattened dynamic-conv parameters, one
                row per instance; the columns hold all layer weights
                followed by all layer biases.

        Returns:
            tuple: ``(weight_splits, bias_splits)``, two lists with one
            tensor per dynamic conv layer, reshaped for a grouped 1x1
            convolution over all instances.
        """
        n_inst = flatten_kernels.size(0)
        n_layers = len(self.head_module.weight_nums)
        params_splits = list(
            torch.split_with_sizes(
                flatten_kernels,
                self.head_module.weight_nums + self.head_module.bias_nums,
                dim=1))
        weight_splits = params_splits[:n_layers]
        bias_splits = params_splits[n_layers:]
        for i in range(n_layers):
            if i < n_layers - 1:
                weight_splits[i] = weight_splits[i].reshape(
                    n_inst * self.head_module.dyconv_channels, -1, 1, 1)
                bias_splits[i] = bias_splits[i].reshape(
                    n_inst * self.head_module.dyconv_channels)
            else:
                # The last layer emits one output channel per instance.
                weight_splits[i] = weight_splits[i].reshape(n_inst, -1, 1, 1)
                bias_splits[i] = bias_splits[i].reshape(n_inst)

        return weight_splits, bias_splits

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Training loss; not implemented for this head.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
|
|
@ -1,10 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.config import Config
|
||||
from mmengine.structures import InstanceData
|
||||
|
||||
from mmyolo.models import RTMDetInsSepBNHead
|
||||
from mmyolo.models.dense_heads import RTMDetHead
|
||||
from mmyolo.utils import register_all_modules
|
||||
|
||||
|
@ -137,3 +139,85 @@ class TestRTMDetHead(TestCase):
|
|||
'cls loss should be non-zero')
|
||||
self.assertGreater(onegt_box_loss.item(), 0,
|
||||
'box loss should be non-zero')
|
||||
|
||||
|
||||
class TestRTMDetInsHead(TestCase):
    """Smoke tests for the RTMDet-Ins head: weight init and inference."""

    def setUp(self):
        # Tiny head module config shared by all tests.
        self.head_module = dict(
            type='RTMDetInsSepBNHeadModule',
            num_classes=4,
            in_channels=1,
            stacked_convs=1,
            feat_channels=64,
            featmap_strides=[4, 8, 16],
            num_prototypes=8,
            dyconv_channels=8,
            num_dyconvs=3,
            share_conv=True,
            use_sigmoid_cls=True)

    def test_init_weights(self):
        head = RTMDetInsSepBNHead(head_module=self.head_module)
        head.head_module.init_weights()

    def test_predict_by_feat(self):
        s = 256
        plain_meta = {
            'img_shape': (s, s, 3),
            'ori_shape': (s, s, 3),
            'scale_factor': (1.0, 1.0)
        }
        img_metas = [dict(plain_meta, pad_param=np.array([0., 0., 0., 0.]))]
        img_metas_without_pad_param = [dict(plain_meta)]

        test_cfg = Config(
            dict(
                multi_label=False,
                nms_pre=1000,
                min_bbox_size=0,
                score_thr=0.05,
                nms=dict(type='nms', iou_threshold=0.6),
                max_per_img=100,
                mask_thr_binary=0.5))

        head = RTMDetInsSepBNHead(
            head_module=self.head_module, test_cfg=test_cfg)
        feat = [
            torch.rand(1, 1, s // stride, s // stride)
            for stride in (4, 8, 16)
        ]
        # (cls_scores, bbox_preds, kernel_preds, mask_feat)
        head_outs = head.forward(feat)

        # Inference must run both with and without `pad_param` in the metas.
        for metas in (img_metas, img_metas_without_pad_param):
            head.predict_by_feat(
                *head_outs,
                batch_img_metas=metas,
                cfg=test_cfg,
                rescale=True,
                with_nms=True)

        # Disabling NMS is rejected by the head.
        with self.assertRaises(AssertionError):
            head.predict_by_feat(
                *head_outs,
                batch_img_metas=img_metas,
                cfg=test_cfg,
                rescale=False,
                with_nms=False)
|
||||
|
|
Loading…
Reference in New Issue