mirror of https://github.com/open-mmlab/mmyolo.git
[Enchance] Optimize and accelerate YOLOX with RTMDet hyps (#542)
* enchance yolox * update * update * fix * fix * fix lintpull/571/head
parent
8cdc741fd3
commit
6400fba1af
|
@ -26,7 +26,7 @@ anchors = [
|
|||
[(142, 110), (192, 243), (459, 401)] # P5/32
|
||||
]
|
||||
# -----train val related-----
|
||||
# Base learning rate for optim_wrapper. Corresponding to 8xb16=64 bs
|
||||
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
|
||||
base_lr = 0.01
|
||||
max_epochs = 300 # Maximum training epochs
|
||||
|
||||
|
|
|
@ -17,12 +17,28 @@ In this report, we present some experienced improvements to YOLO series, forming
|
|||
YOLOX-l model structure
|
||||
</div>
|
||||
|
||||
## Results and Models
|
||||
## 🥳 🚀 Results and Models
|
||||
|
||||
| Backbone | size | Mem (GB) | box AP | Config | Download |
|
||||
| :--------: | :--: | :------: | :----: | :------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| YOLOX-tiny | 416 | 2.8 | 32.7 | [config](https://github.com/open-mmlab/mmyolo/tree/main/configs/yolox/yolox_tiny_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908.log.json) |
|
||||
| YOLOX-s | 640 | 5.6 | 40.8 | [config](https://github.com/open-mmlab/mmyolo/tree/main/configs/yolox/yolox_s_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738-d7e60cb2.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738.log.json) |
|
||||
| Backbone | Size | Batch Size | AMP | RTMDet-Hyp | Mem (GB) | Box AP | Config | Download |
|
||||
| :--------: | :--: | :--------: | :-: | :--------: | :------: | :---------: | :-----------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| YOLOX-tiny | 416 | 8xb8 | No | No | 2.8 | 32.7 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_tiny_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908.log.json) |
|
||||
| YOLOX-tiny | 416 | 8xb32 | Yes | Yes | 4.9 | 34.3 (+1.6) | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco_20230210_143637-4c338102.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco_20230210_143637.log.json) |
|
||||
| YOLOX-s | 640 | 8xb8 | Yes | No | 2.9 | 40.7 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_s_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_fast_8xb8-300e_coco/yolox_s_fast_8xb8-300e_coco_20230213_142600-2b224d8b.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_fast_8xb8-300e_coco/yolox_s_fast_8xb8-300e_coco_20230213_142600.log.json) |
|
||||
| YOLOX-s | 640 | 8xb32 | Yes | Yes | 9.8 | 41.9 (+1.2) | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco_20230210_134645-3a8dfbd7.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco_20230210_134645.log.json) |
|
||||
| YOLOX-m | 640 | 8xb8 | Yes | No | 4.9 | 46.9 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_m_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_m_fast_8xb8-300e_coco/yolox_m_fast_8xb8-300e_coco_20230213_160218-a71a6b25.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_m_fast_8xb8-300e_coco/yolox_m_fast_8xb8-300e_coco_20230213_160218.log.json) |
|
||||
| YOLOX-m | 640 | 8xb32 | Yes | Yes | 17.6 | 47.5 (+0.6) | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco_20230210_144328-e657e182.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco_20230210_144328.log.json) |
|
||||
| YOLOX-l | 640 | 8xb8 | Yes | No | 8.0 | 50.1 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_l_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_l_fast__8xb8-300e_coco/yolox_l_fast_8xb8-300e_coco_20230213_160715-c731eb1c.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_l_fast_8xb8-300e_coco/yolox_l_fast_8xb8-300e_coco_20230213_160715.log.json) |
|
||||
| YOLOX-x | 640 | 8xb8 | Yes | No | 9.8 | 51.4 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolox/yolox_x_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_x_fast_8xb8-300e_coco/yolox_x_fast_8xb8-300e_coco_20230215_133950-1d509fab.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_x_fast_8xb8-300e_coco/yolox_x_fast_8xb8-300e_coco_20230215_133950.log.json) |
|
||||
|
||||
YOLOX uses a default training configuration of `8xbs8` which results in a long training time, we expect it to use `8xbs32` to speed up the training and not cause a decrease in mAP. We modified `train_batch_size_per_gpu` from 8 to 32, `batch_augments_interval` from 10 to 1 and `base_lr` from 0.01 to 0.04 under YOLOX-s default configuration based on the linear scaling rule, which resulted in mAP degradation. Finally, I found that using RTMDet's training hyperparameter can improve performance in YOLOX Tiny/S/M, which also validates the superiority of RTMDet's training hyperparameter.
|
||||
|
||||
The modified training parameters are as follows:
|
||||
|
||||
1. train_batch_size_per_gpu: 8 -> 32
|
||||
2. batch_augments_interval: 10 -> 1
|
||||
3. num_last_epochs: 15 -> 20
|
||||
4. optim cfg: SGD -> AdamW, base_lr 0.01 -> 0.004, weight_decay 0.0005 -> 0.05
|
||||
5. ema momentum: 0.0001 -> 0.0002
|
||||
|
||||
**Note**:
|
||||
|
||||
|
|
|
@ -36,11 +36,83 @@ Models:
|
|||
In Collection: YOLOX
|
||||
Config: configs/yolox/yolox_s_fast_8xb8-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 5.6
|
||||
Training Memory (GB): 2.9
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 40.8
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738-d7e60cb2.pth
|
||||
box AP: 40.7
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_fast_8xb8-300e_coco/yolox_s_fast_8xb8-300e_coco_20230213_142600-2b224d8b.pth
|
||||
- Name: yolox_m_fast_8xb8-300e_coco
|
||||
In Collection: YOLOX
|
||||
Config: configs/yolox/yolox_m_fast_8xb8-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 4.9
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 46.9
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_m_fast_8xb8-300e_coco/yolox_m_fast_8xb8-300e_coco_20230213_160218-a71a6b25.pth
|
||||
- Name: yolox_l_fast_8xb8-300e_coco
|
||||
In Collection: YOLOX
|
||||
Config: configs/yolox/yolox_l_fast_8xb8-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 8.0
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 50.1
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_l_fast_8xb8-300e_coco/yolox_l_fast_8xb8-300e_coco_20230213_160715-c731eb1c.pth
|
||||
- Name: yolox_x_fast_8xb8-300e_coco
|
||||
In Collection: YOLOX
|
||||
Config: configs/yolox/yolox_x_fast_8xb8-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 9.8
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 51.4
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_x_fast_8xb8-300e_coco/yolox_x_fast_8xb8-300e_coco_20230215_133950-1d509fab.pth
|
||||
- Name: yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco
|
||||
In Collection: YOLOX
|
||||
Config: configs/yolox/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 4.9
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 34.3
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco/yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco_20230210_143637-4c338102.pth
|
||||
- Name: yolox_s_fast_8xb32-300e-rtmdet-hyp_coco
|
||||
In Collection: YOLOX
|
||||
Config: configs/yolox/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 9.8
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 41.9
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco_20230210_134645-3a8dfbd7.pth
|
||||
- Name: yolox_m_fast_8xb32-300e-rtmdet-hyp_coco
|
||||
In Collection: YOLOX
|
||||
Config: configs/yolox/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 17.6
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 47.5
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco/yolox_m_fast_8xb32-300e-rtmdet-hyp_coco_20230210_144328-e657e182.pth
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = './yolox_s_fast_8xb32-300e-rtmdet-hyp_coco.py'
|
||||
|
||||
# ========================modified parameters======================
|
||||
deepen_factor = 0.67
|
||||
widen_factor = 0.75
|
||||
|
||||
# =======================Unmodified in most cases==================
|
||||
# model settings
|
||||
model = dict(
|
||||
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
|
||||
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
|
||||
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
|
|
@ -0,0 +1,21 @@
|
|||
_base_ = './yolox_tiny_fast_8xb32-300e-rtmdet-hyp_coco.py'
|
||||
|
||||
# ========================modified parameters======================
|
||||
deepen_factor = 0.33
|
||||
widen_factor = 0.25
|
||||
use_depthwise = True
|
||||
|
||||
# =======================Unmodified in most cases==================
|
||||
# model settings
|
||||
model = dict(
|
||||
backbone=dict(
|
||||
deepen_factor=deepen_factor,
|
||||
widen_factor=widen_factor,
|
||||
use_depthwise=use_depthwise),
|
||||
neck=dict(
|
||||
deepen_factor=deepen_factor,
|
||||
widen_factor=widen_factor,
|
||||
use_depthwise=use_depthwise),
|
||||
bbox_head=dict(
|
||||
head_module=dict(
|
||||
widen_factor=widen_factor, use_depthwise=use_depthwise)))
|
|
@ -0,0 +1,87 @@
|
|||
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
|
||||
|
||||
# ========================modified parameters======================
|
||||
# Batch size of a single GPU during training
|
||||
# 8 -> 32
|
||||
train_batch_size_per_gpu = 32
|
||||
|
||||
# Multi-scale training intervals
|
||||
# 10 -> 1
|
||||
batch_augments_interval = 1
|
||||
|
||||
# Last epoch number to switch training pipeline
|
||||
# 15 -> 20
|
||||
num_last_epochs = 20
|
||||
|
||||
# Base learning rate for optim_wrapper. Corresponding to 8xb32=256 bs
|
||||
base_lr = 0.004
|
||||
|
||||
# SGD -> AdamW
|
||||
optim_wrapper = dict(
|
||||
_delete_=True,
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
|
||||
|
||||
# 0.0001 -> 0.0002
|
||||
ema_momentum = 0.0002
|
||||
|
||||
# ============================== Unmodified in most cases ===================
|
||||
model = dict(
|
||||
data_preprocessor=dict(batch_augments=[
|
||||
dict(
|
||||
type='YOLOXBatchSyncRandomResize',
|
||||
random_size_range=(480, 800),
|
||||
size_divisor=32,
|
||||
interval=batch_augments_interval)
|
||||
]))
|
||||
|
||||
param_scheduler = [
|
||||
dict(
|
||||
# use quadratic formula to warm up 5 epochs
|
||||
# and lr is updated by iteration
|
||||
# TODO: fix default scope in get function
|
||||
type='mmdet.QuadraticWarmupLR',
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=5,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
# use cosine lr from 5 to 285 epoch
|
||||
type='CosineAnnealingLR',
|
||||
eta_min=base_lr * 0.05,
|
||||
begin=5,
|
||||
T_max=_base_.max_epochs - num_last_epochs,
|
||||
end=_base_.max_epochs - num_last_epochs,
|
||||
by_epoch=True,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
# use fixed lr during last num_last_epochs epochs
|
||||
type='ConstantLR',
|
||||
by_epoch=True,
|
||||
factor=1,
|
||||
begin=_base_.max_epochs - num_last_epochs,
|
||||
end=_base_.max_epochs,
|
||||
)
|
||||
]
|
||||
|
||||
custom_hooks = [
|
||||
dict(
|
||||
type='YOLOXModeSwitchHook',
|
||||
num_last_epochs=num_last_epochs,
|
||||
new_train_pipeline=_base_.train_pipeline_stage2,
|
||||
priority=48),
|
||||
dict(type='mmdet.SyncNormHook', priority=48),
|
||||
dict(
|
||||
type='EMAHook',
|
||||
ema_type='ExpMomentumEMA',
|
||||
momentum=ema_momentum,
|
||||
update_buffers=True,
|
||||
strict_load=False,
|
||||
priority=49)
|
||||
]
|
||||
|
||||
train_dataloader = dict(batch_size=train_batch_size_per_gpu)
|
||||
train_cfg = dict(dynamic_intervals=[(_base_.max_epochs - num_last_epochs, 1)])
|
||||
auto_scale_lr = dict(base_batch_size=8 * train_batch_size_per_gpu)
|
|
@ -47,9 +47,16 @@ deepen_factor = 0.33
|
|||
# The scaling factor that controls the width of the network structure
|
||||
widen_factor = 0.5
|
||||
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)
|
||||
# generate new random resize shape interval
|
||||
batch_augments_interval = 10
|
||||
|
||||
# -----train val related-----
|
||||
weight_decay = 0.0005
|
||||
loss_cls_weight = 1.0
|
||||
loss_bbox_weight = 5.0
|
||||
loss_obj_weight = 1.0
|
||||
loss_bbox_aux_weight = 1.0
|
||||
center_radius = 2.5 # SimOTAAssigner
|
||||
num_last_epochs = 15
|
||||
random_affine_scaling_ratio_range = (0.1, 2)
|
||||
mixup_ratio_range = (0.8, 1.6)
|
||||
|
@ -58,6 +65,8 @@ save_epoch_intervals = 10
|
|||
# The maximum checkpoints to keep.
|
||||
max_keep_ckpts = 3
|
||||
|
||||
ema_momentum = 0.0001
|
||||
|
||||
# ===============================Unmodified in most cases====================
|
||||
# model settings
|
||||
model = dict(
|
||||
|
@ -79,7 +88,7 @@ model = dict(
|
|||
type='YOLOXBatchSyncRandomResize',
|
||||
random_size_range=(480, 800),
|
||||
size_divisor=32,
|
||||
interval=10)
|
||||
interval=batch_augments_interval)
|
||||
]),
|
||||
backbone=dict(
|
||||
type='YOLOXCSPDarknet',
|
||||
|
@ -116,24 +125,26 @@ model = dict(
|
|||
type='mmdet.CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
reduction='sum',
|
||||
loss_weight=1.0),
|
||||
loss_weight=loss_cls_weight),
|
||||
loss_bbox=dict(
|
||||
type='mmdet.IoULoss',
|
||||
mode='square',
|
||||
eps=1e-16,
|
||||
reduction='sum',
|
||||
loss_weight=5.0),
|
||||
loss_weight=loss_bbox_weight),
|
||||
loss_obj=dict(
|
||||
type='mmdet.CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
reduction='sum',
|
||||
loss_weight=1.0),
|
||||
loss_weight=loss_obj_weight),
|
||||
loss_bbox_aux=dict(
|
||||
type='mmdet.L1Loss', reduction='sum', loss_weight=1.0)),
|
||||
type='mmdet.L1Loss',
|
||||
reduction='sum',
|
||||
loss_weight=loss_bbox_aux_weight)),
|
||||
train_cfg=dict(
|
||||
assigner=dict(
|
||||
type='mmdet.SimOTAAssigner',
|
||||
center_radius=2.5,
|
||||
center_radius=center_radius,
|
||||
iou_calculator=dict(type='mmdet.BboxOverlaps2D'))),
|
||||
test_cfg=model_test_cfg)
|
||||
|
||||
|
@ -303,7 +314,7 @@ custom_hooks = [
|
|||
dict(
|
||||
type='EMAHook',
|
||||
ema_type='ExpMomentumEMA',
|
||||
momentum=0.0001,
|
||||
momentum=ema_momentum,
|
||||
update_buffers=True,
|
||||
strict_load=False,
|
||||
priority=49)
|
||||
|
@ -315,6 +326,6 @@ train_cfg = dict(
|
|||
val_interval=save_epoch_intervals,
|
||||
dynamic_intervals=[(max_epochs - num_last_epochs, 1)])
|
||||
|
||||
auto_scale_lr = dict(base_batch_size=64)
|
||||
auto_scale_lr = dict(base_batch_size=8 * train_batch_size_per_gpu)
|
||||
val_cfg = dict(type='ValLoop')
|
||||
test_cfg = dict(type='TestLoop')
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
_base_ = './yolox_s_fast_8xb32-300e-rtmdet-hyp_coco.py'
|
||||
|
||||
# ========================modified parameters======================
|
||||
deepen_factor = 0.33
|
||||
widen_factor = 0.375
|
||||
|
||||
# Multi-scale training intervals
|
||||
# 10 -> 1
|
||||
batch_augments_interval = 1
|
||||
|
||||
scaling_ratio_range = (0.5, 1.5)
|
||||
|
||||
# =======================Unmodified in most cases==================
|
||||
img_scale = _base_.img_scale
|
||||
pre_transform = _base_.pre_transform
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
data_preprocessor=dict(batch_augments=[
|
||||
dict(
|
||||
type='YOLOXBatchSyncRandomResize',
|
||||
random_size_range=(320, 640),
|
||||
size_divisor=32,
|
||||
interval=batch_augments_interval)
|
||||
]),
|
||||
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
|
||||
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
|
||||
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
|
||||
|
||||
train_pipeline_stage1 = [
|
||||
*pre_transform,
|
||||
dict(
|
||||
type='Mosaic',
|
||||
img_scale=img_scale,
|
||||
pad_val=114.0,
|
||||
pre_transform=pre_transform),
|
||||
dict(
|
||||
type='mmdet.RandomAffine',
|
||||
scaling_ratio_range=scaling_ratio_range, # note
|
||||
# img_scale is (width, height)
|
||||
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
|
||||
dict(type='mmdet.YOLOXHSVRandomAug'),
|
||||
dict(type='mmdet.RandomFlip', prob=0.5),
|
||||
dict(
|
||||
type='mmdet.FilterAnnotations',
|
||||
min_gt_bbox_wh=(1, 1),
|
||||
keep_empty=False),
|
||||
dict(
|
||||
type='mmdet.PackDetInputs',
|
||||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
|
||||
'flip_direction'))
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
|
||||
dict(type='mmdet.Resize', scale=(416, 416), keep_ratio=True), # note
|
||||
dict(
|
||||
type='mmdet.Pad',
|
||||
pad_to_square=True,
|
||||
pad_val=dict(img=(114.0, 114.0, 114.0))),
|
||||
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
|
||||
dict(
|
||||
type='mmdet.PackDetInputs',
|
||||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
|
||||
'scale_factor'))
|
||||
]
|
||||
|
||||
train_dataloader = dict(dataset=dict(pipeline=train_pipeline_stage1))
|
||||
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
|
@ -3,12 +3,12 @@ _base_ = './yolox_s_fast_8xb8-300e_coco.py'
|
|||
# ========================modified parameters======================
|
||||
deepen_factor = 0.33
|
||||
widen_factor = 0.375
|
||||
|
||||
img_scale = _base_.img_scale
|
||||
pre_transform = _base_.pre_transform
|
||||
scaling_ratio_range = (0.5, 1.5)
|
||||
|
||||
# =======================Unmodified in most cases==================
|
||||
img_scale = _base_.img_scale
|
||||
pre_transform = _base_.pre_transform
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
data_preprocessor=dict(batch_augments=[
|
||||
|
|
Loading…
Reference in New Issue