[Improve] Beautify the YOLOX configuration (#529)

* Beautify the YOLOX configuration

* fix checks

* Update configs/yolox/yolox_s_fast_8xb8-300e_coco.py

Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>

* fix letter case problem

* Beautify YOLOX configs except the yolox_s config

* fix lint

* Update configs/yolox/yolox_s_fast_8xb8-300e_coco.py

Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>

* fix yolox_s yolox_tiny

* fix tiny

* fix tiny

* simple tiny

---------

Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>
Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>
Youfu 2023-02-10 10:50:59 +08:00 committed by GitHub
parent 4e8bf17c90
commit 3a6899e232
6 changed files with 91 additions and 37 deletions
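
The pattern applied across all six files is the same: tunable values are hoisted into named variables under a "modified parameters" banner and then referenced from the nested dicts below. As a quick way to see the result, the sketch below (not part of the PR; it assumes MMYOLO is installed and is run from the repository root) resolves the refactored yolox_s config with MMEngine and checks a few of the hoisted values:

# Minimal sketch: resolve the beautified yolox_s config and confirm the
# top-level variables flow into the nested dicts. Assumes an MMYOLO checkout.
from mmengine.config import Config

cfg = Config.fromfile('configs/yolox/yolox_s_fast_8xb8-300e_coco.py')

print(cfg.deepen_factor, cfg.widen_factor)       # 0.33 0.5
print(cfg.model.backbone.deepen_factor)          # same value, via the variable
print(cfg.model.test_cfg.nms.iou_threshold)      # 0.65, taken from model_test_cfg
print(cfg.default_hooks.checkpoint.interval)     # 10, i.e. save_epoch_intervals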

View File

@@ -1,8 +1,10 @@
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
# ========================modified parameters======================
deepen_factor = 1.0
widen_factor = 1.0
# =======================Unmodified in most cases==================
# model settings
model = dict(
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
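
Derived configs like the one above only restate the fields that differ from the base; MMEngine merges these partial dicts into yolox_s_fast_8xb8-300e_coco.py recursively, so every key that is not restated keeps its base value. The helper below is only an illustration of that merge semantics, not MMEngine's implementation:

# Illustration: a child config's partial `model` dict overrides the base by
# merging nested dicts key by key and replacing scalars.
def merge(base: dict, override: dict) -> dict:
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)
        else:
            out[key] = value
    return out

base_model = {'backbone': {'deepen_factor': 0.33, 'widen_factor': 0.5,
                           'out_indices': (2, 3, 4)}}
child_model = {'backbone': {'deepen_factor': 1.0, 'widen_factor': 1.0}}
print(merge(base_model, child_model))
# {'backbone': {'deepen_factor': 1.0, 'widen_factor': 1.0, 'out_indices': (2, 3, 4)}}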

View File

@@ -1,8 +1,10 @@
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
# ========================modified parameters======================
deepen_factor = 0.67
widen_factor = 0.75
# =======================Unmodified in most cases==================
# model settings
model = dict(
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),

View File

@@ -1,9 +1,11 @@
_base_ = './yolox_tiny_fast_8xb8-300e_coco.py'
# ========================modified parameters======================
deepen_factor = 0.33
widen_factor = 0.25
use_depthwise = True
# =======================Unmodified in most cases==================
# model settings
model = dict(
backbone=dict(
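
Besides the scaling factors, this variant hoists use_depthwise = True, which the model dict passes on so that depthwise-separable convolutions replace the standard ones and the parameter count drops further. As a rough, generic illustration of that building block (not MMYOLO's actual module):

# Generic PyTorch sketch of a depthwise-separable convolution, the kind of
# block `use_depthwise=True` switches to; illustrative only.
import torch
import torch.nn as nn

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3):
        super().__init__()
        # each input channel is convolved on its own (groups=in_ch) ...
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size,
                                   padding=kernel_size // 2, groups=in_ch)
        # ... then a 1x1 pointwise conv mixes channels, at far lower cost
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pointwise(self.depthwise(x))

x = torch.randn(1, 64, 80, 80)
print(DepthwiseSeparableConv(64, 128)(x).shape)  # torch.Size([1, 128, 80, 80])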

View File

@@ -1,21 +1,64 @@
_base_ = '../_base_/default_runtime.py'
data_root = 'data/coco/'
dataset_type = 'YOLOv5CocoDataset'
# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/' # Root path of data
# path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/' # Prefix of train image path
# path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/' # Prefix of val image path
img_scale = (640, 640) # width, height
deepen_factor = 0.33
widen_factor = 0.5
save_epoch_intervals = 10
num_classes = 80 # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 8
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True
# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb8=64 bs
base_lr = 0.01
max_epochs = 300 # Maximum training epochs
model_test_cfg = dict(
yolox_style=True, # better
# The config of multi-label for multi-class prediction
multi_label=True, # 40.5 -> 40.7
score_thr=0.001, # Threshold to filter out boxes
max_per_img=300, # Max number of detections of each image
nms=dict(type='nms', iou_threshold=0.65)) # NMS type and threshold
# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640) # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2
max_epochs = 300
num_last_epochs = 15
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)
# -----train val related-----
weight_decay = 0.0005
num_last_epochs = 15
random_affine_scaling_ratio_range = (0.1, 2)
mixup_ratio_range = (0.8, 1.6)
# Interval (in epochs) for saving checkpoints and running validation
save_epoch_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# ===============================Unmodified in most cases====================
# model settings
model = dict(
type='YOLODetector',
@@ -44,7 +87,7 @@ model = dict(
widen_factor=widen_factor,
out_indices=(2, 3, 4),
spp_kernal_sizes=(5, 9, 13),
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
norm_cfg=norm_cfg,
act_cfg=dict(type='SiLU', inplace=True),
),
neck=dict(
@@ -53,20 +96,20 @@ model = dict(
widen_factor=widen_factor,
in_channels=[256, 512, 1024],
out_channels=256,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
norm_cfg=norm_cfg,
act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(
type='YOLOXHead',
head_module=dict(
type='YOLOXHeadModule',
num_classes=80,
num_classes=num_classes,
in_channels=256,
feat_channels=256,
widen_factor=widen_factor,
stacked_convs=2,
featmap_strides=(8, 16, 32),
use_depthwise=False,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
norm_cfg=norm_cfg,
act_cfg=dict(type='SiLU', inplace=True),
),
loss_cls=dict(
@@ -92,12 +135,7 @@ model = dict(
type='mmdet.SimOTAAssigner',
center_radius=2.5,
iou_calculator=dict(type='mmdet.BboxOverlaps2D'))),
test_cfg=dict(
yolox_style=True, # better
multi_label=True, # 40.5 -> 40.7
score_thr=0.001,
max_per_img=300,
nms=dict(type='nms', iou_threshold=0.65)))
test_cfg=model_test_cfg)
pre_transform = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
@@ -113,13 +151,13 @@ train_pipeline_stage1 = [
pre_transform=pre_transform),
dict(
type='mmdet.RandomAffine',
scaling_ratio_range=(0.1, 2),
scaling_ratio_range=random_affine_scaling_ratio_range,
# img_scale is (width, height)
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
dict(
type='YOLOXMixUp',
img_scale=img_scale,
ratio_range=(0.8, 1.6),
ratio_range=mixup_ratio_range,
pad_val=114.0,
pre_transform=pre_transform),
dict(type='mmdet.YOLOXHSVRandomAug'),
@@ -155,15 +193,15 @@ train_pipeline_stage2 = [
train_dataloader = dict(
batch_size=train_batch_size_per_gpu,
num_workers=train_num_workers,
persistent_workers=True,
persistent_workers=persistent_workers,
pin_memory=True,
collate_fn=dict(type='yolov5_collate'),
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
ann_file=train_ann_file,
data_prefix=dict(img=train_data_prefix),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=train_pipeline_stage1))
@@ -184,15 +222,15 @@ test_pipeline = [
val_dataloader = dict(
batch_size=val_batch_size_per_gpu,
num_workers=val_num_workers,
persistent_workers=True,
persistent_workers=persistent_workers,
pin_memory=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_val2017.json',
data_prefix=dict(img='val2017/'),
ann_file=val_ann_file,
data_prefix=dict(img=val_data_prefix),
test_mode=True,
pipeline=test_pipeline))
test_dataloader = val_dataloader
@@ -201,18 +239,20 @@ test_dataloader = val_dataloader
val_evaluator = dict(
type='mmdet.CocoMetric',
proposal_nums=(100, 1, 10),
ann_file=data_root + 'annotations/instances_val2017.json',
ann_file=data_root + val_ann_file,
metric='bbox')
test_evaluator = val_evaluator
# optimizer
# default 8 gpu
base_lr = 0.01
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4,
type='SGD',
lr=base_lr,
momentum=0.9,
weight_decay=weight_decay,
nesterov=True),
paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
@@ -248,7 +288,10 @@ param_scheduler = [
default_hooks = dict(
checkpoint=dict(
type='CheckpointHook', interval=1, max_keep_ckpts=3, save_best='auto'))
type='CheckpointHook',
interval=save_epoch_intervals,
max_keep_ckpts=max_keep_ckpts,
save_best='auto'))
custom_hooks = [
dict(
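
Most of the edits in this file are pure renames: an inline literal such as norm_cfg=dict(type='BN', momentum=0.03, eps=0.001) becomes a reference to a top-level variable holding the same value, so the resolved config should be unchanged apart from the checkpoint hook, whose interval moves from 1 to save_epoch_intervals (10). A hedged way to verify that, assuming the old and new revisions are checked out at the placeholder paths below, is to resolve both and compare the sections that should match:

# Sketch: resolve the config before and after the refactor and compare the
# parts that should be identical. The two checkout paths are placeholders.
from mmengine.config import Config

old_cfg = Config.fromfile('/tmp/mmyolo_old/configs/yolox/yolox_s_fast_8xb8-300e_coco.py')
new_cfg = Config.fromfile('/tmp/mmyolo_new/configs/yolox/yolox_s_fast_8xb8-300e_coco.py')

for key in ('model', 'train_dataloader', 'val_dataloader', 'optim_wrapper'):
    assert old_cfg[key] == new_cfg[key], f'{key} changed by the refactor'

# The one intended behavioural difference: checkpoint saving now follows
# save_epoch_intervals instead of every epoch.
print(old_cfg.default_hooks.checkpoint.interval)  # 1
print(new_cfg.default_hooks.checkpoint.interval)  # 10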

View File

@@ -1,14 +1,20 @@
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
# ========================modified parameters======================
deepen_factor = 0.33
widen_factor = 0.375
img_scale = _base_.img_scale
pre_transform = _base_.pre_transform
scaling_ratio_range = (0.5, 1.5)
# =======================Unmodified in most cases==================
# model settings
model = dict(
data_preprocessor=dict(batch_augments=[
dict(
type='mmdet.BatchSyncRandomResize',
random_size_range=(320, 640), # note
type='YOLOXBatchSyncRandomResize',
random_size_range=(320, 640),
size_divisor=32,
interval=10)
]),
@@ -16,9 +22,6 @@ model = dict(
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
img_scale = _base_.img_scale
pre_transform = _base_.pre_transform
train_pipeline_stage1 = [
*pre_transform,
dict(
@@ -28,7 +31,7 @@ train_pipeline_stage1 = [
pre_transform=pre_transform),
dict(
type='mmdet.RandomAffine',
scaling_ratio_range=(0.5, 1.5), # note
scaling_ratio_range=scaling_ratio_range, # note
# img_scale is (width, height)
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
dict(type='mmdet.YOLOXHSVRandomAug'),
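
Two changes stand out in this file: the shared values img_scale and pre_transform are now pulled from the base at the top of the file with the _base_.<name> syntax instead of halfway down, and the batch augmentation switches from mmdet.BatchSyncRandomResize to MMYOLO's YOLOXBatchSyncRandomResize. Reduced to its essentials, the cross-config reference pattern looks like the sketch below (illustrative, not part of the PR):

# Minimal sketch of the `_base_.<name>` pattern: a child config reads values
# the base defined at top level and reuses them locally.
_base_ = './yolox_s_fast_8xb8-300e_coco.py'

img_scale = _base_.img_scale            # (640, 640), defined once in the base
pre_transform = _base_.pre_transform    # shared loading transforms from the base
scaling_ratio_range = (0.5, 1.5)        # the narrower range this variant uses

# These names can then be used anywhere below, e.g. when this file rebuilds
# train_pipeline_stage1 with the smaller scaling_ratio_range.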

View File

@@ -1,8 +1,10 @@
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
# ========================modified parameters======================
deepen_factor = 1.33
widen_factor = 1.25
# =======================Unmodified in most cases==================
# model settings
model = dict(
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
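
None of these edits change how training is launched; each file is still a complete MMEngine config. For reference, a minimal sketch of starting a run from Python (assuming an MMYOLO environment, COCO under data/coco/, and execution from the repository root; MMYOLO's tools/train.py wraps the same calls with argument parsing and related setup):

# Sketch: train one of these configs through MMEngine's Runner.
from mmengine.config import Config
from mmengine.runner import Runner
from mmyolo.utils import register_all_modules

register_all_modules()  # register MMYOLO/MMDet modules with the registry

cfg = Config.fromfile('configs/yolox/yolox_s_fast_8xb8-300e_coco.py')
cfg.work_dir = './work_dirs/yolox_s_fast_8xb8-300e_coco'  # Runner requires a work_dir

runner = Runner.from_cfg(cfg)
runner.train()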