Beautify the YOLOv8 configuration (#516)

* Update yolov5_s-v61_syncbn_8xb16-300e_coco.py

* Update yolov8_s_syncbn_fast_8xb16-500e_coco.py

* Update yolov8_m_syncbn_fast_8xb16-500e_coco.py

* Update yolov8_l_syncbn_fast_8xb16-500e_coco.py

* Update yolov8_s_syncbn_fast_8xb16-500e_coco.py

* Add todo

* Update yolov8_s_syncbn_fast_8xb16-500e_coco.py

* Update transforms.py
Range King 2023-02-06 19:39:39 +08:00 committed by GitHub
parent f54e5603fd
commit c3acf42db4
5 changed files with 140 additions and 77 deletions

yolov5_s-v61_syncbn_8xb16-300e_coco.py

@@ -69,7 +69,7 @@ widen_factor = 0.5
 # Strides of multi-scale prior box
 strides = [8, 16, 32]
 num_det_layers = 3  # The number of model output scales
-norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)
+norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)  # Normalization config
 # -----train val related-----
 affine_scale = 0.5  # YOLOv5RandomAffine scaling ratio

yolov8_l_syncbn_fast_8xb16-500e_coco.py

@@ -1,10 +1,13 @@
 _base_ = './yolov8_m_syncbn_fast_8xb16-500e_coco.py'
+
+# ========================modified parameters======================
 deepen_factor = 1.00
 widen_factor = 1.00
 last_stage_out_channels = 512
-mixup_ratio = 0.15
+mixup_prob = 0.15

+# =======================Unmodified in most cases==================
 model = dict(
     backbone=dict(
         last_stage_out_channels=last_stage_out_channels,
@@ -22,15 +25,15 @@ model = dict(
 pre_transform = _base_.pre_transform
 albu_train_transform = _base_.albu_train_transform
-mosaic_affine_transform = _base_.mosaic_affine_transform
+mosaic_affine_pipeline = _base_.mosaic_affine_pipeline
 last_transform = _base_.last_transform

 train_pipeline = [
-    *pre_transform, *mosaic_affine_transform,
+    *pre_transform, *mosaic_affine_pipeline,
     dict(
         type='YOLOv5MixUp',
-        prob=mixup_ratio,
-        pre_transform=[*pre_transform, *mosaic_affine_transform]),
+        prob=mixup_prob,
+        pre_transform=[*pre_transform, *mosaic_affine_pipeline]),
     *last_transform
 ]
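These child configs lean on MMEngine's `_base_.<name>` references, which resolve against the inherited base file, so the renamed `mosaic_affine_pipeline` is defined once in the YOLOv8-m config and only re-exported here. A minimal sketch of the pattern, using hypothetical file names `base_cfg.py` and `child_cfg.py`:

# base_cfg.py (hypothetical)
mosaic_affine_pipeline = [
    dict(type='Mosaic'),
    dict(type='YOLOv5RandomAffine'),
]

# child_cfg.py (hypothetical)
_base_ = './base_cfg.py'
# Re-export the base pipeline instead of copy-pasting it; edits to the
# base file then propagate to every child config automatically.
mosaic_affine_pipeline = _base_.mosaic_affine_pipeline
train_pipeline = [
    *mosaic_affine_pipeline,
    dict(type='YOLOv5MixUp', prob=0.15),
]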

yolov8_m_syncbn_fast_8xb16-500e_coco.py

@@ -1,12 +1,14 @@
 _base_ = './yolov8_s_syncbn_fast_8xb16-500e_coco.py'
+
+# ========================modified parameters======================
 deepen_factor = 0.67
 widen_factor = 0.75
 last_stage_out_channels = 768
 affine_scale = 0.9
-mixup_ratio = 0.1
+mixup_prob = 0.1

+# =======================Unmodified in most cases==================
 num_classes = _base_.num_classes
 num_det_layers = _base_.num_det_layers
 img_scale = _base_.img_scale
@@ -30,7 +32,7 @@ pre_transform = _base_.pre_transform
 albu_train_transform = _base_.albu_train_transform
 last_transform = _base_.last_transform

-mosaic_affine_transform = [
+mosaic_affine_pipeline = [
     dict(
         type='Mosaic',
         img_scale=img_scale,
@@ -47,12 +49,13 @@ mosaic_affine_transform = [
         border_val=(114, 114, 114))
 ]

+# enable mixup
 train_pipeline = [
-    *pre_transform, *mosaic_affine_transform,
+    *pre_transform, *mosaic_affine_pipeline,
     dict(
         type='YOLOv5MixUp',
-        prob=mixup_ratio,
-        pre_transform=[*pre_transform, *mosaic_affine_transform]),
+        prob=mixup_prob,
+        pre_transform=[*pre_transform, *mosaic_affine_pipeline]),
     *last_transform
 ]
@@ -85,6 +88,6 @@ custom_hooks = [
         priority=49),
     dict(
         type='mmdet.PipelineSwitchHook',
-        switch_epoch=_base_.max_epochs - 10,
+        switch_epoch=_base_.max_epochs - _base_.close_mosaic_epochs,
         switch_pipeline=train_pipeline_stage2)
 ]
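Replacing the bare `- 10` with `_base_.close_mosaic_epochs` ties the stage-2 pipeline switch to the base-config variable instead of a magic number. With the values defined in this commit the result is unchanged, as a minimal sketch shows:

# Values taken from the base config in this commit.
max_epochs = 500
close_mosaic_epochs = 10
# PipelineSwitchHook disables Mosaic/MixUp for the last stage-2 epochs:
switch_epoch = max_epochs - close_mosaic_epochs  # 490, same epoch as before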

yolov8_s_syncbn_fast_8xb16-500e_coco.py

@@ -1,37 +1,100 @@
 _base_ = '../_base_/default_runtime.py'

-# dataset settings
-data_root = 'data/coco/'
-dataset_type = 'YOLOv5CocoDataset'
+# ========================Frequently modified parameters======================
+# -----data related-----
+data_root = 'data/coco/'  # Root path of data
+# Path of train annotation file
+train_ann_file = 'annotations/instances_train2017.json'
+train_data_prefix = 'train2017/'  # Prefix of train image path
+# Path of val annotation file
+val_ann_file = 'annotations/instances_val2017.json'
+val_data_prefix = 'val2017/'  # Prefix of val image path

-# parameters that often need to be modified
-num_classes = 80
-img_scale = (640, 640)  # height, width
-deepen_factor = 0.33
-widen_factor = 0.5
-max_epochs = 500
-save_epoch_intervals = 10
+num_classes = 80  # Number of classes for classification
+# Batch size of a single GPU during training
 train_batch_size_per_gpu = 16
+# Worker to pre-fetch data for each single GPU during training
 train_num_workers = 8
-val_batch_size_per_gpu = 1
-val_num_workers = 2
-# persistent_workers must be False if num_workers is 0.
+# persistent_workers must be False if num_workers is 0
 persistent_workers = True
-strides = [8, 16, 32]
-num_det_layers = 3
-last_stage_out_channels = 1024

-# Base learning rate for optim_wrapper
+# -----train val related-----
+# Base learning rate for optim_wrapper. Corresponding to 8xb16=64 bs
 base_lr = 0.01
-lr_factor = 0.01
+max_epochs = 500  # Maximum training epochs
+# Disable mosaic augmentation for final 10 epochs (stage 2)
+close_mosaic_epochs = 10
+
-# single-scale training is recommended to
+model_test_cfg = dict(
+    # The config of multi-label for multi-class prediction.
+    multi_label=True,
+    # The number of boxes before NMS
+    nms_pre=30000,
+    score_thr=0.001,  # Threshold to filter out boxes.
+    nms=dict(type='nms', iou_threshold=0.7),  # NMS type and threshold
+    max_per_img=300)  # Max number of detections of each image
+
+# ========================Possible modified parameters========================
+# -----data related-----
+img_scale = (640, 640)  # width, height
+# Dataset type, this will be used to define the dataset
+dataset_type = 'YOLOv5CocoDataset'
+# Batch size of a single GPU during validation
+val_batch_size_per_gpu = 1
+# Worker to pre-fetch data for each single GPU during validation
+val_num_workers = 2
+
+# Config of batch shapes. Only on val.
+# We tested YOLOv8-m will get 0.02 higher than not using it.
+batch_shapes_cfg = None
+# You can turn on `batch_shapes_cfg` by uncommenting the following lines.
+# batch_shapes_cfg = dict(
+#     type='BatchShapePolicy',
+#     batch_size=val_batch_size_per_gpu,
+#     img_size=img_scale[0],
+#     # The image scale of padding should be divided by pad_size_divisor
+#     size_divisor=32,
+#     # Additional paddings for pixel scale
+#     extra_pad_ratio=0.5)
+
+# -----model related-----
+# The scaling factor that controls the depth of the network structure
+deepen_factor = 0.33
+# The scaling factor that controls the width of the network structure
+widen_factor = 0.5
+# Strides of multi-scale prior box
+strides = [8, 16, 32]
+# The output channel of the last stage
+last_stage_out_channels = 1024
+num_det_layers = 3  # The number of model output scales
+norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)  # Normalization config
+
+# -----train val related-----
+affine_scale = 0.5  # YOLOv5RandomAffine scaling ratio
+# YOLOv5RandomAffine aspect ratio of width and height thres to filter bboxes
+max_aspect_ratio = 100
+tal_topk = 10  # Number of bbox selected in each level
+tal_alpha = 0.5  # A Hyper-parameter related to alignment_metrics
+tal_beta = 6.0  # A Hyper-parameter related to alignment_metrics
+# TODO: Automatically scale loss_weight based on number of detection layers
+loss_cls_weight = 0.5
+loss_bbox_weight = 7.5
+# Since the dfloss is implemented differently in the official
+# and mmdet, we're going to divide loss_weight by 4.
+loss_dfl_weight = 1.5 / 4
+lr_factor = 0.01  # Learning rate scaling factor
+weight_decay = 0.0005
+# Save model checkpoint and validation intervals in stage 1
+save_epoch_intervals = 10
+# validation intervals in stage 2
+val_interval_stage2 = 1
+# The maximum checkpoints to keep.
+max_keep_ckpts = 2
+# Single-scale training is recommended to
 # be turned on, which can speed up training.
 env_cfg = dict(cudnn_benchmark=True)

+# ===============================Unmodified in most cases====================
 model = dict(
     type='YOLODetector',
     data_preprocessor=dict(
@@ -45,7 +108,7 @@ model = dict(
         last_stage_out_channels=last_stage_out_channels,
         deepen_factor=deepen_factor,
         widen_factor=widen_factor,
-        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+        norm_cfg=norm_cfg,
         act_cfg=dict(type='SiLU', inplace=True)),
     neck=dict(
         type='YOLOv8PAFPN',
@@ -54,7 +117,7 @@ model = dict(
         in_channels=[256, 512, last_stage_out_channels],
         out_channels=[256, 512, last_stage_out_channels],
         num_csp_blocks=3,
-        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+        norm_cfg=norm_cfg,
         act_cfg=dict(type='SiLU', inplace=True)),
     bbox_head=dict(
         type='YOLOv8Head',
@@ -64,45 +127,39 @@ model = dict(
             in_channels=[256, 512, last_stage_out_channels],
             widen_factor=widen_factor,
             reg_max=16,
-            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+            norm_cfg=norm_cfg,
             act_cfg=dict(type='SiLU', inplace=True),
-            featmap_strides=[8, 16, 32]),
+            featmap_strides=strides),
         prior_generator=dict(
-            type='mmdet.MlvlPointGenerator', offset=0.5, strides=[8, 16, 32]),
+            type='mmdet.MlvlPointGenerator', offset=0.5, strides=strides),
         bbox_coder=dict(type='DistancePointBBoxCoder'),
         # scaled based on number of detection layers
         loss_cls=dict(
             type='mmdet.CrossEntropyLoss',
             use_sigmoid=True,
             reduction='none',
-            loss_weight=0.5),
+            loss_weight=loss_cls_weight),
         loss_bbox=dict(
             type='IoULoss',
             iou_mode='ciou',
             bbox_format='xyxy',
             reduction='sum',
-            loss_weight=7.5,
+            loss_weight=loss_bbox_weight,
             return_iou=False),
-        # Since the dfloss is implemented differently in the official
-        # and mmdet, we're going to divide loss_weight by 4.
         loss_dfl=dict(
             type='mmdet.DistributionFocalLoss',
             reduction='mean',
-            loss_weight=1.5 / 4)),
+            loss_weight=loss_dfl_weight)),
     train_cfg=dict(
         assigner=dict(
             type='BatchTaskAlignedAssigner',
             num_classes=num_classes,
             use_ciou=True,
-            topk=10,
-            alpha=0.5,
-            beta=6.0,
+            topk=tal_topk,
+            alpha=tal_alpha,
+            beta=tal_beta,
             eps=1e-9)),
-    test_cfg=dict(
-        multi_label=True,
-        nms_pre=30000,
-        score_thr=0.001,
-        nms=dict(type='nms', iou_threshold=0.7),
-        max_per_img=300))
+    test_cfg=model_test_cfg)

 albu_train_transform = [
     dict(type='Blur', p=0.01),
@@ -135,6 +192,7 @@ last_transform = [
         meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                    'flip_direction'))
 ]
+
 train_pipeline = [
     *pre_transform,
     dict(
@@ -146,8 +204,8 @@ train_pipeline = [
         type='YOLOv5RandomAffine',
         max_rotate_degree=0.0,
         max_shear_degree=0.0,
-        scaling_ratio_range=(0.5, 1.5),
-        max_aspect_ratio=100,
+        scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
+        max_aspect_ratio=max_aspect_ratio,
         # img_scale is (width, height)
         border=(-img_scale[0] // 2, -img_scale[1] // 2),
         border_val=(114, 114, 114)),
@@ -166,8 +224,8 @@ train_pipeline_stage2 = [
         type='YOLOv5RandomAffine',
         max_rotate_degree=0.0,
         max_shear_degree=0.0,
-        scaling_ratio_range=(0.5, 1.5),
-        max_aspect_ratio=100,
+        scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
+        max_aspect_ratio=max_aspect_ratio,
         border_val=(114, 114, 114)), *last_transform
 ]
@@ -181,8 +239,8 @@ train_dataloader = dict(
     dataset=dict(
         type=dataset_type,
         data_root=data_root,
-        ann_file='annotations/instances_train2017.json',
-        data_prefix=dict(img='train2017/'),
+        ann_file=train_ann_file,
+        data_prefix=dict(img=train_data_prefix),
         filter_cfg=dict(filter_empty_gt=False, min_size=32),
         pipeline=train_pipeline))
@@ -201,17 +259,6 @@ test_pipeline = [
                    'scale_factor', 'pad_param'))
 ]

-# only on Val
-# you can turn on `batch_shapes_cfg`,
-# we tested YOLOv8-m will get 0.02 higher than not using it.
-batch_shapes_cfg = None
-# batch_shapes_cfg = dict(
-#     type='BatchShapePolicy',
-#     batch_size=val_batch_size_per_gpu,
-#     img_size=img_scale[0],
-#     size_divisor=32,
-#     extra_pad_ratio=0.5)

 val_dataloader = dict(
     batch_size=val_batch_size_per_gpu,
     num_workers=val_num_workers,
@@ -223,8 +270,8 @@ val_dataloader = dict(
         type=dataset_type,
         data_root=data_root,
         test_mode=True,
-        data_prefix=dict(img='val2017/'),
-        ann_file='annotations/instances_val2017.json',
+        data_prefix=dict(img=val_data_prefix),
+        ann_file=val_ann_file,
         pipeline=test_pipeline,
         batch_shapes_cfg=batch_shapes_cfg))
@@ -238,7 +285,7 @@ optim_wrapper = dict(
         type='SGD',
         lr=base_lr,
         momentum=0.937,
-        weight_decay=0.0005,
+        weight_decay=weight_decay,
         nesterov=True,
         batch_size_per_gpu=train_batch_size_per_gpu),
     constructor='YOLOv5OptimizerConstructor')
@@ -253,7 +300,7 @@ default_hooks = dict(
         type='CheckpointHook',
         interval=save_epoch_intervals,
         save_best='auto',
-        max_keep_ckpts=2))
+        max_keep_ckpts=max_keep_ckpts))

 custom_hooks = [
     dict(
@@ -265,14 +312,14 @@ custom_hooks = [
         priority=49),
     dict(
         type='mmdet.PipelineSwitchHook',
-        switch_epoch=max_epochs - 10,
+        switch_epoch=max_epochs - close_mosaic_epochs,
         switch_pipeline=train_pipeline_stage2)
 ]

 val_evaluator = dict(
     type='mmdet.CocoMetric',
     proposal_nums=(100, 1, 10),
-    ann_file=data_root + 'annotations/instances_val2017.json',
+    ann_file=data_root + val_ann_file,
     metric='bbox')
 test_evaluator = val_evaluator
@@ -280,7 +327,8 @@ train_cfg = dict(
     type='EpochBasedTrainLoop',
     max_epochs=max_epochs,
     val_interval=save_epoch_intervals,
-    dynamic_intervals=[(max_epochs - 10, 1)])
+    dynamic_intervals=[((max_epochs - close_mosaic_epochs),
+                        val_interval_stage2)])

 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
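Two of the replaced literals can be sanity-checked by hand: the derived `scaling_ratio_range` reproduces the old hard-coded tuple for YOLOv8-s and widens under YOLOv8-m's `affine_scale = 0.9` override, and the DFL weight is plain arithmetic. A quick check, with values from the diffs above:

affine_scale = 0.5   # yolov8_s value above
scaling_ratio_range = (1 - affine_scale, 1 + affine_scale)  # (0.5, 1.5), the old literal
affine_scale = 0.9   # yolov8_m override earlier in this commit
scaling_ratio_range = (1 - affine_scale, 1 + affine_scale)  # ~(0.1, 1.9): stronger scale jitter

# Official dfloss weight divided by 4 for the mmdet implementation:
loss_dfl_weight = 1.5 / 4  # 0.375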

transforms.py

@@ -450,6 +450,15 @@ class YOLOv5RandomAffine(BaseTransform):
            the border of the image. In some dataset like MOT17, the gt bboxes
            are allowed to cross the border of images. Therefore, we don't
            need to clip the gt bboxes in these cases. Defaults to True.
+        min_bbox_size (float): Width and height threshold to filter bboxes.
+            If the height or width of a box is smaller than this value, it
+            will be removed. Defaults to 2.
+        min_area_ratio (float): Threshold of area ratio between
+            original bboxes and wrapped bboxes. If smaller than this value,
+            the box will be removed. Defaults to 0.1.
+        max_aspect_ratio (float): Aspect ratio of width and height
+            threshold to filter bboxes. If max(h/w, w/h) larger than this
+            value, the box will be removed. Defaults to 20.
     """

     def __init__(self,
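The three new docstring entries describe the post-affine box filtering. The following NumPy sketch illustrates the documented rules; it uses a hypothetical helper name and is an illustration of the described behavior, not the transform's actual implementation:

import numpy as np

def filter_boxes(origin_bboxes, wrapped_bboxes,
                 min_bbox_size=2, min_area_ratio=0.1, max_aspect_ratio=20):
    """Hypothetical helper mirroring the documented filter rules.

    Both inputs are (N, 4) arrays in xyxy format: `origin_bboxes` before
    the affine warp, `wrapped_bboxes` after it. Returns a boolean mask of
    boxes to keep.
    """
    ow = origin_bboxes[:, 2] - origin_bboxes[:, 0]
    oh = origin_bboxes[:, 3] - origin_bboxes[:, 1]
    ww = wrapped_bboxes[:, 2] - wrapped_bboxes[:, 0]
    wh = wrapped_bboxes[:, 3] - wrapped_bboxes[:, 1]
    eps = 1e-16
    # Rule 1: drop boxes whose warped width or height fell below min_bbox_size.
    size_ok = (ww > min_bbox_size) & (wh > min_bbox_size)
    # Rule 2: drop boxes that kept less than min_area_ratio of their area.
    area_ok = (ww * wh) / (ow * oh + eps) > min_area_ratio
    # Rule 3: drop boxes whose max(h/w, w/h) exceeds max_aspect_ratio.
    aspect_ok = np.maximum(ww / (wh + eps), wh / (ww + eps)) < max_aspect_ratio
    return size_ok & area_ok & aspect_ok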