[Feature] Implement fast version of RTMDet. (#425)

* Accelerate RTMDet

* update

* update

* update

* update1

* update2

* update pipeline

* update lr cudnnbenchmark

* revert batchsize

* fix batch inference

* refactor head

* update box

* bs=16

* update

* move reduce mean

* update head

* per img loss

* fix

* fix sum

* concat loss

* batch dsla

* sort topk

* bs 32

* clean code

* update readme

* update ut

* update checkpoint

* num_class

* clean code

* resolve comments

* fix readme

* fix ut

Co-authored-by: huanghaian <huanghaian@sensetime.com>
Co-authored-by: hha <1286304229@qq.com>
RangiLyu 2023-01-04 16:06:44 +08:00 committed by GitHub
parent c7a9026812
commit 48f8896e84
12 changed files with 470 additions and 518 deletions


@ -1,26 +1,32 @@
# RTMDet
# RTMDet: An Empirical Study of Designing Real-Time Object Detectors
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/real-time-instance-segmentation-on-mscoco)](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-hrsc2016)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-hrsc2016?p=rtmdet-an-empirical-study-of-designing-real)
<!-- [ALGORITHM] -->
## Abstract
Our tech-report will be released soon.
In this paper, we aim to design an efficient real-time object detector that exceeds the YOLO series and is easily extensible for many object recognition tasks such as instance segmentation and rotated object detection. To obtain a more efficient model architecture, we explore an architecture that has compatible capacities in the backbone and neck, constructed by a basic building block that consists of large-kernel depth-wise convolutions. We further introduce soft labels when calculating matching costs in the dynamic label assignment to improve accuracy. Together with better training techniques, the resulting object detector, named RTMDet, achieves 52.8% AP on COCO with 300+ FPS on an NVIDIA 3090 GPU, outperforming the current mainstream industrial detectors. RTMDet achieves the best parameter-accuracy trade-off with tiny/small/medium/large/extra-large model sizes for various application scenarios, and obtains new state-of-the-art performance on real-time instance segmentation and rotated object detection. We hope the experimental results can provide new insights into designing versatile real-time object detectors for many object recognition tasks.
<div align=center>
<img src="https://user-images.githubusercontent.com/12907710/192182907-f9a671d6-89cb-4d73-abd8-c2b9dada3c66.png"/>
<img src="https://user-images.githubusercontent.com/12907710/208070055-7233a3d8-955f-486a-82da-b714b3c3bbd6.png"/>
</div>
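The basic building block mentioned in the abstract pairs a large-kernel depth-wise convolution with a point-wise convolution. Below is a minimal PyTorch sketch of such a block for illustration only; the module name, the 5x5 kernel size and the layer ordering are assumptions, not the exact implementation in this repository.

```python
import torch
import torch.nn as nn


class LargeKernelDWBlock(nn.Module):
    """Illustrative large-kernel depth-wise block (hypothetical name/sizes)."""

    def __init__(self, channels: int, kernel_size: int = 5):
        super().__init__()
        # Depth-wise convolution: groups == channels keeps the cost low
        # while the large kernel enlarges the receptive field.
        self.dw_conv = nn.Conv2d(
            channels,
            channels,
            kernel_size,
            padding=kernel_size // 2,
            groups=channels,
            bias=False)
        self.bn = nn.BatchNorm2d(channels)
        # Point-wise convolution mixes information across channels.
        self.pw_conv = nn.Conv2d(channels, channels, 1, bias=False)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.pw_conv(self.bn(self.dw_conv(x))))


# Shape check: the block preserves the feature map size.
print(LargeKernelDWBlock(64)(torch.randn(1, 64, 80, 80)).shape)
```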
## Results and Models
| Backbone | size | SyncBN | box AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | Config | Download |
| :---------: | :--: | :----: | -----: | :-------: | :------: | :------------------: | :-----------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| RTMDet-tiny | 640 | Yes | 40.9 | 4.8 | 8.1 | 0.98 | [config](./rtmdet_tiny_syncbn_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco/rtmdet_tiny_syncbn_8xb32-300e_coco_20220902_112414-259f3241.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414.log.json) |
| RTMDet-s | 640 | Yes | 44.5 | 8.89 | 14.8 | 1.22 | [config](./rtmdet_s_syncbn_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_8xb32-300e_coco/rtmdet_s_syncbn_8xb32-300e_coco_20220905_161602-fd1cacb9.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602.log.json) |
| RTMDet-m | 640 | Yes | 49.1 | 24.71 | 39.27 | 1.62 | [config](./rtmdet_m_syncbn_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_8xb32-300e_coco/rtmdet_m_syncbn_8xb32-300e_coco_20220924_132959-d9f2e90d.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220924_132959.log.json) |
| RTMDet-l | 640 | Yes | 51.3 | 52.3 | 80.23 | 2.44 | [config](./rtmdet_l_syncbn_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_8xb32-300e_coco/rtmdet_l_syncbn_8xb32-300e_coco_20220926_150401-40c754b5.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220926_150401.log.json) |
| RTMDet-x | 640 | Yes | 52.6 | 94.86 | 141.67 | 3.10 | [config](./rtmdet_x_syncbn_8xb32-300e_coco.py) | [model](<>) \| [log](<>) |
## Object Detection
| Model | size | box AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | Config | Download |
| :---------: | :--: | :----: | :-------: | :------: | :------------------: | :-------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| RTMDet-tiny | 640 | 41.0 | 4.8 | 8.1 | 0.98 | [config](./rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco/rtmdet_tiny_syncbn_fast_8xb32-300e_coco_20230102_140117-dbb1dc83.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco/rtmdet_tiny_syncbn_fast_8xb32-300e_coco_20230102_140117.log.json) |
| RTMDet-s | 640 | 44.6 | 8.89 | 14.8 | 1.22 | [config](./rtmdet_s_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco/rtmdet_s_syncbn_fast_8xb32-300e_coco_20221230_182329-0a8c901a.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco/rtmdet_s_syncbn_fast_8xb32-300e_coco_20221230_182329.log.json) |
| RTMDet-m | 640 | 49.3 | 24.71 | 39.27 | 1.62 | [config](./rtmdet_m_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco/rtmdet_m_syncbn_fast_8xb32-300e_coco_20230102_135952-40af4fe8.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco/rtmdet_m_syncbn_fast_8xb32-300e_coco_20230102_135952.log.json) |
| RTMDet-l | 640 | 51.4 | 52.3 | 80.23 | 2.44 | [config](./rtmdet_l_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco/rtmdet_l_syncbn_fast_8xb32-300e_coco_20230102_135928-ee3abdc4.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco/rtmdet_l_syncbn_fast_8xb32-300e_coco_20230102_135928.log.json) |
| RTMDet-x | 640 | 52.8 | 94.86 | 141.67 | 3.10 | [config](./rtmdet_x_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco/rtmdet_x_syncbn_fast_8xb32-300e_coco_20221231_100345-b85cd476.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco/rtmdet_x_syncbn_fast_8xb32-300e_coco_20221231_100345.log.json) |
**Note**:
1. The inference speed is measured on an NVIDIA 3090 GPU with TensorRT 8.4.3, cuDNN 8.2.0, FP16, batch size=1, and without NMS.
2. We still directly use the weights trained by `mmdet` currently. A re-trained model will be released later.
1. The inference speed of RTMDet is measured on an NVIDIA 3090 GPU with TensorRT 8.4.3, cuDNN 8.2.0, FP16, batch size=1, and without NMS.
2. For a fair comparison, the config of bbox postprocessing is changed to be consistent with YOLOv5/6/7 after [PR#9494](https://github.com/open-mmlab/mmdetection/pull/9494), bringing about 0.1~0.3% AP improvement.
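
For reference, a minimal sketch of running one of the checkpoints above through the Python inference API. It assumes the MMDetection 3.x `init_detector`/`inference_detector` entry points and MMYOLO's `register_all_modules` helper; exact entry points and the demo image path may differ between versions.

```python
from mmdet.apis import inference_detector, init_detector

from mmyolo.utils import register_all_modules

register_all_modules()  # make MMYOLO modules visible to the registry

config = 'configs/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py'
# Checkpoint URL taken from the table above.
checkpoint = ('https://download.openmmlab.com/mmyolo/v0/rtmdet/'
              'rtmdet_s_syncbn_fast_8xb32-300e_coco/'
              'rtmdet_s_syncbn_fast_8xb32-300e_coco_20221230_182329-0a8c901a.pth')

model = init_detector(config, checkpoint, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')  # hypothetical image path
print(result.pred_instances.scores[:5])
```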


@ -15,9 +15,9 @@ Collections:
Version: v0.1.1
Models:
- Name: rtmdet_tiny_syncbn_8xb32-300e_coco
- Name: rtmdet_tiny_syncbn_fast_8xb32-300e_coco
In Collection: RTMDet
Config: configs/rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py
Config: configs/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py
Metadata:
Training Memory (GB): 11.7
Epochs: 300
@ -25,12 +25,12 @@ Models:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 40.9
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco/rtmdet_tiny_syncbn_8xb32-300e_coco_20220902_112414-259f3241.pth
box AP: 41.0
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco/rtmdet_tiny_syncbn_fast_8xb32-300e_coco_20230102_140117-dbb1dc83.pth
- Name: rtmdet_s_syncbn_8xb32-300e_coco
- Name: rtmdet_s_syncbn_fast_8xb32-300e_coco
In Collection: RTMDet
Config: configs/rtmdet/rtmdet_s_syncbn_8xb32-300e_coco.py
Config: configs/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py
Metadata:
Training Memory (GB): 15.9
Epochs: 300
@ -38,12 +38,12 @@ Models:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 44.5
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_8xb32-300e_coco/rtmdet_s_syncbn_8xb32-300e_coco_20220905_161602-fd1cacb9.pth
box AP: 44.6
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco/rtmdet_s_syncbn_fast_8xb32-300e_coco_20221230_182329-0a8c901a.pth
- Name: rtmdet_m_syncbn_8xb32-300e_coco
- Name: rtmdet_m_syncbn_fast_8xb32-300e_coco
In Collection: RTMDet
Config: configs/rtmdet/rtmdet_m_syncbn_8xb32-300e_coco.py
Config: configs/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco.py
Metadata:
Training Memory (GB): 27.8
Epochs: 300
@ -51,12 +51,12 @@ Models:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 49.1
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_8xb32-300e_coco/rtmdet_m_syncbn_8xb32-300e_coco_20220924_132959-d9f2e90d.pth
box AP: 49.3
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco/rtmdet_m_syncbn_fast_8xb32-300e_coco_20230102_135952-40af4fe8.pth
- Name: rtmdet_l_syncbn_8xb32-300e_coco
- Name: rtmdet_l_syncbn_fast_8xb32-300e_coco
In Collection: RTMDet
Config: configs/rtmdet/rtmdet_l_syncbn_8xb32-300e_coco.py
Config: configs/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco.py
Metadata:
Training Memory (GB): 43.2
Epochs: 300
@ -64,5 +64,18 @@ Models:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 51.3
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_8xb32-300e_coco/rtmdet_l_syncbn_8xb32-300e_coco_20220926_150401-40c754b5.pth
box AP: 51.4
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco/rtmdet_l_syncbn_fast_8xb32-300e_coco_20230102_135928-ee3abdc4.pth
- Name: rtmdet_x_syncbn_fast_8xb32-300e_coco
In Collection: RTMDet
Config: configs/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco.py
Metadata:
Training Memory (GB): 63.4
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 52.8
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco/rtmdet_x_syncbn_fast_8xb32-300e_coco_20221231_100345-b85cd476.pth


@ -9,20 +9,33 @@ widen_factor = 1.0
max_epochs = 300
stage2_num_epochs = 20
interval = 10
num_classes = 80
train_batch_size_per_gpu = 32
train_num_workers = 10
val_batch_size_per_gpu = 5
val_batch_size_per_gpu = 32
val_num_workers = 10
# persistent_workers must be False if num_workers is 0.
persistent_workers = True
strides = [8, 16, 32]
base_lr = 0.004
# single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)
# only on Val
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
size_divisor=32,
extra_pad_ratio=0.5)
model = dict(
type='YOLODetector',
data_preprocessor=dict(
type='mmdet.DetDataPreprocessor',
type='YOLOv5DetDataPreprocessor',
mean=[103.53, 116.28, 123.675],
std=[57.375, 57.12, 58.395],
bgr_to_rgb=False),
@ -49,7 +62,7 @@ model = dict(
type='RTMDetHead',
head_module=dict(
type='RTMDetSepBNHeadModule',
num_classes=80,
num_classes=num_classes,
in_channels=256,
stacked_convs=2,
feat_channels=256,
@ -60,7 +73,7 @@ model = dict(
featmap_strides=strides),
prior_generator=dict(
type='mmdet.MlvlPointGenerator', offset=0, strides=strides),
bbox_coder=dict(type='mmdet.DistancePointBBoxCoder'),
bbox_coder=dict(type='DistancePointBBoxCoder'),
loss_cls=dict(
type='mmdet.QualityFocalLoss',
use_sigmoid=True,
@ -69,18 +82,19 @@ model = dict(
loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=2.0)),
train_cfg=dict(
assigner=dict(
type='mmdet.DynamicSoftLabelAssigner',
type='BatchDynamicSoftLabelAssigner',
num_classes=num_classes,
topk=13,
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100),
multi_label=True,
nms_pre=30000,
score_thr=0.001,
nms=dict(type='nms', iou_threshold=0.65),
max_per_img=300),
)
train_pipeline = [
@ -102,13 +116,7 @@ train_pipeline = [
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(
type='YOLOXMixUp',
img_scale=img_scale,
use_cached=True,
ratio_range=(1.0, 1.0),
max_cached_images=20,
pad_val=(114, 114, 114)),
dict(type='YOLOv5MixUp', use_cached=True, max_cached_images=20),
dict(type='mmdet.PackDetInputs')
]
@ -130,13 +138,17 @@ train_pipeline_stage2 = [
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor'))
'scale_factor', 'pad_param'))
]
train_dataloader = dict(
@ -144,6 +156,7 @@ train_dataloader = dict(
num_workers=train_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
collate_fn=dict(type='yolov5_collate'),
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
@ -166,6 +179,7 @@ val_dataloader = dict(
ann_file='annotations/instances_val2017.json',
data_prefix=dict(img='val2017/'),
test_mode=True,
batch_shapes_cfg=batch_shapes_cfg,
pipeline=test_pipeline))
test_dataloader = val_dataloader
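
The dataloader above switches to `yolov5_collate`, which packs the ground truths of a whole batch into one `[all_gt_bboxes, 6]` tensor with a leading image index; this is the layout the batched head and assigner consume. The snippet below is a simplified stand-alone sketch of that packing, assuming each row becomes `[img_idx, label, x1, y1, x2, y2]` (inferred from the head code in this commit), not the actual collate implementation.

```python
import torch


def pack_batch_gts(per_image_gts):
    """Toy version of a yolov5-style collate for ground truths.

    per_image_gts: list of (labels [n_i], bboxes [n_i, 4]) per image.
    Returns a [sum(n_i), 6] tensor with rows [img_idx, label, x1, y1, x2, y2].
    """
    rows = []
    for img_idx, (labels, bboxes) in enumerate(per_image_gts):
        idx_col = bboxes.new_full((len(labels), 1), img_idx)
        rows.append(torch.cat((idx_col, labels[:, None].float(), bboxes), dim=1))
    return torch.cat(rows, dim=0) if rows else torch.zeros((0, 6))


# Two images: two gts and one gt.
gts = [
    (torch.tensor([3, 7]), torch.tensor([[0., 0., 10., 10.], [5., 5., 20., 20.]])),
    (torch.tensor([1]), torch.tensor([[2., 2., 8., 8.]])),
]
print(pack_batch_gts(gts))  # shape [3, 6]
```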


@ -1,4 +1,4 @@
_base_ = './rtmdet_l_syncbn_8xb32-300e_coco.py'
_base_ = './rtmdet_l_syncbn_fast_8xb32-300e_coco.py'
deepen_factor = 0.67
widen_factor = 0.75


@ -1,4 +1,4 @@
_base_ = './rtmdet_l_syncbn_8xb32-300e_coco.py'
_base_ = './rtmdet_l_syncbn_fast_8xb32-300e_coco.py'
checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth' # noqa
deepen_factor = 0.33
@ -42,13 +42,7 @@ train_pipeline = [
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(
type='YOLOXMixUp',
img_scale=img_scale,
use_cached=True,
ratio_range=(1.0, 1.0),
max_cached_images=20,
pad_val=(114, 114, 114)),
dict(type='YOLOv5MixUp', use_cached=True, max_cached_images=20),
dict(type='mmdet.PackDetInputs')
]


@ -1,4 +1,4 @@
_base_ = './rtmdet_s_syncbn_8xb32-300e_coco.py'
_base_ = './rtmdet_s_syncbn_fast_8xb32-300e_coco.py'
checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth' # noqa
@ -40,14 +40,11 @@ train_pipeline = [
dict(type='mmdet.RandomFlip', prob=0.5),
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(
type='YOLOXMixUp',
img_scale=img_scale,
ratio_range=(1.0, 1.0),
max_cached_images=10, # note
type='YOLOv5MixUp',
use_cached=True,
random_pop=False, # note
pad_val=(114, 114, 114),
prob=0.5), # note
random_pop=False,
max_cached_images=10,
prob=0.5),
dict(type='mmdet.PackDetInputs')
]


@ -1,4 +1,4 @@
_base_ = './rtmdet_l_syncbn_8xb32-300e_coco.py'
_base_ = './rtmdet_l_syncbn_fast_8xb32-300e_coco.py'
deepen_factor = 1.33
widen_factor = 1.25


@ -1,19 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Tuple, Union
from typing import List, Sequence, Tuple, Union
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, is_norm
from mmdet.models.task_modules.prior_generators import anchor_inside_flags
from mmdet.models.task_modules.samplers import PseudoSampler
from mmdet.models.utils import images_to_levels, multi_apply, unmap
from mmdet.structures.bbox import distance2bbox
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
OptInstanceList, OptMultiConfig, reduce_mean)
from mmengine.config import ConfigDict
from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
normal_init)
from mmengine.structures import InstanceData
from torch import Tensor
from mmyolo.registry import MODELS, TASK_UTILS
@ -172,7 +168,7 @@ class RTMDetSepBNHeadModule(BaseModule):
cls_scores = []
bbox_preds = []
for idx, (x, stride) in enumerate(zip(feats, self.featmap_strides)):
for idx, x in enumerate(feats):
cls_feat = x
reg_feat = x
@ -183,7 +179,7 @@ class RTMDetSepBNHeadModule(BaseModule):
for reg_layer in self.reg_convs[idx]:
reg_feat = reg_layer(reg_feat)
reg_dist = self.rtm_reg[idx](reg_feat) * stride
reg_dist = self.rtm_reg[idx](reg_feat)
cls_scores.append(cls_score)
bbox_preds.append(reg_dist)
return tuple(cls_scores), tuple(bbox_preds)
@ -210,28 +206,28 @@ class RTMDetHead(YOLOv5Head):
Defaults to None.
"""
def __init__(
self,
head_module: ConfigType,
prior_generator: ConfigType = dict(
type='mmdet.MlvlPointGenerator', offset=0, strides=[8, 16,
32]),
bbox_coder: ConfigType = dict(type='mmdet.DistancePointBBoxCoder'),
loss_cls: ConfigType = dict(
type='mmdet.QualityFocalLoss',
use_sigmoid=True,
beta=2.0,
loss_weight=1.0),
loss_bbox: ConfigType = dict(
type='mmdet.GIoULoss', loss_weight=2.0),
loss_obj: ConfigType = dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None):
def __init__(self,
head_module: ConfigType,
prior_generator: ConfigType = dict(
type='mmdet.MlvlPointGenerator',
offset=0,
strides=[8, 16, 32]),
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
loss_cls: ConfigType = dict(
type='mmdet.QualityFocalLoss',
use_sigmoid=True,
beta=2.0,
loss_weight=1.0),
loss_bbox: ConfigType = dict(
type='mmdet.GIoULoss', loss_weight=2.0),
loss_obj: ConfigType = dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None):
super().__init__(
head_module=head_module,
@ -276,116 +272,6 @@ class RTMDetHead(YOLOv5Head):
"""
return self.head_module(x)
def predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = True,
with_nms: bool = True) -> List[InstanceData]:
"""Transform a batch of output features extracted from the head into
bbox results.
Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
list[:obj:`InstanceData`]: Object detection results of each image
after the post process. Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
return super(YOLOv5Head, self).predict_by_feat(
cls_scores,
bbox_preds,
None,
batch_img_metas=batch_img_metas,
cfg=cfg,
rescale=rescale,
with_nms=with_nms)
def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
labels: Tensor, label_weights: Tensor,
bbox_targets: Tensor, assign_metrics: Tensor,
stride: List[int]) -> list:
"""Compute loss of a single scale level.
Args:
cls_score (Tensor): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W).
bbox_pred (Tensor): Decoded bboxes for each scale
level with shape (N, num_anchors * 4, H, W).
labels (Tensor): Labels of each anchors with shape
(N, num_total_anchors).
label_weights (Tensor): Label weights of each anchor with shape
(N, num_total_anchors).
bbox_targets (Tensor): BBox regression targets of each anchor with
shape (N, num_total_anchors, 4).
assign_metrics (Tensor): Assign metrics with shape
(N, num_total_anchors).
stride (List[int]): Downsample stride of the feature map.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert stride[0] == stride[1], 'h stride is not equal to w stride!'
cls_score = cls_score.permute(0, 2, 3, 1).reshape(
-1, self.cls_out_channels).contiguous()
bbox_pred = bbox_pred.reshape(-1, 4)
bbox_targets = bbox_targets.reshape(-1, 4)
labels = labels.reshape(-1)
assign_metrics = assign_metrics.reshape(-1)
label_weights = label_weights.reshape(-1)
targets = (labels, assign_metrics)
loss_cls = self.loss_cls(
cls_score, targets, label_weights, avg_factor=1.0)
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
bg_class_ind = self.num_classes
pos_inds = ((labels >= 0)
& (labels < bg_class_ind)).nonzero().squeeze(1)
if len(pos_inds) > 0:
pos_bbox_targets = bbox_targets[pos_inds]
pos_bbox_pred = bbox_pred[pos_inds]
pos_decode_bbox_pred = pos_bbox_pred
pos_decode_bbox_targets = pos_bbox_targets
# regression loss
pos_bbox_weight = assign_metrics[pos_inds]
loss_bbox = self.loss_bbox(
pos_decode_bbox_pred,
pos_decode_bbox_targets,
weight=pos_bbox_weight,
avg_factor=1.0)
else:
loss_bbox = bbox_pred.sum() * 0
pos_bbox_weight = bbox_targets.new_tensor(0.)
return loss_cls, loss_bbox, assign_metrics.sum(), pos_bbox_weight.sum()
def loss_by_feat(
self,
cls_scores: List[Tensor],
@ -418,286 +304,131 @@ class RTMDetHead(YOLOv5Head):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.prior_generator.num_levels
gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)
gt_labels = gt_info[:, :, :1]
gt_bboxes = gt_info[:, :, 1:] # xyxy
pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
device = cls_scores[0].device
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, batch_img_metas, device=device)
# If the shape does not match, generate new priors
if featmap_sizes != self.featmap_sizes:
self.featmap_sizes = featmap_sizes
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes, device=device, with_stride=True)
self.flatten_priors = torch.cat(mlvl_priors, dim=0)
self.mlvl_priors = [mlvl[:, :2] for mlvl in mlvl_priors]
flatten_cls_scores = torch.cat([
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.cls_out_channels)
for cls_score in cls_scores
], 1).contiguous()
flatten_bboxes = torch.cat([
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
], 1)
decoded_bboxes = []
for anchor, bbox_pred in zip(anchor_list[0], bbox_preds):
anchor = anchor.reshape(-1, 4)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
bbox_pred = distance2bbox(anchor, bbox_pred)
decoded_bboxes.append(bbox_pred)
flatten_bboxes = flatten_bboxes * self.flatten_priors[..., -1, None]
flatten_bboxes = distance2bbox(self.flatten_priors[..., :2],
flatten_bboxes)
flatten_bboxes = torch.cat(decoded_bboxes, 1)
assigned_result = self.assigner(flatten_bboxes.detach(),
flatten_cls_scores.detach(),
self.flatten_priors, gt_labels,
gt_bboxes, pad_bbox_flag)
cls_reg_targets = self.get_targets(
flatten_cls_scores,
flatten_bboxes,
anchor_list,
valid_flag_list,
batch_gt_instances,
batch_img_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore)
(anchor_list, labels_list, label_weights_list, bbox_targets_list,
assign_metrics_list) = cls_reg_targets
labels = assigned_result['assigned_labels'].reshape(-1)
label_weights = assigned_result['assigned_labels_weights'].reshape(-1)
bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 4)
assign_metrics = assigned_result['assign_metrics'].reshape(-1)
cls_preds = flatten_cls_scores.reshape(-1, self.num_classes)
bbox_preds = flatten_bboxes.reshape(-1, 4)
losses_cls, losses_bbox,\
cls_avg_factors, bbox_avg_factors = multi_apply(
self.loss_by_feat_single,
cls_scores,
decoded_bboxes,
labels_list,
label_weights_list,
bbox_targets_list,
assign_metrics_list,
self.prior_generator.strides)
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
bg_class_ind = self.num_classes
pos_inds = ((labels >= 0)
& (labels < bg_class_ind)).nonzero().squeeze(1)
avg_factor = reduce_mean(assign_metrics.sum()).clamp_(min=1).item()
cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls))
loss_cls = self.loss_cls(
cls_preds, (labels, assign_metrics),
label_weights,
avg_factor=avg_factor)
bbox_avg_factor = reduce_mean(
sum(bbox_avg_factors)).clamp_(min=1).item()
losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
def get_targets(self,
cls_scores: Tensor,
bbox_preds: Tensor,
anchor_list: List[List[Tensor]],
valid_flag_list: List[List[Tensor]],
batch_gt_instances: InstanceList,
batch_img_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None,
unmap_outputs=True) -> Union[tuple, None]:
"""Compute regression and classification targets for anchors in
multiple images.
Args:
cls_scores (Tensor): Classification predictions of images,
a 3D-Tensor with shape [num_imgs, num_priors, num_classes].
bbox_preds (Tensor): Decoded bboxes predictions of one image,
a 3D-Tensor with shape [num_imgs, num_priors, 4] in [tl_x,
tl_y, br_x, br_y] format.
anchor_list (list[list[Tensor]]): Multi level anchors of each
image. The outer list indicates images, and the inner list
corresponds to feature levels of the image. Each element of
the inner list is a tensor of shape (num_anchors, 4).
valid_flag_list (list[list[Tensor]]): Multi level valid flags of
each image. The outer list indicates images, and the inner list
corresponds to feature levels of the image. Each element of
the inner list is a tensor of shape (num_anchors, )
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors. Defaults to True.
Returns:
tuple: a tuple containing learning targets.
- anchors_list (list[list[Tensor]]): Anchors of each level.
- labels_list (list[Tensor]): Labels of each level.
- label_weights_list (list[Tensor]): Label weights of each
level.
- bbox_targets_list (list[Tensor]): BBox targets of each level.
- assign_metrics_list (list[Tensor]): alignment metrics of each
level.
"""
num_imgs = len(batch_img_metas)
assert len(anchor_list) == len(valid_flag_list) == num_imgs
# anchor number of multi levels
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
# concat all level anchors and flags to a single tensor
for i in range(num_imgs):
assert len(anchor_list[i]) == len(valid_flag_list[i])
anchor_list[i] = torch.cat(anchor_list[i])
valid_flag_list[i] = torch.cat(valid_flag_list[i])
# compute targets for each image
if batch_gt_instances_ignore is None:
batch_gt_instances_ignore = [None] * num_imgs
# anchor_list: list(b * [-1, 4])
(all_anchors, all_labels, all_label_weights, all_bbox_targets,
all_assign_metrics) = multi_apply(
self._get_targets_single,
cls_scores.detach(),
bbox_preds.detach(),
anchor_list,
valid_flag_list,
batch_gt_instances,
batch_img_metas,
batch_gt_instances_ignore,
unmap_outputs=unmap_outputs)
# no valid anchors
if any([labels is None for labels in all_labels]):
return None
# split targets to a list w.r.t. multiple levels
anchors_list = images_to_levels(all_anchors, num_level_anchors)
labels_list = images_to_levels(all_labels, num_level_anchors)
label_weights_list = images_to_levels(all_label_weights,
num_level_anchors)
bbox_targets_list = images_to_levels(all_bbox_targets,
num_level_anchors)
assign_metrics_list = images_to_levels(all_assign_metrics,
num_level_anchors)
return (anchors_list, labels_list, label_weights_list,
bbox_targets_list, assign_metrics_list)
def _get_targets_single(self,
cls_scores: Tensor,
bbox_preds: Tensor,
flat_anchors: Tensor,
valid_flags: Tensor,
gt_instances: InstanceData,
img_meta: dict,
gt_instances_ignore: Optional[InstanceData] = None,
unmap_outputs=True) -> tuple:
"""Compute regression, classification targets for anchors in a single
image.
Args:
cls_scores (list(Tensor)): Box scores for each image.
bbox_preds (list(Tensor)): Box energies / deltas for each image.
flat_anchors (Tensor): Multi-level anchors of the image, which are
concatenated into a single tensor of shape (num_anchors ,4)
valid_flags (Tensor): Multi level valid flags of the image,
which are concatenated into a single tensor of
shape (num_anchors,).
gt_instances (:obj:`InstanceData`): Ground truth of instance
annotations. It usually includes ``bboxes`` and ``labels``
attributes.
img_meta (dict): Meta information for current image.
gt_instances_ignore (:obj:`InstanceData`, optional): Instances
to be ignored during training. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors. Defaults to True.
Returns:
tuple: N is the number of total anchors in the image.
- anchors (Tensor): All anchors in the image with shape (N, 4).
- labels (Tensor): Labels of all anchors in the image with shape
(N,).
- label_weights (Tensor): Label weights of all anchor in the
image with shape (N,).
- bbox_targets (Tensor): BBox targets of all anchors in the
image with shape (N, 4).
- norm_alignment_metrics (Tensor): Normalized alignment metrics
of all priors in the image with shape (N,).
"""
inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
img_meta['img_shape'][:2],
self.train_cfg.allowed_border)
if not inside_flags.any():
return (None, ) * 7
# assign gt and sample anchors
anchors = flat_anchors[inside_flags, :]
pred_instances = InstanceData(
scores=cls_scores[inside_flags, :],
bboxes=bbox_preds[inside_flags, :],
priors=anchors)
assign_result = self.assigner.assign(pred_instances, gt_instances,
gt_instances_ignore)
sampling_result = self.sampler.sample(assign_result, pred_instances,
gt_instances)
num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors)
labels = anchors.new_full((num_valid_anchors, ),
self.num_classes,
dtype=torch.long)
label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
assign_metrics = anchors.new_zeros(
num_valid_anchors, dtype=torch.float)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
# point-based
pos_bbox_targets = sampling_result.pos_gt_bboxes
bbox_targets[pos_inds, :] = pos_bbox_targets
loss_bbox = self.loss_bbox(
bbox_preds[pos_inds],
bbox_targets[pos_inds],
weight=assign_metrics[pos_inds],
avg_factor=avg_factor)
else:
loss_bbox = bbox_preds.sum() * 0
labels[pos_inds] = sampling_result.pos_gt_labels
if self.train_cfg.pos_weight <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = self.train_cfg.pos_weight
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
class_assigned_gt_inds = torch.unique(
sampling_result.pos_assigned_gt_inds)
for gt_inds in class_assigned_gt_inds:
gt_class_inds = pos_inds[sampling_result.pos_assigned_gt_inds ==
gt_inds]
assign_metrics[gt_class_inds] = assign_result.max_overlaps[
gt_class_inds]
@staticmethod
def gt_instances_preprocess(batch_gt_instances: Union[Tensor, Sequence],
batch_size: int) -> Tensor:
"""Split batch_gt_instances with batch size, from [all_gt_bboxes, 6]
to.
# map up to original set of anchors
if unmap_outputs:
num_total_anchors = flat_anchors.size(0)
anchors = unmap(anchors, num_total_anchors, inside_flags)
labels = unmap(
labels, num_total_anchors, inside_flags, fill=self.num_classes)
label_weights = unmap(label_weights, num_total_anchors,
inside_flags)
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
assign_metrics = unmap(assign_metrics, num_total_anchors,
inside_flags)
return anchors, labels, label_weights, bbox_targets, assign_metrics
def get_anchors(self,
featmap_sizes: List[tuple],
batch_img_metas: List[dict],
device: Union[torch.device, str] = 'cuda') \
-> Tuple[List[List[Tensor]], List[List[Tensor]]]:
"""Get anchors according to feature map sizes.
If the number of gt bboxes in an image is smaller than the maximum
across the batch, the missing rows are filled with [-1., 0., 0., 0., 0.].
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
batch_img_metas (list[dict]): Image meta info.
device (torch.device or str): Device for returned tensors.
Defaults to cuda.
batch_gt_instances (Sequence[Tensor]): Ground truth
instances for whole batch, shape [all_gt_bboxes, 6]
batch_size (int): Batch size.
Returns:
tuple:
- anchor_list (list[list[Tensor]]): Anchors of each image.
- valid_flag_list (list[list[Tensor]]): Valid flags of each
image.
Tensor: batch gt instances data, shape [batch_size, number_gt, 5]
"""
num_imgs = len(batch_img_metas)
if isinstance(batch_gt_instances, Sequence):
max_gt_bbox_len = max(
[len(gt_instances) for gt_instances in batch_gt_instances])
# pad with [-1., 0., 0., 0., 0.] if an image has
# fewer gts than max_gt_bbox_len
batch_instance_list = []
for index, gt_instance in enumerate(batch_gt_instances):
bboxes = gt_instance.bboxes
labels = gt_instance.labels
batch_instance_list.append(
torch.cat((labels[:, None], bboxes), dim=-1))
# since feature map sizes of all images are the same, we only compute
# anchors for one time
multi_level_anchors = self.prior_generator.grid_priors(
featmap_sizes, device=device, with_stride=True)
anchor_list = [multi_level_anchors for _ in range(num_imgs)]
if bboxes.shape[0] >= max_gt_bbox_len:
continue
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for img_id, img_meta in enumerate(batch_img_metas):
multi_level_flags = self.prior_generator.valid_flags(
featmap_sizes, img_meta['pad_shape'], device)
valid_flag_list.append(multi_level_flags)
return anchor_list, valid_flag_list
fill_tensor = bboxes.new_full(
[max_gt_bbox_len - bboxes.shape[0], 5], 0)
fill_tensor[:, 0] = -1.
batch_instance_list[index] = torch.cat(
(batch_instance_list[-1], fill_tensor), dim=0)
return torch.stack(batch_instance_list)
else:
# faster version
# split batch gt instance [all_gt_bboxes, 6] ->
# [batch_size, number_gt_each_batch, 5]
batch_instance_list = []
max_gt_bbox_len = 0
for i in range(batch_size):
single_batch_instance = \
batch_gt_instances[batch_gt_instances[:, 0] == i, :]
single_batch_instance = single_batch_instance[:, 1:]
batch_instance_list.append(single_batch_instance)
if len(single_batch_instance) > max_gt_bbox_len:
max_gt_bbox_len = len(single_batch_instance)
# pad with [-1., 0., 0., 0., 0.] if an image has
# fewer gts than max_gt_bbox_len
for index, gt_instance in enumerate(batch_instance_list):
if gt_instance.shape[0] >= max_gt_bbox_len:
continue
fill_tensor = batch_gt_instances.new_full(
[max_gt_bbox_len - gt_instance.shape[0], 5], 0)
fill_tensor[:, 0] = -1.
batch_instance_list[index] = torch.cat(
(batch_instance_list[index], fill_tensor), dim=0)
return torch.stack(batch_instance_list)
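
To make the faster branch of `gt_instances_preprocess` above easier to follow, here is a simplified stand-alone sketch: it splits a packed `[all_gt_bboxes, 6]` tensor by its image-index column and pads every image to the maximum ground-truth count with `[-1., 0., 0., 0., 0.]` rows. It is an illustration, not the exact code.

```python
import torch


def split_and_pad(batch_gt_instances: torch.Tensor,
                  batch_size: int) -> torch.Tensor:
    """Toy [all_gt_bboxes, 6] -> [batch_size, max_gt, 5] conversion."""
    per_img = [
        batch_gt_instances[batch_gt_instances[:, 0] == i, 1:]
        for i in range(batch_size)
    ]
    max_gt = max(len(gt) for gt in per_img)
    padded = []
    for gt in per_img:
        if len(gt) < max_gt:
            # Padding rows use label -1 and zero boxes so they can be
            # masked out later (pad_bbox_flag is derived from the zero boxes).
            fill = gt.new_zeros((max_gt - len(gt), 5))
            fill[:, 0] = -1.
            gt = torch.cat((gt, fill), dim=0)
        padded.append(gt)
    return torch.stack(padded)


packed = torch.tensor([
    [0., 3., 0., 0., 10., 10.],
    [0., 7., 5., 5., 20., 20.],
    [1., 1., 2., 2., 8., 8.],
])
print(split_and_pad(packed, batch_size=2).shape)  # torch.Size([2, 2, 5])
```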


@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .batch_atss_assigner import BatchATSSAssigner
from .batch_dsl_assigner import BatchDynamicSoftLabelAssigner
from .batch_task_aligned_assigner import BatchTaskAlignedAssigner
from .utils import (select_candidates_in_gts, select_highest_overlaps,
yolov6_iou_calculator)
@ -7,5 +8,5 @@ from .utils import (select_candidates_in_gts, select_highest_overlaps,
__all__ = [
'BatchATSSAssigner', 'BatchTaskAlignedAssigner',
'select_candidates_in_gts', 'select_highest_overlaps',
'yolov6_iou_calculator'
'yolov6_iou_calculator', 'BatchDynamicSoftLabelAssigner'
]


@ -0,0 +1,193 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.structures.bbox import BaseBoxes
from mmdet.utils import ConfigType
from torch import Tensor
from mmyolo.registry import TASK_UTILS
INF = 100000000
EPS = 1.0e-7
@TASK_UTILS.register_module()
class BatchDynamicSoftLabelAssigner(nn.Module):
"""Computes matching between predictions and ground truth with dynamic soft
label assignment.
Args:
num_classes (int): Number of classes.
soft_center_radius (float): Radius of the soft center prior.
Defaults to 3.0.
topk (int): Select top-k predictions to calculate dynamic k
best matches for each gt. Defaults to 13.
iou_weight (float): The scale factor of iou cost. Defaults to 3.0.
iou_calculator (ConfigType): Config of overlaps Calculator.
Defaults to dict(type='mmdet.BboxOverlaps2D').
"""
def __init__(
self,
num_classes,
soft_center_radius: float = 3.0,
topk: int = 13,
iou_weight: float = 3.0,
iou_calculator: ConfigType = dict(type='mmdet.BboxOverlaps2D')
) -> None:
super().__init__()
self.num_classes = num_classes
self.soft_center_radius = soft_center_radius
self.topk = topk
self.iou_weight = iou_weight
self.iou_calculator = TASK_UTILS.build(iou_calculator)
@torch.no_grad()
def forward(self, pred_bboxes: Tensor, pred_scores: Tensor, priors: Tensor,
gt_labels: Tensor, gt_bboxes: Tensor,
pad_bbox_flag: Tensor) -> dict:
num_gt = gt_bboxes.size(1)
decoded_bboxes = pred_bboxes
num_bboxes = decoded_bboxes.size(1)
batch_size = decoded_bboxes.size(0)
if num_gt == 0 or num_bboxes == 0:
return {
'assigned_labels':
gt_labels.new_full(
pred_scores[..., 0].shape,
self.num_classes,
dtype=torch.long),
'assigned_labels_weights':
gt_bboxes.new_full(pred_scores[..., 0].shape, 1),
'assigned_bboxes':
gt_bboxes.new_full(pred_bboxes.shape, 0),
'assign_metrics':
gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
}
prior_center = priors[:, :2]
if isinstance(gt_bboxes, BaseBoxes):
raise NotImplementedError(
f'type of {type(gt_bboxes)} is not implemented!')
else:
# Tensor boxes will be treated as horizontal boxes by default
lt_ = prior_center[:, None, None] - gt_bboxes[..., :2]
rb_ = gt_bboxes[..., 2:] - prior_center[:, None, None]
deltas = torch.cat([lt_, rb_], dim=-1)
is_in_gts = deltas.min(dim=-1).values > 0
is_in_gts = is_in_gts * pad_bbox_flag[..., 0][None]
is_in_gts = is_in_gts.permute(1, 0, 2)
valid_mask = is_in_gts.sum(dim=-1) > 0
# Tensor boxes will be treated as horizontal boxes by default
gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
strides = priors[..., 2]
distance = (priors[None].unsqueeze(2)[..., :2] -
gt_center[:, None, :, :]
).pow(2).sum(-1).sqrt() / strides[None, :, None]
# prevent overflow
distance = distance * valid_mask.unsqueeze(-1)
soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
pairwise_ious = self.iou_calculator(decoded_bboxes, gt_bboxes)
iou_cost = -torch.log(pairwise_ious + EPS) * self.iou_weight
# select the predicted scores corresponding to the gt_labels
pairwise_pred_scores = pred_scores.permute(0, 2, 1)
idx = torch.zeros([2, batch_size, num_gt], dtype=torch.long)
idx[0] = torch.arange(end=batch_size).view(-1, 1).repeat(1, num_gt)
idx[1] = gt_labels.long().squeeze(-1)
pairwise_pred_scores = pairwise_pred_scores[idx[0],
idx[1]].permute(0, 2, 1)
# classification cost
scale_factor = pairwise_ious - pairwise_pred_scores.sigmoid()
pairwise_cls_cost = F.binary_cross_entropy_with_logits(
pairwise_pred_scores, pairwise_ious,
reduction='none') * scale_factor.abs().pow(2.0)
cost_matrix = pairwise_cls_cost + iou_cost + soft_center_prior
max_pad_value = torch.ones_like(cost_matrix) * INF
cost_matrix = torch.where(valid_mask[..., None].repeat(1, 1, num_gt),
cost_matrix, max_pad_value)
(matched_pred_ious, matched_gt_inds,
fg_mask_inboxes) = self.dynamic_k_matching(cost_matrix, pairwise_ious,
pad_bbox_flag)
del pairwise_ious, cost_matrix
batch_index = (fg_mask_inboxes > 0).nonzero(as_tuple=True)[0]
assigned_labels = gt_labels.new_full(pred_scores[..., 0].shape,
self.num_classes)
assigned_labels[fg_mask_inboxes] = gt_labels[
batch_index, matched_gt_inds].squeeze(-1)
assigned_labels = assigned_labels.long()
assigned_labels_weights = gt_bboxes.new_full(pred_scores[..., 0].shape,
1)
assigned_bboxes = gt_bboxes.new_full(pred_bboxes.shape, 0)
assigned_bboxes[fg_mask_inboxes] = gt_bboxes[batch_index,
matched_gt_inds]
assign_metrics = gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
assign_metrics[fg_mask_inboxes] = matched_pred_ious
return dict(
assigned_labels=assigned_labels,
assigned_labels_weights=assigned_labels_weights,
assigned_bboxes=assigned_bboxes,
assign_metrics=assign_metrics)
def dynamic_k_matching(self, cost_matrix: Tensor, pairwise_ious: Tensor,
pad_bbox_flag: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""Use IoU and matching cost to calculate the dynamic top-k positive
targets.
Args:
cost_matrix (Tensor): Cost matrix.
pairwise_ious (Tensor): Pairwise iou matrix.
pad_bbox_flag (Tensor): Flag indicating which gt entries are
real boxes (1) and which are padding (0).
Returns:
tuple: matched ious, matched gt indexes and the foreground mask.
"""
matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
# select candidate topk ious for dynamic-k calculation
candidate_topk = min(self.topk, pairwise_ious.size(1))
topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
# calculate dynamic k for each gt
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
num_gts = pad_bbox_flag.sum((1, 2)).int()
# sorting the batch cost matrix is faster than topk
_, sorted_indices = torch.sort(cost_matrix, dim=1)
for b in range(pad_bbox_flag.shape[0]):
for gt_idx in range(num_gts[b]):
topk_ids = sorted_indices[b, :dynamic_ks[b, gt_idx], gt_idx]
matching_matrix[b, :, gt_idx][topk_ids] = 1
del topk_ious, dynamic_ks
prior_match_gt_mask = matching_matrix.sum(2) > 1
if prior_match_gt_mask.sum() > 0:
cost_min, cost_argmin = torch.min(
cost_matrix[prior_match_gt_mask, :], dim=1)
matching_matrix[prior_match_gt_mask, :] *= 0
matching_matrix[prior_match_gt_mask, cost_argmin] = 1
# get foreground mask inside box and center prior
fg_mask_inboxes = matching_matrix.sum(2) > 0
matched_pred_ious = (matching_matrix *
pairwise_ious).sum(2)[fg_mask_inboxes]
matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
return matched_pred_ious, matched_gt_inds, fg_mask_inboxes
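
To summarize the matching step above: for each ground truth, the top-k IoUs are summed (and clamped to at least 1) to obtain a dynamic k, and then the k lowest-cost priors are taken as its positives. The snippet below is a simplified single-image sketch of that idea, ignoring the batch padding flags and the step that resolves priors matched to several gts.

```python
import torch


def dynamic_k_single(cost: torch.Tensor, ious: torch.Tensor,
                     topk: int = 13) -> torch.Tensor:
    """Simplified dynamic-k selection for one image.

    cost/ious: [num_priors, num_gt]. Returns a 0/1 matching matrix.
    """
    num_priors, num_gt = cost.shape
    matching = torch.zeros_like(cost, dtype=torch.uint8)
    candidate_k = min(topk, num_priors)
    topk_ious, _ = torch.topk(ious, candidate_k, dim=0)
    # Dynamic k per gt: sum of its best IoUs, at least 1.
    dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
    for gt_idx in range(num_gt):
        # Take the k lowest-cost priors for this gt.
        _, pos_idx = torch.topk(
            cost[:, gt_idx], k=int(dynamic_ks[gt_idx]), largest=False)
        matching[pos_idx, gt_idx] = 1
    return matching


cost = torch.rand(100, 3)
ious = torch.rand(100, 3)
print(dynamic_k_single(cost, ious).sum(0))  # number of positives per gt
```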


@ -33,11 +33,13 @@ class TestRTMDetHead(TestCase):
'ori_shape': (s, s, 3),
'scale_factor': (1.0, 1.0),
}]
test_cfg = Config(
dict(
max_per_img=300,
score_thr=0.01,
nms=dict(type='nms', iou_threshold=0.65)))
test_cfg = dict(
multi_label=True,
nms_pre=30000,
score_thr=0.001,
nms=dict(type='nms', iou_threshold=0.65),
max_per_img=300)
test_cfg = Config(test_cfg)
head = RTMDetHead(head_module=self.head_module, test_cfg=test_cfg)
feat = [
@ -48,14 +50,14 @@ class TestRTMDetHead(TestCase):
head.predict_by_feat(
cls_scores,
bbox_preds,
img_metas,
batch_img_metas=img_metas,
cfg=test_cfg,
rescale=True,
with_nms=True)
head.predict_by_feat(
cls_scores,
bbox_preds,
img_metas,
batch_img_metas=img_metas,
cfg=test_cfg,
rescale=False,
with_nms=False)
@ -64,18 +66,19 @@ class TestRTMDetHead(TestCase):
s = 256
img_metas = [{
'img_shape': (s, s, 3),
'pad_shape': (s, s, 3),
'batch_input_shape': (s, s),
'scale_factor': 1,
}]
train_cfg = Config(
dict(
assigner=dict(
type='mmdet.DynamicSoftLabelAssigner',
topk=13,
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
allowed_border=-1,
pos_weight=-1))
train_cfg = dict(
assigner=dict(
num_classes=80,
type='BatchDynamicSoftLabelAssigner',
topk=13,
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
allowed_border=-1,
pos_weight=-1,
debug=False)
train_cfg = Config(train_cfg)
head = RTMDetHead(head_module=self.head_module, train_cfg=train_cfg)
feat = [
@ -84,53 +87,53 @@ class TestRTMDetHead(TestCase):
]
cls_scores, bbox_preds = head.forward(feat)
# TODO
# Test that empty ground truth encourages the network to predict
# background
gt_instances = InstanceData(
bboxes=torch.empty((0, 4)), labels=torch.LongTensor([]))
# empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
# [gt_instances],
# img_metas)
# # When there is no truth, the cls loss should be nonzero but there
# # should be no box loss.
# empty_cls_loss = empty_gt_losses['loss_cls'].sum()
# empty_box_loss = empty_gt_losses['loss_bbox'].sum()
# self.assertEqual(
# empty_cls_loss.item(), 0,
# 'there should be no cls loss when there are no true boxes')
# self.assertEqual(
# empty_box_loss.item(), 0,
# 'there should be no box loss when there are no true boxes')
empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
[gt_instances], img_metas)
# When there is no truth, the cls loss should be nonzero but there
# should be no box loss.
empty_cls_loss = empty_gt_losses['loss_cls'].sum()
empty_box_loss = empty_gt_losses['loss_bbox'].sum()
self.assertGreater(empty_cls_loss.item(), 0,
'classification loss should be non-zero')
self.assertEqual(
empty_box_loss.item(), 0,
'there should be no box loss when there are no true boxes')
# When truth is non-empty then both cls and box loss should be nonzero
# for random inputs
head = RTMDetHead(head_module=self.head_module, train_cfg=train_cfg)
gt_instances = InstanceData(
bboxes=torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
labels=torch.LongTensor([2]))
labels=torch.LongTensor([1]))
one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
[gt_instances], img_metas)
onegt_cls_loss = sum(one_gt_losses['loss_cls'])
onegt_box_loss = sum(one_gt_losses['loss_bbox'])
onegt_cls_loss = one_gt_losses['loss_cls'].sum()
onegt_box_loss = one_gt_losses['loss_bbox'].sum()
self.assertGreater(onegt_cls_loss.item(), 0,
'cls loss should be non-zero')
self.assertGreater(onegt_box_loss.item(), 0,
'box loss should be non-zero')
# Test ground truth out of bound
# test num_class = 1
self.head_module['num_classes'] = 1
head = RTMDetHead(head_module=self.head_module, train_cfg=train_cfg)
gt_instances = InstanceData(
bboxes=torch.Tensor([[s * 4, s * 4, s * 4 + 10, s * 4 + 10]]),
labels=torch.LongTensor([2]))
gt_losses = head.loss_by_feat(cls_scores, bbox_preds, [gt_instances],
img_metas)
cls_loss = sum(gt_losses['loss_cls'])
empty_box_loss = sum(gt_losses['loss_bbox'])
self.assertGreater(
cls_loss.item(), 0,
'there should be no cls loss when gt_bboxes out of bound')
self.assertEqual(
empty_box_loss.item(), 0,
'there should be no box loss when gt_bboxes out of bound')
bboxes=torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
labels=torch.LongTensor([0]))
cls_scores, bbox_preds = head.forward(feat)
one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
[gt_instances], img_metas)
onegt_cls_loss = one_gt_losses['loss_cls'].sum()
onegt_box_loss = one_gt_losses['loss_bbox'].sum()
self.assertGreater(onegt_cls_loss.item(), 0,
'cls loss should be non-zero')
self.assertGreater(onegt_box_loss.item(), 0,
'box loss should be non-zero')


@ -22,7 +22,7 @@ class TestSingleStageDetector(TestCase):
'yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py',
'yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py',
'yolox/yolox_tiny_8xb8-300e_coco.py',
'rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py',
'rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py',
'yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py'
])
def test_init(self, cfg_file):
@ -39,7 +39,7 @@ class TestSingleStageDetector(TestCase):
('yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py', ('cuda', 'cpu')),
('yolox/yolox_s_8xb8-300e_coco.py', ('cuda', 'cpu')),
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
('rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py', ('cuda', 'cpu'))
('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu'))
])
def test_forward_loss_mode(self, cfg_file, devices):
message_hub = MessageHub.get_instance(
@ -79,7 +79,7 @@ class TestSingleStageDetector(TestCase):
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
('rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py', ('cuda', 'cpu'))
('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu'))
])
def test_forward_predict_mode(self, cfg_file, devices):
model = get_detector_cfg(cfg_file)
@ -111,7 +111,7 @@ class TestSingleStageDetector(TestCase):
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
('rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py', ('cuda', 'cpu'))
('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu'))
])
def test_forward_tensor_mode(self, cfg_file, devices):
model = get_detector_cfg(cfg_file)