[Enchance] Optimize and accelerate YOLOX with RTMDet hyps (#542)

* enchance yolox

* update

* update

* fix

* fix

* fix lint
pull/571/head
Haian Huang(深度眸) 2023-02-17 11:27:03 +08:00 committed by GitHub
parent 8cdc741fd3
commit 6400fba1af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 309 additions and 20 deletions

View File

@ -26,7 +26,7 @@ anchors = [
[(142, 110), (192, 243), (459, 401)] # P5/32
]
# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=64 bs
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
base_lr = 0.01
max_epochs = 300 # Maximum training epochs

View File

@ -17,12 +17,28 @@ In this report, we present some experienced improvements to YOLO series, forming
YOLOX-l model structure
</div>
## Results and Models
## 🥳 🚀 Results and Models
| Backbone | size | Mem (GB) | box AP | Config | Download |
| :--------: | :--: | :------: | :----: | :------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| YOLOX-tiny | 416 | 2.8 | 32.7 | [config](https://github.com/open-mmlab/mmyolo/tree/main/configs/yolox/yolox_tiny_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908.log.json) |
| YOLOX-s | 640 | 5.6 | 40.8 | [config](https://github.com/open-mmlab/mmyolo/tree/main/configs/yolox/yolox_s_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738-d7e60cb2.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738.log.json) |
| Backbone | Size | Batch Size | AMP | RTMDet-Hyp | Mem (GB) | Box AP | Config | Download |
| :--------: | :--: | :--------: | :-: | :--------: | :------: | :---------: | :-----------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| YOLOX-tiny | 416 | 8xb8 | No | No | 2.8 | 32.7 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_tiny_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908.log.json) |
| YOLOX-tiny | 416 | 8xb32 | Yes | Yes | 4.9 | 34.3 (+1.6) | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco_20230210_143637-4c338102.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco_20230210_143637.log.json) |
| YOLOX-s | 640 | 8xb8 | Yes | No | 2.9 | 40.7 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_s_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_fast_8xb8-300e_coco/yolox_s_fast_8xb8-300e_coco_20230213_142600-2b224d8b.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_fast_8xb8-300e_coco/yolox_s_fast_8xb8-300e_coco_20230213_142600.log.json) |
| YOLOX-s | 640 | 8xb32 | Yes | Yes | 9.8 | 41.9 (+1.2) | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco_20230210_134645-3a8dfbd7.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco_20230210_134645.log.json) |
| YOLOX-m | 640 | 8xb8 | Yes | No | 4.9 | 46.9 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_m_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_m_fast_8xb8-300e_coco/yolox_m_fast_8xb8-300e_coco_20230213_160218-a71a6b25.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_m_fast_8xb8-300e_coco/yolox_m_fast_8xb8-300e_coco_20230213_160218.log.json) |
| YOLOX-m | 640 | 8xb32 | Yes | Yes | 17.6 | 47.5 (+0.6) | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco_20230210_144328-e657e182.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco_20230210_144328.log.json) |
| YOLOX-l | 640 | 8xb8 | Yes | No | 8.0 | 50.1 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_l_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_l_fast__8xb8-300e_coco/yolox_l_fast_8xb8-300e_coco_20230213_160715-c731eb1c.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_l_fast_8xb8-300e_coco/yolox_l_fast_8xb8-300e_coco_20230213_160715.log.json) |
| YOLOX-x | 640 | 8xb8 | Yes | No | 9.8 | 51.4 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_x_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_x_fast_8xb8-300e_coco/yolox_x_fast_8xb8-300e_coco_20230215_133950-1d509fab.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_x_fast_8xb8-300e_coco/yolox_x_fast_8xb8-300e_coco_20230215_133950.log.json) |
YOLOX uses a default training configuration of `8xbs8` which results in a long training time, we expect it to use `8xbs32` to speed up the training and not cause a decrease in mAP. We modified `train_batch_size_per_gpu` from 8 to 32, `batch_augments_interval` from 10 to 1 and `base_lr` from 0.01 to 0.04 under YOLOX-s default configuration based on the linear scaling rule, which resulted in mAP degradation. Finally, I found that using RTMDet's training hyperparameter can improve performance in YOLOX Tiny/S/M, which also validates the superiority of RTMDet's training hyperparameter.
The modified training parameters are as follows
1. train_batch_size_per_gpu: 8 -> 32
2. batch_augments_interval: 10 -> 1
3. num_last_epochs: 15 -> 20
4. optim cfg: SGD -> AdamW, base_lr 0.01 -> 0.004, weight_decay 0.0005 -> 0.05
5. ema momentum: 0.0001 -> 0.0002
**Note**:

View File

