[FEATURE] Support YOLOv6 3.0 inference ()

* [FEATURE] Support YOLOv6 3.0 inference

* add CSPSPPFBottleneck module for YOLOv6 3.0 backbone

* add BiFusion module, YOLOv6RepBiPAFPN module for YOLOv6 3.0 neck

* modify YOLOv6HeadModule to support YOLOv6 3.0 head

* add a yolov6v3 l/m/s/t/n conifgs

* [Fix] Modify YOLOv6 3.0 neck

* Modify YOLOv6RepBiPAFPN

* Add unit tests

* [Fix] Modify configs

* Rename yolov6_v3 configs

* Fix a bug in building BiFusion Module

* Add a checkpoint convert script
pull/754/merge
Qingren 2023-04-25 10:22:48 +08:00 committed by GitHub
parent 1aa1ecd27b
commit 9f3adc426f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1109 additions and 11 deletions

View File

@ -0,0 +1,28 @@
_base_ = './yolov6_v3_m_syncbn_fast_8xb32-300e_coco.py'
# ======================= Possible modified parameters =======================
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 1
# The scaling factor that controls the width of the network structure
widen_factor = 1
# ============================== Unmodified in most cases ===================
model = dict(
backbone=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
hidden_ratio=1. / 2,
block_cfg=dict(
type='ConvWrapper',
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
act_cfg=dict(type='SiLU', inplace=True)),
neck=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
hidden_ratio=1. / 2,
block_cfg=dict(
type='ConvWrapper',
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
block_act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))

View File

@ -0,0 +1,63 @@
_base_ = './yolov6_v3_s_syncbn_fast_8xb32-300e_coco.py'
# ======================= Possible modified parameters =======================
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.6
# The scaling factor that controls the width of the network structure
widen_factor = 0.75
# -----train val related-----
affine_scale = 0.9 # YOLOv5RandomAffine scaling ratio
# ============================== Unmodified in most cases ===================
model = dict(
backbone=dict(
type='YOLOv6CSPBep',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
hidden_ratio=2. / 3,
block_cfg=dict(type='RepVGGBlock'),
act_cfg=dict(type='ReLU', inplace=True)),
neck=dict(
type='YOLOv6CSPRepBiPAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
block_cfg=dict(type='RepVGGBlock'),
hidden_ratio=2. / 3,
block_act_cfg=dict(type='ReLU', inplace=True)),
bbox_head=dict(
type='YOLOv6Head',
head_module=dict(reg_max=16, widen_factor=widen_factor)))
mosaic_affine_pipeline = [
dict(
type='Mosaic',
img_scale=_base_.img_scale,
pad_val=114.0,
pre_transform=_base_.pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
# img_scale is (width, height)
border=(-_base_.img_scale[0] // 2, -_base_.img_scale[1] // 2),
border_val=(114, 114, 114))
]
train_pipeline = [
*_base_.pre_transform, *mosaic_affine_pipeline,
dict(
type='YOLOv5MixUp',
prob=0.1,
pre_transform=[*_base_.pre_transform, *mosaic_affine_pipeline]),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))

View File

@ -0,0 +1,21 @@
_base_ = './yolov6_v3_s_syncbn_fast_8xb32-300e_coco.py'
# ======================= Possible modified parameters =======================
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.25
# -----train val related-----
lr_factor = 0.02 # Learning rate scaling factor
# ============================== Unmodified in most cases ===================
model = dict(
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
bbox_head=dict(
head_module=dict(widen_factor=widen_factor),
loss_bbox=dict(iou_mode='siou')))
default_hooks = dict(param_scheduler=dict(lr_factor=lr_factor))

View File

