# mmyolo/configs/rtmdet/rtmdet_l_syncbn_8xb32-300e_...

_base_ = '../_base_/default_runtime.py'
data_root = 'data/coco/'
dataset_type = 'YOLOv5CocoDataset'
img_scale = (640, 640)  # width, height
deepen_factor = 1.0
widen_factor = 1.0
max_epochs = 300
stage2_num_epochs = 20
interval = 10
train_batch_size_per_gpu = 32
train_num_workers = 10
val_batch_size_per_gpu = 5
val_num_workers = 10
# persistent_workers must be False if num_workers is 0.
persistent_workers = True
strides = [8, 16, 32]
base_lr = 0.004
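
# With the 8 GPU x 32 imgs/GPU setup the file name implies, the effective
# batch size is 8 * 32 = 256, i.e. base_lr = 0.004 is 1.5625e-5 per image.
# If you train at a different total batch size, a common convention (an
# assumption here, not something this config enforces) is to scale linearly:
#   base_lr = 0.004 * total_batch_size / 256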
model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        mean=[103.53, 116.28, 123.675],
        std=[57.375, 57.12, 58.395],
        bgr_to_rgb=False),
    backbone=dict(
        type='CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        channel_attention=True,
        norm_cfg=dict(type='BN'),
        act_cfg=dict(type='SiLU', inplace=True)),
    neck=dict(
        type='CSPNeXtPAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=[256, 512, 1024],
        out_channels=256,
        num_csp_blocks=3,
        expand_ratio=0.5,
        norm_cfg=dict(type='BN'),
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='RTMDetHead',
        head_module=dict(
            type='RTMDetSepBNHeadModule',
            num_classes=80,
            in_channels=256,
            stacked_convs=2,
            feat_channels=256,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='SiLU', inplace=True),
            share_conv=True,
            pred_kernel_size=1,
            featmap_strides=strides),
        prior_generator=dict(
            type='mmdet.MlvlPointGenerator', offset=0, strides=strides),
        bbox_coder=dict(type='mmdet.DistancePointBBoxCoder'),
        loss_cls=dict(
            type='mmdet.QualityFocalLoss',
            use_sigmoid=True,
            beta=2.0,
            loss_weight=1.0),
        loss_bbox=dict(type='mmdet.GIoULoss', loss_weight=2.0)),
    train_cfg=dict(
        assigner=dict(
            type='mmdet.DynamicSoftLabelAssigner',
            topk=13,
            iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.6),
        max_per_img=100),
)
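
# Minimal usage sketch (not part of the config itself): mmengine builds the
# detector, dataloaders and loops from this file. Assuming mmyolo and its
# dependencies are installed, training can be launched programmatically
# (paths below are hypothetical) or with the repo's tools/train.py script:
#
#   from mmengine.config import Config
#   from mmengine.runner import Runner
#
#   cfg = Config.fromfile('path/to/this_config.py')
#   cfg.work_dir = './work_dirs/rtmdet_l'
#   runner = Runner.from_cfg(cfg)
#   runner.train()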
train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='Mosaic',
        img_scale=img_scale,
        use_cached=True,
        max_cached_images=40,
        pad_val=114.0),
    dict(
        type='mmdet.RandomResize',
        scale=(img_scale[0] * 2, img_scale[1] * 2),
        ratio_range=(0.1, 2.0),
        resize_type='mmdet.Resize',
        keep_ratio=True),
    dict(type='mmdet.RandomCrop', crop_size=img_scale),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
    dict(
        type='YOLOXMixUp',
        img_scale=img_scale,
        use_cached=True,
        ratio_range=(1.0, 1.0),
        max_cached_images=20,
        pad_val=(114, 114, 114)),
    dict(type='mmdet.PackDetInputs')
]
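
# Stage-1 augmentation geometry, for reference: cached Mosaic stitches four
# images onto a double-size canvas, RandomResize rescales that canvas around
# (1280, 1280) by a factor in (0.1, 2.0), and RandomCrop cuts it back to the
# 640x640 training size before cached MixUp blends in one more sample.
# use_cached keeps a small pool of recently seen images (40 for Mosaic, 20
# for MixUp) so mixing does not reload images from disk.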
train_pipeline_stage2 = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='mmdet.RandomResize',
        scale=img_scale,
        ratio_range=(0.1, 2.0),
        resize_type='mmdet.Resize',
        keep_ratio=True),
    dict(type='mmdet.RandomCrop', crop_size=img_scale),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
    dict(type='mmdet.PackDetInputs')
]
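
# Stage 2 drops Mosaic and MixUp and trains on lightly augmented single
# images; PipelineSwitchHook (see custom_hooks below) swaps this pipeline in
# for the final stage2_num_epochs = 20 epochs.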
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='mmdet.Resize', scale=img_scale, keep_ratio=True),
    dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
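
# At test time the image is resized with keep_ratio=True and padded to
# 640x640 with the same gray value (114) used in training, so the network
# always sees a fixed input size.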
train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=test_pipeline))
test_dataloader = val_dataloader
# Reduce evaluation time
val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + 'annotations/instances_val2017.json',
    metric='bbox')
test_evaluator = val_evaluator
# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
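
# paramwise_cfg exempts normalization layers and biases from weight decay
# (norm_decay_mult=0, bias_decay_mult=0), a standard practice with decoupled
# weight decay optimizers such as AdamW.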
# learning rate
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1.0e-5,
        by_epoch=False,
        begin=0,
        end=1000),
    dict(
        # use cosine lr from 150 to 300 epoch
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]
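
# Spelled out: iterations 0-1000 warm up linearly from base_lr * 1e-5 to
# base_lr; the lr then stays at base_lr until epoch max_epochs // 2 = 150;
# from epoch 150 to 300, standard cosine annealing applies,
#   lr(t) = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T_max)),
# decaying from 0.004 down to eta_min = 0.004 * 0.05 = 2e-4.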
# hooks
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=interval,
        max_keep_ckpts=3  # only keep latest 3 checkpoints
    ))
custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        strict_load=False,
        priority=49),
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=max_epochs - stage2_num_epochs,
        switch_pipeline=train_pipeline_stage2)
]
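
# EMAHook keeps an exponential moving average of the weights for evaluation.
# With ExpMomentumEMA the per-step momentum itself decays towards the
# configured value (0.0002) as training progresses, so early updates follow
# the raw weights more closely; priority=49 runs the hook just ahead of the
# default NORMAL (50) priority hooks.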
train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=interval,
    dynamic_intervals=[(max_epochs - stage2_num_epochs, 1)])
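
# dynamic_intervals tightens validation to every epoch once the stage-2
# (no-Mosaic) phase starts at epoch max_epochs - stage2_num_epochs = 280;
# before that, validation runs every `interval` = 10 epochs.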
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')