@ -36,11 +36,83 @@ Models:
In Collection: YOLOX
Config: configs/yolox/yolox_s_fast_8xb8-300e_coco.py
Metadata:
Training Memory (GB): 5.6
Training Memory (GB): 2.9
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 40.8
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738-d7e60cb2.pth
box AP: 40.7
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_fast_8xb8-300e_coco/yolox_s_fast_8xb8-300e_coco_20230213_142600-2b224d8b.pth
- Name: yolox_m_fast_8xb8-300e_coco
In Collection: YOLOX
Config: configs/yolox/yolox_m_fast_8xb8-300e_coco.py
Metadata:
Training Memory (GB): 4.9
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 46.9
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_m_fast_8xb8-300e_coco/yolox_m_fast_8xb8-300e_coco_20230213_160218-a71a6b25.pth
- Name: yolox_l_fast_8xb8-300e_coco
In Collection: YOLOX
Config: configs/yolox/yolox_l_fast_8xb8-300e_coco.py
Metadata:
Training Memory (GB): 8.0
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 50.1
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_l_fast_8xb8-300e_coco/yolox_l_fast_8xb8-300e_coco_20230213_160715-c731eb1c.pth
- Name: yolox_x_fast_8xb8-300e_coco
In Collection: YOLOX
Config: configs/yolox/yolox_x_fast_8xb8-300e_coco.py
Metadata:
Training Memory (GB): 9.8
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 51.4
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_x_fast_8xb8-300e_coco/yolox_x_fast_8xb8-300e_coco_20230215_133950-1d509fab.pth
- Name: yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco
In Collection: YOLOX
Config: configs/yolox/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco.py
Metadata:
Training Memory (GB): 4.9
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 34.3
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco_20230210_143637-4c338102.pth
- Name: yolox_s_fast_8xb32-300e-rtmdet-hyp_coco
In Collection: YOLOX
Config: configs/yolox/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco.py
Metadata:
Training Memory (GB): 9.8
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 41.9
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco_20230210_134645-3a8dfbd7.pth
- Name: yolox_m_fast_8xb32-300e-rtmdet-hyp_coco
In Collection: YOLOX
Config: configs/yolox/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco.py
Metadata:
Training Memory (GB): 17.6
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 47.5
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco_20230210_144328-e657e182.pth

View File

@ -0,0 +1,12 @@
_base_ = './yolox_s_fast_8xb32-300e-rtmdet-hyp_coco.py'
# ========================modified parameters======================
deepen_factor = 0.67
widen_factor = 0.75
# =======================Unmodified in most cases==================
# model settings
model = dict(
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))

View File

@ -0,0 +1,21 @@
_base_ = './yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco.py'
# ========================modified parameters======================
deepen_factor = 0.33
widen_factor = 0.25
use_depthwise = True
# =======================Unmodified in most cases==================
# model settings
model = dict(
backbone=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
use_depthwise=use_depthwise),
neck=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
use_depthwise=use_depthwise),
bbox_head=dict(
head_module=dict(
widen_factor=widen_factor, use_depthwise=use_depthwise)))

View File

@ -0,0 +1,87 @@
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
# ========================modified parameters======================
# Batch size of a single GPU during training
# 8 -> 32
train_batch_size_per_gpu = 32
# Multi-scale training intervals
# 10 -> 1
batch_augments_interval = 1
# Last epoch number to switch training pipeline
# 15 -> 20
num_last_epochs = 20
# Base learning rate for optim_wrapper. Corresponding to 8xb32=256 bs
base_lr = 0.004
# SGD -> AdamW
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
paramwise_cfg=dict(
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
# 0.0001 -> 0.0002
ema_momentum = 0.0002
# ============================== Unmodified in most cases ===================
model = dict(
data_preprocessor=dict(batch_augments=[
dict(
type='YOLOXBatchSyncRandomResize',
random_size_range=(480, 800),
size_divisor=32,
interval=batch_augments_interval)
]))
param_scheduler = [
dict(
# use quadratic formula to warm up 5 epochs
# and lr is updated by iteration
# TODO: fix default scope in get function
type='mmdet.QuadraticWarmupLR',
by_epoch=True,
begin=0,
end=5,
convert_to_iter_based=True),
dict(
# use cosine lr from 5 to 285 epoch
type='CosineAnnealingLR',
eta_min=base_lr * 0.05,
begin=5,
T_max=_base_.max_epochs - num_last_epochs,
end=_base_.max_epochs - num_last_epochs,
by_epoch=True,
convert_to_iter_based=True),
dict(
# use fixed lr during last num_last_epochs epochs
type='ConstantLR',
by_epoch=True,
factor=1,
begin=_base_.max_epochs - num_last_epochs,
end=_base_.max_epochs,
)
]
custom_hooks = [
dict(
type='YOLOXModeSwitchHook',
num_last_epochs=num_last_epochs,
new_train_pipeline=_base_.train_pipeline_stage2,
priority=48),
dict(type='mmdet.SyncNormHook', priority=48),
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=ema_momentum,
update_buffers=True,
strict_load=False,
priority=49)
]
train_dataloader = dict(batch_size=train_batch_size_per_gpu)
train_cfg = dict(dynamic_intervals=[(_base_.max_epochs - num_last_epochs, 1)])
auto_scale_lr = dict(base_batch_size=8 * train_batch_size_per_gpu)