@ -0,0 +1,282 @@
_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']
# ======================= Frequently modified parameters =====================
# -----data related-----
data_root = 'data/coco/' # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/' # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/' # Prefix of val image path
num_classes = 80 # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 32
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True
# -----train val related-----
# Base learning rate for optim_wrapper
base_lr = 0.01
max_epochs = 300 # Maximum training epochs
num_last_epochs = 15 # Last epoch number to switch training pipeline
# ======================= Possible modified parameters =======================
# -----data related-----
img_scale = (640, 640) # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2
# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
size_divisor=32,
extra_pad_ratio=0.5)
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
# -----train val related-----
affine_scale = 0.5 # YOLOv5RandomAffine scaling ratio
lr_factor = 0.01 # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_epoch_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)
# ============================== Unmodified in most cases ===================
model = dict(
type='YOLODetector',
data_preprocessor=dict(
type='YOLOv5DetDataPreprocessor',
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True),
backbone=dict(
type='YOLOv6EfficientRep',
out_indices=[1, 2, 3, 4],
use_cspsppf=True,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='ReLU', inplace=True)),
neck=dict(
type='YOLOv6RepBiPAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=[128, 256, 512, 1024],
out_channels=[128, 256, 512],
num_csp_blocks=12,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='ReLU', inplace=True),
),
bbox_head=dict(
type='YOLOv6Head',
head_module=dict(
type='YOLOv6HeadModule',
num_classes=num_classes,
in_channels=[128, 256, 512],
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True),
featmap_strides=[8, 16, 32]),
loss_bbox=dict(
type='IoULoss',
iou_mode='giou',
bbox_format='xyxy',
reduction='mean',
loss_weight=2.5,
return_iou=False)),
train_cfg=dict(
initial_epoch=4,
initial_assigner=dict(
type='BatchATSSAssigner',
num_classes=num_classes,
topk=9,
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
assigner=dict(
type='BatchTaskAlignedAssigner',
num_classes=num_classes,
topk=13,
alpha=1,
beta=6),
),
test_cfg=dict(
multi_label=True,
nms_pre=30000,
score_thr=0.001,
nms=dict(type='nms', iou_threshold=0.65),
max_per_img=300))
# The training pipeline of YOLOv6 is basically the same as YOLOv5.
# The difference is that Mosaic and RandomAffine will be closed in the last 15 epochs. # noqa
pre_transform = [
dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
dict(type='LoadAnnotations', with_bbox=True)
]
train_pipeline = [
*pre_transform,
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_translate_ratio=0.1,
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
# img_scale is (width, height)
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114),
max_shear_degree=0.0),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
train_pipeline_stage2 = [
*pre_transform,
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=True,
pad_val=dict(img=114)),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_translate_ratio=0.1,
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
max_shear_degree=0.0,
),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
train_dataloader = dict(
batch_size=train_batch_size_per_gpu,
num_workers=train_num_workers,
collate_fn=dict(type='yolov5_collate'),
persistent_workers=persistent_workers,
pin_memory=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=train_ann_file,
data_prefix=dict(img=train_data_prefix),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=train_pipeline))
test_pipeline = [
dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]
val_dataloader = dict(
batch_size=val_batch_size_per_gpu,
num_workers=val_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
test_mode=True,
data_prefix=dict(img=val_data_prefix),
ann_file=val_ann_file,
pipeline=test_pipeline,
batch_shapes_cfg=batch_shapes_cfg))
test_dataloader = val_dataloader
# Optimizer and learning rate scheduler of YOLOv6 are basically the same as YOLOv5. # noqa
# The difference is that the scheduler_type of YOLOv6 is cosine.
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=base_lr,
momentum=0.937,
weight_decay=weight_decay,
nesterov=True,
batch_size_per_gpu=train_batch_size_per_gpu),
constructor='YOLOv5OptimizerConstructor')
default_hooks = dict(
param_scheduler=dict(
type='YOLOv5ParamSchedulerHook',
scheduler_type='cosine',
lr_factor=lr_factor,
max_epochs=max_epochs),
checkpoint=dict(
type='CheckpointHook',
interval=save_epoch_intervals,
max_keep_ckpts=max_keep_ckpts,
save_best='auto'))
custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0001,
update_buffers=True,
strict_load=False,
priority=49),
dict(
type='mmdet.PipelineSwitchHook',
switch_epoch=max_epochs - num_last_epochs,
switch_pipeline=train_pipeline_stage2)
]
val_evaluator = dict(
type='mmdet.CocoMetric',
proposal_nums=(100, 1, 10),
ann_file=data_root + val_ann_file,
metric='bbox')
test_evaluator = val_evaluator
train_cfg = dict(
type='EpochBasedTrainLoop',
max_epochs=max_epochs,
val_interval=save_epoch_intervals,
dynamic_intervals=[(max_epochs - num_last_epochs, 1)])
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

