[Feature] Support PPYOLOE training (#259)

* add ppyoloe backbone, neck

* add ppyoloe test

* add docstring

* add ppyoloe m/l/x configfile

* add ppyoloe_coco.py

* rename config

* add typehint

* format code; add ut

* add datapre

* add datapre

* add ppyoloe datapre

* add ppyoloe datapre

* add ppyoloe datapre

* reproduce coco v0.1

* add ut

* add ut, docstring

* fix transforms bug

* use mmdet dfloss

* add non plus model config

* add non plus model config

* fix

* add ut

* produce coco v0.2

* fix config

* fix config

* fix eps and transforms bug

* add ema

* fix resize

* fix transforms.py

* fix transforms.py

* fix transforms.py

* old version

* old version

* old version

* old version

* old version

* old version

* fix stride loss error

* add INTER_LANCZOS4

* fix crop bug

* init commit

* format code

* format code

* bgr transforms.py

* add typehint and doc in transforms.py

* inherit from the new yolov6 head implementation; remove unnecessary comments

* fix transforms var name bug

* bbox decode uses stride_tensor instead of priors

* add headmodule todo

* add ppyoloe README.md

* add ppyoloe README.md

* Update tests/test_datasets/test_transforms/test_transforms.py

Co-authored-by: Range King <RangeKingHZ@gmail.com>

* Update tests/test_datasets/test_transforms/test_transforms.py

Co-authored-by: Range King <RangeKingHZ@gmail.com>

* save ckpt last 10 epochs

* save_best ckpt

* del ppyoloe collate

* change name of ppyoloebatchrandomresize

* add the reason for rewritten PPYOLOEDetDataPreprocessor

* solve ppyoloerandomresize name error

* rm PPYOLOERandomExpand

* rm l1 loss

* rm yolov6 loss_obj

* add difference between yolov6 and ppyoloe

* add reason for rewrite paramscheduler

* change proj init way

* fix error

* rm proj_conv in pth

* format code

* add load_from

* update

* support fast training

* add pretrained model url

* update

* add pretrained model url

* fix error

* add imagenet model convert and use init_cfg to init backbone

* add plus model pretrain model

* add ut

* add ut

* fix ut

* fix withstride bug

* cat in yolov5_collate

* merge

* fix typehint

* update readme

* add reason for gap

* fix log in README.md

* rollback yolov6

* change inherit

* fix ut

* fix ut

Co-authored-by: Range King <RangeKingHZ@gmail.com>
Co-authored-by: hha <1286304229@qq.com>
Co-authored-by: huanghaian <huanghaian@sensetime.com>
Nioolek committed 2023-01-06 15:54:39 +08:00 (committed via GitHub)
commit 8127805dd3, parent a20b160f0f
28 changed files with 1698 additions and 180 deletions


@@ -0,0 +1,38 @@
# PPYOLOE

<!-- [ALGORITHM] -->

## Abstract

PP-YOLOE is an excellent single-stage anchor-free model built on PP-YOLOv2 that surpasses a variety of popular YOLO models. PP-YOLOE comes in a series of model sizes, named s/m/l/x, which are configured through width and depth multipliers. PP-YOLOE avoids special operators such as Deformable Convolution and Matrix NMS so that it can be deployed easily on a wide range of hardware.

<div align=center>
<img src="https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/docs/images/ppyoloe_plus_map_fps.png" width="600" />
</div>

## Results and models

### PPYOLOE+ COCO

| Backbone | Arch | Size | Epoch | SyncBN | Mem (GB) | Box AP | Config | Download |
| :---------: | :--: | :--: | :---: | :----: | :------: | :----: | :-------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| PPYOLOE+ -s | P5 | 640 | 80 | Yes | 4.7 | 43.5 | [config](../ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco/ppyoloe_plus_s_fast_8xb8-80e_coco_20230101_154052-9fee7619.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco/ppyoloe_plus_s_fast_8xb8-80e_coco_20230101_154052.log.json) |
| PPYOLOE+ -m | P5 | 640 | 80 | Yes | 8.4 | 49.5 | [config](../ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco/ppyoloe_plus_m_fast_8xb8-80e_coco_20230104_193132-e4325ada.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco/ppyoloe_plus_m_fast_8xb8-80e_coco_20230104_193132.log.json) |
| PPYOLOE+ -l | P5 | 640 | 80 | Yes | 13.2 | 52.6 | [config](../ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco/ppyoloe_plus_l_fast_8xb8-80e_coco_20230102_203825-1864e7b3.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco/ppyoloe_plus_l_fast_8xb8-80e_coco_20230102_203825.log.json) |
| PPYOLOE+ -x | P5 | 640 | 80 | Yes | 19.1 | 54.2 | [config](../ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco/ppyoloe_plus_x_fast_8xb8-80e_coco_20230104_194921-8c953949.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco/ppyoloe_plus_x_fast_8xb8-80e_coco_20230104_194921.log.json) |

**Note**:

1. The Box AP values above are reported for the best-performing checkpoint of each model on COCO.
2. There is a gap of about 0.3 mAP between the results above and the official release. To speed up training, MMYOLO implements the image resizing in `PPYOLOEBatchRandomResize` (used for multi-scale training) with PyTorch, while the official PPYOLOE uses OpenCV; in addition, `lanczos4` interpolation is not yet supported in `PPYOLOEBatchRandomResize`. These two differences account for the gap (see the sketch after this list), and we will keep experimenting to close it in future releases.
3. The mAP of the non-Plus version needs more verification, and we will publish more details about it in future versions.
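
Below is a minimal sketch, not part of this PR, of the interpolation difference behind note 2: resizing the same image with OpenCV and with PyTorch's `F.interpolate` produces close but not bit-identical pixels.

```python
import cv2
import numpy as np
import torch
import torch.nn.functional as F

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)

# Official PPYOLOE resizes with OpenCV (it also uses lanczos4, which has no
# F.interpolate counterpart); here we compare plain bilinear resizing.
cv_out = cv2.resize(img, (64, 64), interpolation=cv2.INTER_LINEAR)

# MMYOLO's PPYOLOEBatchRandomResize resizes with torch.nn.functional.
t = torch.from_numpy(img).permute(2, 0, 1)[None].float()
pt_out = F.interpolate(
    t, size=(64, 64), mode='bilinear',
    align_corners=False)[0].permute(1, 2, 0).numpy()

# Typically small but non-zero: OpenCV computes in fixed point and rounds
# back to uint8, while PyTorch stays in float32.
print(np.abs(cv_out.astype(np.float32) - pt_out).max())
```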
```latex
@article{Xu2022PPYOLOEAE,
title={PP-YOLOE: An evolved version of YOLO},
author={Shangliang Xu and Xinxin Wang and Wenyu Lv and Qinyao Chang and Cheng Cui and Kaipeng Deng and Guanzhong Wang and Qingqing Dang and Shengyun Wei and Yuning Du and Baohua Lai},
journal={ArXiv},
year={2022},
volume={abs/2203.16250}
}
```


@@ -0,0 +1,69 @@
Collections:
  - Name: PPYOLOE
    Metadata:
      Training Data: COCO
      Training Techniques:
        - SGD with Nesterov
        - Weight Decay
        - Synchronize BN
      Training Resources: 8x A100 GPUs
      Architecture:
        - PPYOLOECSPResNet
        - PPYOLOECSPPAFPN
    Paper:
      URL: https://arxiv.org/abs/2203.16250
      Title: 'PP-YOLOE: An evolved version of YOLO'
    README: configs/ppyoloe/README.md
    Code:
      URL: https://github.com/open-mmlab/mmyolo/blob/v0.0.1/mmyolo/models/detectors/yolo_detector.py#L12
      Version: v0.0.1

Models:
  - Name: ppyoloe_plus_s_fast_8xb8-80e_coco
    In Collection: PPYOLOE
    Config: configs/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py
    Metadata:
      Training Memory (GB): 4.7
      Epochs: 80
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 43.5
    Weights: https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco/ppyoloe_plus_s_fast_8xb8-80e_coco_20230101_154052-9fee7619.pth
  - Name: ppyoloe_plus_m_fast_8xb8-80e_coco
    In Collection: PPYOLOE
    Config: configs/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco.py
    Metadata:
      Training Memory (GB): 8.4
      Epochs: 80
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 49.5
    Weights: https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco/ppyoloe_plus_m_fast_8xb8-80e_coco_20230104_193132-e4325ada.pth
  - Name: ppyoloe_plus_l_fast_8xb8-80e_coco
    In Collection: PPYOLOE
    Config: configs/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco.py
    Metadata:
      Training Memory (GB): 13.2
      Epochs: 80
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 52.6
    Weights: https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco/ppyoloe_plus_l_fast_8xb8-80e_coco_20230102_203825-1864e7b3.pth
  - Name: ppyoloe_plus_x_fast_8xb8-80e_coco
    In Collection: PPYOLOE
    Config: configs/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco.py
    Metadata:
      Training Memory (GB): 19.1
      Epochs: 80
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 54.2
    Weights: https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco/ppyoloe_plus_x_fast_8xb8-80e_coco_20230104_194921-8c953949.pth
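
A minimal inference sketch using the released weights above; it assumes an MMYOLO v0.x environment, the API names follow the MMDet 3.x interface that MMYOLO builds on, and `demo/demo.jpg` is a placeholder path.

```python
from mmdet.apis import inference_detector, init_detector

from mmyolo.utils import register_all_modules

register_all_modules()

config = 'configs/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py'
checkpoint = ('https://download.openmmlab.com/mmyolo/v0/ppyoloe/'
              'ppyoloe_plus_s_fast_8xb8-80e_coco/'
              'ppyoloe_plus_s_fast_8xb8-80e_coco_20230101_154052-9fee7619.pth')

# init_detector downloads the checkpoint from the URL automatically.
model = init_detector(config, checkpoint, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')
```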


@@ -1,15 +1,23 @@
_base_ = './ppyoloe_s_fast_8xb32-300e_coco.py'

# The pretrained model is obtained and converted from the official PPYOLOE.
# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
checkpoint = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/cspresnet_l_imagenet1k_pretrained-c0010e6c.pth'  # noqa

deepen_factor = 1.0
widen_factor = 1.0

train_batch_size_per_gpu = 20

model = dict(
    backbone=dict(
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        init_cfg=dict(checkpoint=checkpoint)),
    neck=dict(
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
    ),
    bbox_head=dict(head_module=dict(widen_factor=widen_factor)))

train_dataloader = dict(batch_size=train_batch_size_per_gpu)
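
A note on `init_cfg=dict(checkpoint=checkpoint)` above: MMEngine merges a child config into its `_base_` key by key, so only `checkpoint` is overridden while `type='Pretrained'`, `prefix` and `map_location` are inherited from the base config. A minimal sketch of that merge semantics (values here are placeholders):

```python
from mmengine.config import Config

cfg = Config(
    dict(
        model=dict(
            backbone=dict(
                init_cfg=dict(
                    type='Pretrained',
                    prefix='backbone.',
                    map_location='cpu',
                    checkpoint='base.pth')))))
# Override a single nested key, as the child config file does.
cfg.merge_from_dict({'model.backbone.init_cfg.checkpoint': 'new.pth'})
print(cfg.model.backbone.init_cfg)
# {'type': 'Pretrained', 'prefix': 'backbone.', 'map_location': 'cpu',
#  'checkpoint': 'new.pth'}
```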


@@ -1,15 +1,23 @@
_base_ = './ppyoloe_s_fast_8xb32-300e_coco.py'

# The pretrained model is obtained and converted from the official PPYOLOE.
# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
checkpoint = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/cspresnet_m_imagenet1k_pretrained-09f1eba2.pth'  # noqa

deepen_factor = 0.67
widen_factor = 0.75

train_batch_size_per_gpu = 28

model = dict(
    backbone=dict(
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        init_cfg=dict(checkpoint=checkpoint)),
    neck=dict(
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
    ),
    bbox_head=dict(head_module=dict(widen_factor=widen_factor)))

train_dataloader = dict(batch_size=train_batch_size_per_gpu)


@@ -1,5 +1,9 @@
_base_ = './ppyoloe_plus_s_fast_8xb8-80e_coco.py'

# The pretrained model is obtained and converted from the official PPYOLOE.
# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
load_from = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/ppyoloe_plus_l_obj365_pretrained-3dd89562.pth'  # noqa

deepen_factor = 1.0
widen_factor = 1.0


@@ -1,5 +1,9 @@
_base_ = './ppyoloe_plus_s_fast_8xb8-80e_coco.py'

# The pretrained model is obtained and converted from the official PPYOLOE.
# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
load_from = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/ppyoloe_plus_m_ojb365_pretrained-03206892.pth'  # noqa

deepen_factor = 0.67
widen_factor = 0.75


@@ -9,21 +9,40 @@ img_scale = (640, 640)  # height, width
deepen_factor = 0.33
widen_factor = 0.5
max_epochs = 80
num_classes = 80
save_epoch_intervals = 5
train_batch_size_per_gpu = 8
train_num_workers = 8
val_batch_size_per_gpu = 1
val_num_workers = 2

# The pretrained model is obtained and converted from the official PPYOLOE.
# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
load_from = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/ppyoloe_plus_s_obj365_pretrained-bcfe8478.pth'  # noqa

# persistent_workers must be False if num_workers is 0.
persistent_workers = True

# Base learning rate for optim_wrapper
base_lr = 0.001

strides = [8, 16, 32]

model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        # use this to support multi_scale training
        type='PPYOLOEDetDataPreprocessor',
        pad_size_divisor=32,
        batch_augments=[
            dict(
                type='PPYOLOEBatchRandomResize',
                random_size_range=(320, 800),
                interval=1,
                size_divisor=32,
                random_interp=True,
                keep_ratio=False)
        ],
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        bgr_to_rgb=True),
@@ -56,11 +75,52 @@ model = dict(
        type='PPYOLOEHead',
        head_module=dict(
            type='PPYOLOEHeadModule',
            num_classes=num_classes,
            in_channels=[192, 384, 768],
            widen_factor=widen_factor,
            featmap_strides=strides,
            reg_max=16,
            norm_cfg=dict(type='BN', momentum=0.1, eps=1e-5),
            act_cfg=dict(type='SiLU', inplace=True),
            num_base_priors=1),
        prior_generator=dict(
            type='mmdet.MlvlPointGenerator', offset=0.5, strides=strides),
        bbox_coder=dict(type='DistancePointBBoxCoder'),
        loss_cls=dict(
            type='mmdet.VarifocalLoss',
            use_sigmoid=True,
            alpha=0.75,
            gamma=2.0,
            iou_weighted=True,
            reduction='sum',
            loss_weight=1.0),
        loss_bbox=dict(
            type='IoULoss',
            iou_mode='giou',
            bbox_format='xyxy',
            reduction='mean',
            loss_weight=2.5,
            return_iou=False),
        # Since the dfl loss is implemented differently in the official
        # repo and in mmdet, we divide loss_weight by 4.
        loss_dfl=dict(
            type='mmdet.DistributionFocalLoss',
            reduction='mean',
            loss_weight=0.5 / 4)),
    train_cfg=dict(
        initial_epoch=30,
        initial_assigner=dict(
            type='BatchATSSAssigner',
            num_classes=num_classes,
            topk=9,
            iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
        assigner=dict(
            type='BatchTaskAlignedAssigner',
            num_classes=num_classes,
            topk=13,
            alpha=1,
            beta=6,
            eps=1e-9)),
    test_cfg=dict(
        multi_label=True,
        nms_pre=1000,
@@ -68,10 +128,36 @@ model = dict(
        nms=dict(type='nms', iou_threshold=0.7),
        max_per_img=300))

train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='PPYOLOERandomDistort'),
    dict(type='mmdet.Expand', mean=(103.53, 116.28, 123.675)),
    dict(type='PPYOLOERandomCrop'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                   'flip_direction'))
]

train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='yolov5_collate', use_ms_training=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=0),
        pipeline=train_pipeline))

test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(
        type='mmdet.FixShapeResize',
        width=img_scale[1],
@@ -103,6 +189,41 @@ val_dataloader = dict(
test_dataloader = val_dataloader

param_scheduler = None
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD',
        lr=base_lr,
        momentum=0.9,
        weight_decay=5e-4,
        nesterov=False),
    paramwise_cfg=dict(norm_decay_mult=0.))

default_hooks = dict(
    param_scheduler=dict(
        type='PPYOLOEParamSchedulerHook',
        warmup_min_iter=1000,
        start_factor=0.,
        warmup_epochs=5,
        min_lr_ratio=0.0,
        total_epochs=int(max_epochs * 1.2)),
    checkpoint=dict(
        type='CheckpointHook',
        interval=save_epoch_intervals,
        save_best='auto',
        max_keep_ckpts=3))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        strict_load=False,
        priority=49)
]

val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
@@ -110,5 +231,9 @@ val_evaluator = dict(
    metric='bbox')
test_evaluator = val_evaluator

train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=save_epoch_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
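
On `loss_weight=0.5 / 4` in the config above: per the config comment, the official PPYOLOE DFL averages the loss over the four box sides of each positive sample, while the mmdet implementation used here treats the four sides as independent targets and sums them; summing four values is 4x their mean, hence the division by 4. A small sketch of that equivalence (per-side values are hypothetical):

```python
import torch

per_side = torch.tensor([0.8, 1.2, 0.4, 1.6])  # hypothetical DFL per side
official = 0.5 * per_side.mean()          # paddle-style: mean of the 4 sides
mmdet_style = (0.5 / 4) * per_side.sum()  # mmdet-style: sum of the 4 sides
assert torch.isclose(official, mmdet_style)
```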


@@ -1,5 +1,9 @@
_base_ = './ppyoloe_plus_s_fast_8xb8-80e_coco.py'

# The pretrained model is obtained and converted from the official PPYOLOE.
# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
load_from = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/ppyoloe_plus_x_obj365_pretrained-43a8000d.pth'  # noqa

deepen_factor = 1.33
widen_factor = 1.25


@@ -1,11 +1,36 @@
_base_ = './ppyoloe_plus_s_fast_8xb8-80e_coco.py'

# The pretrained model is obtained and converted from the official PPYOLOE.
# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
checkpoint = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/cspresnet_s_imagenet1k_pretrained-2be81763.pth'  # noqa

train_batch_size_per_gpu = 32
max_epochs = 300

# Base learning rate for optim_wrapper
base_lr = 0.01

model = dict(
    data_preprocessor=dict(
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255., 0.224 * 255., 0.225 * 255.]),
    backbone=dict(
        block_cfg=dict(use_alpha=False),
        init_cfg=dict(
            type='Pretrained',
            prefix='backbone.',
            checkpoint=checkpoint,
            map_location='cpu')),
    train_cfg=dict(initial_epoch=100))

train_dataloader = dict(batch_size=train_batch_size_per_gpu)

optim_wrapper = dict(optimizer=dict(lr=base_lr))

default_hooks = dict(param_scheduler=dict(total_epochs=int(max_epochs * 1.2)))

train_cfg = dict(max_epochs=max_epochs)

# PPYOLOE plus uses an Obj365 pretrained model, but PPYOLOE does not,
# so `load_from` needs to be set to None.
load_from = None


@@ -1,4 +1,9 @@
_base_ = './ppyoloe_s_fast_8xb32-300e_coco.py'

max_epochs = 400

model = dict(train_cfg=dict(initial_epoch=133))

default_hooks = dict(param_scheduler=dict(total_epochs=int(max_epochs * 1.2)))

train_cfg = dict(max_epochs=max_epochs)


@@ -1,15 +1,23 @@
_base_ = './ppyoloe_s_fast_8xb32-300e_coco.py'

# The pretrained model is obtained and converted from the official PPYOLOE.
# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
checkpoint = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/cspresnet_x_imagenet1k_pretrained-81c33ccb.pth'  # noqa

deepen_factor = 1.33
widen_factor = 1.25

train_batch_size_per_gpu = 16

model = dict(
    backbone=dict(
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        init_cfg=dict(checkpoint=checkpoint)),
    neck=dict(
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
    ),
    bbox_head=dict(head_module=dict(widen_factor=widen_factor)))

train_dataloader = dict(batch_size=train_batch_size_per_gpu)


@@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mix_img_transforms import Mosaic, Mosaic9, YOLOv5MixUp, YOLOXMixUp
from .transforms import (LetterResize, LoadAnnotations, PPYOLOERandomCrop,
                         PPYOLOERandomDistort, YOLOv5HSVRandomAug,
                         YOLOv5KeepRatioResize, YOLOv5RandomAffine)

__all__ = [
    'YOLOv5KeepRatioResize', 'LetterResize', 'Mosaic', 'YOLOXMixUp',
    'YOLOv5MixUp', 'YOLOv5HSVRandomAug', 'LoadAnnotations',
    'YOLOv5RandomAffine', 'PPYOLOERandomDistort', 'PPYOLOERandomCrop',
    'Mosaic9'
]


@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Tuple, Union

import cv2
import mmcv

@@ -675,3 +675,397 @@ class YOLOv5RandomAffine(BaseTransform):
        translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]],
                                      dtype=np.float32)
        return translation_matrix
@TRANSFORMS.register_module()
class PPYOLOERandomDistort(BaseTransform):
"""Random hue, saturation, contrast and brightness distortion.
Required Keys:
- img
Modified Keys:
- img (np.float32)
Args:
hue_cfg (dict): Hue settings. Defaults to dict(min=-18,
max=18, prob=0.5).
saturation_cfg (dict): Saturation settings. Defaults to dict(
min=0.5, max=1.5, prob=0.5).
contrast_cfg (dict): Contrast settings. Defaults to dict(
min=0.5, max=1.5, prob=0.5).
brightness_cfg (dict): Brightness settings. Defaults to dict(
min=0.5, max=1.5, prob=0.5).
num_distort_func (int): The number of distort function. Defaults
to 4.
"""
def __init__(self,
hue_cfg: dict = dict(min=-18, max=18, prob=0.5),
saturation_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
contrast_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
brightness_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
num_distort_func: int = 4):
self.hue_cfg = hue_cfg
self.saturation_cfg = saturation_cfg
self.contrast_cfg = contrast_cfg
self.brightness_cfg = brightness_cfg
self.num_distort_func = num_distort_func
        assert 0 < self.num_distort_func <= 4, \
            'num_distort_func must be > 0 and <= 4'
        for cfg in [
                self.hue_cfg, self.saturation_cfg, self.contrast_cfg,
                self.brightness_cfg
        ]:
            assert 0. <= cfg['prob'] <= 1., 'prob must be >= 0 and <= 1'
def transform_hue(self, results):
"""Transform hue randomly."""
if random.uniform(0., 1.) >= self.hue_cfg['prob']:
return results
img = results['img']
delta = random.uniform(self.hue_cfg['min'], self.hue_cfg['max'])
u = np.cos(delta * np.pi)
w = np.sin(delta * np.pi)
delta_iq = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
rgb2yiq_matrix = np.array([[0.114, 0.587, 0.299],
[-0.321, -0.274, 0.596],
[0.311, -0.523, 0.211]])
        yiq2rgb_matrix = np.array([[1.0, -1.107, 1.705], [1.0, -0.272, -0.647],
                                   [1.0, 0.956, 0.621]])
        t = np.dot(np.dot(yiq2rgb_matrix, delta_iq), rgb2yiq_matrix).T
img = np.dot(img, t)
results['img'] = img
return results
def transform_saturation(self, results):
"""Transform saturation randomly."""
if random.uniform(0., 1.) >= self.saturation_cfg['prob']:
return results
img = results['img']
delta = random.uniform(self.saturation_cfg['min'],
self.saturation_cfg['max'])
# convert bgr img to gray img
gray = img * np.array([[[0.114, 0.587, 0.299]]], dtype=np.float32)
gray = gray.sum(axis=2, keepdims=True)
gray *= (1.0 - delta)
img *= delta
img += gray
results['img'] = img
return results
def transform_contrast(self, results):
"""Transform contrast randomly."""
if random.uniform(0., 1.) >= self.contrast_cfg['prob']:
return results
img = results['img']
delta = random.uniform(self.contrast_cfg['min'],
self.contrast_cfg['max'])
img *= delta
results['img'] = img
return results
def transform_brightness(self, results):
"""Transform brightness randomly."""
if random.uniform(0., 1.) >= self.brightness_cfg['prob']:
return results
img = results['img']
delta = random.uniform(self.brightness_cfg['min'],
self.brightness_cfg['max'])
img += delta
results['img'] = img
return results
def transform(self, results: dict) -> dict:
"""The hue, saturation, contrast and brightness distortion function.
Args:
results (dict): The result dict.
Returns:
dict: The result dict.
"""
results['img'] = results['img'].astype(np.float32)
functions = [
self.transform_brightness, self.transform_contrast,
self.transform_saturation, self.transform_hue
]
distortions = random.permutation(functions)[:self.num_distort_func]
for func in distortions:
results = func(results)
return results
@TRANSFORMS.register_module()
class PPYOLOERandomCrop(BaseTransform):
"""Random crop the img and bboxes. Different thresholds are used in PPYOLOE
to judge whether the clipped image meets the requirements. This
implementation is different from the implementation of RandomCrop in mmdet.
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_ignore_flags (bool) (optional)
Modified Keys:
- img
- img_shape
- gt_bboxes (optional)
- gt_bboxes_labels (optional)
- gt_ignore_flags (optional)
Added Keys:
- pad_param (np.float32)
Args:
aspect_ratio (List[float]): Aspect ratio of cropped region. Default to
[.5, 2].
thresholds (List[float]): Iou thresholds for decide a valid bbox crop
in [min, max] format. Defaults to [.0, .1, .3, .5, .7, .9].
scaling (List[float]): Ratio between a cropped region and the original
image in [min, max] format. Default to [.3, 1.].
num_attempts (int): Number of tries for each threshold before
giving up. Default to 50.
allow_no_crop (bool): Allow return without actually cropping them.
Default to True.
cover_all_box (bool): Ensure all bboxes are covered in the final crop.
Default to False.
"""
def __init__(self,
aspect_ratio: List[float] = [.5, 2.],
thresholds: List[float] = [.0, .1, .3, .5, .7, .9],
scaling: List[float] = [.3, 1.],
num_attempts: int = 50,
allow_no_crop: bool = True,
cover_all_box: bool = False):
self.aspect_ratio = aspect_ratio
self.thresholds = thresholds
self.scaling = scaling
self.num_attempts = num_attempts
self.allow_no_crop = allow_no_crop
self.cover_all_box = cover_all_box
def _crop_data(self, results: dict, crop_box: Tuple[int, int, int, int],
valid_inds: np.ndarray) -> Union[dict, None]:
"""Function to randomly crop images, bounding boxes, masks, semantic
segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
crop_box (Tuple[int, int, int, int]): Expected absolute coordinates
for cropping, (x1, y1, x2, y2).
valid_inds (np.ndarray): The indexes of gt that needs to be
retained.
Returns:
results (Union[dict, None]): Randomly cropped results, 'img_shape'
key in result dict is updated according to crop size. None will
be returned when there is no valid bbox after cropping.
"""
# crop the image
img = results['img']
crop_x1, crop_y1, crop_x2, crop_y2 = crop_box
img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
results['img'] = img
img_shape = img.shape
results['img_shape'] = img.shape
# crop bboxes accordingly and clip to the image boundary
if results.get('gt_bboxes', None) is not None:
bboxes = results['gt_bboxes']
bboxes.translate_([-crop_x1, -crop_y1])
bboxes.clip_(img_shape[:2])
results['gt_bboxes'] = bboxes[valid_inds]
if results.get('gt_ignore_flags', None) is not None:
results['gt_ignore_flags'] = \
results['gt_ignore_flags'][valid_inds]
if results.get('gt_bboxes_labels', None) is not None:
results['gt_bboxes_labels'] = \
results['gt_bboxes_labels'][valid_inds]
if results.get('gt_masks', None) is not None:
results['gt_masks'] = results['gt_masks'][
valid_inds.nonzero()[0]].crop(
np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
# crop semantic seg
if results.get('gt_seg_map', None) is not None:
results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2,
crop_x1:crop_x2]
return results
@autocast_box_type()
def transform(self, results: dict) -> Union[dict, None]:
"""The random crop transform function.
Args:
results (dict): The result dict.
Returns:
dict: The result dict.
"""
if results.get('gt_bboxes', None) is None or len(
results['gt_bboxes']) == 0:
return results
orig_img_h, orig_img_w = results['img'].shape[:2]
gt_bboxes = results['gt_bboxes']
thresholds = list(self.thresholds)
if self.allow_no_crop:
thresholds.append('no_crop')
random.shuffle(thresholds)
for thresh in thresholds:
# Determine the coordinates for cropping
if thresh == 'no_crop':
return results
found = False
for i in range(self.num_attempts):
crop_h, crop_w = self._get_crop_size((orig_img_h, orig_img_w))
if self.aspect_ratio is None:
if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
continue
# get image crop_box
margin_h = max(orig_img_h - crop_h, 0)
margin_w = max(orig_img_w - crop_w, 0)
offset_h, offset_w = self._rand_offset((margin_h, margin_w))
crop_y1, crop_y2 = offset_h, offset_h + crop_h
crop_x1, crop_x2 = offset_w, offset_w + crop_w
crop_box = [crop_x1, crop_y1, crop_x2, crop_y2]
# Calculate the iou between gt_bboxes and crop_boxes
iou = self._iou_matrix(gt_bboxes,
np.array([crop_box], dtype=np.float32))
# If the maximum value of the iou is less than thresh,
# the current crop_box is considered invalid.
if iou.max() < thresh:
continue
# If cover_all_box == True and the minimum value of
# the iou is less than thresh, the current crop_box
# is considered invalid.
if self.cover_all_box and iou.min() < thresh:
continue
# Get which gt_bboxes to keep after cropping.
valid_inds = self._get_valid_inds(
gt_bboxes, np.array(crop_box, dtype=np.float32))
if valid_inds.size > 0:
found = True
break
if found:
results = self._crop_data(results, crop_box, valid_inds)
return results
return results
@cache_randomness
def _rand_offset(self, margin: Tuple[int, int]) -> Tuple[int, int]:
"""Randomly generate crop offset.
Args:
margin (Tuple[int, int]): The upper bound for the offset generated
randomly.
Returns:
Tuple[int, int]: The random offset for the crop.
"""
margin_h, margin_w = margin
offset_h = np.random.randint(0, margin_h + 1)
offset_w = np.random.randint(0, margin_w + 1)
return (offset_h, offset_w)
@cache_randomness
def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]:
"""Randomly generates the crop size based on `image_size`.
Args:
image_size (Tuple[int, int]): (h, w).
Returns:
crop_size (Tuple[int, int]): (crop_h, crop_w) in absolute pixels.
"""
h, w = image_size
scale = random.uniform(*self.scaling)
if self.aspect_ratio is not None:
min_ar, max_ar = self.aspect_ratio
aspect_ratio = random.uniform(
max(min_ar, scale**2), min(max_ar, scale**-2))
h_scale = scale / np.sqrt(aspect_ratio)
w_scale = scale * np.sqrt(aspect_ratio)
else:
h_scale = random.uniform(*self.scaling)
w_scale = random.uniform(*self.scaling)
crop_h = h * h_scale
crop_w = w * w_scale
return int(crop_h), int(crop_w)
def _iou_matrix(self,
gt_bbox: HorizontalBoxes,
crop_bbox: np.ndarray,
eps: float = 1e-10) -> np.ndarray:
"""Calculate iou between gt and image crop box.
Args:
gt_bbox (HorizontalBoxes): Ground truth bounding boxes.
crop_bbox (np.ndarray): Image crop coordinates in
[x1, y1, x2, y2] format.
eps (float): Default to 1e-10.
Return:
(np.ndarray): IoU.
"""
gt_bbox = gt_bbox.tensor.numpy()
lefttop = np.maximum(gt_bbox[:, np.newaxis, :2], crop_bbox[:, :2])
rightbottom = np.minimum(gt_bbox[:, np.newaxis, 2:], crop_bbox[:, 2:])
overlap = np.prod(
rightbottom - lefttop,
axis=2) * (lefttop < rightbottom).all(axis=2)
        area_gt_bbox = np.prod(gt_bbox[:, 2:] - gt_bbox[:, :2], axis=1)
        area_crop_bbox = np.prod(crop_bbox[:, 2:] - crop_bbox[:, :2], axis=1)
area_o = (area_gt_bbox[:, np.newaxis] + area_crop_bbox - overlap)
return overlap / (area_o + eps)
def _get_valid_inds(self, gt_bbox: HorizontalBoxes,
img_crop_bbox: np.ndarray) -> np.ndarray:
"""Get which Bboxes to keep at the current cropping coordinates.
Args:
gt_bbox (HorizontalBoxes): Ground truth bounding boxes.
img_crop_bbox (np.ndarray): Image crop coordinates in
[x1, y1, x2, y2] format.
Returns:
(np.ndarray): Valid indexes.
"""
cropped_box = gt_bbox.tensor.numpy().copy()
gt_bbox = gt_bbox.tensor.numpy().copy()
cropped_box[:, :2] = np.maximum(gt_bbox[:, :2], img_crop_bbox[:2])
cropped_box[:, 2:] = np.minimum(gt_bbox[:, 2:], img_crop_bbox[2:])
cropped_box[:, :2] -= img_crop_bbox[:2]
cropped_box[:, 2:] -= img_crop_bbox[:2]
centers = (gt_bbox[:, :2] + gt_bbox[:, 2:]) / 2
valid = np.logical_and(img_crop_bbox[:2] <= centers,
centers < img_crop_bbox[2:]).all(axis=1)
valid = np.logical_and(
valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
return np.where(valid)[0]
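
A minimal usage sketch for the transforms above (assumes an MMYOLO install; `BaseTransform` makes the instance callable and dispatches to `transform`):

```python
import numpy as np

from mmyolo.datasets.transforms import PPYOLOERandomDistort

results = {'img': np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)}
results = PPYOLOERandomDistort()(results)
print(results['img'].dtype)  # float32: the transform casts before distorting
```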


@@ -9,8 +9,14 @@ from ..registry import TASK_UTILS
@COLLATE_FUNCTIONS.register_module()
def yolov5_collate(data_batch: Sequence,
                   use_ms_training: bool = False) -> dict:
    """Rewrite collate_fn to get faster training speed.

    Args:
       data_batch (Sequence): Batch of data.
       use_ms_training (bool): Whether to use multi-scale training.
    """
    batch_imgs = []
    batch_bboxes_labels = []
    for i in range(len(data_batch)):

@@ -25,10 +31,16 @@ def yolov5_collate(data_batch: Sequence) -> dict:
        batch_bboxes_labels.append(bboxes_labels)
        batch_imgs.append(inputs)
    if use_ms_training:
        return {
            'inputs': batch_imgs,
            'data_samples': torch.cat(batch_bboxes_labels, 0)
        }
    else:
        return {
            'inputs': torch.stack(batch_imgs, 0),
            'data_samples': torch.cat(batch_bboxes_labels, 0)
        }
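
A sketch of the two output contracts of `yolov5_collate` (shapes and values here are hypothetical): with `use_ms_training=True` the images stay a list so `PPYOLOEBatchRandomResize` can resize the whole batch later; otherwise they are stacked into one tensor.

```python
import torch

imgs = [torch.zeros(3, 640, 640) for _ in range(4)]
# Each row of bboxes_labels: (batch_idx, label, x1, y1, x2, y2).
bboxes_labels = torch.zeros(7, 6)

ms_batch = {'inputs': imgs, 'data_samples': bboxes_labels}
fixed_batch = {'inputs': torch.stack(imgs, 0), 'data_samples': bboxes_labels}
print(type(ms_batch['inputs']))     # list of (3, H, W) tensors
print(fixed_batch['inputs'].shape)  # torch.Size([4, 3, 640, 640])
```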
@TASK_UTILS.register_module()


@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ppyoloe_param_scheduler_hook import PPYOLOEParamSchedulerHook
from .switch_to_deploy_hook import SwitchToDeployHook
from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
from .yolox_mode_switch_hook import YOLOXModeSwitchHook

__all__ = [
    'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook',
    'PPYOLOEParamSchedulerHook'
]


@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional
from mmengine.hooks import ParamSchedulerHook
from mmengine.runner import Runner
from mmyolo.registry import HOOKS
@HOOKS.register_module()
class PPYOLOEParamSchedulerHook(ParamSchedulerHook):
"""A hook to update learning rate and momentum in optimizer of PPYOLOE. We
use this hook to implement adaptive computation for `warmup_total_iters`,
which is not possible with the built-in ParamScheduler in mmyolo.
Args:
warmup_min_iter (int): Minimum warmup iters. Defaults to 1000.
start_factor (float): The number we multiply learning rate in the
first epoch. The multiplication factor changes towards end_factor
in the following epochs. Defaults to 0.
warmup_epochs (int): Epochs for warmup. Defaults to 5.
min_lr_ratio (float): Minimum learning rate ratio.
total_epochs (int): In PPYOLOE, `total_epochs` is set to
training_epochs x 1.2. Defaults to 360.
"""
priority = 9
def __init__(self,
warmup_min_iter: int = 1000,
start_factor: float = 0.,
warmup_epochs: int = 5,
min_lr_ratio: float = 0.0,
total_epochs: int = 360):
self.warmup_min_iter = warmup_min_iter
self.start_factor = start_factor
self.warmup_epochs = warmup_epochs
self.min_lr_ratio = min_lr_ratio
self.total_epochs = total_epochs
self._warmup_end = False
self._base_lr = None
def before_train(self, runner: Runner):
"""Operations before train.
Args:
runner (Runner): The runner of the training process.
"""
optimizer = runner.optim_wrapper.optimizer
for group in optimizer.param_groups:
            # If the param has never been scheduled, record the current
            # value as the initial value.
group.setdefault('initial_lr', group['lr'])
self._base_lr = [
group['initial_lr'] for group in optimizer.param_groups
]
self._min_lr = [i * self.min_lr_ratio for i in self._base_lr]
def before_train_iter(self,
runner: Runner,
batch_idx: int,
data_batch: Optional[dict] = None):
"""Operations before each training iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (dict or tuple or list, optional): Data from dataloader.
"""
cur_iters = runner.iter
optimizer = runner.optim_wrapper.optimizer
dataloader_len = len(runner.train_dataloader)
# The minimum warmup is self.warmup_min_iter
warmup_total_iters = max(
round(self.warmup_epochs * dataloader_len), self.warmup_min_iter)
if cur_iters <= warmup_total_iters:
# warm up
alpha = cur_iters / warmup_total_iters
factor = self.start_factor * (1 - alpha) + alpha
for group_idx, param in enumerate(optimizer.param_groups):
param['lr'] = self._base_lr[group_idx] * factor
else:
for group_idx, param in enumerate(optimizer.param_groups):
total_iters = self.total_epochs * dataloader_len
lr = self._min_lr[group_idx] + (
self._base_lr[group_idx] -
self._min_lr[group_idx]) * 0.5 * (
math.cos((cur_iters - warmup_total_iters) * math.pi /
(total_iters - warmup_total_iters)) + 1.0)
param['lr'] = lr
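
A standalone sketch of the schedule this hook implements: linear warmup to the base LR over `max(warmup_epochs * iters_per_epoch, warmup_min_iter)` iterations, then cosine decay towards `min_lr` over `total_epochs` (deliberately 1.2x the real training length, so training stops before the cosine reaches its floor). `iters_per_epoch` below is a hypothetical value.

```python
import math


def ppyoloe_lr(cur_iter: int,
               base_lr: float = 0.001,
               iters_per_epoch: int = 1833,
               warmup_min_iter: int = 1000,
               start_factor: float = 0.,
               warmup_epochs: int = 5,
               min_lr_ratio: float = 0.,
               total_epochs: int = 96) -> float:
    min_lr = base_lr * min_lr_ratio
    warmup_total_iters = max(
        round(warmup_epochs * iters_per_epoch), warmup_min_iter)
    if cur_iter <= warmup_total_iters:
        # Linear warmup from start_factor * base_lr to base_lr.
        alpha = cur_iter / warmup_total_iters
        return base_lr * (start_factor * (1 - alpha) + alpha)
    # Cosine decay from base_lr towards min_lr.
    total_iters = total_epochs * iters_per_epoch
    return min_lr + (base_lr - min_lr) * 0.5 * (
        math.cos((cur_iter - warmup_total_iters) * math.pi /
                 (total_iters - warmup_total_iters)) + 1.0)


# LR at the start, right after warmup, and at the end of the 80 real epochs.
print(ppyoloe_lr(0), ppyoloe_lr(5 * 1833), ppyoloe_lr(80 * 1833))
```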


@@ -1,4 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .data_preprocessor import (PPYOLOEBatchRandomResize,
                                PPYOLOEDetDataPreprocessor,
                                YOLOv5DetDataPreprocessor)

__all__ = [
    'YOLOv5DetDataPreprocessor', 'PPYOLOEDetDataPreprocessor',
    'PPYOLOEBatchRandomResize'
]


@@ -1,6 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Tuple, Union

import torch
import torch.nn.functional as F
from mmdet.models import BatchSyncRandomResize
from mmdet.models.data_preprocessors import DetDataPreprocessor
from mmengine import MessageHub, is_list_of
from torch import Tensor

from mmyolo.registry import MODELS

@@ -50,3 +57,191 @@ class YOLOv5DetDataPreprocessor(DetDataPreprocessor):
        data_samples = {'bboxes_labels': data_samples, 'img_metas': img_metas}
        return {'inputs': inputs, 'data_samples': data_samples}
@MODELS.register_module()
class PPYOLOEDetDataPreprocessor(DetDataPreprocessor):
"""Image pre-processor for detection tasks.
    The main difference between PPYOLOEDetDataPreprocessor and
    DetDataPreprocessor is the normalization order. The official PPYOLOE
    resizes the image first and then normalizes it; in DetDataPreprocessor
    the order is reversed.
Note: It must be used together with
`mmyolo.datasets.utils.yolov5_collate`
"""
def forward(self, data: dict, training: bool = False) -> dict:
"""Perform normalization、padding and bgr2rgb conversion based on
``BaseDataPreprocessor``. This class use batch_augments first, and then
normalize the image, which is different from the `DetDataPreprocessor`
.
Args:
data (dict): Data sampled from dataloader.
training (bool): Whether to enable training time augmentation.
Returns:
dict: Data in the same format as the model input.
"""
if not training:
return super().forward(data, training)
assert isinstance(data['inputs'], list) and is_list_of(
data['inputs'], torch.Tensor), \
'"inputs" should be a list of Tensor, but got ' \
f'{type(data["inputs"])}. The possible reason for this ' \
'is that you are not using it with ' \
'"mmyolo.datasets.utils.yolov5_collate". Please refer to ' \
'"cconfigs/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py".'
data = self.cast_data(data)
inputs, data_samples = data['inputs'], data['data_samples']
# Process data.
batch_inputs = []
for _batch_input, data_sample in zip(inputs, data_samples):
# channel transform
if self._channel_conversion:
_batch_input = _batch_input[[2, 1, 0], ...]
# Convert to float after channel conversion to ensure
# efficiency
_batch_input = _batch_input.float()
batch_inputs.append(_batch_input)
# Batch random resize image.
if self.batch_augments is not None:
for batch_aug in self.batch_augments:
inputs, data_samples = batch_aug(batch_inputs, data_samples)
if self._enable_normalize:
inputs = (inputs - self.mean) / self.std
img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
data_samples = {'bboxes_labels': data_samples, 'img_metas': img_metas}
return {'inputs': inputs, 'data_samples': data_samples}
# TODO: No generality. Its input data format is different from mmdet's
# batch augmentations, and it must be made compatible in the future.
@MODELS.register_module()
class PPYOLOEBatchRandomResize(BatchSyncRandomResize):
"""PPYOLOE batch random resize.
Args:
random_size_range (tuple): The multi-scale random range during
multi-scale training.
interval (int): The iter interval of change
image size. Defaults to 10.
size_divisor (int): Image size divisible factor.
Defaults to 32.
random_interp (bool): Whether to choose interp_mode randomly.
If set to True, the type of `interp_mode` must be list.
If set to False, the type of `interp_mode` must be str.
Defaults to True.
interp_mode (Union[List, str]): The modes available for resizing
are ('nearest', 'bilinear', 'bicubic', 'area').
keep_ratio (bool): Whether to keep the aspect ratio when resizing
the image. Now we only support keep_ratio=False.
Defaults to False.
"""
def __init__(self,
random_size_range: Tuple[int, int],
interval: int = 1,
size_divisor: int = 32,
random_interp=True,
interp_mode: Union[List[str], str] = [
'nearest', 'bilinear', 'bicubic', 'area'
],
keep_ratio: bool = False) -> None:
super().__init__(random_size_range, interval, size_divisor)
self.random_interp = random_interp
self.keep_ratio = keep_ratio
# TODO: need to support keep_ratio==True
assert not self.keep_ratio, 'We do not yet support keep_ratio=True'
if self.random_interp:
            assert isinstance(interp_mode, list) and len(interp_mode) > 1, \
                'While random_interp==True, the type of `interp_mode` must ' \
                'be list and len(interp_mode) must be larger than 1'
self.interp_mode_list = interp_mode
self.interp_mode = None
else:
assert isinstance(interp_mode, str),\
'While random_interp==False, the type of ' \
'`interp_mode` must be str'
assert interp_mode in ['nearest', 'bilinear', 'bicubic', 'area']
self.interp_mode_list = None
self.interp_mode = interp_mode
def forward(self, inputs: list,
data_samples: Tensor) -> Tuple[Tensor, Tensor]:
"""Resize a batch of images and bboxes to shape ``self._input_size``.
The inputs and data_samples should be list, and
``PPYOLOEBatchRandomResize`` must be used with
``PPYOLOEDetDataPreprocessor`` and ``yolov5_collate`` with
``use_ms_training == True``.
"""
assert isinstance(inputs, list),\
'The type of inputs must be list. The possible reason for this ' \
'is that you are not using it with `PPYOLOEDetDataPreprocessor` ' \
'and `yolov5_collate` with use_ms_training == True.'
message_hub = MessageHub.get_current_instance()
if (message_hub.get_info('iter') + 1) % self._interval == 0:
# get current input size
self._input_size, interp_mode = self._get_random_size_and_interp()
if self.random_interp:
self.interp_mode = interp_mode
# TODO: need to support type(inputs)==Tensor
if isinstance(inputs, list):
outputs = []
for i in range(len(inputs)):
_batch_input = inputs[i]
h, w = _batch_input.shape[-2:]
scale_y = self._input_size[0] / h
scale_x = self._input_size[1] / w
if scale_x != 1. or scale_y != 1.:
if self.interp_mode in ('nearest', 'area'):
align_corners = None
else:
align_corners = False
_batch_input = F.interpolate(
_batch_input.unsqueeze(0),
size=self._input_size,
mode=self.interp_mode,
align_corners=align_corners)
# rescale boxes
indexes = data_samples[:, 0] == i
data_samples[indexes, 2] *= scale_x
data_samples[indexes, 3] *= scale_y
data_samples[indexes, 4] *= scale_x
data_samples[indexes, 5] *= scale_y
else:
_batch_input = _batch_input.unsqueeze(0)
outputs.append(_batch_input)
# convert to Tensor
return torch.cat(outputs, dim=0), data_samples
else:
raise NotImplementedError('Not implemented yet!')
def _get_random_size_and_interp(self) -> Tuple[int, int]:
"""Randomly generate a shape in ``_random_size_range`` and a
interp_mode in interp_mode_list."""
size = random.randint(*self._random_size_range)
input_size = (self._size_divisor * size, self._size_divisor * size)
if self.random_interp:
interp_ind = random.randint(0, len(self.interp_mode_list) - 1)
interp_mode = self.interp_mode_list[interp_ind]
else:
interp_mode = None
return input_size, interp_mode
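
The core of the resize above as a self-contained sketch: every image in the batch is interpolated to one randomly chosen square size, and the rows of `data_samples` (which, judging from the indexing above, are `(batch_idx, label, x1, y1, x2, y2)`) are rescaled to match. Sizes and boxes here are hypothetical.

```python
import torch
import torch.nn.functional as F

inputs = [torch.rand(3, 480, 640), torch.rand(3, 640, 640)]
data_samples = torch.tensor([[0., 1., 10., 10., 100., 100.],
                             [1., 3., 50., 50., 200., 200.]])
new_hw = (320, 320)  # hypothetical pick from random_size_range

outputs = []
for i, img in enumerate(inputs):
    h, w = img.shape[-2:]
    scale_y, scale_x = new_hw[0] / h, new_hw[1] / w
    img = F.interpolate(
        img[None], size=new_hw, mode='bilinear', align_corners=False)
    # Rescale the boxes belonging to image i.
    idx = data_samples[:, 0] == i
    data_samples[idx, 2] *= scale_x
    data_samples[idx, 3] *= scale_y
    data_samples[idx, 4] *= scale_x
    data_samples[idx, 5] *= scale_y
    outputs.append(img)
batch = torch.cat(outputs, dim=0)  # (2, 3, 320, 320)
```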


@@ -1,24 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.utils import multi_apply
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
                         OptMultiConfig, reduce_mean)
from mmengine import MessageHub
from mmengine.model import BaseModule, bias_init_with_prob
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.registry import MODELS
from ..layers.yolo_bricks import PPYOLOESELayer
from .yolov6_head import YOLOv6Head


@MODELS.register_module()
class PPYOLOEHeadModule(BaseModule):
    """PPYOLOEHead head module used in `PPYOLOE
    <https://arxiv.org/abs/2203.16250>`_.

    Args:
        num_classes (int): Number of categories excluding the background
@@ -30,7 +33,8 @@ class PPYOLOEHeadModule(BaseModule):
            on the feature grid.
        featmap_strides (Sequence[int]): Downsample factor of each feature map.
            Defaults to (8, 16, 32).
        reg_max (int): Max value of the integral set :math:`{0, ..., reg_max}`
            in QFL setting. Defaults to 16.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
@@ -100,15 +104,12 @@ class PPYOLOEHeadModule(BaseModule):
            self.reg_preds.append(
                nn.Conv2d(in_channel, 4 * (self.reg_max + 1), 3, padding=1))

        # init proj
        proj = torch.linspace(0, self.reg_max, self.reg_max + 1).view(
            [1, self.reg_max + 1, 1, 1])
        self.register_buffer('proj', proj, persistent=False)

    def forward(self, x: Tuple[Tensor]) -> Tensor:
        """Forward features from the upstream network.

        Args:

@@ -131,17 +132,24 @@ class PPYOLOEHeadModule(BaseModule):
        hw = h * w
        avg_feat = F.adaptive_avg_pool2d(x, (1, 1))
        cls_logit = cls_pred(cls_stem(x, avg_feat) + x)
        bbox_dist_preds = reg_pred(reg_stem(x, avg_feat))
        # TODO: Test whether using matmul instead of conv speeds up training.
        bbox_dist_preds = bbox_dist_preds.reshape(
            [-1, 4, self.reg_max + 1, hw]).permute(0, 2, 3, 1)
        bbox_preds = F.conv2d(F.softmax(bbox_dist_preds, dim=1), self.proj)
        if self.training:
            return cls_logit, bbox_preds, bbox_dist_preds
        else:
            return cls_logit, bbox_preds
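
What the registered `proj` buffer does, in isolation: a softmax over the `reg_max + 1` bins followed by a 1x1 convolution whose weights are `0..reg_max` computes the expectation of the distance distribution (the DFL "integral"). A self-contained check:

```python
import torch
import torch.nn.functional as F

reg_max = 16
proj = torch.linspace(0, reg_max, reg_max + 1).view(1, reg_max + 1, 1, 1)

dist_logits = torch.randn(2, reg_max + 1, 8400, 4)  # (bs, bins, h*w, 4 sides)
probs = F.softmax(dist_logits, dim=1)
expected = F.conv2d(probs, proj)                    # (2, 1, 8400, 4)

# The same expectation computed by hand.
bins = torch.arange(reg_max + 1, dtype=torch.float32).view(1, -1, 1, 1)
assert torch.allclose(
    expected, (probs * bins).sum(1, keepdim=True), atol=1e-4)
```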

@MODELS.register_module()
class PPYOLOEHead(YOLOv6Head):
    """PPYOLOEHead head used in `PPYOLOE <https://arxiv.org/abs/2203.16250>`_.

    The YOLOv6 head and the PPYOLOE head are only slightly different.
    Distribution focal loss is additionally used in PPYOLOE, but not in
    YOLOv6.

    Args:
        head_module(ConfigType): Base module used for YOLOv5Head

@@ -150,7 +158,8 @@ class PPYOLOEHead(YOLOv5Head):
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        loss_dfl (:obj:`ConfigDict` or dict): Config of distribution focal
            loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
@@ -168,17 +177,24 @@ class PPYOLOEHead(YOLOv5Head):
                     strides=[8, 16, 32]),
                 bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.VarifocalLoss',
                     use_sigmoid=True,
                     alpha=0.75,
                     gamma=2.0,
                     iou_weighted=True,
                     reduction='sum',
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='IoULoss',
                     iou_mode='giou',
                     bbox_format='xyxy',
                     reduction='mean',
                     loss_weight=2.5,
                     return_iou=False),
                 loss_dfl: ConfigType = dict(
                     type='mmdet.DistributionFocalLoss',
                     reduction='mean',
                     loss_weight=0.5 / 4),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
@@ -188,19 +204,18 @@ class PPYOLOEHead(YOLOv5Head):
            bbox_coder=bbox_coder,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)
        self.loss_dfl = MODELS.build(loss_dfl)
        # PPYOLOE doesn't need loss_obj
        self.loss_obj = None

    def loss_by_feat(
            self,
            cls_scores: Sequence[Tensor],
            bbox_preds: Sequence[Tensor],
            bbox_dist_preds: Sequence[Tensor],
            batch_gt_instances: Sequence[InstanceData],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
@@ -214,6 +229,8 @@ class PPYOLOEHead(YOLOv5Head):
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
                each scale level with shape (bs, reg_max + 1, H*W, 4).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
@@ -226,4 +243,131 @@ class PPYOLOEHead(YOLOv5Head):
        Returns:
            dict[str, Tensor]: A dictionary of losses.
        """
# get epoch information from message hub
message_hub = MessageHub.get_current_instance()
current_epoch = message_hub.get_info('epoch')
num_imgs = len(batch_img_metas)
current_featmap_sizes = [
cls_score.shape[2:] for cls_score in cls_scores
]
        # If the shapes are not equal, regenerate the priors
if current_featmap_sizes != self.featmap_sizes_train:
self.featmap_sizes_train = current_featmap_sizes
mlvl_priors_with_stride = self.prior_generator.grid_priors(
self.featmap_sizes_train,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device,
with_stride=True)
self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
self.flatten_priors_train = torch.cat(
mlvl_priors_with_stride, dim=0)
self.stride_tensor = self.flatten_priors_train[..., [2]]
# gt info
gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)
gt_labels = gt_info[:, :, :1]
gt_bboxes = gt_info[:, :, 1:] # xyxy
pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
# pred info
flatten_cls_preds = [
cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.num_classes)
for cls_pred in cls_scores
]
flatten_pred_bboxes = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]
# (bs, reg_max+1, n, 4) -> (bs, n, 4, reg_max+1)
flatten_pred_dists = [
bbox_pred_org.permute(0, 2, 3, 1).reshape(
num_imgs, -1, (self.head_module.reg_max + 1) * 4)
for bbox_pred_org in bbox_dist_preds
]
flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
flatten_pred_bboxes = self.bbox_coder.decode(
self.flatten_priors_train[..., :2], flatten_pred_bboxes,
self.stride_tensor[..., 0])
pred_scores = torch.sigmoid(flatten_cls_preds)
if current_epoch < self.initial_epoch:
assigned_result = self.initial_assigner(
flatten_pred_bboxes.detach(), self.flatten_priors_train,
self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
else:
assigned_result = self.assigner(flatten_pred_bboxes.detach(),
pred_scores.detach(),
self.flatten_priors_train,
gt_labels, gt_bboxes,
pad_bbox_flag)
assigned_bboxes = assigned_result['assigned_bboxes']
assigned_scores = assigned_result['assigned_scores']
fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
# cls loss
with torch.cuda.amp.autocast(enabled=False):
loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores)
# rescale bbox
assigned_bboxes /= self.stride_tensor
flatten_pred_bboxes /= self.stride_tensor
assigned_scores_sum = assigned_scores.sum()
# reduce_mean between all gpus
assigned_scores_sum = torch.clamp(
reduce_mean(assigned_scores_sum), min=1)
loss_cls /= assigned_scores_sum
# select positive samples mask
num_pos = fg_mask_pre_prior.sum()
if num_pos > 0:
# when num_pos > 0, assigned_scores_sum will be > 0, so the loss_bbox
# will not report an error
# iou loss
prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
pred_bboxes_pos = torch.masked_select(
flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
assigned_bboxes_pos = torch.masked_select(
assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
bbox_weight = torch.masked_select(
assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
loss_bbox = self.loss_bbox(
pred_bboxes_pos,
assigned_bboxes_pos,
weight=bbox_weight,
avg_factor=assigned_scores_sum)
# dfl loss
dist_mask = fg_mask_pre_prior.unsqueeze(-1).repeat(
[1, 1, (self.head_module.reg_max + 1) * 4])
pred_dist_pos = torch.masked_select(
flatten_dist_preds,
dist_mask).reshape([-1, 4, self.head_module.reg_max + 1])
assigned_ltrb = self.bbox_coder.encode(
self.flatten_priors_train[..., :2] / self.stride_tensor,
assigned_bboxes,
max_dis=self.head_module.reg_max,
eps=0.01)
assigned_ltrb_pos = torch.masked_select(
assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
loss_dfl = self.loss_dfl(
pred_dist_pos.reshape(-1, self.head_module.reg_max + 1),
assigned_ltrb_pos.reshape(-1),
weight=bbox_weight.expand(-1, 4).reshape(-1),
avg_factor=assigned_scores_sum)
else:
loss_bbox = flatten_pred_bboxes.sum() * 0
loss_dfl = flatten_pred_bboxes.sum() * 0
return dict(loss_cls=loss_cls, loss_bbox=loss_bbox, loss_dfl=loss_dfl)
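
The two details that are easiest to get wrong when re-implementing this loss are the epoch-gated assigner switch and the normalization of `loss_cls` by the assigned-score sum. A minimal sketch of the normalization step, assuming a standalone `reduce_mean` along the lines of mmdet's helper:

```python
import torch
import torch.distributed as dist


def reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Average a tensor across processes (a no-op when not distributed)."""
    if dist.is_available() and dist.is_initialized():
        tensor = tensor.clone()
        dist.all_reduce(tensor.div_(dist.get_world_size()))
    return tensor


# assigned_scores: (bs, num_priors, num_classes) soft targets from the assigner
assigned_scores = torch.rand(2, 100, 4)
loss_cls = assigned_scores.new_tensor(3.7)  # stand-in for the summed VFL loss

# Clamping the globally averaged score sum to >= 1 keeps the divisor valid
# even when no prior is assigned to any ground truth in this step.
avg_factor = torch.clamp(reduce_mean(assigned_scores.sum()), min=1)
loss_cls = loss_cls / avg_factor
```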

View File

@@ -169,7 +169,6 @@ class YOLOv6Head(YOLOv5Head):
in 2D points-based detectors.
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
-loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
anchor head. Defaults to None.
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
@@ -201,11 +200,6 @@ class YOLOv6Head(YOLOv5Head):
reduction='mean',
loss_weight=2.5,
return_iou=False),
-loss_obj: ConfigType = dict(
-type='mmdet.CrossEntropyLoss',
-use_sigmoid=True,
-reduction='sum',
-loss_weight=1.0),
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None):
@@ -215,13 +209,11 @@ class YOLOv6Head(YOLOv5Head):
bbox_coder=bbox_coder,
loss_cls=loss_cls,
loss_bbox=loss_bbox,
-loss_obj=loss_obj,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg)
-self.loss_bbox = MODELS.build(loss_bbox)
-self.loss_cls = MODELS.build(loss_cls)
+# yolov6 doesn't need loss_obj
+self.loss_obj = None
def special_init(self):
"""Since YOLO series algorithms will inherit from YOLOv5Head, but
@@ -236,10 +228,9 @@ class YOLOv6Head(YOLOv5Head):
self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
# Add common attributes to reduce calculation
-self.featmap_sizes = None
+self.featmap_sizes_train = None
-self.mlvl_priors = None
self.num_level_priors = None
-self.flatten_priors = None
+self.flatten_priors_train = None
self.stride_tensor = None
def loss_by_feat(
@@ -284,19 +275,19 @@ class YOLOv6Head(YOLOv5Head):
cls_score.shape[2:] for cls_score in cls_scores
]
# If the shape does not equal, generate new one
-if current_featmap_sizes != self.featmap_sizes:
+if current_featmap_sizes != self.featmap_sizes_train:
-self.featmap_sizes = current_featmap_sizes
+self.featmap_sizes_train = current_featmap_sizes
-mlvl_priors = self.prior_generator.grid_priors(
-self.featmap_sizes,
+mlvl_priors_with_stride = self.prior_generator.grid_priors(
+self.featmap_sizes_train,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device,
with_stride=True)
-self.num_level_priors = [len(n) for n in mlvl_priors]
+self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
-self.flatten_priors = torch.cat(mlvl_priors, dim=0)
-self.stride_tensor = self.flatten_priors[..., [2]]
-self.mlvl_priors = [mlvl[:, :2] for mlvl in mlvl_priors]
+self.flatten_priors_train = torch.cat(
+mlvl_priors_with_stride, dim=0)
+self.stride_tensor = self.flatten_priors_train[..., [2]]
# gt info
gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)
@@ -319,19 +310,20 @@ class YOLOv6Head(YOLOv5Head):
flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
flatten_pred_bboxes = self.bbox_coder.decode(
-self.flatten_priors[..., :2], flatten_pred_bboxes,
-self.flatten_priors[..., 2])
+self.flatten_priors_train[..., :2], flatten_pred_bboxes,
+self.stride_tensor[:, 0])
pred_scores = torch.sigmoid(flatten_cls_preds)
if current_epoch < self.initial_epoch:
assigned_result = self.initial_assigner(
-flatten_pred_bboxes.detach(), self.flatten_priors,
+flatten_pred_bboxes.detach(), self.flatten_priors_train,
self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
else:
assigned_result = self.assigner(flatten_pred_bboxes.detach(),
pred_scores.detach(),
-self.flatten_priors, gt_labels,
-gt_bboxes, pad_bbox_flag)
+self.flatten_priors_train,
+gt_labels, gt_bboxes,
+pad_bbox_flag)
assigned_bboxes = assigned_result['assigned_bboxes']
assigned_scores = assigned_result['assigned_scores']
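
Every renamed attribute above comes from the `with_stride=True` priors, whose last two columns carry the per-level stride. A small sketch of that layout, assuming mmdet's `MlvlPointGenerator`:

```python
import torch
from mmdet.models.task_modules.prior_generators import MlvlPointGenerator

prior_generator = MlvlPointGenerator(strides=[8, 16, 32], offset=0.5)
mlvl_priors_with_stride = prior_generator.grid_priors(
    [(4, 4), (2, 2), (1, 1)],  # toy featmap sizes for a 32x32 input
    dtype=torch.float32,
    device='cpu',
    with_stride=True)

flatten_priors = torch.cat(mlvl_priors_with_stride, dim=0)
# each row is (x, y, stride_w, stride_h), so column 2 holds the stride
stride_tensor = flatten_priors[..., [2]]  # (num_priors, 1), dim kept
print(flatten_priors.shape)  # torch.Size([21, 4])
print(stride_tensor[:3, 0])  # tensor([8., 8., 8.])
```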

View File

@@ -92,7 +92,7 @@ class BatchATSSAssigner(nn.Module):
Args:
pred_bboxes (Tensor): Predicted bounding boxes,
shape(batch_size, num_priors, 4)
-priors (Tensor): Model priors, shape(num_priors, 4)
+priors (Tensor): Model priors with stride, shape(num_priors, 4)
num_level_priors (List): Number of bboxes in each level, len(3)
gt_labels (Tensor): Ground truth label,
shape(batch_size, num_gt, 1)

View File

@@ -4,7 +4,7 @@ from typing import Optional, Sequence, Union
import torch
from mmdet.models.task_modules.coders import \
DistancePointBBoxCoder as MMDET_DistancePointBBoxCoder
-from mmdet.structures.bbox import distance2bbox
+from mmdet.structures.bbox import bbox2distance, distance2bbox
from mmyolo.registry import TASK_UTILS
@@ -51,3 +51,29 @@ class DistancePointBBoxCoder(MMDET_DistancePointBBoxCoder):
pred_bboxes = pred_bboxes * stride[None, :, None]
return distance2bbox(points, pred_bboxes, max_shape)
def encode(self,
points: torch.Tensor,
gt_bboxes: torch.Tensor,
max_dis: float = 16.,
eps: float = 0.01) -> torch.Tensor:
"""Encode bounding box to distances. The rewrite is to support batch
operations.
Args:
points (Tensor): Shape (B, N, 2) or (N, 2), The format is [x, y].
gt_bboxes (Tensor or :obj:`BaseBoxes`): Shape (N, 4), The format
is "xyxy"
max_dis (float): Upper bound of the distance. Defaults to 16.
eps (float): A small value to ensure the target is less than
max_dis, instead of less than or equal to. Defaults to 0.01.
Returns:
Tensor: Box transformation deltas. The shape is (N, 4) or
(B, N, 4).
"""
assert points.size(-2) == gt_bboxes.size(-2)
assert points.size(-1) == 2
assert gt_bboxes.size(-1) == 4
return bbox2distance(points, gt_bboxes, max_dis, eps)
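
A quick round-trip makes the convention concrete. This sketch calls mmdet's `bbox2distance`/`distance2bbox` directly, with made-up values:

```python
import torch
from mmdet.structures.bbox import bbox2distance, distance2bbox

points = torch.tensor([[16., 16.], [48., 16.]])  # prior centers, (N, 2)
gt_bboxes = torch.tensor([[4., 6., 24., 28.],
                          [40., 8., 60., 30.]])  # xyxy, (N, 4)

# encode: box -> (left, top, right, bottom) distances, clamped to
# max_dis - eps so DFL targets stay strictly below reg_max
ltrb = bbox2distance(points, gt_bboxes, max_dis=16., eps=0.01)

# decode inverts the mapping exactly while every distance is in bounds
restored = distance2bbox(points, ltrb)
assert torch.allclose(restored, gt_bboxes)
```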

View File

@@ -4,3 +4,4 @@ Import:
- configs/yolox/metafile.yml
- configs/rtmdet/metafile.yml
- configs/yolov7/metafile.yml
- configs/ppyoloe/metafile.yml

View File

@@ -13,6 +13,8 @@ from mmyolo.datasets.transforms import (LetterResize, LoadAnnotations,
YOLOv5HSVRandomAug,
YOLOv5KeepRatioResize,
YOLOv5RandomAffine)
from mmyolo.datasets.transforms.transforms import (PPYOLOERandomCrop,
PPYOLOERandomDistort)
class TestLetterResize(unittest.TestCase):
@@ -355,3 +357,100 @@ class TestYOLOv5RandomAffine(unittest.TestCase):
self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
self.assertTrue(results['gt_bboxes'].dtype == torch.float32)
self.assertTrue(results['gt_ignore_flags'].dtype == bool)
class TestPPYOLOERandomCrop(unittest.TestCase):
def setUp(self):
"""Setup the data info which are used in every test method.
TestCase calls functions in this order: setUp() -> testMethod() ->
tearDown() -> cleanUp()
"""
self.results = {
'img':
np.random.random((224, 224, 3)),
'img_shape': (224, 224),
'gt_bboxes_labels':
np.array([1, 2, 3], dtype=np.int64),
'gt_bboxes':
np.array([[10, 10, 20, 20], [20, 20, 40, 40], [40, 40, 80, 80]],
dtype=np.float32),
'gt_ignore_flags':
np.array([0, 0, 1], dtype=bool),
}
def test_transform(self):
transform = PPYOLOERandomCrop()
results = transform(copy.deepcopy(self.results))
self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
results['gt_bboxes'].shape[0])
self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
self.assertTrue(results['gt_bboxes'].dtype == np.float32)
self.assertTrue(results['gt_ignore_flags'].dtype == bool)
def test_transform_with_boxlist(self):
results = copy.deepcopy(self.results)
results['gt_bboxes'] = HorizontalBoxes(results['gt_bboxes'])
transform = PPYOLOERandomCrop()
results = transform(copy.deepcopy(results))
self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
results['gt_bboxes'].shape[0])
self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
self.assertTrue(results['gt_bboxes'].dtype == torch.float32)
self.assertTrue(results['gt_ignore_flags'].dtype == bool)
class TestPPYOLOERandomDistort(unittest.TestCase):
def setUp(self):
"""Setup the data info which are used in every test method.
TestCase calls functions in this order: setUp() -> testMethod() ->
tearDown() -> cleanUp()
"""
self.results = {
'img':
np.random.random((224, 224, 3)),
'img_shape': (224, 224),
'gt_bboxes_labels':
np.array([1, 2, 3], dtype=np.int64),
'gt_bboxes':
np.array([[10, 10, 20, 20], [20, 20, 40, 40], [40, 40, 80, 80]],
dtype=np.float32),
'gt_ignore_flags':
np.array([0, 0, 1], dtype=bool),
}
def test_transform(self):
# test assertion for invalid prob
with self.assertRaises(AssertionError):
transform = PPYOLOERandomDistort(
hue_cfg=dict(min=-18, max=18, prob=1.5))
# test assertion for invalid num_distort_func
with self.assertRaises(AssertionError):
transform = PPYOLOERandomDistort(num_distort_func=5)
transform = PPYOLOERandomDistort()
results = transform(copy.deepcopy(self.results))
self.assertTrue(results['img'].shape[:2] == (224, 224))
self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
results['gt_bboxes'].shape[0])
self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
self.assertTrue(results['gt_bboxes'].dtype == np.float32)
self.assertTrue(results['gt_ignore_flags'].dtype == bool)
def test_transform_with_boxlist(self):
results = copy.deepcopy(self.results)
results['gt_bboxes'] = HorizontalBoxes(results['gt_bboxes'])
transform = PPYOLOERandomDistort()
results = transform(copy.deepcopy(results))
self.assertTrue(results['img'].shape[:2] == (224, 224))
self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
results['gt_bboxes'].shape[0])
self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
self.assertTrue(results['gt_bboxes'].dtype == torch.float32)
self.assertTrue(results['gt_ignore_flags'].dtype == bool)

View File

@@ -47,6 +47,36 @@ class TestYOLOv5Collate(unittest.TestCase):
self.assertTrue(out['inputs'].shape == (2, 3, 10, 10))
self.assertTrue(out['data_samples'].shape == (8, 6))
def test_yolov5_collate_with_multi_scale(self):
rng = np.random.RandomState(0)
inputs = torch.randn((3, 10, 10))
data_samples = DetDataSample()
gt_instances = InstanceData()
bboxes = _rand_bboxes(rng, 4, 6, 8)
gt_instances.bboxes = HorizontalBoxes(bboxes, dtype=torch.float32)
labels = rng.randint(1, 2, size=len(bboxes))
gt_instances.labels = torch.LongTensor(labels)
data_samples.gt_instances = gt_instances
out = yolov5_collate([dict(inputs=inputs, data_samples=data_samples)],
use_ms_training=True)
self.assertIsInstance(out, dict)
self.assertTrue(out['inputs'][0].shape == (3, 10, 10))
self.assertTrue(out['data_samples'].shape == (4, 6))
self.assertIsInstance(out['inputs'], list)
self.assertIsInstance(out['data_samples'], torch.Tensor)
out = yolov5_collate(
[dict(inputs=inputs, data_samples=data_samples)] * 2,
use_ms_training=True)
self.assertIsInstance(out, dict)
self.assertTrue(out['inputs'][0].shape == (3, 10, 10))
self.assertTrue(out['data_samples'].shape == (8, 6))
self.assertIsInstance(out['inputs'], list)
self.assertIsInstance(out['data_samples'], torch.Tensor)
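
For reference, a hedged sketch of how the multi-scale path might be switched on from a config; mmyolo builds collate functions from `collate_fn` dicts, but the shipped PPYOLOE config may differ in detail:

```python
# Only the collate_fn line matters here; the rest is a minimal stand-in.
train_dataloader = dict(
    batch_size=8,
    num_workers=4,
    collate_fn=dict(type='yolov5_collate', use_ms_training=True),
    dataset=dict(
        type='YOLOv5CocoDataset',
        data_root='data/coco/',
        ann_file='annotations/instances_train2017.json'))
```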
class TestBatchShapePolicy(unittest.TestCase):

View File

@@ -3,8 +3,13 @@ from unittest import TestCase
import torch
from mmdet.structures import DetDataSample
from mmengine import MessageHub
from mmyolo.models import PPYOLOEBatchRandomResize, PPYOLOEDetDataPreprocessor
from mmyolo.models.data_preprocessors import YOLOv5DetDataPreprocessor
from mmyolo.utils import register_all_modules
register_all_modules()
class TestYOLOv5DetDataPreprocessor(TestCase):
@@ -69,3 +74,51 @@ class TestYOLOv5DetDataPreprocessor(TestCase):
# data_samples must be tensor
with self.assertRaises(AssertionError):
processor(data, training=True)
class TestPPYOLOEDetDataPreprocessor(TestCase):
def test_batch_random_resize(self):
processor = PPYOLOEDetDataPreprocessor(
pad_size_divisor=32,
batch_augments=[
dict(
type='PPYOLOEBatchRandomResize',
random_size_range=(320, 480),
interval=1,
size_divisor=32,
random_interp=True,
keep_ratio=False)
],
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True)
self.assertTrue(
isinstance(processor.batch_augments[0], PPYOLOEBatchRandomResize))
message_hub = MessageHub.get_instance('test_batch_random_resize')
message_hub.update_info('iter', 0)
# test training
data = {
'inputs': [
torch.randint(0, 256, (3, 10, 11)),
torch.randint(0, 256, (3, 10, 11))
],
'data_samples':
torch.randint(0, 11, (18, 6)).float(),
}
out_data = processor(data, training=True)
batch_data_samples = out_data['data_samples']
self.assertIn('img_metas', batch_data_samples)
self.assertIn('bboxes_labels', batch_data_samples)
self.assertIsInstance(batch_data_samples['bboxes_labels'],
torch.Tensor)
self.assertIsInstance(batch_data_samples['img_metas'], list)
data = {
'inputs': [torch.randint(0, 256, (3, 11, 10))],
'data_samples': DetDataSample()
}
# data_samples must be list
with self.assertRaises(TypeError):
processor(data, training=True)
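
The arguments the test passes would sit under `model.data_preprocessor` in a training config. A hedged sketch mirroring the test above, not copied from the shipped configs:

```python
model = dict(
    data_preprocessor=dict(
        type='PPYOLOEDetDataPreprocessor',
        pad_size_divisor=32,
        batch_augments=[
            dict(
                type='PPYOLOEBatchRandomResize',
                random_size_range=(320, 480),
                interval=1,          # re-sample the target size every iter
                size_divisor=32,
                random_interp=True,  # pick a random cv2 interpolation mode
                keep_ratio=False)
        ],
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        bgr_to_rgb=True))
```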

View File

@@ -2,6 +2,7 @@
from unittest import TestCase
import torch
from mmengine import ConfigDict, MessageHub
from mmengine.config import Config
from mmengine.model import bias_init_with_prob
from mmengine.testing import assert_allclose
@@ -12,11 +13,14 @@ from mmyolo.utils import register_all_modules
register_all_modules()
-class TestYOLOXHead(TestCase):
+class TestPPYOLOEHead(TestCase):
def setUp(self):
self.head_module = dict(
-type='PPYOLOEHeadModule', num_classes=4, in_channels=[32, 64, 128])
+type='PPYOLOEHeadModule',
+num_classes=4,
+in_channels=[32, 64, 128],
+featmap_strides=(8, 16, 32))
def test_init_weights(self):
head = PPYOLOEHead(head_module=self.head_module)
@@ -50,6 +54,7 @@ class TestYOLOXHead(TestCase):
max_per_img=300))
head = PPYOLOEHead(head_module=self.head_module, test_cfg=test_cfg)
head.eval()
feat = [
torch.rand(1, in_channels, s // feat_size, s // feat_size)
for in_channels, feat_size in [[32, 8], [64, 16], [128, 32]]
@@ -71,3 +76,130 @@ class TestYOLOXHead(TestCase):
cfg=test_cfg,
rescale=False,
with_nms=False)
def test_loss_by_feat(self):
message_hub = MessageHub.get_instance('test_ppyoloe_loss_by_feat')
message_hub.update_info('epoch', 1)
s = 256
img_metas = [{
'img_shape': (s, s, 3),
'batch_input_shape': (s, s),
'scale_factor': 1,
}]
head = PPYOLOEHead(
head_module=self.head_module,
train_cfg=ConfigDict(
initial_epoch=31,
initial_assigner=dict(
type='BatchATSSAssigner',
num_classes=4,
topk=9,
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
assigner=dict(
type='BatchTaskAlignedAssigner',
num_classes=4,
topk=13,
alpha=1,
beta=6)))
head.train()
feat = []
for i in range(len(self.head_module['in_channels'])):
in_channel = self.head_module['in_channels'][i]
feat_size = self.head_module['featmap_strides'][i]
feat.append(
torch.rand(1, in_channel, s // feat_size, s // feat_size))
cls_scores, bbox_preds, bbox_dist_preds = head.forward(feat)
# Test that empty ground truth encourages the network to predict
# background
gt_instances = torch.empty((0, 6), dtype=torch.float32)
empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
bbox_dist_preds, gt_instances,
img_metas)
# When there is no truth, the cls loss should be nonzero but there
# should be no box loss.
empty_cls_loss = empty_gt_losses['loss_cls'].sum()
empty_box_loss = empty_gt_losses['loss_bbox'].sum()
empty_dfl_loss = empty_gt_losses['loss_dfl'].sum()
self.assertGreater(empty_cls_loss.item(), 0,
'cls loss should be non-zero')
self.assertEqual(
empty_box_loss.item(), 0,
'there should be no box loss when there are no true boxes')
self.assertEqual(
empty_dfl_loss.item(), 0,
'there should be no dfl loss when there are no true boxes')
# When truth is non-empty then both cls and box loss should be nonzero
# for random inputs
head = PPYOLOEHead(
head_module=self.head_module,
train_cfg=ConfigDict(
initial_epoch=31,
initial_assigner=dict(
type='BatchATSSAssigner',
num_classes=4,
topk=9,
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
assigner=dict(
type='BatchTaskAlignedAssigner',
num_classes=4,
topk=13,
alpha=1,
beta=6)))
head.train()
gt_instances = torch.Tensor(
[[0., 0., 23.6667, 23.8757, 238.6326, 151.8874]])
one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
bbox_dist_preds, gt_instances,
img_metas)
onegt_cls_loss = one_gt_losses['loss_cls'].sum()
onegt_box_loss = one_gt_losses['loss_bbox'].sum()
onegt_loss_dfl = one_gt_losses['loss_dfl'].sum()
self.assertGreater(onegt_cls_loss.item(), 0,
'cls loss should be non-zero')
self.assertGreater(onegt_box_loss.item(), 0,
'box loss should be non-zero')
self.assertGreater(onegt_loss_dfl.item(), 0,
'dfl loss should be non-zero')
# test num_class = 1
self.head_module['num_classes'] = 1
head = PPYOLOEHead(
head_module=self.head_module,
train_cfg=ConfigDict(
initial_epoch=31,
initial_assigner=dict(
type='BatchATSSAssigner',
num_classes=1,
topk=9,
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
assigner=dict(
type='BatchTaskAlignedAssigner',
num_classes=1,
topk=13,
alpha=1,
beta=6)))
head.train()
gt_instances = torch.Tensor(
[[0., 0., 23.6667, 23.8757, 238.6326, 151.8874]])
cls_scores, bbox_preds, bbox_dist_preds = head.forward(feat)
one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
bbox_dist_preds, gt_instances,
img_metas)
onegt_cls_loss = one_gt_losses['loss_cls'].sum()
onegt_box_loss = one_gt_losses['loss_bbox'].sum()
onegt_loss_dfl = one_gt_losses['loss_dfl'].sum()
self.assertGreater(onegt_cls_loss.item(), 0,
'cls loss should be non-zero')
self.assertGreater(onegt_box_loss.item(), 0,
'box loss should be non-zero')
self.assertGreater(onegt_loss_dfl.item(), 0,
'dfl loss should be non-zero')

View File

@@ -5,13 +5,13 @@ from collections import OrderedDict
import torch
-def convert_bn(k):
+def convert_bn(k: str):
name = k.replace('._mean',
'.running_mean').replace('._variance', '.running_var')
return name
-def convert_repvgg(k):
+def convert_repvgg(k: str):
if '.conv2.conv1.' in k:
name = k.replace('.conv2.conv1.', '.conv2.rbr_dense.')
return name
@@ -22,111 +22,142 @@ def convert_repvgg(k):
return k
-def convert(src, dst):
+def convert(src: str, dst: str, imagenet_pretrain: bool = False):
-# TODO: add pretrained model convert
with open(src, 'rb') as f:
model = pickle.load(f)
new_state_dict = OrderedDict()
-for k, v in model.items():
-name = k
-if k.startswith('backbone.'):
-if '.stem.' in k:
-# backbone.stem.conv1.conv.weight
-# -> backbone.stem.0.conv.weight
-org_ind = k.split('.')[2][-1]
-new_ind = str(int(org_ind) - 1)
-name = k.replace('.stem.conv%s.' % org_ind,
-'.stem.%s.' % new_ind)
-else:
-# backbone.stages.1.conv2.bn._variance
-# -> backbone.stage2.0.conv2.bn.running_var
-org_stage_ind = k.split('.')[2]
-new_stage_ind = str(int(org_stage_ind) + 1)
-name = k.replace('.stages.%s.' % org_stage_ind,
-'.stage%s.0.' % new_stage_ind)
-name = convert_repvgg(name)
-if '.attn.' in k:
-name = name.replace('.attn.fc.', '.attn.fc.conv.')
-name = convert_bn(name)
-elif k.startswith('neck.'):
-# fpn_stages
-if k.startswith('neck.fpn_stages.'):
-# neck.fpn_stages.0.0.conv1.conv.weight
-# -> neck.reduce_layers.2.0.conv1.conv.weight
-if k.startswith('neck.fpn_stages.0.0.'):
-name = k.replace('neck.fpn_stages.0.0.',
-'neck.reduce_layers.2.0.')
-if '.spp.' in name:
-name = name.replace('.spp.conv.', '.spp.conv2.')
-# neck.fpn_stages.1.0.conv1.conv.weight
-# -> neck.top_down_layers.0.0.conv1.conv.weight
-elif k.startswith('neck.fpn_stages.1.0.'):
-name = k.replace('neck.fpn_stages.1.0.',
-'neck.top_down_layers.0.0.')
-elif k.startswith('neck.fpn_stages.2.0.'):
-name = k.replace('neck.fpn_stages.2.0.',
-'neck.top_down_layers.1.0.')
-else:
-raise NotImplementedError('Not implemented.')
-name = name.replace('.0.convs.', '.0.blocks.')
-elif k.startswith('neck.fpn_routes.'):
-# neck.fpn_routes.0.conv.weight
-# -> neck.upsample_layers.0.0.conv.weight
-index = k.split('.')[2]
-name = 'neck.upsample_layers.' + index + '.0.' + '.'.join(
-k.split('.')[-2:])
-name = name.replace('.0.convs.', '.0.blocks.')
-elif k.startswith('neck.pan_stages.'):
-# neck.pan_stages.0.0.conv1.conv.weight
-# -> neck.bottom_up_layers.1.0.conv1.conv.weight
-ind = k.split('.')[2]
-name = k.replace(
-'neck.pan_stages.' + ind,
-'neck.bottom_up_layers.' + ('0' if ind == '1' else '1'))
-name = name.replace('.0.convs.', '.0.blocks.')
-elif k.startswith('neck.pan_routes.'):
-# neck.pan_routes.0.conv.weight
-# -> neck.downsample_layers.0.conv.weight
-ind = k.split('.')[2]
-name = k.replace(
-'neck.pan_routes.' + ind,
-'neck.downsample_layers.' + ('0' if ind == '1' else '1'))
-name = name.replace('.0.convs.', '.0.blocks.')
-else:
-raise NotImplementedError('Not implement.')
-name = convert_repvgg(name)
-name = convert_bn(name)
-elif k.startswith('yolo_head.'):
-if ('anchor_points' in k) or ('stride_tensor' in k):
-continue
-if 'proj_conv' in k:
-name = k.replace('yolo_head.proj_conv.',
-'bbox_head.head_module.proj_conv.')
-else:
-for org_key, rep_key in [[
-'yolo_head.stem_cls.',
-'bbox_head.head_module.cls_stems.'
-], ['yolo_head.stem_reg.', 'bbox_head.head_module.reg_stems.'],
-[
-'yolo_head.pred_cls.',
-'bbox_head.head_module.cls_preds.'
-],
-[
-'yolo_head.pred_reg.',
-'bbox_head.head_module.reg_preds.'
-]]:
-name = name.replace(org_key, rep_key)
-name = name.split('.')
-ind = name[3]
-name[3] = str(2 - int(ind))
-name = '.'.join(name)
-name = convert_bn(name)
-else:
-continue
+if imagenet_pretrain:
+for k, v in model.items():
+if '@@' in k:
+continue
+if 'stem.' in k:
+# backbone.stem.conv1.conv.weight
+# -> backbone.stem.0.conv.weight
+org_ind = k.split('.')[1][-1]
+new_ind = str(int(org_ind) - 1)
+name = k.replace('stem.conv%s.' % org_ind,
+'stem.%s.' % new_ind)
+else:
+# backbone.stages.1.conv2.bn._variance
+# -> backbone.stage2.0.conv2.bn.running_var
+org_stage_ind = k.split('.')[1]
+new_stage_ind = str(int(org_stage_ind) + 1)
+name = k.replace('stages.%s.' % org_stage_ind,
+'stage%s.0.' % new_stage_ind)
+name = convert_repvgg(name)
+if '.attn.' in k:
+name = name.replace('.attn.fc.', '.attn.fc.conv.')
+name = convert_bn(name)
+name = 'backbone.' + name
+new_state_dict[name] = torch.from_numpy(v)
+else:
+for k, v in model.items():
+name = k
+if k.startswith('backbone.'):
+if '.stem.' in k:
+# backbone.stem.conv1.conv.weight
+# -> backbone.stem.0.conv.weight
+org_ind = k.split('.')[2][-1]
+new_ind = str(int(org_ind) - 1)
+name = k.replace('.stem.conv%s.' % org_ind,
+'.stem.%s.' % new_ind)
+else:
+# backbone.stages.1.conv2.bn._variance
+# -> backbone.stage2.0.conv2.bn.running_var
+org_stage_ind = k.split('.')[2]
+new_stage_ind = str(int(org_stage_ind) + 1)
+name = k.replace('.stages.%s.' % org_stage_ind,
+'.stage%s.0.' % new_stage_ind)
+name = convert_repvgg(name)
+if '.attn.' in k:
+name = name.replace('.attn.fc.', '.attn.fc.conv.')
+name = convert_bn(name)
+elif k.startswith('neck.'):
+# fpn_stages
+if k.startswith('neck.fpn_stages.'):
+# neck.fpn_stages.0.0.conv1.conv.weight
+# -> neck.reduce_layers.2.0.conv1.conv.weight
+if k.startswith('neck.fpn_stages.0.0.'):
+name = k.replace('neck.fpn_stages.0.0.',
+'neck.reduce_layers.2.0.')
+if '.spp.' in name:
+name = name.replace('.spp.conv.', '.spp.conv2.')
+# neck.fpn_stages.1.0.conv1.conv.weight
+# -> neck.top_down_layers.0.0.conv1.conv.weight
+elif k.startswith('neck.fpn_stages.1.0.'):
+name = k.replace('neck.fpn_stages.1.0.',
+'neck.top_down_layers.0.0.')
+elif k.startswith('neck.fpn_stages.2.0.'):
+name = k.replace('neck.fpn_stages.2.0.',
+'neck.top_down_layers.1.0.')
+else:
+raise NotImplementedError('Not implemented.')
+name = name.replace('.0.convs.', '.0.blocks.')
+elif k.startswith('neck.fpn_routes.'):
+# neck.fpn_routes.0.conv.weight
+# -> neck.upsample_layers.0.0.conv.weight
+index = k.split('.')[2]
+name = 'neck.upsample_layers.' + index + '.0.' + '.'.join(
+k.split('.')[-2:])
+name = name.replace('.0.convs.', '.0.blocks.')
+elif k.startswith('neck.pan_stages.'):
+# neck.pan_stages.0.0.conv1.conv.weight
+# -> neck.bottom_up_layers.1.0.conv1.conv.weight
+ind = k.split('.')[2]
+name = k.replace(
+'neck.pan_stages.' + ind, 'neck.bottom_up_layers.' +
+('0' if ind == '1' else '1'))
+name = name.replace('.0.convs.', '.0.blocks.')
+elif k.startswith('neck.pan_routes.'):
+# neck.pan_routes.0.conv.weight
+# -> neck.downsample_layers.0.conv.weight
+ind = k.split('.')[2]
+name = k.replace(
+'neck.pan_routes.' + ind, 'neck.downsample_layers.' +
+('0' if ind == '1' else '1'))
+name = name.replace('.0.convs.', '.0.blocks.')
+else:
+raise NotImplementedError('Not implement.')
+name = convert_repvgg(name)
+name = convert_bn(name)
+elif k.startswith('yolo_head.'):
+if ('anchor_points' in k) or ('stride_tensor' in k):
+continue
+if 'proj_conv' in k:
+name = k.replace('yolo_head.proj_conv.',
+'bbox_head.head_module.proj_conv.')
+else:
+for org_key, rep_key in [
+[
+'yolo_head.stem_cls.',
+'bbox_head.head_module.cls_stems.'
+],
+[
+'yolo_head.stem_reg.',
+'bbox_head.head_module.reg_stems.'
+],
+[
+'yolo_head.pred_cls.',
+'bbox_head.head_module.cls_preds.'
+],
+[
+'yolo_head.pred_reg.',
+'bbox_head.head_module.reg_preds.'
+]
+]:
+name = name.replace(org_key, rep_key)
+name = name.split('.')
+ind = name[3]
+name[3] = str(2 - int(ind))
+name = '.'.join(name)
+name = convert_bn(name)
+else:
+continue
new_state_dict[name] = torch.from_numpy(v)
data = {'state_dict': new_state_dict}
torch.save(data, dst)
@@ -139,8 +170,14 @@ def main():
help='src ppyoloe model path')
parser.add_argument(
'--dst', default='mmppyoloe_plus_s.pt', help='save path')
parser.add_argument(
'--imagenet-pretrain',
action='store_true',
default=False,
help='Load model pretrained on the ImageNet dataset, which only '
'has weights for the backbone.')
args = parser.parse_args()
-convert(args.src, args.dst)
+convert(args.src, args.dst, args.imagenet_pretrain)
if __name__ == '__main__':
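
Typical usage, as a hedged sketch; the checkpoint filenames below are illustrative, not taken from this commit:

```python
# COCO-trained PP-YOLOE+ weights: full backbone/neck/head key mapping
convert('ppyoloe_plus_crn_s_80e_coco.pdparams', 'mmppyoloe_plus_s.pt')

# ImageNet-pretrained CSPResNet weights cover only the backbone, so the
# flag switches to the backbone-only mapping and prefixes 'backbone.'
convert(
    'CSPResNetb_s_pretrained.pdparams',
    'ppyoloe_s_imagenet_backbone.pt',
    imagenet_pretrain=True)
```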