View File

@ -47,9 +47,16 @@ deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)
# generate new random resize shape interval
batch_augments_interval = 10
# -----train val related-----
weight_decay = 0.0005
loss_cls_weight = 1.0
loss_bbox_weight = 5.0
loss_obj_weight = 1.0
loss_bbox_aux_weight = 1.0
center_radius = 2.5 # SimOTAAssigner
num_last_epochs = 15
random_affine_scaling_ratio_range = (0.1, 2)
mixup_ratio_range = (0.8, 1.6)
@ -58,6 +65,8 @@ save_epoch_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
ema_momentum = 0.0001
# ===============================Unmodified in most cases====================
# model settings
model = dict(
@ -79,7 +88,7 @@ model = dict(
type='YOLOXBatchSyncRandomResize',
random_size_range=(480, 800),
size_divisor=32,
interval=10)
interval=batch_augments_interval)
]),
backbone=dict(
type='YOLOXCSPDarknet',
@ -116,24 +125,26 @@ model = dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
loss_weight=loss_cls_weight),
loss_bbox=dict(
type='mmdet.IoULoss',
mode='square',
eps=1e-16,
reduction='sum',
loss_weight=5.0),
loss_weight=loss_bbox_weight),
loss_obj=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
loss_weight=loss_obj_weight),
loss_bbox_aux=dict(
type='mmdet.L1Loss', reduction='sum', loss_weight=1.0)),
type='mmdet.L1Loss',
reduction='sum',
loss_weight=loss_bbox_aux_weight)),
train_cfg=dict(
assigner=dict(
type='mmdet.SimOTAAssigner',
center_radius=2.5,
center_radius=center_radius,
iou_calculator=dict(type='mmdet.BboxOverlaps2D'))),
test_cfg=model_test_cfg)
@ -303,7 +314,7 @@ custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0001,
momentum=ema_momentum,
update_buffers=True,
strict_load=False,
priority=49)
@ -315,6 +326,6 @@ train_cfg = dict(
val_interval=save_epoch_intervals,
dynamic_intervals=[(max_epochs - num_last_epochs, 1)])
auto_scale_lr = dict(base_batch_size=64)
auto_scale_lr = dict(base_batch_size=8 * train_batch_size_per_gpu)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

View File

@ -0,0 +1,70 @@
_base_ = './yolox_s_fast_8xb32-300e-rtmdet-hyp_coco.py'
# ========================modified parameters======================
deepen_factor = 0.33
widen_factor = 0.375
# Multi-scale training intervals
# 10 -> 1
batch_augments_interval = 1
scaling_ratio_range = (0.5, 1.5)
# =======================Unmodified in most cases==================
img_scale = _base_.img_scale
pre_transform = _base_.pre_transform
# model settings
model = dict(
data_preprocessor=dict(batch_augments=[
dict(
type='YOLOXBatchSyncRandomResize',
random_size_range=(320, 640),
size_divisor=32,
interval=batch_augments_interval)
]),
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
train_pipeline_stage1 = [
*pre_transform,
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='mmdet.RandomAffine',
scaling_ratio_range=scaling_ratio_range, # note
# img_scale is (width, height)
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.FilterAnnotations',
min_gt_bbox_wh=(1, 1),
keep_empty=False),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='mmdet.Resize', scale=(416, 416), keep_ratio=True), # note
dict(
type='mmdet.Pad',
pad_to_square=True,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor'))
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline_stage1))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

View File

@ -3,12 +3,12 @@ _base_ = './yolox_s_fast_8xb8-300e_coco.py'
# ========================modified parameters======================
deepen_factor = 0.33
widen_factor = 0.375
img_scale = _base_.img_scale
pre_transform = _base_.pre_transform
scaling_ratio_range = (0.5, 1.5)
# =======================Unmodified in most cases==================
img_scale = _base_.img_scale
pre_transform = _base_.pre_transform
# model settings
model = dict(
data_preprocessor=dict(batch_augments=[