Beautify the YOLOv5 configuration (#501)

* refactor_config
pull/504/head
Haian Huang(深度眸) 2023-02-03 14:28:35 +08:00 committed by GitHub
parent 74558aa2f7
commit 79f0aae555
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 157 additions and 90 deletions

View File

@ -1,10 +1,15 @@
_base_ = './yolov5_s-p6-v62_syncbn_fast_8xb16-300e_coco.py'
# ========================modified parameters======================
deepen_factor = 0.67
widen_factor = 0.75
lr_factor = 0.1 # lrf=0.1
lr_factor = 0.1
affine_scale = 0.9
loss_cls_weight = 0.3
loss_obj_weight = 0.7
mixup_prob = 0.1
# =======================Unmodified in most cases==================
num_classes = _base_.num_classes
num_det_layers = _base_.num_det_layers
img_scale = _base_.img_scale
@ -20,9 +25,9 @@ model = dict(
),
bbox_head=dict(
head_module=dict(widen_factor=widen_factor),
loss_cls=dict(loss_weight=0.3 *
loss_cls=dict(loss_weight=loss_cls_weight *
(num_classes / 80 * 3 / num_det_layers)),
loss_obj=dict(loss_weight=0.7 *
loss_obj=dict(loss_weight=loss_obj_weight *
((img_scale[0] / 640)**2 * 3 / num_det_layers))))
pre_transform = _base_.pre_transform
@ -49,7 +54,7 @@ train_pipeline = [
*pre_transform, *mosaic_affine_pipeline,
dict(
type='YOLOv5MixUp',
prob=0.1,
prob=mixup_prob,
pre_transform=[*pre_transform, *mosaic_affine_pipeline]),
dict(
type='mmdet.Albu',
@ -71,5 +76,4 @@ train_pipeline = [
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
default_hooks = dict(param_scheduler=dict(lr_factor=lr_factor))

View File

@ -1,10 +1,15 @@
_base_ = './yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'
# ========================modified parameters======================
deepen_factor = 0.67
widen_factor = 0.75
lr_factor = 0.1 # lrf=0.1
lr_factor = 0.1
affine_scale = 0.9
loss_cls_weight = 0.3
loss_obj_weight = 0.7
mixup_prob = 0.1
# =======================Unmodified in most cases==================
num_classes = _base_.num_classes
num_det_layers = _base_.num_det_layers
img_scale = _base_.img_scale
@ -20,9 +25,9 @@ model = dict(
),
bbox_head=dict(
head_module=dict(widen_factor=widen_factor),
loss_cls=dict(loss_weight=0.3 *
loss_cls=dict(loss_weight=loss_cls_weight *
(num_classes / 80 * 3 / num_det_layers)),
loss_obj=dict(loss_weight=0.7 *
loss_obj=dict(loss_weight=loss_obj_weight *
((img_scale[0] / 640)**2 * 3 / num_det_layers))))
pre_transform = _base_.pre_transform
@ -49,7 +54,7 @@ train_pipeline = [
*pre_transform, *mosaic_affine_pipeline,
dict(
type='YOLOv5MixUp',
prob=0.1,
prob=mixup_prob,
pre_transform=[*pre_transform, *mosaic_affine_pipeline]),
dict(
type='mmdet.Albu',
@ -71,5 +76,4 @@ train_pipeline = [
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
default_hooks = dict(param_scheduler=dict(lr_factor=lr_factor))

View File

@ -1,19 +1,32 @@
_base_ = 'yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'
# ========================modified parameters======================
img_scale = (1280, 1280) # width, height
num_classes = 80
# only on Val
batch_shapes_cfg = dict(img_size=img_scale[0], size_divisor=64)
num_classes = 80 # Number of classes for classification
# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
img_size=img_scale[0],
# The image scale of padding should be divided by pad_size_divisor
size_divisor=64)
# Basic size of multi-scale prior box
anchors = [
[(19, 27), (44, 40), (38, 94)], # P3/8
[(96, 68), (86, 152), (180, 137)], # P4/16
[(140, 301), (303, 264), (238, 542)], # P5/32
[(436, 615), (739, 380), (925, 792)] # P6/64
]
# Strides of multi-scale prior box
strides = [8, 16, 32, 64]
num_det_layers = 4
num_det_layers = 4 # The number of model output scales
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
# The obj loss weights of the three output layers
obj_level_weights = [4.0, 1.0, 0.25, 0.06]
affine_scale = 0.5 # YOLOv5RandomAffine scaling ratio
# =======================Unmodified in most cases==================
model = dict(
backbone=dict(arch='P6', out_indices=(2, 3, 4, 5)),
neck=dict(
@ -23,12 +36,12 @@ model = dict(
in_channels=[256, 512, 768, 1024], featmap_strides=strides),
prior_generator=dict(base_sizes=anchors, strides=strides),
# scaled based on number of detection layers
loss_cls=dict(loss_weight=0.5 *
loss_cls=dict(loss_weight=loss_cls_weight *
(num_classes / 80 * 3 / num_det_layers)),
loss_bbox=dict(loss_weight=0.05 * (3 / num_det_layers)),
loss_obj=dict(loss_weight=1.0 *
loss_bbox=dict(loss_weight=loss_bbox_weight * (3 / num_det_layers)),
loss_obj=dict(loss_weight=loss_obj_weight *
((img_scale[0] / 640)**2 * 3 / num_det_layers)),
obj_level_weights=[4.0, 1.0, 0.25, 0.06]))
obj_level_weights=obj_level_weights))
pre_transform = _base_.pre_transform
albu_train_transforms = _base_.albu_train_transforms
@ -44,7 +57,7 @@ train_pipeline = [
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
scaling_ratio_range=(0.5, 1.5),
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
# img_scale is (width, height)
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114)),

View File

@ -1,9 +1,7 @@
_base_ = 'yolov5_s-v61_syncbn_8xb16-300e_coco.py'
test_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args={{_base_.file_client_args}}),
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(
type='LetterResize',
scale=_base_.img_scale,

View File

@ -1,47 +1,95 @@
_base_ = '../_base_/default_runtime.py'
# dataset settings
data_root = 'data/coco/'
dataset_type = 'YOLOv5CocoDataset'
# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/' # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/' # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/' # Prefix of val image path
# parameters that often need to be modified
num_classes = 80
img_scale = (640, 640) # width, height
deepen_factor = 0.33
widen_factor = 0.5
max_epochs = 300
save_epoch_intervals = 10
num_classes = 80 # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
val_batch_size_per_gpu = 1
val_num_workers = 2
# persistent_workers must be False if num_workers is 0.
# persistent_workers must be False if num_workers is 0
persistent_workers = True
# Base learning rate for optim_wrapper
base_lr = 0.01
# only on Val
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
size_divisor=32,
extra_pad_ratio=0.5)
# -----model related-----
# Basic size of multi-scale prior box
anchors = [
[(10, 13), (16, 30), (33, 23)], # P3/8
[(30, 61), (62, 45), (59, 119)], # P4/16
[(116, 90), (156, 198), (373, 326)] # P5/32
]
strides = [8, 16, 32]
num_det_layers = 3
# single-scale training is recommended to
# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=64 bs
base_lr = 0.01
max_epochs = 300 # Maximum training epochs
model_test_cfg = dict(
# The config of multi-label for multi-class prediction.
multi_label=True,
# The number of boxes before NMS
nms_pre=30000,
score_thr=0.001, # Threshold to filter out boxes.
nms=dict(type='nms', iou_threshold=0.65), # NMS type and threshold
max_per_img=300) # Max number of detections of each image
# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640) # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2
# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
# The image scale of padding should be divided by pad_size_divisor
size_divisor=32,
# Additional paddings for pixel scale
extra_pad_ratio=0.5)
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
# Strides of multi-scale prior box
strides = [8, 16, 32]
num_det_layers = 3 # The number of model output scales
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)
# -----train val related-----
affine_scale = 0.5 # YOLOv5RandomAffine scaling ratio
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4. # Priori box matching threshold
obj_level_weights = [4., 1.,
0.4] # The obj loss weights of the three output layers
lr_factor = 0.01 # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_epoch_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)
# ===============================Unmodified in most cases====================
model = dict(
type='YOLODetector',
data_preprocessor=dict(
@ -53,7 +101,7 @@ model = dict(
type='YOLOv5CSPDarknet',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
norm_cfg=norm_cfg,
act_cfg=dict(type='SiLU', inplace=True)),
neck=dict(
type='YOLOv5PAFPN',
@ -62,7 +110,7 @@ model = dict(
in_channels=[256, 512, 1024],
out_channels=[256, 512, 1024],
num_csp_blocks=3,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
norm_cfg=norm_cfg,
act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(
type='YOLOv5Head',
@ -82,28 +130,25 @@ model = dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=0.5 * (num_classes / 80 * 3 / num_det_layers)),
loss_weight=loss_cls_weight *
(num_classes / 80 * 3 / num_det_layers)),
loss_bbox=dict(
type='IoULoss',
iou_mode='ciou',
bbox_format='xywh',
eps=1e-7,
reduction='mean',
loss_weight=0.05 * (3 / num_det_layers),
loss_weight=loss_bbox_weight * (3 / num_det_layers),
return_iou=True),
loss_obj=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=1.0 * ((img_scale[0] / 640)**2 * 3 / num_det_layers)),
prior_match_thr=4.,
obj_level_weights=[4., 1., 0.4]),
test_cfg=dict(
multi_label=True,
nms_pre=30000,
score_thr=0.001,
nms=dict(type='nms', iou_threshold=0.65),
max_per_img=300))
loss_weight=loss_obj_weight *
((img_scale[0] / 640)**2 * 3 / num_det_layers)),
prior_match_thr=prior_match_thr,
obj_level_weights=obj_level_weights),
test_cfg=model_test_cfg)
albu_train_transforms = [
dict(type='Blur', p=0.01),
@ -128,7 +173,7 @@ train_pipeline = [
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
scaling_ratio_range=(0.5, 1.5),
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
# img_scale is (width, height)
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114)),
@ -160,8 +205,8 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
ann_file=train_ann_file,
data_prefix=dict(img=train_data_prefix),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=train_pipeline))
@ -191,8 +236,8 @@ val_dataloader = dict(
type=dataset_type,
data_root=data_root,
test_mode=True,
data_prefix=dict(img='val2017/'),
ann_file='annotations/instances_val2017.json',
data_prefix=dict(img=val_data_prefix),
ann_file=val_ann_file,
pipeline=test_pipeline,
batch_shapes_cfg=batch_shapes_cfg))
@ -205,7 +250,7 @@ optim_wrapper = dict(
type='SGD',
lr=base_lr,
momentum=0.937,
weight_decay=0.0005,
weight_decay=weight_decay,
nesterov=True,
batch_size_per_gpu=train_batch_size_per_gpu),
constructor='YOLOv5OptimizerConstructor')
@ -214,13 +259,13 @@ default_hooks = dict(
param_scheduler=dict(
type='YOLOv5ParamSchedulerHook',
scheduler_type='linear',
lr_factor=0.01,
lr_factor=lr_factor,
max_epochs=max_epochs),
checkpoint=dict(
type='CheckpointHook',
interval=save_epoch_intervals,
save_best='auto',
max_keep_ckpts=3))
max_keep_ckpts=max_keep_ckpts))
custom_hooks = [
dict(
@ -235,7 +280,7 @@ custom_hooks = [
val_evaluator = dict(
type='mmdet.CocoMetric',
proposal_nums=(100, 1, 10),
ann_file=data_root + 'annotations/instances_val2017.json',
ann_file=data_root + val_ann_file,
metric='bbox')
test_evaluator = val_evaluator

View File

@ -1,39 +1,42 @@
_base_ = './yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'
# ========================modified parameters======================
data_root = 'data/balloon/'
train_batch_size_per_gpu = 4
train_num_workers = 2
# Path of train annotation file
train_ann_file = 'train.json'
train_data_prefix = 'train/' # Prefix of train image path
# Path of val annotation file
val_ann_file = 'val.json'
val_data_prefix = 'val/' # Prefix of val image path
metainfo = {
'classes': ('balloon', ),
'palette': [
(220, 20, 60),
]
}
num_classes = 1
train_batch_size_per_gpu = 4
train_num_workers = 2
log_interval = 1
# =======================Unmodified in most cases==================
train_dataloader = dict(
batch_size=train_batch_size_per_gpu,
num_workers=train_num_workers,
dataset=dict(
data_root=data_root,
metainfo=metainfo,
data_prefix=dict(img='train/'),
ann_file='train.json'))
data_prefix=dict(img=train_data_prefix),
ann_file=train_ann_file))
val_dataloader = dict(
dataset=dict(
data_root=data_root,
metainfo=metainfo,
data_prefix=dict(img='val/'),
ann_file='val.json'))
data_prefix=dict(img=val_data_prefix),
ann_file=val_ann_file))
test_dataloader = val_dataloader
val_evaluator = dict(ann_file=data_root + 'val.json')
val_evaluator = dict(ann_file=data_root + val_ann_file)
test_evaluator = val_evaluator
model = dict(bbox_head=dict(head_module=dict(num_classes=1)))
default_hooks = dict(logger=dict(interval=1))
model = dict(bbox_head=dict(head_module=dict(num_classes=num_classes)))
default_hooks = dict(logger=dict(interval=log_interval))