mirror of https://github.com/open-mmlab/mmyolo.git
Beautify the YOLOv8 configuration (#516)
* Update yolov5_s-v61_syncbn_8xb16-300e_coco.py * Update yolov8_s_syncbn_fast_8xb16-500e_coco.py * Update yolov8_m_syncbn_fast_8xb16-500e_coco.py * Update yolov8_l_syncbn_fast_8xb16-500e_coco.py * Update yolov8_s_syncbn_fast_8xb16-500e_coco.py * Add todo * Update yolov8_s_syncbn_fast_8xb16-500e_coco.py * Update transforms.pypull/517/head
parent
f54e5603fd
commit
c3acf42db4
|
@ -69,7 +69,7 @@ widen_factor = 0.5
|
|||
# Strides of multi-scale prior box
|
||||
strides = [8, 16, 32]
|
||||
num_det_layers = 3 # The number of model output scales
|
||||
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)
|
||||
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001) # Normalization config
|
||||
|
||||
# -----train val related-----
|
||||
affine_scale = 0.5 # YOLOv5RandomAffine scaling ratio
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
_base_ = './yolov8_m_syncbn_fast_8xb16-500e_coco.py'
|
||||
|
||||
# ========================modified parameters======================
|
||||
deepen_factor = 1.00
|
||||
widen_factor = 1.00
|
||||
last_stage_out_channels = 512
|
||||
mixup_ratio = 0.15
|
||||
|
||||
mixup_prob = 0.15
|
||||
|
||||
# =======================Unmodified in most cases==================
|
||||
model = dict(
|
||||
backbone=dict(
|
||||
last_stage_out_channels=last_stage_out_channels,
|
||||
|
@ -22,15 +25,15 @@ model = dict(
|
|||
|
||||
pre_transform = _base_.pre_transform
|
||||
albu_train_transform = _base_.albu_train_transform
|
||||
mosaic_affine_transform = _base_.mosaic_affine_transform
|
||||
mosaic_affine_pipeline = _base_.mosaic_affine_pipeline
|
||||
last_transform = _base_.last_transform
|
||||
|
||||
train_pipeline = [
|
||||
*pre_transform, *mosaic_affine_transform,
|
||||
*pre_transform, *mosaic_affine_pipeline,
|
||||
dict(
|
||||
type='YOLOv5MixUp',
|
||||
prob=mixup_ratio,
|
||||
pre_transform=[*pre_transform, *mosaic_affine_transform]),
|
||||
prob=mixup_prob,
|
||||
pre_transform=[*pre_transform, *mosaic_affine_pipeline]),
|
||||
*last_transform
|
||||
]
|
||||
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
_base_ = './yolov8_s_syncbn_fast_8xb16-500e_coco.py'
|
||||
|
||||
# ========================modified parameters======================
|
||||
deepen_factor = 0.67
|
||||
widen_factor = 0.75
|
||||
last_stage_out_channels = 768
|
||||
|
||||
affine_scale = 0.9
|
||||
mixup_ratio = 0.1
|
||||
mixup_prob = 0.1
|
||||
|
||||
# =======================Unmodified in most cases==================
|
||||
num_classes = _base_.num_classes
|
||||
num_det_layers = _base_.num_det_layers
|
||||
img_scale = _base_.img_scale
|
||||
|
@ -30,7 +32,7 @@ pre_transform = _base_.pre_transform
|
|||
albu_train_transform = _base_.albu_train_transform
|
||||
last_transform = _base_.last_transform
|
||||
|
||||
mosaic_affine_transform = [
|
||||
mosaic_affine_pipeline = [
|
||||
dict(
|
||||
type='Mosaic',
|
||||
img_scale=img_scale,
|
||||
|
@ -47,12 +49,13 @@ mosaic_affine_transform = [
|
|||
border_val=(114, 114, 114))
|
||||
]
|
||||
|
||||
# enable mixup
|
||||
train_pipeline = [
|
||||
*pre_transform, *mosaic_affine_transform,
|
||||
*pre_transform, *mosaic_affine_pipeline,
|
||||
dict(
|
||||
type='YOLOv5MixUp',
|
||||
prob=mixup_ratio,
|
||||
pre_transform=[*pre_transform, *mosaic_affine_transform]),
|
||||
prob=mixup_prob,
|
||||
pre_transform=[*pre_transform, *mosaic_affine_pipeline]),
|
||||
*last_transform
|
||||
]
|
||||
|
||||
|
@ -85,6 +88,6 @@ custom_hooks = [
|
|||
priority=49),
|
||||
dict(
|
||||
type='mmdet.PipelineSwitchHook',
|
||||
switch_epoch=_base_.max_epochs - 10,
|
||||
switch_epoch=_base_.max_epochs - _base_.close_mosaic_epochs,
|
||||
switch_pipeline=train_pipeline_stage2)
|
||||
]
|
||||
|
|
|
@ -1,37 +1,100 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
# dataset settings
|
||||
data_root = 'data/coco/'
|
||||
dataset_type = 'YOLOv5CocoDataset'
|
||||
# ========================Frequently modified parameters======================
|
||||
# -----data related-----
|
||||
data_root = 'data/coco/' # Root path of data
|
||||
# Path of train annotation file
|
||||
train_ann_file = 'annotations/instances_train2017.json'
|
||||
train_data_prefix = 'train2017/' # Prefix of train image path
|
||||
# Path of val annotation file
|
||||
val_ann_file = 'annotations/instances_val2017.json'
|
||||
val_data_prefix = 'val2017/' # Prefix of val image path
|
||||
|
||||
# parameters that often need to be modified
|
||||
num_classes = 80
|
||||
img_scale = (640, 640) # height, width
|
||||
deepen_factor = 0.33
|
||||
widen_factor = 0.5
|
||||
max_epochs = 500
|
||||
save_epoch_intervals = 10
|
||||
num_classes = 80 # Number of classes for classification
|
||||
# Batch size of a single GPU during training
|
||||
train_batch_size_per_gpu = 16
|
||||
# Worker to pre-fetch data for each single GPU during training
|
||||
train_num_workers = 8
|
||||
val_batch_size_per_gpu = 1
|
||||
val_num_workers = 2
|
||||
|
||||
# persistent_workers must be False if num_workers is 0.
|
||||
# persistent_workers must be False if num_workers is 0
|
||||
persistent_workers = True
|
||||
|
||||
strides = [8, 16, 32]
|
||||
num_det_layers = 3
|
||||
|
||||
last_stage_out_channels = 1024
|
||||
|
||||
# Base learning rate for optim_wrapper
|
||||
# -----train val related-----
|
||||
# Base learning rate for optim_wrapper. Corresponding to 8xb16=64 bs
|
||||
base_lr = 0.01
|
||||
lr_factor = 0.01
|
||||
max_epochs = 500 # Maximum training epochs
|
||||
# Disable mosaic augmentation for final 10 epochs (stage 2)
|
||||
close_mosaic_epochs = 10
|
||||
|
||||
# single-scale training is recommended to
|
||||
model_test_cfg = dict(
|
||||
# The config of multi-label for multi-class prediction.
|
||||
multi_label=True,
|
||||
# The number of boxes before NMS
|
||||
nms_pre=30000,
|
||||
score_thr=0.001, # Threshold to filter out boxes.
|
||||
nms=dict(type='nms', iou_threshold=0.7), # NMS type and threshold
|
||||
max_per_img=300) # Max number of detections of each image
|
||||
|
||||
# ========================Possible modified parameters========================
|
||||
# -----data related-----
|
||||
img_scale = (640, 640) # width, height
|
||||
# Dataset type, this will be used to define the dataset
|
||||
dataset_type = 'YOLOv5CocoDataset'
|
||||
# Batch size of a single GPU during validation
|
||||
val_batch_size_per_gpu = 1
|
||||
# Worker to pre-fetch data for each single GPU during validation
|
||||
val_num_workers = 2
|
||||
|
||||
# Config of batch shapes. Only on val.
|
||||
# We tested YOLOv8-m will get 0.02 higher than not using it.
|
||||
batch_shapes_cfg = None
|
||||
# You can turn on `batch_shapes_cfg` by uncommenting the following lines.
|
||||
# batch_shapes_cfg = dict(
|
||||
# type='BatchShapePolicy',
|
||||
# batch_size=val_batch_size_per_gpu,
|
||||
# img_size=img_scale[0],
|
||||
# # The image scale of padding should be divided by pad_size_divisor
|
||||
# size_divisor=32,
|
||||
# # Additional paddings for pixel scale
|
||||
# extra_pad_ratio=0.5)
|
||||
|
||||
# -----model related-----
|
||||
# The scaling factor that controls the depth of the network structure
|
||||
deepen_factor = 0.33
|
||||
# The scaling factor that controls the width of the network structure
|
||||
widen_factor = 0.5
|
||||
# Strides of multi-scale prior box
|
||||
strides = [8, 16, 32]
|
||||
# The output channel of the last stage
|
||||
last_stage_out_channels = 1024
|
||||
num_det_layers = 3 # The number of model output scales
|
||||
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001) # Normalization config
|
||||
|
||||
# -----train val related-----
|
||||
affine_scale = 0.5 # YOLOv5RandomAffine scaling ratio
|
||||
# YOLOv5RandomAffine aspect ratio of width and height thres to filter bboxes
|
||||
max_aspect_ratio = 100
|
||||
tal_topk = 10 # Number of bbox selected in each level
|
||||
tal_alpha = 0.5 # A Hyper-parameter related to alignment_metrics
|
||||
tal_beta = 6.0 # A Hyper-parameter related to alignment_metrics
|
||||
# TODO: Automatically scale loss_weight based on number of detection layers
|
||||
loss_cls_weight = 0.5
|
||||
loss_bbox_weight = 7.5
|
||||
# Since the dfloss is implemented differently in the official
|
||||
# and mmdet, we're going to divide loss_weight by 4.
|
||||
loss_dfl_weight = 1.5 / 4
|
||||
lr_factor = 0.01 # Learning rate scaling factor
|
||||
weight_decay = 0.0005
|
||||
# Save model checkpoint and validation intervals in stage 1
|
||||
save_epoch_intervals = 10
|
||||
# validation intervals in stage 2
|
||||
val_interval_stage2 = 1
|
||||
# The maximum checkpoints to keep.
|
||||
max_keep_ckpts = 2
|
||||
# Single-scale training is recommended to
|
||||
# be turned on, which can speed up training.
|
||||
env_cfg = dict(cudnn_benchmark=True)
|
||||
|
||||
# ===============================Unmodified in most cases====================
|
||||
model = dict(
|
||||
type='YOLODetector',
|
||||
data_preprocessor=dict(
|
||||
|
@ -45,7 +108,7 @@ model = dict(
|
|||
last_stage_out_channels=last_stage_out_channels,
|
||||
deepen_factor=deepen_factor,
|
||||
widen_factor=widen_factor,
|
||||
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=dict(type='SiLU', inplace=True)),
|
||||
neck=dict(
|
||||
type='YOLOv8PAFPN',
|
||||
|
@ -54,7 +117,7 @@ model = dict(
|
|||
in_channels=[256, 512, last_stage_out_channels],
|
||||
out_channels=[256, 512, last_stage_out_channels],
|
||||
num_csp_blocks=3,
|
||||
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=dict(type='SiLU', inplace=True)),
|
||||
bbox_head=dict(
|
||||
type='YOLOv8Head',
|
||||
|
@ -64,45 +127,39 @@ model = dict(
|
|||
in_channels=[256, 512, last_stage_out_channels],
|
||||
widen_factor=widen_factor,
|
||||
reg_max=16,
|
||||
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=dict(type='SiLU', inplace=True),
|
||||
featmap_strides=[8, 16, 32]),
|
||||
featmap_strides=strides),
|
||||
prior_generator=dict(
|
||||
type='mmdet.MlvlPointGenerator', offset=0.5, strides=[8, 16, 32]),
|
||||
type='mmdet.MlvlPointGenerator', offset=0.5, strides=strides),
|
||||
bbox_coder=dict(type='DistancePointBBoxCoder'),
|
||||
# scaled based on number of detection layers
|
||||
loss_cls=dict(
|
||||
type='mmdet.CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
reduction='none',
|
||||
loss_weight=0.5),
|
||||
loss_weight=loss_cls_weight),
|
||||
loss_bbox=dict(
|
||||
type='IoULoss',
|
||||
iou_mode='ciou',
|
||||
bbox_format='xyxy',
|
||||
reduction='sum',
|
||||
loss_weight=7.5,
|
||||
loss_weight=loss_bbox_weight,
|
||||
return_iou=False),
|
||||
# Since the dfloss is implemented differently in the official
|
||||
# and mmdet, we're going to divide loss_weight by 4.
|
||||
loss_dfl=dict(
|
||||
type='mmdet.DistributionFocalLoss',
|
||||
reduction='mean',
|
||||
loss_weight=1.5 / 4)),
|
||||
loss_weight=loss_dfl_weight)),
|
||||
train_cfg=dict(
|
||||
assigner=dict(
|
||||
type='BatchTaskAlignedAssigner',
|
||||
num_classes=num_classes,
|
||||
use_ciou=True,
|
||||
topk=10,
|
||||
alpha=0.5,
|
||||
beta=6.0,
|
||||
topk=tal_topk,
|
||||
alpha=tal_alpha,
|
||||
beta=tal_beta,
|
||||
eps=1e-9)),
|
||||
test_cfg=dict(
|
||||
multi_label=True,
|
||||
nms_pre=30000,
|
||||
score_thr=0.001,
|
||||
nms=dict(type='nms', iou_threshold=0.7),
|
||||
max_per_img=300))
|
||||
test_cfg=model_test_cfg)
|
||||
|
||||
albu_train_transform = [
|
||||
dict(type='Blur', p=0.01),
|
||||
|
@ -135,6 +192,7 @@ last_transform = [
|
|||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
|
||||
'flip_direction'))
|
||||
]
|
||||
|
||||
train_pipeline = [
|
||||
*pre_transform,
|
||||
dict(
|
||||
|
@ -146,8 +204,8 @@ train_pipeline = [
|
|||
type='YOLOv5RandomAffine',
|
||||
max_rotate_degree=0.0,
|
||||
max_shear_degree=0.0,
|
||||
scaling_ratio_range=(0.5, 1.5),
|
||||
max_aspect_ratio=100,
|
||||
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
|
||||
max_aspect_ratio=max_aspect_ratio,
|
||||
# img_scale is (width, height)
|
||||
border=(-img_scale[0] // 2, -img_scale[1] // 2),
|
||||
border_val=(114, 114, 114)),
|
||||
|
@ -166,8 +224,8 @@ train_pipeline_stage2 = [
|
|||
type='YOLOv5RandomAffine',
|
||||
max_rotate_degree=0.0,
|
||||
max_shear_degree=0.0,
|
||||
scaling_ratio_range=(0.5, 1.5),
|
||||
max_aspect_ratio=100,
|
||||
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
|
||||
max_aspect_ratio=max_aspect_ratio,
|
||||
border_val=(114, 114, 114)), *last_transform
|
||||
]
|
||||
|
||||
|
@ -181,8 +239,8 @@ train_dataloader = dict(
|
|||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
ann_file='annotations/instances_train2017.json',
|
||||
data_prefix=dict(img='train2017/'),
|
||||
ann_file=train_ann_file,
|
||||
data_prefix=dict(img=train_data_prefix),
|
||||
filter_cfg=dict(filter_empty_gt=False, min_size=32),
|
||||
pipeline=train_pipeline))
|
||||
|
||||
|
@ -201,17 +259,6 @@ test_pipeline = [
|
|||
'scale_factor', 'pad_param'))
|
||||
]
|
||||
|
||||
# only on Val
|
||||
# you can turn on `batch_shapes_cfg`,
|
||||
# we tested YOLOv8-m will get 0.02 higher than not using it.
|
||||
batch_shapes_cfg = None
|
||||
# batch_shapes_cfg = dict(
|
||||
# type='BatchShapePolicy',
|
||||
# batch_size=val_batch_size_per_gpu,
|
||||
# img_size=img_scale[0],
|
||||
# size_divisor=32,
|
||||
# extra_pad_ratio=0.5)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=val_batch_size_per_gpu,
|
||||
num_workers=val_num_workers,
|
||||
|
@ -223,8 +270,8 @@ val_dataloader = dict(
|
|||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
test_mode=True,
|
||||
data_prefix=dict(img='val2017/'),
|
||||
ann_file='annotations/instances_val2017.json',
|
||||
data_prefix=dict(img=val_data_prefix),
|
||||
ann_file=val_ann_file,
|
||||
pipeline=test_pipeline,
|
||||
batch_shapes_cfg=batch_shapes_cfg))
|
||||
|
||||
|
@ -238,7 +285,7 @@ optim_wrapper = dict(
|
|||
type='SGD',
|
||||
lr=base_lr,
|
||||
momentum=0.937,
|
||||
weight_decay=0.0005,
|
||||
weight_decay=weight_decay,
|
||||
nesterov=True,
|
||||
batch_size_per_gpu=train_batch_size_per_gpu),
|
||||
constructor='YOLOv5OptimizerConstructor')
|
||||
|
@ -253,7 +300,7 @@ default_hooks = dict(
|
|||
type='CheckpointHook',
|
||||
interval=save_epoch_intervals,
|
||||
save_best='auto',
|
||||
max_keep_ckpts=2))
|
||||
max_keep_ckpts=max_keep_ckpts))
|
||||
|
||||
custom_hooks = [
|
||||
dict(
|
||||
|
@ -265,14 +312,14 @@ custom_hooks = [
|
|||
priority=49),
|
||||
dict(
|
||||
type='mmdet.PipelineSwitchHook',
|
||||
switch_epoch=max_epochs - 10,
|
||||
switch_epoch=max_epochs - close_mosaic_epochs,
|
||||
switch_pipeline=train_pipeline_stage2)
|
||||
]
|
||||
|
||||
val_evaluator = dict(
|
||||
type='mmdet.CocoMetric',
|
||||
proposal_nums=(100, 1, 10),
|
||||
ann_file=data_root + 'annotations/instances_val2017.json',
|
||||
ann_file=data_root + val_ann_file,
|
||||
metric='bbox')
|
||||
test_evaluator = val_evaluator
|
||||
|
||||
|
@ -280,7 +327,8 @@ train_cfg = dict(
|
|||
type='EpochBasedTrainLoop',
|
||||
max_epochs=max_epochs,
|
||||
val_interval=save_epoch_intervals,
|
||||
dynamic_intervals=[(max_epochs - 10, 1)])
|
||||
dynamic_intervals=[((max_epochs - close_mosaic_epochs),
|
||||
val_interval_stage2)])
|
||||
|
||||
val_cfg = dict(type='ValLoop')
|
||||
test_cfg = dict(type='TestLoop')
|
||||
|
|
|
@ -450,6 +450,15 @@ class YOLOv5RandomAffine(BaseTransform):
|
|||
the border of the image. In some dataset like MOT17, the gt bboxes
|
||||
are allowed to cross the border of images. Therefore, we don't
|
||||
need to clip the gt bboxes in these cases. Defaults to True.
|
||||
min_bbox_size (float): Width and height threshold to filter bboxes.
|
||||
If the height or width of a box is smaller than this value, it
|
||||
will be removed. Defaults to 2.
|
||||
min_area_ratio (float): Threshold of area ratio between
|
||||
original bboxes and wrapped bboxes. If smaller than this value,
|
||||
the box will be removed. Defaults to 0.1.
|
||||
max_aspect_ratio (float): Aspect ratio of width and height
|
||||
threshold to filter bboxes. If max(h/w, w/h) larger than this
|
||||
value, the box will be removed. Defaults to 20.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
|
Loading…
Reference in New Issue