mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Implement fast version of RTMDet. (#425)
* Accelerate RTMDet * update * update * update * update1 * update2 * update pipeline * update lr cudnnbenchmark * revert batchsize * fix batch inference * refactor head * update box * bs=16 * update * move reduce mean * update head * per img loss * fix * fix sum * concat loss * batch dsla * sort topk * bs 32 * clean code * update readme * update ut * update checkpoint * num_class * clean code * resolve comments * fix readme * fix ut Co-authored-by: huanghaian <huanghaian@sensetime.com> Co-authored-by: hha <1286304229@qq.com>pull/259/head
parent
c7a9026812
commit
48f8896e84
|
@ -1,26 +1,32 @@
|
|||
# RTMDet
|
||||
# RTMDet: An Empirical Study of Designing Real-Time Object Detectors
|
||||
|
||||
[](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real)
|
||||
[](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real)
|
||||
[](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-hrsc2016?p=rtmdet-an-empirical-study-of-designing-real)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
Our tech-report will be released soon.
|
||||
In this paper, we aim to design an efficient real-time object detector that exceeds the YOLO series and is easily extensible for many object recognition tasks such as instance segmentation and rotated object detection. To obtain a more efficient model architecture, we explore an architecture that has compatible capacities in the backbone and neck, constructed by a basic building block that consists of large-kernel depth-wise convolutions. We further introduce soft labels when calculating matching costs in the dynamic label assignment to improve accuracy. Together with better training techniques, the resulting object detector, named RTMDet, achieves 52.8% AP on COCO with 300+ FPS on an NVIDIA 3090 GPU, outperforming the current mainstream industrial detectors. RTMDet achieves the best parameter-accuracy trade-off with tiny/small/medium/large/extra-large model sizes for various application scenarios, and obtains new state-of-the-art performance on real-time instance segmentation and rotated object detection. We hope the experimental results can provide new insights into designing versatile real-time object detectors for many object recognition tasks.
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/12907710/192182907-f9a671d6-89cb-4d73-abd8-c2b9dada3c66.png"/>
|
||||
<img src="https://user-images.githubusercontent.com/12907710/208070055-7233a3d8-955f-486a-82da-b714b3c3bbd6.png"/>
|
||||
</div>
|
||||
|
||||
## Results and Models
|
||||
|
||||
| Backbone | size | SyncBN | box AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | Config | Download |
|
||||
| :---------: | :--: | :----: | -----: | :-------: | :------: | :------------------: | :-----------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| RTMDet-tiny | 640 | Yes | 40.9 | 4.8 | 8.1 | 0.98 | [config](./rtmdet_tiny_syncbn_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco/rtmdet_tiny_syncbn_8xb32-300e_coco_20220902_112414-259f3241.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414.log.json) |
|
||||
| RTMDet-s | 640 | Yes | 44.5 | 8.89 | 14.8 | 1.22 | [config](./rtmdet_s_syncbn_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_8xb32-300e_coco/rtmdet_s_syncbn_8xb32-300e_coco_20220905_161602-fd1cacb9.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602.log.json) |
|
||||
| RTMDet-m | 640 | Yes | 49.1 | 24.71 | 39.27 | 1.62 | [config](./rtmdet_m_syncbn_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_8xb32-300e_coco/rtmdet_m_syncbn_8xb32-300e_coco_20220924_132959-d9f2e90d.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220924_132959.log.json) |
|
||||
| RTMDet-l | 640 | Yes | 51.3 | 52.3 | 80.23 | 2.44 | [config](./rtmdet_l_syncbn_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_8xb32-300e_coco/rtmdet_l_syncbn_8xb32-300e_coco_20220926_150401-40c754b5.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220926_150401.log.json) |
|
||||
| RTMDet-x | 640 | Yes | 52.6 | 94.86 | 141.67 | 3.10 | [config](./rtmdet_x_syncbn_8xb32-300e_coco.py) | [model](<>) \| [log](<>) |
|
||||
## Object Detection
|
||||
|
||||
| Model | size | box AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | Config | Download |
|
||||
| :---------: | :--: | :----: | :-------: | :------: | :------------------: | :-------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| RTMDet-tiny | 640 | 41.0 | 4.8 | 8.1 | 0.98 | [config](./rtmdet_l_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco/rtmdet_tiny_syncbn_fast_8xb32-300e_coco_20230102_140117-dbb1dc83.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco/rtmdet_tiny_syncbn_fast_8xb32-300e_coco_20230102_140117.log.json) |
|
||||
| RTMDet-s | 640 | 44.6 | 8.89 | 14.8 | 1.22 | [config](./rtmdet_s_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco/rtmdet_s_syncbn_fast_8xb32-300e_coco_20221230_182329-0a8c901a.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco/rtmdet_s_syncbn_fast_8xb32-300e_coco_20221230_182329.log.json) |
|
||||
| RTMDet-m | 640 | 49.3 | 24.71 | 39.27 | 1.62 | [config](./rtmdet_m_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco/rtmdet_m_syncbn_fast_8xb32-300e_coco_20230102_135952-40af4fe8.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco/rtmdet_m_syncbn_fast_8xb32-300e_coco_20230102_135952.log.json) |
|
||||
| RTMDet-l | 640 | 51.4 | 52.3 | 80.23 | 2.44 | [config](./rtmdet_l_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco/rtmdet_l_syncbn_fast_8xb32-300e_coco_20230102_135928-ee3abdc4.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco/rtmdet_l_syncbn_fast_8xb32-300e_coco_20230102_135928.log.json) |
|
||||
| RTMDet-x | 640 | 52.8 | 94.86 | 141.67 | 3.10 | [config](./rtmdet_x_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco/rtmdet_x_syncbn_fast_8xb32-300e_coco_20221231_100345-b85cd476.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco/rtmdet_x_syncbn_fast_8xb32-300e_coco_20221231_100345.log.json) |
|
||||
|
||||
**Note**:
|
||||
|
||||
1. The inference speed is measured on an NVIDIA 3090 GPU with TensorRT 8.4.3, cuDNN 8.2.0, FP16, batch size=1, and without NMS.
|
||||
2. We still directly use the weights trained by `mmdet` currently. A re-trained model will be released later.
|
||||
1. The inference speed of RTMDet is measured on an NVIDIA 3090 GPU with TensorRT 8.4.3, cuDNN 8.2.0, FP16, batch size=1, and without NMS.
|
||||
2. For a fair comparison, the config of bbox postprocessing is changed to be consistent with YOLOv5/6/7 after [PR#9494](https://github.com/open-mmlab/mmdetection/pull/9494), bringing about 0.1~0.3% AP improvement.
|
||||
|
|
|
@ -15,9 +15,9 @@ Collections:
|
|||
Version: v0.1.1
|
||||
|
||||
Models:
|
||||
- Name: rtmdet_tiny_syncbn_8xb32-300e_coco
|
||||
- Name: rtmdet_tiny_syncbn_fast_8xb32-300e_coco
|
||||
In Collection: RTMDet
|
||||
Config: configs/rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py
|
||||
Config: configs/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 11.7
|
||||
Epochs: 300
|
||||
|
@ -25,12 +25,12 @@ Models:
|
|||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 40.9
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco/rtmdet_tiny_syncbn_8xb32-300e_coco_20220902_112414-259f3241.pth
|
||||
box AP: 41.0
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco/rtmdet_tiny_syncbn_fast_8xb32-300e_coco_20230102_140117-dbb1dc83.pth
|
||||
|
||||
- Name: rtmdet_s_syncbn_8xb32-300e_coco
|
||||
- Name: rtmdet_s_syncbn_fast_8xb32-300e_coco
|
||||
In Collection: RTMDet
|
||||
Config: configs/rtmdet/rtmdet_s_syncbn_8xb32-300e_coco.py
|
||||
Config: configs/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 15.9
|
||||
Epochs: 300
|
||||
|
@ -38,12 +38,12 @@ Models:
|
|||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 44.5
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_8xb32-300e_coco/rtmdet_s_syncbn_8xb32-300e_coco_20220905_161602-fd1cacb9.pth
|
||||
box AP: 44.6
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco/rtmdet_s_syncbn_fast_8xb32-300e_coco_20221230_182329-0a8c901a.pth
|
||||
|
||||
- Name: rtmdet_m_syncbn_8xb32-300e_coco
|
||||
- Name: rtmdet_m_syncbn_fast_8xb32-300e_coco
|
||||
In Collection: RTMDet
|
||||
Config: configs/rtmdet/rtmdet_m_syncbn_8xb32-300e_coco.py
|
||||
Config: configs/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 27.8
|
||||
Epochs: 300
|
||||
|
@ -51,12 +51,12 @@ Models:
|
|||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 49.1
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_8xb32-300e_coco/rtmdet_m_syncbn_8xb32-300e_coco_20220924_132959-d9f2e90d.pth
|
||||
box AP: 49.3
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_m_syncbn_fast_8xb32-300e_coco/rtmdet_m_syncbn_fast_8xb32-300e_coco_20230102_135952-40af4fe8.pth
|
||||
|
||||
- Name: rtmdet_l_syncbn_8xb32-300e_coco
|
||||
- Name: rtmdet_l_syncbn_fast_8xb32-300e_coco
|
||||
In Collection: RTMDet
|
||||
Config: configs/rtmdet/rtmdet_l_syncbn_8xb32-300e_coco.py
|
||||
Config: configs/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 43.2
|
||||
Epochs: 300
|
||||
|
@ -64,5 +64,18 @@ Models:
|
|||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 51.3
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_8xb32-300e_coco/rtmdet_l_syncbn_8xb32-300e_coco_20220926_150401-40c754b5.pth
|
||||
box AP: 51.4
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_l_syncbn_fast_8xb32-300e_coco/rtmdet_l_syncbn_fast_8xb32-300e_coco_20230102_135928-ee3abdc4.pth
|
||||
|
||||
- Name: rtmdet_x_syncbn_fast_8xb32-300e_coco
|
||||
In Collection: RTMDet
|
||||
Config: configs/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 63.4
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 52.8
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_x_syncbn_fast_8xb32-300e_coco/rtmdet_x_syncbn_fast_8xb32-300e_coco_20221231_100345-b85cd476.pth
|
||||
|
|
|
@ -9,20 +9,33 @@ widen_factor = 1.0
|
|||
max_epochs = 300
|
||||
stage2_num_epochs = 20
|
||||
interval = 10
|
||||
num_classes = 80
|
||||
|
||||
train_batch_size_per_gpu = 32
|
||||
train_num_workers = 10
|
||||
val_batch_size_per_gpu = 5
|
||||
val_batch_size_per_gpu = 32
|
||||
val_num_workers = 10
|
||||
# persistent_workers must be False if num_workers is 0.
|
||||
persistent_workers = True
|
||||
strides = [8, 16, 32]
|
||||
base_lr = 0.004
|
||||
|
||||
# single-scale training is recommended to
|
||||
# be turned on, which can speed up training.
|
||||
env_cfg = dict(cudnn_benchmark=True)
|
||||
|
||||
# only on Val
|
||||
batch_shapes_cfg = dict(
|
||||
type='BatchShapePolicy',
|
||||
batch_size=val_batch_size_per_gpu,
|
||||
img_size=img_scale[0],
|
||||
size_divisor=32,
|
||||
extra_pad_ratio=0.5)
|
||||
|
||||
model = dict(
|
||||
type='YOLODetector',
|
||||
data_preprocessor=dict(
|
||||
type='mmdet.DetDataPreprocessor',
|
||||
type='YOLOv5DetDataPreprocessor',
|
||||
mean=[103.53, 116.28, 123.675],
|
||||
std=[57.375, 57.12, 58.395],
|
||||
bgr_to_rgb=False),
|
||||
|
@ -49,7 +62,7 @@ model = dict(
|
|||
type='RTMDetHead',
|
||||
head_module=dict(
|
||||
type='RTMDetSepBNHeadModule',
|
||||
num_classes=80,
|
||||
num_classes=num_classes,
|
||||
in_channels=256,
|
||||
stacked_convs=2,
|
||||
feat_channels=256,
|
||||
|
@ -60,7 +73,7 @@ model = dict(
|
|||
featmap_strides=strides),
|
||||
prior_generator=dict(
|
||||
type='mmdet.MlvlPointGenerator', offset=0, strides=strides),
|
||||
bbox_coder=dict(type='mmdet.DistancePointBBoxCoder'),
|
||||
bbox_coder=dict(type='DistancePointBBoxCoder'),
|
||||
loss_cls=dict(
|
||||
type='mmdet.QualityFocalLoss',
|
||||
use_sigmoid=True,
|
||||
|
@ -69,18 +82,19 @@ model = dict(
|
|||
loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=2.0)),
|
||||
train_cfg=dict(
|
||||
assigner=dict(
|
||||
type='mmdet.DynamicSoftLabelAssigner',
|
||||
type='BatchDynamicSoftLabelAssigner',
|
||||
num_classes=num_classes,
|
||||
topk=13,
|
||||
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
|
||||
allowed_border=-1,
|
||||
pos_weight=-1,
|
||||
debug=False),
|
||||
test_cfg=dict(
|
||||
nms_pre=1000,
|
||||
min_bbox_size=0,
|
||||
score_thr=0.05,
|
||||
nms=dict(type='nms', iou_threshold=0.6),
|
||||
max_per_img=100),
|
||||
multi_label=True,
|
||||
nms_pre=30000,
|
||||
score_thr=0.001,
|
||||
nms=dict(type='nms', iou_threshold=0.65),
|
||||
max_per_img=300),
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
|
@ -102,13 +116,7 @@ train_pipeline = [
|
|||
dict(type='mmdet.YOLOXHSVRandomAug'),
|
||||
dict(type='mmdet.RandomFlip', prob=0.5),
|
||||
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
|
||||
dict(
|
||||
type='YOLOXMixUp',
|
||||
img_scale=img_scale,
|
||||
use_cached=True,
|
||||
ratio_range=(1.0, 1.0),
|
||||
max_cached_images=20,
|
||||
pad_val=(114, 114, 114)),
|
||||
dict(type='YOLOv5MixUp', use_cached=True, max_cached_images=20),
|
||||
dict(type='mmdet.PackDetInputs')
|
||||
]
|
||||
|
||||
|
@ -130,13 +138,17 @@ train_pipeline_stage2 = [
|
|||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
|
||||
dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
|
||||
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
|
||||
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
|
||||
dict(
|
||||
type='LetterResize',
|
||||
scale=img_scale,
|
||||
allow_scale_up=False,
|
||||
pad_val=dict(img=114)),
|
||||
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
|
||||
dict(
|
||||
type='mmdet.PackDetInputs',
|
||||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
|
||||
'scale_factor'))
|
||||
'scale_factor', 'pad_param'))
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
|
@ -144,6 +156,7 @@ train_dataloader = dict(
|
|||
num_workers=train_num_workers,
|
||||
persistent_workers=persistent_workers,
|
||||
pin_memory=True,
|
||||
collate_fn=dict(type='yolov5_collate'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
|
@ -166,6 +179,7 @@ val_dataloader = dict(
|
|||
ann_file='annotations/instances_val2017.json',
|
||||
data_prefix=dict(img='val2017/'),
|
||||
test_mode=True,
|
||||
batch_shapes_cfg=batch_shapes_cfg,
|
||||
pipeline=test_pipeline))
|
||||
|
||||
test_dataloader = val_dataloader
|
|
@ -1,4 +1,4 @@
|
|||
_base_ = './rtmdet_l_syncbn_8xb32-300e_coco.py'
|
||||
_base_ = './rtmdet_l_syncbn_fast_8xb32-300e_coco.py'
|
||||
|
||||
deepen_factor = 0.67
|
||||
widen_factor = 0.75
|
|
@ -1,4 +1,4 @@
|
|||
_base_ = './rtmdet_l_syncbn_8xb32-300e_coco.py'
|
||||
_base_ = './rtmdet_l_syncbn_fast_8xb32-300e_coco.py'
|
||||
checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth' # noqa
|
||||
|
||||
deepen_factor = 0.33
|
||||
|
@ -42,13 +42,7 @@ train_pipeline = [
|
|||
dict(type='mmdet.YOLOXHSVRandomAug'),
|
||||
dict(type='mmdet.RandomFlip', prob=0.5),
|
||||
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
|
||||
dict(
|
||||
type='YOLOXMixUp',
|
||||
img_scale=img_scale,
|
||||
use_cached=True,
|
||||
ratio_range=(1.0, 1.0),
|
||||
max_cached_images=20,
|
||||
pad_val=(114, 114, 114)),
|
||||
dict(type='YOLOv5MixUp', use_cached=True, max_cached_images=20),
|
||||
dict(type='mmdet.PackDetInputs')
|
||||
]
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
_base_ = './rtmdet_s_syncbn_8xb32-300e_coco.py'
|
||||
_base_ = './rtmdet_s_syncbn_fast_8xb32-300e_coco.py'
|
||||
|
||||
checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth' # noqa
|
||||
|
||||
|
@ -40,14 +40,11 @@ train_pipeline = [
|
|||
dict(type='mmdet.RandomFlip', prob=0.5),
|
||||
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
|
||||
dict(
|
||||
type='YOLOXMixUp',
|
||||
img_scale=img_scale,
|
||||
ratio_range=(1.0, 1.0),
|
||||
max_cached_images=10, # note
|
||||
type='YOLOv5MixUp',
|
||||
use_cached=True,
|
||||
random_pop=False, # note
|
||||
pad_val=(114, 114, 114),
|
||||
prob=0.5), # note
|
||||
random_pop=False,
|
||||
max_cached_images=10,
|
||||
prob=0.5),
|
||||
dict(type='mmdet.PackDetInputs')
|
||||
]
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
_base_ = './rtmdet_l_syncbn_8xb32-300e_coco.py'
|
||||
_base_ = './rtmdet_l_syncbn_fast_8xb32-300e_coco.py'
|
||||
|
||||
deepen_factor = 1.33
|
||||
widen_factor = 1.25
|
|
@ -1,19 +1,15 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
from typing import List, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, is_norm
|
||||
from mmdet.models.task_modules.prior_generators import anchor_inside_flags
|
||||
from mmdet.models.task_modules.samplers import PseudoSampler
|
||||
from mmdet.models.utils import images_to_levels, multi_apply, unmap
|
||||
from mmdet.structures.bbox import distance2bbox
|
||||
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
|
||||
OptInstanceList, OptMultiConfig, reduce_mean)
|
||||
from mmengine.config import ConfigDict
|
||||
from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
|
||||
normal_init)
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmyolo.registry import MODELS, TASK_UTILS
|
||||
|
@ -172,7 +168,7 @@ class RTMDetSepBNHeadModule(BaseModule):
|
|||
|
||||
cls_scores = []
|
||||
bbox_preds = []
|
||||
for idx, (x, stride) in enumerate(zip(feats, self.featmap_strides)):
|
||||
for idx, x in enumerate(feats):
|
||||
cls_feat = x
|
||||
reg_feat = x
|
||||
|
||||
|
@ -183,7 +179,7 @@ class RTMDetSepBNHeadModule(BaseModule):
|
|||
for reg_layer in self.reg_convs[idx]:
|
||||
reg_feat = reg_layer(reg_feat)
|
||||
|
||||
reg_dist = self.rtm_reg[idx](reg_feat) * stride
|
||||
reg_dist = self.rtm_reg[idx](reg_feat)
|
||||
cls_scores.append(cls_score)
|
||||
bbox_preds.append(reg_dist)
|
||||
return tuple(cls_scores), tuple(bbox_preds)
|
||||
|
@ -210,28 +206,28 @@ class RTMDetHead(YOLOv5Head):
|
|||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_module: ConfigType,
|
||||
prior_generator: ConfigType = dict(
|
||||
type='mmdet.MlvlPointGenerator', offset=0, strides=[8, 16,
|
||||
32]),
|
||||
bbox_coder: ConfigType = dict(type='mmdet.DistancePointBBoxCoder'),
|
||||
loss_cls: ConfigType = dict(
|
||||
type='mmdet.QualityFocalLoss',
|
||||
use_sigmoid=True,
|
||||
beta=2.0,
|
||||
loss_weight=1.0),
|
||||
loss_bbox: ConfigType = dict(
|
||||
type='mmdet.GIoULoss', loss_weight=2.0),
|
||||
loss_obj: ConfigType = dict(
|
||||
type='mmdet.CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
reduction='sum',
|
||||
loss_weight=1.0),
|
||||
train_cfg: OptConfigType = None,
|
||||
test_cfg: OptConfigType = None,
|
||||
init_cfg: OptMultiConfig = None):
|
||||
def __init__(self,
|
||||
head_module: ConfigType,
|
||||
prior_generator: ConfigType = dict(
|
||||
type='mmdet.MlvlPointGenerator',
|
||||
offset=0,
|
||||
strides=[8, 16, 32]),
|
||||
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
|
||||
loss_cls: ConfigType = dict(
|
||||
type='mmdet.QualityFocalLoss',
|
||||
use_sigmoid=True,
|
||||
beta=2.0,
|
||||
loss_weight=1.0),
|
||||
loss_bbox: ConfigType = dict(
|
||||
type='mmdet.GIoULoss', loss_weight=2.0),
|
||||
loss_obj: ConfigType = dict(
|
||||
type='mmdet.CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
reduction='sum',
|
||||
loss_weight=1.0),
|
||||
train_cfg: OptConfigType = None,
|
||||
test_cfg: OptConfigType = None,
|
||||
init_cfg: OptMultiConfig = None):
|
||||
|
||||
super().__init__(
|
||||
head_module=head_module,
|
||||
|
@ -276,116 +272,6 @@ class RTMDetHead(YOLOv5Head):
|
|||
"""
|
||||
return self.head_module(x)
|
||||
|
||||
def predict_by_feat(self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
batch_img_metas: Optional[List[dict]] = None,
|
||||
cfg: Optional[ConfigDict] = None,
|
||||
rescale: bool = True,
|
||||
with_nms: bool = True) -> List[InstanceData]:
|
||||
"""Transform a batch of output features extracted from the head into
|
||||
bbox results.
|
||||
|
||||
Args:
|
||||
cls_scores (list[Tensor]): Classification scores for all
|
||||
scale levels, each is a 4D-tensor, has shape
|
||||
(batch_size, num_priors * num_classes, H, W).
|
||||
bbox_preds (list[Tensor]): Box energies / deltas for all
|
||||
scale levels, each is a 4D-tensor, has shape
|
||||
(batch_size, num_priors * 4, H, W).
|
||||
batch_img_metas (list[dict], Optional): Batch image meta info.
|
||||
Defaults to None.
|
||||
cfg (ConfigDict, optional): Test / postprocessing
|
||||
configuration, if None, test_cfg would be used.
|
||||
Defaults to None.
|
||||
rescale (bool): If True, return boxes in original image space.
|
||||
Defaults to False.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
list[:obj:`InstanceData`]: Object detection results of each image
|
||||
after the post process. Each item usually contains following keys.
|
||||
|
||||
- scores (Tensor): Classification scores, has a shape
|
||||
(num_instance, )
|
||||
- labels (Tensor): Labels of bboxes, has a shape
|
||||
(num_instances, ).
|
||||
- bboxes (Tensor): Has a shape (num_instances, 4),
|
||||
the last dimension 4 arrange as (x1, y1, x2, y2).
|
||||
"""
|
||||
return super(YOLOv5Head, self).predict_by_feat(
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
None,
|
||||
batch_img_metas=batch_img_metas,
|
||||
cfg=cfg,
|
||||
rescale=rescale,
|
||||
with_nms=with_nms)
|
||||
|
||||
def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
|
||||
labels: Tensor, label_weights: Tensor,
|
||||
bbox_targets: Tensor, assign_metrics: Tensor,
|
||||
stride: List[int]) -> list:
|
||||
"""Compute loss of a single scale level.
|
||||
|
||||
Args:
|
||||
cls_score (Tensor): Box scores for each scale level
|
||||
Has shape (N, num_anchors * num_classes, H, W).
|
||||
bbox_pred (Tensor): Decoded bboxes for each scale
|
||||
level with shape (N, num_anchors * 4, H, W).
|
||||
labels (Tensor): Labels of each anchors with shape
|
||||
(N, num_total_anchors).
|
||||
label_weights (Tensor): Label weights of each anchor with shape
|
||||
(N, num_total_anchors).
|
||||
bbox_targets (Tensor): BBox regression targets of each anchor with
|
||||
shape (N, num_total_anchors, 4).
|
||||
assign_metrics (Tensor): Assign metrics with shape
|
||||
(N, num_total_anchors).
|
||||
stride (List[int]): Downsample stride of the feature map.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
assert stride[0] == stride[1], 'h stride is not equal to w stride!'
|
||||
cls_score = cls_score.permute(0, 2, 3, 1).reshape(
|
||||
-1, self.cls_out_channels).contiguous()
|
||||
bbox_pred = bbox_pred.reshape(-1, 4)
|
||||
bbox_targets = bbox_targets.reshape(-1, 4)
|
||||
labels = labels.reshape(-1)
|
||||
assign_metrics = assign_metrics.reshape(-1)
|
||||
label_weights = label_weights.reshape(-1)
|
||||
targets = (labels, assign_metrics)
|
||||
|
||||
loss_cls = self.loss_cls(
|
||||
cls_score, targets, label_weights, avg_factor=1.0)
|
||||
|
||||
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
|
||||
bg_class_ind = self.num_classes
|
||||
pos_inds = ((labels >= 0)
|
||||
& (labels < bg_class_ind)).nonzero().squeeze(1)
|
||||
|
||||
if len(pos_inds) > 0:
|
||||
pos_bbox_targets = bbox_targets[pos_inds]
|
||||
pos_bbox_pred = bbox_pred[pos_inds]
|
||||
|
||||
pos_decode_bbox_pred = pos_bbox_pred
|
||||
pos_decode_bbox_targets = pos_bbox_targets
|
||||
|
||||
# regression loss
|
||||
pos_bbox_weight = assign_metrics[pos_inds]
|
||||
|
||||
loss_bbox = self.loss_bbox(
|
||||
pos_decode_bbox_pred,
|
||||
pos_decode_bbox_targets,
|
||||
weight=pos_bbox_weight,
|
||||
avg_factor=1.0)
|
||||
else:
|
||||
loss_bbox = bbox_pred.sum() * 0
|
||||
pos_bbox_weight = bbox_targets.new_tensor(0.)
|
||||
|
||||
return loss_cls, loss_bbox, assign_metrics.sum(), pos_bbox_weight.sum()
|
||||
|
||||
def loss_by_feat(
|
||||
self,
|
||||
cls_scores: List[Tensor],
|
||||
|
@ -418,286 +304,131 @@ class RTMDetHead(YOLOv5Head):
|
|||
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
|
||||
assert len(featmap_sizes) == self.prior_generator.num_levels
|
||||
|
||||
gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)
|
||||
gt_labels = gt_info[:, :, :1]
|
||||
gt_bboxes = gt_info[:, :, 1:] # xyxy
|
||||
pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
|
||||
|
||||
device = cls_scores[0].device
|
||||
anchor_list, valid_flag_list = self.get_anchors(
|
||||
featmap_sizes, batch_img_metas, device=device)
|
||||
|
||||
# If the shape does not equal, generate new one
|
||||
if featmap_sizes != self.featmap_sizes:
|
||||
self.featmap_sizes = featmap_sizes
|
||||
mlvl_priors = self.prior_generator.grid_priors(
|
||||
featmap_sizes, device=device, with_stride=True)
|
||||
self.flatten_priors = torch.cat(mlvl_priors, dim=0)
|
||||
self.mlvl_priors = [mlvl[:, :2] for mlvl in mlvl_priors]
|
||||
|
||||
flatten_cls_scores = torch.cat([
|
||||
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
|
||||
self.cls_out_channels)
|
||||
for cls_score in cls_scores
|
||||
], 1).contiguous()
|
||||
|
||||
flatten_bboxes = torch.cat([
|
||||
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
|
||||
for bbox_pred in bbox_preds
|
||||
], 1)
|
||||
decoded_bboxes = []
|
||||
for anchor, bbox_pred in zip(anchor_list[0], bbox_preds):
|
||||
anchor = anchor.reshape(-1, 4)
|
||||
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
|
||||
bbox_pred = distance2bbox(anchor, bbox_pred)
|
||||
decoded_bboxes.append(bbox_pred)
|
||||
flatten_bboxes = flatten_bboxes * self.flatten_priors[..., -1, None]
|
||||
flatten_bboxes = distance2bbox(self.flatten_priors[..., :2],
|
||||
flatten_bboxes)
|
||||
|
||||
flatten_bboxes = torch.cat(decoded_bboxes, 1)
|
||||
assigned_result = self.assigner(flatten_bboxes.detach(),
|
||||
flatten_cls_scores.detach(),
|
||||
self.flatten_priors, gt_labels,
|
||||
gt_bboxes, pad_bbox_flag)
|
||||
|
||||
cls_reg_targets = self.get_targets(
|
||||
flatten_cls_scores,
|
||||
flatten_bboxes,
|
||||
anchor_list,
|
||||
valid_flag_list,
|
||||
batch_gt_instances,
|
||||
batch_img_metas,
|
||||
batch_gt_instances_ignore=batch_gt_instances_ignore)
|
||||
(anchor_list, labels_list, label_weights_list, bbox_targets_list,
|
||||
assign_metrics_list) = cls_reg_targets
|
||||
labels = assigned_result['assigned_labels'].reshape(-1)
|
||||
label_weights = assigned_result['assigned_labels_weights'].reshape(-1)
|
||||
bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 4)
|
||||
assign_metrics = assigned_result['assign_metrics'].reshape(-1)
|
||||
cls_preds = flatten_cls_scores.reshape(-1, self.num_classes)
|
||||
bbox_preds = flatten_bboxes.reshape(-1, 4)
|
||||
|
||||
losses_cls, losses_bbox,\
|
||||
cls_avg_factors, bbox_avg_factors = multi_apply(
|
||||
self.loss_by_feat_single,
|
||||
cls_scores,
|
||||
decoded_bboxes,
|
||||
labels_list,
|
||||
label_weights_list,
|
||||
bbox_targets_list,
|
||||
assign_metrics_list,
|
||||
self.prior_generator.strides)
|
||||
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
|
||||
bg_class_ind = self.num_classes
|
||||
pos_inds = ((labels >= 0)
|
||||
& (labels < bg_class_ind)).nonzero().squeeze(1)
|
||||
avg_factor = reduce_mean(assign_metrics.sum()).clamp_(min=1).item()
|
||||
|
||||
cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
|
||||
losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls))
|
||||
loss_cls = self.loss_cls(
|
||||
cls_preds, (labels, assign_metrics),
|
||||
label_weights,
|
||||
avg_factor=avg_factor)
|
||||
|
||||
bbox_avg_factor = reduce_mean(
|
||||
sum(bbox_avg_factors)).clamp_(min=1).item()
|
||||
losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
|
||||
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
|
||||
|
||||
def get_targets(self,
|
||||
cls_scores: Tensor,
|
||||
bbox_preds: Tensor,
|
||||
anchor_list: List[List[Tensor]],
|
||||
valid_flag_list: List[List[Tensor]],
|
||||
batch_gt_instances: InstanceList,
|
||||
batch_img_metas: List[dict],
|
||||
batch_gt_instances_ignore: OptInstanceList = None,
|
||||
unmap_outputs=True) -> Union[tuple, None]:
|
||||
"""Compute regression and classification targets for anchors in
|
||||
multiple images.
|
||||
|
||||
Args:
|
||||
cls_scores (Tensor): Classification predictions of images,
|
||||
a 3D-Tensor with shape [num_imgs, num_priors, num_classes].
|
||||
bbox_preds (Tensor): Decoded bboxes predictions of one image,
|
||||
a 3D-Tensor with shape [num_imgs, num_priors, 4] in [tl_x,
|
||||
tl_y, br_x, br_y] format.
|
||||
anchor_list (list[list[Tensor]]): Multi level anchors of each
|
||||
image. The outer list indicates images, and the inner list
|
||||
corresponds to feature levels of the image. Each element of
|
||||
the inner list is a tensor of shape (num_anchors, 4).
|
||||
valid_flag_list (list[list[Tensor]]): Multi level valid flags of
|
||||
each image. The outer list indicates images, and the inner list
|
||||
corresponds to feature levels of the image. Each element of
|
||||
the inner list is a tensor of shape (num_anchors, )
|
||||
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
|
||||
gt_instance. It usually includes ``bboxes`` and ``labels``
|
||||
attributes.
|
||||
batch_img_metas (list[dict]): Meta information of each image, e.g.,
|
||||
image size, scaling factor, etc.
|
||||
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
|
||||
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
|
||||
data that is ignored during training and testing.
|
||||
Defaults to None.
|
||||
unmap_outputs (bool): Whether to map outputs back to the original
|
||||
set of anchors. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple: a tuple containing learning targets.
|
||||
|
||||
- anchors_list (list[list[Tensor]]): Anchors of each level.
|
||||
- labels_list (list[Tensor]): Labels of each level.
|
||||
- label_weights_list (list[Tensor]): Label weights of each
|
||||
level.
|
||||
- bbox_targets_list (list[Tensor]): BBox targets of each level.
|
||||
- assign_metrics_list (list[Tensor]): alignment metrics of each
|
||||
level.
|
||||
"""
|
||||
num_imgs = len(batch_img_metas)
|
||||
assert len(anchor_list) == len(valid_flag_list) == num_imgs
|
||||
|
||||
# anchor number of multi levels
|
||||
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
|
||||
|
||||
# concat all level anchors and flags to a single tensor
|
||||
for i in range(num_imgs):
|
||||
assert len(anchor_list[i]) == len(valid_flag_list[i])
|
||||
anchor_list[i] = torch.cat(anchor_list[i])
|
||||
valid_flag_list[i] = torch.cat(valid_flag_list[i])
|
||||
|
||||
# compute targets for each image
|
||||
if batch_gt_instances_ignore is None:
|
||||
batch_gt_instances_ignore = [None] * num_imgs
|
||||
# anchor_list: list(b * [-1, 4])
|
||||
(all_anchors, all_labels, all_label_weights, all_bbox_targets,
|
||||
all_assign_metrics) = multi_apply(
|
||||
self._get_targets_single,
|
||||
cls_scores.detach(),
|
||||
bbox_preds.detach(),
|
||||
anchor_list,
|
||||
valid_flag_list,
|
||||
batch_gt_instances,
|
||||
batch_img_metas,
|
||||
batch_gt_instances_ignore,
|
||||
unmap_outputs=unmap_outputs)
|
||||
# no valid anchors
|
||||
if any([labels is None for labels in all_labels]):
|
||||
return None
|
||||
|
||||
# split targets to a list w.r.t. multiple levels
|
||||
anchors_list = images_to_levels(all_anchors, num_level_anchors)
|
||||
labels_list = images_to_levels(all_labels, num_level_anchors)
|
||||
label_weights_list = images_to_levels(all_label_weights,
|
||||
num_level_anchors)
|
||||
bbox_targets_list = images_to_levels(all_bbox_targets,
|
||||
num_level_anchors)
|
||||
assign_metrics_list = images_to_levels(all_assign_metrics,
|
||||
num_level_anchors)
|
||||
|
||||
return (anchors_list, labels_list, label_weights_list,
|
||||
bbox_targets_list, assign_metrics_list)
|
||||
|
||||
def _get_targets_single(self,
|
||||
cls_scores: Tensor,
|
||||
bbox_preds: Tensor,
|
||||
flat_anchors: Tensor,
|
||||
valid_flags: Tensor,
|
||||
gt_instances: InstanceData,
|
||||
img_meta: dict,
|
||||
gt_instances_ignore: Optional[InstanceData] = None,
|
||||
unmap_outputs=True) -> tuple:
|
||||
"""Compute regression, classification targets for anchors in a single
|
||||
image.
|
||||
|
||||
Args:
|
||||
cls_scores (list(Tensor)): Box scores for each image.
|
||||
bbox_preds (list(Tensor)): Box energies / deltas for each image.
|
||||
flat_anchors (Tensor): Multi-level anchors of the image, which are
|
||||
concatenated into a single tensor of shape (num_anchors ,4)
|
||||
valid_flags (Tensor): Multi level valid flags of the image,
|
||||
which are concatenated into a single tensor of
|
||||
shape (num_anchors,).
|
||||
gt_instances (:obj:`InstanceData`): Ground truth of instance
|
||||
annotations. It usually includes ``bboxes`` and ``labels``
|
||||
attributes.
|
||||
img_meta (dict): Meta information for current image.
|
||||
gt_instances_ignore (:obj:`InstanceData`, optional): Instances
|
||||
to be ignored during training. It includes ``bboxes`` attribute
|
||||
data that is ignored during training and testing.
|
||||
Defaults to None.
|
||||
unmap_outputs (bool): Whether to map outputs back to the original
|
||||
set of anchors. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple: N is the number of total anchors in the image.
|
||||
|
||||
- anchors (Tensor): All anchors in the image with shape (N, 4).
|
||||
- labels (Tensor): Labels of all anchors in the image with shape
|
||||
(N,).
|
||||
- label_weights (Tensor): Label weights of all anchor in the
|
||||
image with shape (N,).
|
||||
- bbox_targets (Tensor): BBox targets of all anchors in the
|
||||
image with shape (N, 4).
|
||||
- norm_alignment_metrics (Tensor): Normalized alignment metrics
|
||||
of all priors in the image with shape (N,).
|
||||
"""
|
||||
inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
|
||||
img_meta['img_shape'][:2],
|
||||
self.train_cfg.allowed_border)
|
||||
if not inside_flags.any():
|
||||
return (None, ) * 7
|
||||
# assign gt and sample anchors
|
||||
anchors = flat_anchors[inside_flags, :]
|
||||
|
||||
pred_instances = InstanceData(
|
||||
scores=cls_scores[inside_flags, :],
|
||||
bboxes=bbox_preds[inside_flags, :],
|
||||
priors=anchors)
|
||||
|
||||
assign_result = self.assigner.assign(pred_instances, gt_instances,
|
||||
gt_instances_ignore)
|
||||
|
||||
sampling_result = self.sampler.sample(assign_result, pred_instances,
|
||||
gt_instances)
|
||||
|
||||
num_valid_anchors = anchors.shape[0]
|
||||
bbox_targets = torch.zeros_like(anchors)
|
||||
labels = anchors.new_full((num_valid_anchors, ),
|
||||
self.num_classes,
|
||||
dtype=torch.long)
|
||||
label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
|
||||
assign_metrics = anchors.new_zeros(
|
||||
num_valid_anchors, dtype=torch.float)
|
||||
|
||||
pos_inds = sampling_result.pos_inds
|
||||
neg_inds = sampling_result.neg_inds
|
||||
if len(pos_inds) > 0:
|
||||
# point-based
|
||||
pos_bbox_targets = sampling_result.pos_gt_bboxes
|
||||
bbox_targets[pos_inds, :] = pos_bbox_targets
|
||||
loss_bbox = self.loss_bbox(
|
||||
bbox_preds[pos_inds],
|
||||
bbox_targets[pos_inds],
|
||||
weight=assign_metrics[pos_inds],
|
||||
avg_factor=avg_factor)
|
||||
else:
|
||||
loss_bbox = bbox_preds.sum() * 0
|
||||
|
||||
labels[pos_inds] = sampling_result.pos_gt_labels
|
||||
if self.train_cfg.pos_weight <= 0:
|
||||
label_weights[pos_inds] = 1.0
|
||||
else:
|
||||
label_weights[pos_inds] = self.train_cfg.pos_weight
|
||||
if len(neg_inds) > 0:
|
||||
label_weights[neg_inds] = 1.0
|
||||
return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
|
||||
|
||||
class_assigned_gt_inds = torch.unique(
|
||||
sampling_result.pos_assigned_gt_inds)
|
||||
for gt_inds in class_assigned_gt_inds:
|
||||
gt_class_inds = pos_inds[sampling_result.pos_assigned_gt_inds ==
|
||||
gt_inds]
|
||||
assign_metrics[gt_class_inds] = assign_result.max_overlaps[
|
||||
gt_class_inds]
|
||||
@staticmethod
|
||||
def gt_instances_preprocess(batch_gt_instances: Union[Tensor, Sequence],
|
||||
batch_size: int) -> Tensor:
|
||||
"""Split batch_gt_instances with batch size, from [all_gt_bboxes, 6]
|
||||
to.
|
||||
|
||||
# map up to original set of anchors
|
||||
if unmap_outputs:
|
||||
num_total_anchors = flat_anchors.size(0)
|
||||
anchors = unmap(anchors, num_total_anchors, inside_flags)
|
||||
labels = unmap(
|
||||
labels, num_total_anchors, inside_flags, fill=self.num_classes)
|
||||
label_weights = unmap(label_weights, num_total_anchors,
|
||||
inside_flags)
|
||||
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
|
||||
assign_metrics = unmap(assign_metrics, num_total_anchors,
|
||||
inside_flags)
|
||||
return anchors, labels, label_weights, bbox_targets, assign_metrics
|
||||
|
||||
def get_anchors(self,
|
||||
featmap_sizes: List[tuple],
|
||||
batch_img_metas: List[dict],
|
||||
device: Union[torch.device, str] = 'cuda') \
|
||||
-> Tuple[List[List[Tensor]], List[List[Tensor]]]:
|
||||
"""Get anchors according to feature map sizes.
|
||||
[batch_size, number_gt, 5]. If some shape of single batch smaller than
|
||||
gt bbox len, then using [-1., 0., 0., 0., 0.] to fill.
|
||||
|
||||
Args:
|
||||
featmap_sizes (list[tuple]): Multi-level feature map sizes.
|
||||
batch_img_metas (list[dict]): Image meta info.
|
||||
device (torch.device or str): Device for returned tensors.
|
||||
Defaults to cuda.
|
||||
batch_gt_instances (Sequence[Tensor]): Ground truth
|
||||
instances for whole batch, shape [all_gt_bboxes, 6]
|
||||
batch_size (int): Batch size.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
|
||||
- anchor_list (list[list[Tensor]]): Anchors of each image.
|
||||
- valid_flag_list (list[list[Tensor]]): Valid flags of each
|
||||
image.
|
||||
Tensor: batch gt instances data, shape [batch_size, number_gt, 5]
|
||||
"""
|
||||
num_imgs = len(batch_img_metas)
|
||||
if isinstance(batch_gt_instances, Sequence):
|
||||
max_gt_bbox_len = max(
|
||||
[len(gt_instances) for gt_instances in batch_gt_instances])
|
||||
# fill [-1., 0., 0., 0., 0.] if some shape of
|
||||
# single batch not equal max_gt_bbox_len
|
||||
batch_instance_list = []
|
||||
for index, gt_instance in enumerate(batch_gt_instances):
|
||||
bboxes = gt_instance.bboxes
|
||||
labels = gt_instance.labels
|
||||
batch_instance_list.append(
|
||||
torch.cat((labels[:, None], bboxes), dim=-1))
|
||||
|
||||
# since feature map sizes of all images are the same, we only compute
|
||||
# anchors for one time
|
||||
multi_level_anchors = self.prior_generator.grid_priors(
|
||||
featmap_sizes, device=device, with_stride=True)
|
||||
anchor_list = [multi_level_anchors for _ in range(num_imgs)]
|
||||
if bboxes.shape[0] >= max_gt_bbox_len:
|
||||
continue
|
||||
|
||||
# for each image, we compute valid flags of multi level anchors
|
||||
valid_flag_list = []
|
||||
for img_id, img_meta in enumerate(batch_img_metas):
|
||||
multi_level_flags = self.prior_generator.valid_flags(
|
||||
featmap_sizes, img_meta['pad_shape'], device)
|
||||
valid_flag_list.append(multi_level_flags)
|
||||
return anchor_list, valid_flag_list
|
||||
fill_tensor = bboxes.new_full(
|
||||
[max_gt_bbox_len - bboxes.shape[0], 5], 0)
|
||||
fill_tensor[:, 0] = -1.
|
||||
batch_instance_list[index] = torch.cat(
|
||||
(batch_instance_list[-1], fill_tensor), dim=0)
|
||||
|
||||
return torch.stack(batch_instance_list)
|
||||
else:
|
||||
# faster version
|
||||
# sqlit batch gt instance [all_gt_bboxes, 6] ->
|
||||
# [batch_size, number_gt_each_batch, 5]
|
||||
batch_instance_list = []
|
||||
max_gt_bbox_len = 0
|
||||
for i in range(batch_size):
|
||||
single_batch_instance = \
|
||||
batch_gt_instances[batch_gt_instances[:, 0] == i, :]
|
||||
single_batch_instance = single_batch_instance[:, 1:]
|
||||
batch_instance_list.append(single_batch_instance)
|
||||
if len(single_batch_instance) > max_gt_bbox_len:
|
||||
max_gt_bbox_len = len(single_batch_instance)
|
||||
|
||||
# fill [-1., 0., 0., 0., 0.] if some shape of
|
||||
# single batch not equal max_gt_bbox_len
|
||||
for index, gt_instance in enumerate(batch_instance_list):
|
||||
if gt_instance.shape[0] >= max_gt_bbox_len:
|
||||
continue
|
||||
fill_tensor = batch_gt_instances.new_full(
|
||||
[max_gt_bbox_len - gt_instance.shape[0], 5], 0)
|
||||
fill_tensor[:, 0] = -1.
|
||||
batch_instance_list[index] = torch.cat(
|
||||
(batch_instance_list[index], fill_tensor), dim=0)
|
||||
|
||||
return torch.stack(batch_instance_list)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .batch_atss_assigner import BatchATSSAssigner
|
||||
from .batch_dsl_assigner import BatchDynamicSoftLabelAssigner
|
||||
from .batch_task_aligned_assigner import BatchTaskAlignedAssigner
|
||||
from .utils import (select_candidates_in_gts, select_highest_overlaps,
|
||||
yolov6_iou_calculator)
|
||||
|
@ -7,5 +8,5 @@ from .utils import (select_candidates_in_gts, select_highest_overlaps,
|
|||
__all__ = [
|
||||
'BatchATSSAssigner', 'BatchTaskAlignedAssigner',
|
||||
'select_candidates_in_gts', 'select_highest_overlaps',
|
||||
'yolov6_iou_calculator'
|
||||
'yolov6_iou_calculator', 'BatchDynamicSoftLabelAssigner'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,193 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmdet.structures.bbox import BaseBoxes
|
||||
from mmdet.utils import ConfigType
|
||||
from torch import Tensor
|
||||
|
||||
from mmyolo.registry import TASK_UTILS
|
||||
|
||||
INF = 100000000
|
||||
EPS = 1.0e-7
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class BatchDynamicSoftLabelAssigner(nn.Module):
|
||||
"""Computes matching between predictions and ground truth with dynamic soft
|
||||
label assignment.
|
||||
|
||||
Args:
|
||||
num_classes (int): number of class
|
||||
soft_center_radius (float): Radius of the soft center prior.
|
||||
Defaults to 3.0.
|
||||
topk (int): Select top-k predictions to calculate dynamic k
|
||||
best matches for each gt. Defaults to 13.
|
||||
iou_weight (float): The scale factor of iou cost. Defaults to 3.0.
|
||||
iou_calculator (ConfigType): Config of overlaps Calculator.
|
||||
Defaults to dict(type='BboxOverlaps2D').
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
soft_center_radius: float = 3.0,
|
||||
topk: int = 13,
|
||||
iou_weight: float = 3.0,
|
||||
iou_calculator: ConfigType = dict(type='mmdet.BboxOverlaps2D')
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.soft_center_radius = soft_center_radius
|
||||
self.topk = topk
|
||||
self.iou_weight = iou_weight
|
||||
self.iou_calculator = TASK_UTILS.build(iou_calculator)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, pred_bboxes: Tensor, pred_scores: Tensor, priors: Tensor,
|
||||
gt_labels: Tensor, gt_bboxes: Tensor,
|
||||
pad_bbox_flag: Tensor) -> dict:
|
||||
num_gt = gt_bboxes.size(1)
|
||||
decoded_bboxes = pred_bboxes
|
||||
num_bboxes = decoded_bboxes.size(1)
|
||||
batch_size = decoded_bboxes.size(0)
|
||||
|
||||
if num_gt == 0 or num_bboxes == 0:
|
||||
return {
|
||||
'assigned_labels':
|
||||
gt_labels.new_full(
|
||||
pred_scores[..., 0].shape,
|
||||
self.num_classes,
|
||||
dtype=torch.long),
|
||||
'assigned_labels_weights':
|
||||
gt_bboxes.new_full(pred_scores[..., 0].shape, 1),
|
||||
'assigned_bboxes':
|
||||
gt_bboxes.new_full(pred_bboxes.shape, 0),
|
||||
'assign_metrics':
|
||||
gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
|
||||
}
|
||||
|
||||
prior_center = priors[:, :2]
|
||||
if isinstance(gt_bboxes, BaseBoxes):
|
||||
raise NotImplementedError(
|
||||
f'type of {type(gt_bboxes)} are not implemented !')
|
||||
else:
|
||||
# Tensor boxes will be treated as horizontal boxes by defaults
|
||||
lt_ = prior_center[:, None, None] - gt_bboxes[..., :2]
|
||||
rb_ = gt_bboxes[..., 2:] - prior_center[:, None, None]
|
||||
|
||||
deltas = torch.cat([lt_, rb_], dim=-1)
|
||||
is_in_gts = deltas.min(dim=-1).values > 0
|
||||
is_in_gts = is_in_gts * pad_bbox_flag[..., 0][None]
|
||||
is_in_gts = is_in_gts.permute(1, 0, 2)
|
||||
valid_mask = is_in_gts.sum(dim=-1) > 0
|
||||
|
||||
# Tensor boxes will be treated as horizontal boxes by defaults
|
||||
gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
|
||||
|
||||
strides = priors[..., 2]
|
||||
distance = (priors[None].unsqueeze(2)[..., :2] -
|
||||
gt_center[:, None, :, :]
|
||||
).pow(2).sum(-1).sqrt() / strides[None, :, None]
|
||||
|
||||
# prevent overflow
|
||||
distance = distance * valid_mask.unsqueeze(-1)
|
||||
soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
|
||||
|
||||
pairwise_ious = self.iou_calculator(decoded_bboxes, gt_bboxes)
|
||||
iou_cost = -torch.log(pairwise_ious + EPS) * self.iou_weight
|
||||
|
||||
# select the predicted scores corresponded to the gt_labels
|
||||
pairwise_pred_scores = pred_scores.permute(0, 2, 1)
|
||||
idx = torch.zeros([2, batch_size, num_gt], dtype=torch.long)
|
||||
idx[0] = torch.arange(end=batch_size).view(-1, 1).repeat(1, num_gt)
|
||||
idx[1] = gt_labels.long().squeeze(-1)
|
||||
pairwise_pred_scores = pairwise_pred_scores[idx[0],
|
||||
idx[1]].permute(0, 2, 1)
|
||||
# classification cost
|
||||
scale_factor = pairwise_ious - pairwise_pred_scores.sigmoid()
|
||||
pairwise_cls_cost = F.binary_cross_entropy_with_logits(
|
||||
pairwise_pred_scores, pairwise_ious,
|
||||
reduction='none') * scale_factor.abs().pow(2.0)
|
||||
|
||||
cost_matrix = pairwise_cls_cost + iou_cost + soft_center_prior
|
||||
|
||||
max_pad_value = torch.ones_like(cost_matrix) * INF
|
||||
cost_matrix = torch.where(valid_mask[..., None].repeat(1, 1, num_gt),
|
||||
cost_matrix, max_pad_value)
|
||||
|
||||
(matched_pred_ious, matched_gt_inds,
|
||||
fg_mask_inboxes) = self.dynamic_k_matching(cost_matrix, pairwise_ious,
|
||||
pad_bbox_flag)
|
||||
|
||||
del pairwise_ious, cost_matrix
|
||||
|
||||
batch_index = (fg_mask_inboxes > 0).nonzero(as_tuple=True)[0]
|
||||
|
||||
assigned_labels = gt_labels.new_full(pred_scores[..., 0].shape,
|
||||
self.num_classes)
|
||||
assigned_labels[fg_mask_inboxes] = gt_labels[
|
||||
batch_index, matched_gt_inds].squeeze(-1)
|
||||
assigned_labels = assigned_labels.long()
|
||||
|
||||
assigned_labels_weights = gt_bboxes.new_full(pred_scores[..., 0].shape,
|
||||
1)
|
||||
|
||||
assigned_bboxes = gt_bboxes.new_full(pred_bboxes.shape, 0)
|
||||
assigned_bboxes[fg_mask_inboxes] = gt_bboxes[batch_index,
|
||||
matched_gt_inds]
|
||||
|
||||
assign_metrics = gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
|
||||
assign_metrics[fg_mask_inboxes] = matched_pred_ious
|
||||
|
||||
return dict(
|
||||
assigned_labels=assigned_labels,
|
||||
assigned_labels_weights=assigned_labels_weights,
|
||||
assigned_bboxes=assigned_bboxes,
|
||||
assign_metrics=assign_metrics)
|
||||
|
||||
def dynamic_k_matching(self, cost_matrix: Tensor, pairwise_ious: Tensor,
|
||||
pad_bbox_flag: int) -> Tuple[Tensor, Tensor]:
|
||||
"""Use IoU and matching cost to calculate the dynamic top-k positive
|
||||
targets.
|
||||
|
||||
Args:
|
||||
cost_matrix (Tensor): Cost matrix.
|
||||
pairwise_ious (Tensor): Pairwise iou matrix.
|
||||
num_gt (int): Number of gt.
|
||||
valid_mask (Tensor): Mask for valid bboxes.
|
||||
Returns:
|
||||
tuple: matched ious and gt indexes.
|
||||
"""
|
||||
matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
|
||||
# select candidate topk ious for dynamic-k calculation
|
||||
candidate_topk = min(self.topk, pairwise_ious.size(1))
|
||||
topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
|
||||
# calculate dynamic k for each gt
|
||||
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
|
||||
|
||||
num_gts = pad_bbox_flag.sum((1, 2)).int()
|
||||
# sorting the batch cost matirx is faster than topk
|
||||
_, sorted_indices = torch.sort(cost_matrix, dim=1)
|
||||
for b in range(pad_bbox_flag.shape[0]):
|
||||
for gt_idx in range(num_gts[b]):
|
||||
topk_ids = sorted_indices[b, :dynamic_ks[b, gt_idx], gt_idx]
|
||||
matching_matrix[b, :, gt_idx][topk_ids] = 1
|
||||
|
||||
del topk_ious, dynamic_ks
|
||||
|
||||
prior_match_gt_mask = matching_matrix.sum(2) > 1
|
||||
if prior_match_gt_mask.sum() > 0:
|
||||
cost_min, cost_argmin = torch.min(
|
||||
cost_matrix[prior_match_gt_mask, :], dim=1)
|
||||
matching_matrix[prior_match_gt_mask, :] *= 0
|
||||
matching_matrix[prior_match_gt_mask, cost_argmin] = 1
|
||||
|
||||
# get foreground mask inside box and center prior
|
||||
fg_mask_inboxes = matching_matrix.sum(2) > 0
|
||||
matched_pred_ious = (matching_matrix *
|
||||
pairwise_ious).sum(2)[fg_mask_inboxes]
|
||||
matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
|
||||
return matched_pred_ious, matched_gt_inds, fg_mask_inboxes
|
|
@ -33,11 +33,13 @@ class TestRTMDetHead(TestCase):
|
|||
'ori_shape': (s, s, 3),
|
||||
'scale_factor': (1.0, 1.0),
|
||||
}]
|
||||
test_cfg = Config(
|
||||
dict(
|
||||
max_per_img=300,
|
||||
score_thr=0.01,
|
||||
nms=dict(type='nms', iou_threshold=0.65)))
|
||||
test_cfg = dict(
|
||||
multi_label=True,
|
||||
nms_pre=30000,
|
||||
score_thr=0.001,
|
||||
nms=dict(type='nms', iou_threshold=0.65),
|
||||
max_per_img=300)
|
||||
test_cfg = Config(test_cfg)
|
||||
|
||||
head = RTMDetHead(head_module=self.head_module, test_cfg=test_cfg)
|
||||
feat = [
|
||||
|
@ -48,14 +50,14 @@ class TestRTMDetHead(TestCase):
|
|||
head.predict_by_feat(
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
batch_img_metas=img_metas,
|
||||
cfg=test_cfg,
|
||||
rescale=True,
|
||||
with_nms=True)
|
||||
head.predict_by_feat(
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
batch_img_metas=img_metas,
|
||||
cfg=test_cfg,
|
||||
rescale=False,
|
||||
with_nms=False)
|
||||
|
@ -64,18 +66,19 @@ class TestRTMDetHead(TestCase):
|
|||
s = 256
|
||||
img_metas = [{
|
||||
'img_shape': (s, s, 3),
|
||||
'pad_shape': (s, s, 3),
|
||||
'batch_input_shape': (s, s),
|
||||
'scale_factor': 1,
|
||||
}]
|
||||
train_cfg = Config(
|
||||
dict(
|
||||
assigner=dict(
|
||||
type='mmdet.DynamicSoftLabelAssigner',
|
||||
topk=13,
|
||||
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
|
||||
allowed_border=-1,
|
||||
pos_weight=-1))
|
||||
|
||||
train_cfg = dict(
|
||||
assigner=dict(
|
||||
num_classes=80,
|
||||
type='BatchDynamicSoftLabelAssigner',
|
||||
topk=13,
|
||||
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
|
||||
allowed_border=-1,
|
||||
pos_weight=-1,
|
||||
debug=False)
|
||||
train_cfg = Config(train_cfg)
|
||||
head = RTMDetHead(head_module=self.head_module, train_cfg=train_cfg)
|
||||
|
||||
feat = [
|
||||
|
@ -84,53 +87,53 @@ class TestRTMDetHead(TestCase):
|
|||
]
|
||||
cls_scores, bbox_preds = head.forward(feat)
|
||||
|
||||
# TODO
|
||||
# Test that empty ground truth encourages the network to predict
|
||||
# background
|
||||
gt_instances = InstanceData(
|
||||
bboxes=torch.empty((0, 4)), labels=torch.LongTensor([]))
|
||||
|
||||
# empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
|
||||
# [gt_instances],
|
||||
# img_metas)
|
||||
# # When there is no truth, the cls loss should be nonzero but there
|
||||
# # should be no box loss.
|
||||
# empty_cls_loss = empty_gt_losses['loss_cls'].sum()
|
||||
# empty_box_loss = empty_gt_losses['loss_bbox'].sum()
|
||||
# self.assertEqual(
|
||||
# empty_cls_loss.item(), 0,
|
||||
# 'there should be no cls loss when there are no true boxes')
|
||||
# self.assertEqual(
|
||||
# empty_box_loss.item(), 0,
|
||||
# 'there should be no box loss when there are no true boxes')
|
||||
empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
|
||||
[gt_instances], img_metas)
|
||||
# When there is no truth, the cls loss should be nonzero but there
|
||||
# should be no box loss.
|
||||
empty_cls_loss = empty_gt_losses['loss_cls'].sum()
|
||||
empty_box_loss = empty_gt_losses['loss_bbox'].sum()
|
||||
self.assertGreater(empty_cls_loss.item(), 0,
|
||||
'classification loss should be non-zero')
|
||||
self.assertEqual(
|
||||
empty_box_loss.item(), 0,
|
||||
'there should be no box loss when there are no true boxes')
|
||||
|
||||
# When truth is non-empty then both cls and box loss should be nonzero
|
||||
# for random inputs
|
||||
head = RTMDetHead(head_module=self.head_module, train_cfg=train_cfg)
|
||||
gt_instances = InstanceData(
|
||||
bboxes=torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
|
||||
labels=torch.LongTensor([2]))
|
||||
labels=torch.LongTensor([1]))
|
||||
|
||||
one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
|
||||
[gt_instances], img_metas)
|
||||
onegt_cls_loss = sum(one_gt_losses['loss_cls'])
|
||||
onegt_box_loss = sum(one_gt_losses['loss_bbox'])
|
||||
onegt_cls_loss = one_gt_losses['loss_cls'].sum()
|
||||
onegt_box_loss = one_gt_losses['loss_bbox'].sum()
|
||||
self.assertGreater(onegt_cls_loss.item(), 0,
|
||||
'cls loss should be non-zero')
|
||||
self.assertGreater(onegt_box_loss.item(), 0,
|
||||
'box loss should be non-zero')
|
||||
|
||||
# Test groud truth out of bound
|
||||
# test num_class = 1
|
||||
self.head_module['num_classes'] = 1
|
||||
head = RTMDetHead(head_module=self.head_module, train_cfg=train_cfg)
|
||||
gt_instances = InstanceData(
|
||||
bboxes=torch.Tensor([[s * 4, s * 4, s * 4 + 10, s * 4 + 10]]),
|
||||
labels=torch.LongTensor([2]))
|
||||
gt_losses = head.loss_by_feat(cls_scores, bbox_preds, [gt_instances],
|
||||
img_metas)
|
||||
cls_loss = sum(gt_losses['loss_cls'])
|
||||
empty_box_loss = sum(gt_losses['loss_bbox'])
|
||||
self.assertGreater(
|
||||
cls_loss.item(), 0,
|
||||
'there should be no cls loss when gt_bboxes out of bound')
|
||||
self.assertEqual(
|
||||
empty_box_loss.item(), 0,
|
||||
'there should be no box loss when gt_bboxes out of bound')
|
||||
bboxes=torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
|
||||
labels=torch.LongTensor([0]))
|
||||
|
||||
cls_scores, bbox_preds = head.forward(feat)
|
||||
|
||||
one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
|
||||
[gt_instances], img_metas)
|
||||
onegt_cls_loss = one_gt_losses['loss_cls'].sum()
|
||||
onegt_box_loss = one_gt_losses['loss_bbox'].sum()
|
||||
self.assertGreater(onegt_cls_loss.item(), 0,
|
||||
'cls loss should be non-zero')
|
||||
self.assertGreater(onegt_box_loss.item(), 0,
|
||||
'box loss should be non-zero')
|
||||
|
|
|
@ -22,7 +22,7 @@ class TestSingleStageDetector(TestCase):
|
|||
'yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py',
|
||||
'yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py',
|
||||
'yolox/yolox_tiny_8xb8-300e_coco.py',
|
||||
'rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py',
|
||||
'rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py',
|
||||
'yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py'
|
||||
])
|
||||
def test_init(self, cfg_file):
|
||||
|
@ -39,7 +39,7 @@ class TestSingleStageDetector(TestCase):
|
|||
('yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py', ('cuda', 'cpu')),
|
||||
('yolox/yolox_s_8xb8-300e_coco.py', ('cuda', 'cpu')),
|
||||
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
|
||||
('rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py', ('cuda', 'cpu'))
|
||||
('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu'))
|
||||
])
|
||||
def test_forward_loss_mode(self, cfg_file, devices):
|
||||
message_hub = MessageHub.get_instance(
|
||||
|
@ -79,7 +79,7 @@ class TestSingleStageDetector(TestCase):
|
|||
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
|
||||
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
|
||||
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
|
||||
('rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py', ('cuda', 'cpu'))
|
||||
('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu'))
|
||||
])
|
||||
def test_forward_predict_mode(self, cfg_file, devices):
|
||||
model = get_detector_cfg(cfg_file)
|
||||
|
@ -111,7 +111,7 @@ class TestSingleStageDetector(TestCase):
|
|||
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
|
||||
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
|
||||
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
|
||||
('rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py', ('cuda', 'cpu'))
|
||||
('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu'))
|
||||
])
|
||||
def test_forward_tensor_mode(self, cfg_file, devices):
|
||||
model = get_detector_cfg(cfg_file)
|
||||
|
|
Loading…
Reference in New Issue