View File

@ -0,0 +1,17 @@
_base_ = './yolov6_v3_s_syncbn_fast_8xb32-300e_coco.py'
# ======================= Possible modified parameters =======================
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.375
# ============================== Unmodified in most cases ===================
model = dict(
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
bbox_head=dict(
type='YOLOv6Head',
head_module=dict(widen_factor=widen_factor),
loss_bbox=dict(iou_mode='siou')))

View File

@ -6,7 +6,7 @@ import torch
import torch.nn as nn
from mmdet.utils import ConfigType, OptMultiConfig
from mmyolo.models.layers.yolo_bricks import SPPFBottleneck
from mmyolo.models.layers.yolo_bricks import CSPSPPFBottleneck, SPPFBottleneck
from mmyolo.registry import MODELS
from ..layers import BepC3StageBlock, RepStageBlock
from ..utils import make_round
@ -72,6 +72,7 @@ class YOLOv6EfficientRep(BaseBackbone):
input_channels: int = 3,
out_indices: Tuple[int] = (2, 3, 4),
frozen_stages: int = -1,
use_cspsppf: bool = False,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
@ -79,6 +80,7 @@ class YOLOv6EfficientRep(BaseBackbone):
block_cfg: ConfigType = dict(type='RepVGGBlock'),
init_cfg: OptMultiConfig = None):
self.block_cfg = block_cfg
self.use_cspsppf = use_cspsppf
super().__init__(
self.arch_settings[arch],
deepen_factor,
@ -145,6 +147,13 @@ class YOLOv6EfficientRep(BaseBackbone):
kernel_sizes=5,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if self.use_cspsppf:
spp = CSPSPPFBottleneck(
in_channels=out_channels,
out_channels=out_channels,
kernel_sizes=5,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.append(spp)
return stage
@ -222,6 +231,7 @@ class YOLOv6CSPBep(YOLOv6EfficientRep):
hidden_ratio: float = 0.5,
out_indices: Tuple[int] = (2, 3, 4),
frozen_stages: int = -1,
use_cspsppf: bool = False,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
@ -229,6 +239,7 @@ class YOLOv6CSPBep(YOLOv6EfficientRep):
block_cfg: ConfigType = dict(type='ConvWrapper'),
init_cfg: OptMultiConfig = None):
self.hidden_ratio = hidden_ratio
self.use_cspsppf = use_cspsppf
super().__init__(
arch=arch,
deepen_factor=deepen_factor,
@ -283,5 +294,12 @@ class YOLOv6CSPBep(YOLOv6EfficientRep):
kernel_sizes=5,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if self.use_cspsppf:
spp = CSPSPPFBottleneck(
in_channels=out_channels,
out_channels=out_channels,
kernel_sizes=5,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.append(spp)
return stage

View File

@ -50,6 +50,7 @@ class YOLOv6HeadModule(BaseModule):
in_channels: Union[int, Sequence],
widen_factor: float = 1.0,
num_base_priors: int = 1,
reg_max=0,
featmap_strides: Sequence[int] = (8, 16, 32),
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
@ -61,6 +62,7 @@ class YOLOv6HeadModule(BaseModule):
self.featmap_strides = featmap_strides
self.num_levels = len(self.featmap_strides)
self.num_base_priors = num_base_priors
self.reg_max = reg_max
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
@ -80,6 +82,12 @@ class YOLOv6HeadModule(BaseModule):
self.cls_preds = nn.ModuleList()
self.reg_preds = nn.ModuleList()
self.stems = nn.ModuleList()
if self.reg_max > 1:
proj = torch.arange(
self.reg_max + self.num_base_priors, dtype=torch.float)
self.register_buffer('proj', proj, persistent=False)
for i in range(self.num_levels):
self.stems.append(
ConvModule(
@ -116,7 +124,7 @@ class YOLOv6HeadModule(BaseModule):
self.reg_preds.append(
nn.Conv2d(
in_channels=self.in_channels[i],
out_channels=self.num_base_priors * 4,
out_channels=(self.num_base_priors + self.reg_max) * 4,
kernel_size=1))
def init_weights(self):
@ -148,6 +156,7 @@ class YOLOv6HeadModule(BaseModule):
cls_pred: nn.Module, reg_conv: nn.Module,
reg_pred: nn.Module) -> Tuple[Tensor, Tensor]:
"""Forward feature of a single scale level."""
b, _, h, w = x.shape
y = stem(x)
cls_x = y
reg_x = y
@ -155,9 +164,26 @@ class YOLOv6HeadModule(BaseModule):
reg_feat = reg_conv(reg_x)
cls_score = cls_pred(cls_feat)
bbox_pred = reg_pred(reg_feat)
bbox_dist_preds = reg_pred(reg_feat)
return cls_score, bbox_pred
if self.reg_max > 1:
bbox_dist_preds = bbox_dist_preds.reshape(
[-1, 4, self.reg_max + self.num_base_priors,
h * w]).permute(0, 3, 1, 2)
# TODO: The get_flops script cannot handle the situation of
# matmul, and needs to be fixed later
# bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
bbox_preds = bbox_dist_preds.softmax(3).matmul(
self.proj.view([-1, 1])).squeeze(-1)
bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
else:
bbox_preds = bbox_dist_preds
if self.training:
return cls_score, bbox_preds, bbox_dist_preds
else:
return cls_score, bbox_preds
@MODELS.register_module()
@ -238,6 +264,7 @@ class YOLOv6Head(YOLOv5Head):
self,
cls_scores: Sequence[Tensor],
bbox_preds: Sequence[Tensor],
bbox_dist_preds: Sequence[Tensor],
batch_gt_instances: Sequence[InstanceData],
batch_img_metas: Sequence[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ema import ExpMomentumEMA
from .yolo_bricks import (BepC3StageBlock, CSPLayerWithTwoConv,
from .yolo_bricks import (BepC3StageBlock, BiFusion, CSPLayerWithTwoConv,
DarknetBottleneck, EELANBlock, EffectiveSELayer,
ELANBlock, ImplicitA, ImplicitM,
MaxPoolAndStrideConvBlock, PPYOLOEBasicBlock,
@ -12,5 +12,5 @@ __all__ = [
'ELANBlock', 'MaxPoolAndStrideConvBlock', 'SPPFCSPBlock',
'PPYOLOEBasicBlock', 'EffectiveSELayer', 'TinyDownSampleBlock',
'EELANBlock', 'ImplicitA', 'ImplicitM', 'BepC3StageBlock',
'CSPLayerWithTwoConv', 'DarknetBottleneck'
'CSPLayerWithTwoConv', 'DarknetBottleneck', 'BiFusion'
]

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
@ -1508,3 +1508,221 @@ class CSPLayerWithTwoConv(BaseModule):
x_main = list(x_main.split((self.mid_channels, self.mid_channels), 1))
x_main.extend(blocks(x_main[-1]) for blocks in self.blocks)
return self.final_conv(torch.cat(x_main, 1))
class BiFusion(nn.Module):
"""BiFusion Block in YOLOv6.
BiFusion fuses current-, high- and low-level features.
Compared with concatenation in PAN, it fuses an extra low-level feature.
Args:
in_channels0 (int): The channels of current-level feature.
in_channels1 (int): The input channels of lower-level feature.
out_channels (int): The out channels of the BiFusion module.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='SiLU', inplace=True).
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_channels0: int,
in_channels1: int,
out_channels: int,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='ReLU', inplace=True)):
super().__init__()
self.conv1 = ConvModule(
in_channels0,
out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv2 = ConvModule(
in_channels1,
out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv3 = ConvModule(
out_channels * 3,
out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.upsample = nn.ConvTranspose2d(
out_channels, out_channels, kernel_size=2, stride=2, bias=True)
self.downsample = ConvModule(
out_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x: List[torch.Tensor]) -> Tensor:
"""Forward process
Args:
x (List[torch.Tensor]): The tensor list of length 3.
x[0]: The high-level feature.
x[1]: The current-level feature.
x[2]: The low-level feature.
"""
x0 = self.upsample(x[0])
x1 = self.conv1(x[1])
x2 = self.downsample(self.conv2(x[2]))
return self.conv3(torch.cat((x0, x1, x2), dim=1))
class CSPSPPFBottleneck(BaseModule):
"""The SPPF block having a CSP-like version in YOLOv6 3.0.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
kernel_sizes (int, tuple[int]): Sequential or number of kernel
sizes of pooling layers. Defaults to 5.
use_conv_first (bool): Whether to use conv before pooling layer.
In YOLOv5 and YOLOX, the para set to True.
In PPYOLOE, the para set to False.
Defaults to True.
mid_channels_scale (float): Channel multiplier, multiply in_channels
by this amount to get mid_channels. This parameter is valid only
when use_conv_fist=True.Defaults to 0.5.
conv_cfg (dict): Config dict for convolution layer. Defaults to None.
which means using conv2d. Defaults to None.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='SiLU', inplace=True).
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_sizes: Union[int, Sequence[int]] = 5,
use_conv_first: bool = True,
mid_channels_scale: float = 0.5,
conv_cfg: ConfigType = None,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg)
if use_conv_first:
mid_channels = int(in_channels * mid_channels_scale)
self.conv1 = ConvModule(
in_channels,
mid_channels,
1,
stride=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv3 = ConvModule(
mid_channels,
mid_channels,
3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv4 = ConvModule(
mid_channels,
mid_channels,
1,
stride=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
else:
mid_channels = in_channels
self.conv1 = None
self.conv3 = None
self.conv4 = None
self.conv2 = ConvModule(
in_channels,
mid_channels,
1,
stride=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.kernel_sizes = kernel_sizes
if isinstance(kernel_sizes, int):
self.poolings = nn.MaxPool2d(
kernel_size=kernel_sizes, stride=1, padding=kernel_sizes // 2)
conv2_in_channels = mid_channels * 4
else:
self.poolings = nn.ModuleList([
nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
for ks in kernel_sizes
])
conv2_in_channels = mid_channels * (len(kernel_sizes) + 1)
self.conv5 = ConvModule(
conv2_in_channels,
mid_channels,
1,
stride=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv6 = ConvModule(
mid_channels,
mid_channels,
3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv7 = ConvModule(
mid_channels * 2,
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x: Tensor) -> Tensor:
"""Forward process
Args:
x (Tensor): The input tensor.
"""
x0 = self.conv4(self.conv3(self.conv1(x))) if self.conv1 else x
y = self.conv2(x)
if isinstance(self.kernel_sizes, int):
x1 = self.poolings(x0)
x2 = self.poolings(x1)
x3 = torch.cat([x0, x1, x2, self.poolings(x2)], dim=1)
else:
x3 = torch.cat(
[x0] + [pooling(x0) for pooling in self.poolings], dim=1)
x3 = self.conv6(self.conv5(x3))
x = self.conv7(torch.cat([y, x3], dim=1))
return x

View File

@ -3,7 +3,8 @@ from .base_yolo_neck import BaseYOLONeck
from .cspnext_pafpn import CSPNeXtPAFPN
from .ppyoloe_csppan import PPYOLOECSPPAFPN
from .yolov5_pafpn import YOLOv5PAFPN
from .yolov6_pafpn import YOLOv6CSPRepPAFPN, YOLOv6RepPAFPN
from .yolov6_pafpn import (YOLOv6CSPRepBiPAFPN, YOLOv6CSPRepPAFPN,
YOLOv6RepBiPAFPN, YOLOv6RepPAFPN)
from .yolov7_pafpn import YOLOv7PAFPN
from .yolov8_pafpn import YOLOv8PAFPN
from .yolox_pafpn import YOLOXPAFPN
@ -11,5 +12,5 @@ from .yolox_pafpn import YOLOXPAFPN
__all__ = [
'YOLOv5PAFPN', 'BaseYOLONeck', 'YOLOv6RepPAFPN', 'YOLOXPAFPN',
'CSPNeXtPAFPN', 'YOLOv7PAFPN', 'PPYOLOECSPPAFPN', 'YOLOv6CSPRepPAFPN',
'YOLOv8PAFPN'
'YOLOv8PAFPN', 'YOLOv6RepBiPAFPN', 'YOLOv6CSPRepBiPAFPN'
]

View File

@ -7,7 +7,7 @@ from mmcv.cnn import ConvModule
from mmdet.utils import ConfigType, OptMultiConfig
from mmyolo.registry import MODELS
from ..layers import BepC3StageBlock, RepStageBlock
from ..layers import BepC3StageBlock, BiFusion, RepStageBlock
from ..utils import make_round
from .base_yolo_neck import BaseYOLONeck
@ -283,3 +283,245 @@ class YOLOv6CSPRepPAFPN(YOLOv6RepPAFPN):
hidden_ratio=self.hidden_ratio,
norm_cfg=self.norm_cfg,
act_cfg=self.block_act_cfg)
@MODELS.register_module()
class YOLOv6RepBiPAFPN(YOLOv6RepPAFPN):
"""Path Aggregation Network used in YOLOv6 3.0.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale)
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Defaults to 1.
freeze_all(bool): Whether to freeze the model.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='ReLU', inplace=True).
block_cfg (dict): Config dict for the block used to build each
layer. Defaults to dict(type='RepVGGBlock').
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_channels: List[int],
out_channels: int,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
num_csp_blocks: int = 12,
freeze_all: bool = False,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
block_cfg: ConfigType = dict(type='RepVGGBlock'),
init_cfg: OptMultiConfig = None):
self.extra_in_channel = in_channels[0]
super().__init__(
in_channels=in_channels[1:],
out_channels=out_channels,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
num_csp_blocks=num_csp_blocks,
freeze_all=freeze_all,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
block_cfg=block_cfg,
init_cfg=init_cfg)
def build_top_down_layer(self, idx: int) -> nn.Module:
"""build top down layer.
Args:
idx (int): layer idx.
Returns:
nn.Module: The top down layer.
"""
block_cfg = self.block_cfg.copy()
layer0 = RepStageBlock(
in_channels=int(self.out_channels[idx - 1] * self.widen_factor),
out_channels=int(self.out_channels[idx - 1] * self.widen_factor),
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
block_cfg=block_cfg)
if idx == 1:
return layer0
elif idx == 2:
layer1 = ConvModule(
in_channels=int(self.out_channels[idx - 1] *
self.widen_factor),
out_channels=int(self.out_channels[idx - 2] *
self.widen_factor),
kernel_size=1,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
return nn.Sequential(layer0, layer1)
def build_upsample_layer(self, idx: int) -> nn.Module:
"""build upsample layer.
Args:
idx (int): layer idx.
Returns:
nn.Module: The upsample layer.
"""
in_channels1 = self.in_channels[
idx - 2] if idx > 1 else self.extra_in_channel
return BiFusion(
in_channels0=int(self.in_channels[idx - 1] * self.widen_factor),
in_channels1=int(in_channels1 * self.widen_factor),
out_channels=int(self.out_channels[idx - 1] * self.widen_factor),
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs: List[torch.Tensor]) -> tuple:
"""Forward function."""
assert len(inputs) == len(self.in_channels) + 1
# reduce layers
reduce_outs = [inputs[0]]
for idx in range(len(self.in_channels)):
reduce_outs.append(self.reduce_layers[idx](inputs[idx + 1]))
# top-down path
inner_outs = [reduce_outs[-1]]
for idx in range(len(self.in_channels) - 1, 0, -1):
feat_high = inner_outs[0]
feat_cur = reduce_outs[idx]
feat_low = reduce_outs[idx - 1]
top_down_layer_inputs = self.upsample_layers[len(self.in_channels)
- 1 - idx]([
feat_high,
feat_cur, feat_low
])
inner_out = self.top_down_layers[len(self.in_channels) - 1 - idx](
top_down_layer_inputs)
inner_outs.insert(0, inner_out)
# bottom-up path
outs = [inner_outs[0]]
for idx in range(len(self.in_channels) - 1):
feat_low = outs[-1]
feat_high = inner_outs[idx + 1]
downsample_feat = self.downsample_layers[idx](feat_low)
out = self.bottom_up_layers[idx](
torch.cat([downsample_feat, feat_high], 1))
outs.append(out)
# out_layers
results = []
for idx in range(len(self.in_channels)):
results.append(self.out_layers[idx](outs[idx]))
return tuple(results)
@MODELS.register_module()
class YOLOv6CSPRepBiPAFPN(YOLOv6RepBiPAFPN):
"""Path Aggregation Network used in YOLOv6 3.0.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale)
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Defaults to 1.
freeze_all(bool): Whether to freeze the model.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='ReLU', inplace=True).
block_cfg (dict): Config dict for the block used to build each
layer. Defaults to dict(type='RepVGGBlock').
block_act_cfg (dict): Config dict for activation layer used in each
stage. Defaults to dict(type='SiLU', inplace=True).
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_channels: List[int],
out_channels: int,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
hidden_ratio: float = 0.5,
num_csp_blocks: int = 12,
freeze_all: bool = False,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
block_act_cfg: ConfigType = dict(type='SiLU', inplace=True),
block_cfg: ConfigType = dict(type='RepVGGBlock'),
init_cfg: OptMultiConfig = None):
self.hidden_ratio = hidden_ratio
self.block_act_cfg = block_act_cfg
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
num_csp_blocks=num_csp_blocks,
freeze_all=freeze_all,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
block_cfg=block_cfg,
init_cfg=init_cfg)
def build_top_down_layer(self, idx: int) -> nn.Module:
"""build top down layer.
Args:
idx (int): layer idx.
Returns:
nn.Module: The top down layer.
"""
block_cfg = self.block_cfg.copy()
layer0 = BepC3StageBlock(
in_channels=int(self.out_channels[idx - 1] * self.widen_factor),
out_channels=int(self.out_channels[idx - 1] * self.widen_factor),
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
block_cfg=block_cfg,
hidden_ratio=self.hidden_ratio,
norm_cfg=self.norm_cfg,
act_cfg=self.block_act_cfg)
if idx == 1:
return layer0
elif idx == 2:
layer1 = ConvModule(
in_channels=int(self.out_channels[idx - 1] *
self.widen_factor),
out_channels=int(self.out_channels[idx - 2] *
self.widen_factor),
kernel_size=1,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
return nn.Sequential(layer0, layer1)
def build_bottom_up_layer(self, idx: int) -> nn.Module:
"""build bottom up layer.
Args:
idx (int): layer idx.
Returns:
nn.Module: The bottom up layer.
"""
block_cfg = self.block_cfg.copy()
return BepC3StageBlock(
in_channels=int(self.out_channels[idx] * 2 * self.widen_factor),
out_channels=int(self.out_channels[idx + 1] * self.widen_factor),
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
block_cfg=block_cfg,
hidden_ratio=self.hidden_ratio,
norm_cfg=self.norm_cfg,
act_cfg=self.block_act_cfg)

View File

@ -34,6 +34,7 @@ class TestYOLOv6Head(TestCase):
nms=dict(type='nms', iou_threshold=0.65)))
head = YOLOv6Head(head_module=self.head_module, test_cfg=test_cfg)
head.eval()
feat = []
for i in range(len(self.head_module['in_channels'])):

View File

@ -3,7 +3,8 @@ from unittest import TestCase
import torch
from mmyolo.models.necks import YOLOv6CSPRepPAFPN, YOLOv6RepPAFPN
from mmyolo.models.necks import (YOLOv6CSPRepBiPAFPN, YOLOv6CSPRepPAFPN,
YOLOv6RepBiPAFPN, YOLOv6RepPAFPN)
from mmyolo.utils import register_all_modules
register_all_modules()
@ -44,3 +45,37 @@ class TestYOLOv6PAFPN(TestCase):
for i in range(len(feats)):
assert outs[i].shape[1] == out_channels[i]
assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
def test_YOLOv6CSPRepBiPAFPN_forward(self):
s = 64
in_channels = [4, 8, 16, 32] # includes an extra input for BiFusion
feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8]
out_channels = [8, 16, 32]
feats = [
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
for i in range(len(in_channels))
]
neck = YOLOv6CSPRepBiPAFPN(
in_channels=in_channels, out_channels=out_channels)
outs = neck(feats)
assert len(outs) == len(feats) - 1
for i in range(len(feats) - 1):
assert outs[i].shape[1] == out_channels[i]
assert outs[i].shape[2] == outs[i].shape[3] == feat_sizes[i + 1]
def test_YOLOv6RepBiPAFPN_forward(self):
s = 64
in_channels = [4, 8, 16, 32] # includes an extra input for BiFusion
feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8]
out_channels = [8, 16, 32]
feats = [
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
for i in range(len(in_channels))
]
neck = YOLOv6RepBiPAFPN(
in_channels=in_channels, out_channels=out_channels)
outs = neck(feats)
assert len(outs) == len(feats) - 1
for i in range(len(feats) - 1):
assert outs[i].shape[1] == out_channels[i]
assert outs[i].shape[2] == outs[i].shape[3] == feat_sizes[i + 1]

View File

@ -0,0 +1,145 @@
import argparse
from collections import OrderedDict
import torch
def convert(src, dst):
import sys
sys.path.append('yolov6')
try:
ckpt = torch.load(src, map_location=torch.device('cpu'))
except ModuleNotFoundError:
raise RuntimeError(
'This script must be placed under the meituan/YOLOv6 repo,'
' because loading the official pretrained model need'
' some python files to build model.')
# The saved model is the model before reparameterization
model = ckpt['ema' if ckpt.get('ema') else 'model'].float()
new_state_dict = OrderedDict()
is_ns = False
for k, v in model.state_dict().items():
name = k
if 'detect' in k:
if 'proj' in k:
continue
if 'reg_preds_lrtb' in k:
is_ns = True
name = k.replace('detect', 'bbox_head.head_module')
if k.find('anchors') >= 0 or k.find('anchor_grid') >= 0:
continue
if 'ERBlock_2' in k:
name = k.replace('ERBlock_2', 'stage1.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'ERBlock_3' in k:
name = k.replace('ERBlock_3', 'stage2.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'ERBlock_4' in k:
name = k.replace('ERBlock_4', 'stage3.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'ERBlock_5' in k:
name = k.replace('ERBlock_5', 'stage4.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
if 'stage4.0.2' in name:
name = name.replace('stage4.0.2', 'stage4.1')
name = name.replace('cv', 'conv')
elif 'reduce_layer0' in k:
name = k.replace('reduce_layer0', 'reduce_layers.2')
elif 'Rep_p4' in k:
name = k.replace('Rep_p4', 'top_down_layers.0.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'reduce_layer1' in k:
name = k.replace('reduce_layer1', 'top_down_layers.0.1')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'Rep_p3' in k:
name = k.replace('Rep_p3', 'top_down_layers.1')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'Bifusion0' in k:
name = k.replace('Bifusion0', 'upsample_layers.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
if '.upsample_transpose.' in k:
name = name.replace('.upsample_transpose.', '.')
elif 'Bifusion1' in k:
name = k.replace('Bifusion1', 'upsample_layers.1')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
if '.upsample_transpose.' in k:
name = name.replace('.upsample_transpose.', '.')
elif 'Rep_n3' in k:
name = k.replace('Rep_n3', 'bottom_up_layers.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'Rep_n4' in k:
name = k.replace('Rep_n4', 'bottom_up_layers.1')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'downsample2' in k:
name = k.replace('downsample2', 'downsample_layers.0')
elif 'downsample1' in k:
name = k.replace('downsample1', 'downsample_layers.1')
new_state_dict[name] = v
# The yolov6_v3_n/s has two regression heads.
# One called 'reg_preds_lrtb' is a regular anchor-free head,
# which is used for inference.
# One called 'reg_preds' is a DFL style head, which
# is only used in training.
if is_ns:
tmp_state_dict = OrderedDict()
for k, v in new_state_dict.items():
name = k
if 'reg_preds_lrtb' in k:
name = k.replace('reg_preds_lrtb', 'reg_preds')
elif 'reg_preds' in k:
name = k.replace('reg_preds', 'distill_ns_head')
tmp_state_dict[name] = v
new_state_dict = tmp_state_dict
data = {'state_dict': new_state_dict}
torch.save(data, dst)
# Note: This script must be placed under the yolov6 repo to run.
def main():
parser = argparse.ArgumentParser(description='Convert model keys')
parser.add_argument(
'--src', default='yolov6s.pt', help='src yolov6 model path')
parser.add_argument('--dst', default='mmyolov6.pt', help='save path')
args = parser.parse_args()
convert(args.src, args.dst)
if __name__ == '__main__':
main()