[Feature] Support YOLOv6 ML model (#265)

* align larger infer

* align large test

* update docstr

* add ut

* Add yolov6m pipeline

* Add yolov6m pipeline

* Improve coding

* update l config

* align m test

* rename config

* support training

* update

* block_cfg -> stage_block_cfg

* update

* Add docstr

* Fix lint

* update convert script

* update

* rename param

* update readme&metafile

* rm duplicate code

* fix lint

* update

* update

* update

Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>
wanghonglie 2022-11-11 19:13:51 +08:00 committed by GitHub
parent 178b0bf3ff
commit ab4e7c5158
17 changed files with 821 additions and 142 deletions

View File

@ -16,17 +16,20 @@ For years, YOLO series have been de facto industry-level standard for efficient
### COCO
| Backbone | Arch | size | SyncBN | AMP | Mem (GB) | box AP | Config | Download |
| :------: | :--: | :--: | :----: | :-: | :------: | :----: | :---------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| YOLOv6-n | P5 | 640 | Yes | Yes | 6.04 | 36.2 | [config](../yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco/yolov6_n_syncbn_fast_8xb32-400e_coco_20221030_202726-d99b2e82.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco/yolov6_n_syncbn_fast_8xb32-400e_coco_20221030_202726.log.json) |
| YOLOv6-t | P5 | 640 | Yes | Yes | 8.13 | 41.0 | [config](../yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755-cf0d278f.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755.log.json) |
| YOLOv6-s | P5 | 640 | Yes | Yes | 8.88 | 44.0 | [config](../yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221102_203035-932e1d91.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221102_203035.log.json) |
| Backbone | Arch | Size | Epoch | SyncBN | AMP | Mem (GB) | Box AP | Config | Download |
| :------: | :--: | :--: | :---: | :----: | :-: | :------: | :----: | :---------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| YOLOv6-n | P5 | 640 | 400 | Yes | Yes | 6.04 | 36.2 | [config](../yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco/yolov6_n_syncbn_fast_8xb32-400e_coco_20221030_202726-d99b2e82.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco/yolov6_n_syncbn_fast_8xb32-400e_coco_20221030_202726.log.json) |
| YOLOv6-t | P5 | 640 | 400 | Yes | Yes | 8.13 | 41.0 | [config](../yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755-cf0d278f.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755.log.json) |
| YOLOv6-s | P5 | 640 | 400 | Yes | Yes | 8.88 | 44.0 | [config](../yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221102_203035-932e1d91.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221102_203035.log.json) |
| YOLOv6-m | P5 | 640 | 300 | Yes | Yes | 16.69 | 48.4 | [config](../yolov6/yolov6_m_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_m_syncbn_fast_8xb32-300e_coco/yolov6_m_syncbn_fast_8xb32-300e_coco_20221109_182658-85bda3f4.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_m_syncbn_fast_8xb32-300e_coco/yolov6_m_syncbn_fast_8xb32-300e_coco_20221109_182658.log.json) |
| YOLOv6-l | P5 | 640 | 300 | Yes | Yes | 20.86 | 51.0 | [config](../yolov6/yolov6_l_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_l_syncbn_fast_8xb32-300e_coco/yolov6_l_syncbn_fast_8xb32-300e_coco_20221109_183156-91e3c447.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_l_syncbn_fast_8xb32-300e_coco/yolov6_l_syncbn_fast_8xb32-300e_coco_20221109_183156.log.json) |
**Note**:
1. The performance is unstable and may fluctuate by about 0.3 mAP.
2. YOLOv6-m,l,x will be supported in later version.
3. If users need the weight of 300 epoch, they can train according to the configs of 300 epoch provided by us, or convert the official weight according to the [converter script](../../tools/model_converters/).
1. The official m and l models use knowledge distillation, which our implementation does not support yet; it will be implemented in [MMRazor](https://github.com/open-mmlab/mmrazor) in the future.
2. The performance is unstable and may fluctuate by about 0.3 mAP.
3. If users need the 300-epoch weights for the nano, tiny and small models, they can either train with the 300-epoch configs we provide, or convert the official weights with the [converter script](../../tools/model_converters/) (a usage sketch follows these notes).
4. We have noticed that the [base model](https://github.com/meituan/YOLOv6/tree/main/configs/base) has recently been released in the official YOLOv6 repository. Although its accuracy is lower, it is more efficient. We will also provide the base model configuration in the future.
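
A minimal converter sketch for note 3 (the import path and file names below are assumptions, not a documented CLI; `convert(src, dst)` is the helper defined in the converter script):

```python
# Hedged sketch: convert an official YOLOv6 checkpoint into MMYOLO format by
# calling the convert() helper from the converter script directly.
# The import path and both file names are assumptions; adjust them to the
# actual script under tools/model_converters/ and to your own weights.
from tools.model_converters.yolov6_to_mmyolo import convert  # hypothetical path

convert(src='yolov6s.pt', dst='yolov6_s_official_converted.pth')
```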
## Citation

View File

@ -57,3 +57,27 @@ Models:
Metrics:
box AP: 41.0
Weights: https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755-cf0d278f.pth
- Name: yolov6_m_syncbn_fast_8xb32-300e_coco
In Collection: YOLOv6
Config: configs/yolov6/yolov6_m_syncbn_fast_8xb32-300e_coco.py
Metadata:
Training Memory (GB): 16.69
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 48.4
Weights: https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_m_syncbn_fast_8xb32-300e_coco/yolov6_m_syncbn_fast_8xb32-300e_coco_20221109_182658-85bda3f4.pth
- Name: yolov6_l_syncbn_fast_8xb32-300e_coco
In Collection: YOLOv6
Config: configs/yolov6/yolov6_l_syncbn_fast_8xb32-300e_coco.py
Metadata:
Training Memory (GB): 20.86
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 51.0
Weights: https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_l_syncbn_fast_8xb32-300e_coco/yolov6_l_syncbn_fast_8xb32-300e_coco_20221109_183156-91e3c447.pth
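
For reference, a hedged inference sketch using one of the released checkpoints listed above with the standard MMDetection 3.x API (the demo image path is an assumption):

```python
# Hedged sketch: load the YOLOv6-m config and released checkpoint from the
# metafile above, then run single-image inference on CPU.
from mmdet.apis import inference_detector, init_detector

from mmyolo.utils import register_all_modules

register_all_modules()
model = init_detector(
    'configs/yolov6/yolov6_m_syncbn_fast_8xb32-300e_coco.py',
    'https://download.openmmlab.com/mmyolo/v0/yolov6/'
    'yolov6_m_syncbn_fast_8xb32-300e_coco/'
    'yolov6_m_syncbn_fast_8xb32-300e_coco_20221109_182658-85bda3f4.pth',
    device='cpu')
result = inference_detector(model, 'demo/demo.jpg')  # assumed image path
```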

View File

@ -0,0 +1,23 @@
_base_ = './yolov6_m_syncbn_fast_8xb32-300e_coco.py'
deepen_factor = 1
widen_factor = 1
model = dict(
backbone=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
hidden_ratio=1. / 2,
block_cfg=dict(
type='ConvWrapper',
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
act_cfg=dict(type='SiLU', inplace=True)),
neck=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
hidden_ratio=1. / 2,
block_cfg=dict(
type='ConvWrapper',
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
block_act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
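
A quick way to check how these l-model overrides merge with the base m config is to load the file with MMEngine (a sketch, assuming the file is saved at the path used in the metafile above):

```python
# Hedged sketch: inspect the merged YOLOv6-l config.
from mmengine.config import Config

cfg = Config.fromfile('configs/yolov6/yolov6_l_syncbn_fast_8xb32-300e_coco.py')
print(cfg.model.backbone.block_cfg)     # ConvWrapper with BN(momentum=0.03, eps=0.001)
print(cfg.model.backbone.hidden_ratio)  # 0.5
print(cfg.model.neck.block_act_cfg)     # SiLU(inplace=True)
```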

View File

@ -0,0 +1,54 @@
_base_ = './yolov6_s_syncbn_fast_8xb32-300e_coco.py'
deepen_factor = 0.6
widen_factor = 0.75
affine_scale = 0.9
model = dict(
backbone=dict(
type='YOLOv6CSPBep',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
hidden_ratio=2. / 3,
block_cfg=dict(type='RepVGGBlock'),
act_cfg=dict(type='ReLU', inplace=True)),
neck=dict(
type='YOLOv6CSPRepPAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
block_cfg=dict(type='RepVGGBlock'),
hidden_ratio=2. / 3,
block_act_cfg=dict(type='ReLU', inplace=True)),
bbox_head=dict(
type='YOLOv6Head', head_module=dict(widen_factor=widen_factor)))
mosaic_affine_pipeline = [
dict(
type='Mosaic',
img_scale=_base_.img_scale,
pad_val=114.0,
pre_transform=_base_.pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
border=(-_base_.img_scale[0] // 2, -_base_.img_scale[1] // 2),
border_val=(114, 114, 114))
]
train_pipeline = [
*_base_.pre_transform, *mosaic_affine_pipeline,
dict(
type='YOLOv5MixUp',
prob=0.1,
pre_transform=[*_base_.pre_transform, *mosaic_affine_pipeline]),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
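
For clarity, the `affine_scale` value above only widens the random scaling range used by `YOLOv5RandomAffine` (a trivial check):

```python
# affine_scale is 0.9 in the m/l configs and 0.5 in the s/t/n base config.
for affine_scale in (0.5, 0.9):
    scaling_ratio_range = (1 - affine_scale, 1 + affine_scale)
    print(affine_scale, tuple(round(v, 2) for v in scaling_ratio_range))
# 0.5 (0.5, 1.5)
# 0.9 (0.1, 1.9)
```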

View File

@ -12,6 +12,7 @@ num_classes = 80
img_scale = (640, 640) # height, width
deepen_factor = 0.33
widen_factor = 0.5
affine_scale = 0.5
save_epoch_intervals = 10
train_batch_size_per_gpu = 32
train_num_workers = 8
@ -112,7 +113,7 @@ train_pipeline = [
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_translate_ratio=0.1,
scaling_ratio_range=(0.5, 1.5),
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114),
max_shear_degree=0.0),
@ -136,7 +137,7 @@ train_pipeline_stage2 = [
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_translate_ratio=0.1,
scaling_ratio_range=(0.5, 1.5),
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
max_shear_degree=0.0,
),
dict(type='YOLOv5HSVRandomAug'),

View File

@ -3,10 +3,10 @@ from .base_backbone import BaseBackbone
from .csp_darknet import YOLOv5CSPDarknet, YOLOXCSPDarknet
from .csp_resnet import PPYOLOECSPResNet
from .cspnext import CSPNeXt
from .efficient_rep import YOLOv6EfficientRep
from .efficient_rep import YOLOv6CSPBep, YOLOv6EfficientRep
from .yolov7_backbone import YOLOv7Backbone
__all__ = [
'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep',
'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOv6CSPBep',
'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet'
]

View File

@ -87,7 +87,6 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
norm_eval: bool = False,
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg)
self.num_stages = len(arch_setting)
self.arch_setting = arch_setting

View File

@ -8,20 +8,18 @@ from mmdet.utils import ConfigType, OptMultiConfig
from mmyolo.models.layers.yolo_bricks import SPPFBottleneck
from mmyolo.registry import MODELS
from ..layers import RepStageBlock, RepVGGBlock
from ..utils import make_divisible, make_round
from ..layers import BepC3StageBlock, RepStageBlock
from ..utils import make_round
from .base_backbone import BaseBackbone
@MODELS.register_module()
class YOLOv6EfficientRep(BaseBackbone):
"""EfficientRep backbone used in YOLOv6.
Args:
arch (str): Architecture of BaseDarknet, from {P5, P6}.
Defaults to P5.
plugins (list[dict]): List of plugins for stages, each dict contains:
- cfg (dict, required): Cfg dict to build plugin.
- stages (tuple[bool], optional): Stages to apply plugin, length
should be same as 'num_stages'.
@ -41,10 +39,10 @@ class YOLOv6EfficientRep(BaseBackbone):
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
block (nn.Module): block used to build each stage.
block_cfg (dict): Config dict for the block used to build each
layer. Defaults to dict(type='RepVGGBlock').
init_cfg (Union[dict, list[dict]], optional): Initialization config
dict. Defaults to None.
Example:
>>> from mmyolo.models import YOLOv6EfficientRep
>>> import torch
@ -78,9 +76,9 @@ class YOLOv6EfficientRep(BaseBackbone):
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
norm_eval: bool = False,
block: nn.Module = RepVGGBlock,
block_cfg: ConfigType = dict(type='RepVGGBlock'),
init_cfg: OptMultiConfig = None):
self.block = block
self.block_cfg = block_cfg
super().__init__(
self.arch_settings[arch],
deepen_factor,
@ -96,12 +94,16 @@ class YOLOv6EfficientRep(BaseBackbone):
def build_stem_layer(self) -> nn.Module:
"""Build a stem layer."""
return self.block(
in_channels=self.input_channels,
out_channels=make_divisible(self.arch_setting[0][0],
self.widen_factor),
kernel_size=3,
stride=2)
block_cfg = self.block_cfg.copy()
block_cfg.update(
dict(
in_channels=self.input_channels,
out_channels=int(self.arch_setting[0][0] * self.widen_factor),
kernel_size=3,
stride=2,
))
return MODELS.build(block_cfg)
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
"""Build a stage layer.
@ -112,24 +114,28 @@ class YOLOv6EfficientRep(BaseBackbone):
"""
in_channels, out_channels, num_blocks, use_spp = setting
in_channels = make_divisible(in_channels, self.widen_factor)
out_channels = make_divisible(out_channels, self.widen_factor)
in_channels = int(in_channels * self.widen_factor)
out_channels = int(out_channels * self.widen_factor)
num_blocks = make_round(num_blocks, self.deepen_factor)
stage = []
rep_stage_block = RepStageBlock(
in_channels=out_channels,
out_channels=out_channels,
num_blocks=num_blocks,
block_cfg=self.block_cfg,
)
ef_block = nn.Sequential(
self.block(
block_cfg = self.block_cfg.copy()
block_cfg.update(
dict(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2),
RepStageBlock(
in_channels=out_channels,
out_channels=out_channels,
n=num_blocks,
block=self.block,
))
stride=2))
stage = []
ef_block = nn.Sequential(MODELS.build(block_cfg), rep_stage_block)
stage.append(ef_block)
if use_spp:
@ -152,3 +158,130 @@ class YOLOv6EfficientRep(BaseBackbone):
m.reset_parameters()
else:
super().init_weights()
@MODELS.register_module()
class YOLOv6CSPBep(YOLOv6EfficientRep):
"""CSPBep backbone used in YOLOv6.
Args:
arch (str): Architecture of BaseDarknet, from {P5, P6}.
Defaults to P5.
plugins (list[dict]): List of plugins for stages, each dict contains:
- cfg (dict, required): Cfg dict to build plugin.
- stages (tuple[bool], optional): Stages to apply plugin, length
should be same as 'num_stages'.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
input_channels (int): Number of input image channels. Defaults to 3.
out_indices (Tuple[int]): Output from which stages.
Defaults to (2, 3, 4).
frozen_stages (int): Stages to be frozen (stop grad and set eval
mode). -1 means not freezing any parameters. Defaults to -1.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='SiLU', inplace=True).
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
block_cfg (dict): Config dict for the block used to build each
layer. Defaults to dict(type='ConvWrapper').
hidden_ratio (float): Hidden channel expansion ratio of the
BepC3StageBlock in each stage. Defaults to 0.5.
init_cfg (Union[dict, list[dict]], optional): Initialization config
dict. Defaults to None.
Example:
>>> from mmyolo.models import YOLOv6CSPBep
>>> import torch
>>> model = YOLOv6CSPBep()
>>> model.eval()
>>> inputs = torch.rand(1, 3, 416, 416)
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
...
(1, 256, 52, 52)
(1, 512, 26, 26)
(1, 1024, 13, 13)
"""
# From left to right:
# in_channels, out_channels, num_blocks, use_spp
arch_settings = {
'P5': [[64, 128, 6, False], [128, 256, 12, False],
[256, 512, 18, False], [512, 1024, 6, True]]
}
def __init__(self,
arch: str = 'P5',
plugins: Union[dict, List[dict]] = None,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
input_channels: int = 3,
hidden_ratio: float = 0.5,
out_indices: Tuple[int] = (2, 3, 4),
frozen_stages: int = -1,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
norm_eval: bool = False,
block_cfg: ConfigType = dict(type='ConvWrapper'),
init_cfg: OptMultiConfig = None):
self.hidden_ratio = hidden_ratio
super().__init__(
arch=arch,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
input_channels=input_channels,
out_indices=out_indices,
plugins=plugins,
frozen_stages=frozen_stages,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
norm_eval=norm_eval,
block_cfg=block_cfg,
init_cfg=init_cfg)
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
"""Build a stage layer.
Args:
stage_idx (int): The index of a stage layer.
setting (list): The architecture setting of a stage layer.
"""
in_channels, out_channels, num_blocks, use_spp = setting
in_channels = int(in_channels * self.widen_factor)
out_channels = int(out_channels * self.widen_factor)
num_blocks = make_round(num_blocks, self.deepen_factor)
rep_stage_block = BepC3StageBlock(
in_channels=out_channels,
out_channels=out_channels,
num_blocks=num_blocks,
hidden_ratio=self.hidden_ratio,
block_cfg=self.block_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
block_cfg = self.block_cfg.copy()
block_cfg.update(
dict(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2))
stage = []
ef_block = nn.Sequential(MODELS.build(block_cfg), rep_stage_block)
stage.append(ef_block)
if use_spp:
spp = SPPFBottleneck(
in_channels=out_channels,
out_channels=out_channels,
kernel_sizes=5,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.append(spp)
return stage

View File

@ -14,7 +14,6 @@ from mmengine.structures import InstanceData
from torch import Tensor
from mmyolo.registry import MODELS, TASK_UTILS
from ..utils import make_divisible
from .yolov5_head import YOLOv5Head
@ -65,12 +64,10 @@ class YOLOv6HeadModule(BaseModule):
self.act_cfg = act_cfg
if isinstance(in_channels, int):
self.in_channels = [make_divisible(in_channels, widen_factor)
self.in_channels = [int(in_channels * widen_factor)
] * self.num_levels
else:
self.in_channels = [
make_divisible(i, widen_factor) for i in in_channels
]
self.in_channels = [int(i * widen_factor) for i in in_channels]
self._init_layers()

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ema import ExpMomentumEMA
from .yolo_bricks import (EffectiveSELayer, ELANBlock,
from .yolo_bricks import (BepC3StageBlock, EffectiveSELayer, ELANBlock,
MaxPoolAndStrideConvBlock, PPYOLOEBasicBlock,
RepStageBlock, RepVGGBlock, SPPFBottleneck,
SPPFCSPBlock)
@ -8,5 +8,5 @@ from .yolo_bricks import (EffectiveSELayer, ELANBlock,
__all__ = [
'SPPFBottleneck', 'RepVGGBlock', 'RepStageBlock', 'ExpMomentumEMA',
'ELANBlock', 'MaxPoolAndStrideConvBlock', 'SPPFCSPBlock',
'PPYOLOEBasicBlock', 'EffectiveSELayer'
'PPYOLOEBasicBlock', 'EffectiveSELayer', 'BepC3StageBlock'
]

View File

@ -9,6 +9,7 @@ from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from mmengine.model import BaseModule
from mmengine.utils import digit_version
from torch import Tensor
from torch.nn.parameter import Parameter
from mmyolo.registry import MODELS
@ -22,7 +23,7 @@ else:
def __init__(self, inplace=True):
super().__init__()
def forward(self, inputs) -> torch.Tensor:
def forward(self, inputs) -> Tensor:
return inputs * torch.sigmoid(inputs)
MODELS.register_module(module=SiLU, name='SiLU')
@ -31,7 +32,6 @@ else:
class SPPFBottleneck(BaseModule):
"""Spatial pyramid pooling - Fast (SPPF) layer for
YOLOv5, YOLOX and PPYOLOE by Glenn Jocher
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
@ -100,7 +100,7 @@ class SPPFBottleneck(BaseModule):
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: Tensor) -> Tensor:
"""Forward process
Args:
x (Tensor): The input tensor.
@ -118,6 +118,7 @@ class SPPFBottleneck(BaseModule):
return x
@MODELS.register_module()
class RepVGGBlock(nn.Module):
"""RepVGGBlock is a basic rep-style block, including training and deploy
status This code is based on
@ -227,11 +228,11 @@ class RepVGGBlock(nn.Module):
norm_cfg=norm_cfg,
act_cfg=None)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
def forward(self, inputs: Tensor) -> Tensor:
"""Forward process.
Args:
inputs (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
@ -270,9 +271,9 @@ class RepVGGBlock(nn.Module):
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
"""Pad 1x1 tensor to 3x3.
Args:
kernel1x1 (Tensor): The input 1x1 kernel need to be padded.
Returns:
Tensor: 3x3 kernel after padded.
"""
@ -281,14 +282,12 @@ class RepVGGBlock(nn.Module):
else:
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
def _fuse_bn_tensor(self,
branch: nn.Module) -> Tuple[np.ndarray, torch.Tensor]:
def _fuse_bn_tensor(self, branch: nn.Module) -> Tuple[np.ndarray, Tensor]:
"""Derives the equivalent kernel and bias of a specific branch layer.
Args:
branch (nn.Module): The layer that needs to be equivalently
transformed, which can be nn.Sequential or nn.Batchnorm2d
Returns:
tuple: Equivalent kernel and bias
"""
@ -348,38 +347,177 @@ class RepVGGBlock(nn.Module):
self.deploy = True
class RepStageBlock(nn.Module):
"""RepStageBlock is a stage block with rep-style basic block.
@MODELS.register_module()
class BepC3StageBlock(nn.Module):
"""Beer-mug RepC3 Block.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
n (int, tuple[int]): Number of blocks. Defaults to 1.
block (nn.Module): Basic unit of RepStage. Defaults to RepVGGBlock.
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
num_blocks (int): Number of blocks. Defaults to 1.
hidden_ratio (float): Hidden channel expansion ratio.
Defaults to 0.5.
concat_all_layer (bool): Whether to concatenate the two branch
outputs before the final conv. Defaults to True.
block_cfg (dict): Config dict for the block used to build each
layer. Defaults to dict(type='RepVGGBlock').
norm_cfg (ConfigType): Config dict for normalization layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (ConfigType): Config dict for activation layer.
Defaults to dict(type='ReLU', inplace=True).
"""
def __init__(self,
in_channels: int,
out_channels: int,
n: int = 1,
block: nn.Module = RepVGGBlock):
num_blocks: int = 1,
hidden_ratio: float = 0.5,
concat_all_layer: bool = True,
block_cfg: ConfigType = dict(type='RepVGGBlock'),
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='ReLU', inplace=True)):
super().__init__()
self.conv1 = block(in_channels, out_channels)
self.block = nn.Sequential(*(block(out_channels, out_channels)
for _ in range(n - 1))) if n > 1 else None
hidden_channels = int(out_channels * hidden_ratio)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward process.
Args:
inputs (Tensor): The input tensor.
self.conv1 = ConvModule(
in_channels,
hidden_channels,
kernel_size=1,
stride=1,
groups=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv2 = ConvModule(
in_channels,
hidden_channels,
kernel_size=1,
stride=1,
groups=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv3 = ConvModule(
2 * hidden_channels,
out_channels,
kernel_size=1,
stride=1,
groups=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.block = RepStageBlock(
in_channels=hidden_channels,
out_channels=hidden_channels,
num_blocks=num_blocks,
block_cfg=block_cfg,
bottle_block=BottleRep)
self.concat_all_layer = concat_all_layer
if not concat_all_layer:
self.conv3 = ConvModule(
hidden_channels,
out_channels,
kernel_size=1,
stride=1,
groups=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
Returns:
Tensor: The output tensor.
"""
x = self.conv1(x)
if self.block is not None:
x = self.block(x)
return x
def forward(self, x):
if self.concat_all_layer is True:
return self.conv3(
torch.cat((self.block(self.conv1(x)), self.conv2(x)), dim=1))
else:
return self.conv3(self.block(self.conv1(x)))
class BottleRep(nn.Module):
"""Bottle Rep Block.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
block_cfg (dict): Config dict for the block used to build each
layer. Defaults to dict(type='RepVGGBlock').
adaptive_weight (bool): Whether to add a learnable alpha weight on
the shortcut branch. Defaults to False.
"""
def __init__(self,
in_channels: int,
out_channels: int,
block_cfg: ConfigType = dict(type='RepVGGBlock'),
adaptive_weight: bool = False):
super().__init__()
conv1_cfg = block_cfg.copy()
conv2_cfg = block_cfg.copy()
conv1_cfg.update(
dict(in_channels=in_channels, out_channels=out_channels))
conv2_cfg.update(
dict(in_channels=out_channels, out_channels=out_channels))
self.conv1 = MODELS.build(conv1_cfg)
self.conv2 = MODELS.build(conv2_cfg)
if in_channels != out_channels:
self.shortcut = False
else:
self.shortcut = True
if adaptive_weight:
self.alpha = Parameter(torch.ones(1))
else:
self.alpha = 1.0
def forward(self, x: Tensor) -> Tensor:
outputs = self.conv1(x)
outputs = self.conv2(outputs)
return outputs + self.alpha * x if self.shortcut else outputs
@MODELS.register_module()
class ConvWrapper(nn.Module):
"""Wrapper for normal Conv with SiLU activation.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple): Stride of the convolution. Default: 1
groups (int, optional): Number of blocked connections from input
channels to output channels. Default: 1
bias (bool, optional): Conv bias. Default: True.
norm_cfg (ConfigType): Config dict for normalization layer.
Defaults to None.
act_cfg (ConfigType): Config dict for activation layer.
Defaults to dict(type='SiLU').
"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
groups: int = 1,
bias: bool = True,
norm_cfg: ConfigType = None,
act_cfg: ConfigType = dict(type='SiLU')):
super().__init__()
self.block = ConvModule(
in_channels,
out_channels,
kernel_size,
stride,
padding=kernel_size // 2,
groups=groups,
bias=bias,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x: Tensor) -> Tensor:
return self.block(x)
@MODELS.register_module()
@ -390,7 +528,6 @@ class EffectiveSELayer(nn.Module):
arxiv (https://arxiv.org/abs/1911.06667)
This code referenced to
https://github.com/youngwanLEE/CenterMask/blob/72147e8aae673fcaf4103ee90a6a6b73863e7fa1/maskrcnn_benchmark/modeling/backbone/vovnet.py#L108-L121 # noqa
Args:
channels (int): The input and output channels of this Module.
act_cfg (dict): Config dict for activation layer.
@ -419,13 +556,13 @@ class EffectiveSELayer(nn.Module):
class PPYOLOESELayer(nn.Module):
"""Squeeze-and-Excitation Attention Module for PPYOLOE.
There are some differences between the current implementation and
There are some differences between the current implementation and
SELayer in mmdet:
1. For fast speed and avoiding double inference in ppyoloe,
use `F.adaptive_avg_pool2d` before PPYOLOESELayer.
2. Special ways to init weights.
3. Different convolution order.
Args:
feat_channels (int): The input (and output) channels of the SE layer.
norm_cfg (dict): Config dict for normalization layer.
@ -473,7 +610,6 @@ class ELANBlock(BaseModule):
- if mode is `no_change_channel`, the output channel does not change.
- if mode is `expand_channel_2x`, the output channel will be
expanded by a factor of 2
Args:
in_channels (int): The input channels of this Module.
mode (str): Output channel mode. Defaults to `expand_channel_2x`.
@ -597,7 +733,6 @@ class MaxPoolAndStrideConvBlock(BaseModule):
- if mode is `reduce_channel_2x`, the output channel will
be reduced by a factor of 2
- if mode is `no_change_channel`, the output channel does not change.
Args:
in_channels (int): The input channels of this Module.
mode (str): Output channel mode. `reduce_channel_2x` or
@ -669,7 +804,6 @@ class MaxPoolAndStrideConvBlock(BaseModule):
class SPPFCSPBlock(BaseModule):
"""Spatial pyramid pooling - Fast (SPPF) layer with CSP for
YOLOv7
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
@ -832,9 +966,9 @@ class PPYOLOEBasicBlock(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Forward process.
Args:
inputs (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
@ -986,3 +1120,66 @@ class CSPResLayer(nn.Module):
y = self.attn(y)
y = self.conv3(y)
return y
@MODELS.register_module()
class RepStageBlock(nn.Module):
"""RepStageBlock is a stage block with rep-style basic block.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
num_blocks (int, tuple[int]): Number of blocks. Defaults to 1.
bottle_block (nn.Module): Basic unit of RepStage.
Defaults to RepVGGBlock.
block_cfg (dict): Config dict for the block used to build each
layer. Defaults to dict(type='RepVGGBlock').
"""
def __init__(self,
in_channels: int,
out_channels: int,
num_blocks: int = 1,
bottle_block: nn.Module = RepVGGBlock,
block_cfg: ConfigType = dict(type='RepVGGBlock')):
super().__init__()
block_cfg = block_cfg.copy()
block_cfg.update(
dict(in_channels=in_channels, out_channels=out_channels))
self.conv1 = MODELS.build(block_cfg)
block_cfg.update(
dict(in_channels=out_channels, out_channels=out_channels))
self.block = nn.Sequential(*(
MODELS.build(block_cfg)
for _ in range(num_blocks - 1))) if num_blocks > 1 else None
if bottle_block == BottleRep:
self.conv1 = BottleRep(
in_channels,
out_channels,
block_cfg=block_cfg,
adaptive_weight=True)
num_blocks = num_blocks // 2
self.block = nn.Sequential(*(
BottleRep(
out_channels,
out_channels,
block_cfg=block_cfg,
adaptive_weight=True)
for _ in range(num_blocks - 1))) if num_blocks > 1 else None
def forward(self, x: Tensor) -> Tensor:
"""Forward process.
Args:
inputs (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
x = self.conv1(x)
if self.block is not None:
x = self.block(x)
return x
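
A minimal usage sketch for the new bricks added above (shapes are illustrative; importing `mmyolo.models.layers` registers the default `RepVGGBlock` sub-blocks that `BepC3StageBlock` builds):

```python
# Hedged sketch: run a BepC3StageBlock with its default block_cfg.
import torch

from mmyolo.models.layers import BepC3StageBlock

block = BepC3StageBlock(in_channels=64, out_channels=64, num_blocks=2)
x = torch.rand(1, 64, 32, 32)
print(block(x).shape)  # torch.Size([1, 64, 32, 32])
```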

View File

@ -3,11 +3,11 @@ from .base_yolo_neck import BaseYOLONeck
from .cspnext_pafpn import CSPNeXtPAFPN
from .ppyoloe_csppan import PPYOLOECSPPAFPN
from .yolov5_pafpn import YOLOv5PAFPN
from .yolov6_pafpn import YOLOv6RepPAFPN
from .yolov6_pafpn import YOLOv6CSPRepPAFPN, YOLOv6RepPAFPN
from .yolov7_pafpn import YOLOv7PAFPN
from .yolox_pafpn import YOLOXPAFPN
__all__ = [
'YOLOv5PAFPN', 'BaseYOLONeck', 'YOLOv6RepPAFPN', 'YOLOXPAFPN',
'CSPNeXtPAFPN', 'YOLOv7PAFPN', 'PPYOLOECSPPAFPN'
'CSPNeXtPAFPN', 'YOLOv7PAFPN', 'PPYOLOECSPPAFPN', 'YOLOv6CSPRepPAFPN'
]

View File

@ -56,12 +56,15 @@ class YOLOv5PAFPN(BaseYOLONeck):
init_cfg=init_cfg)
def init_weights(self):
"""Initialize the parameters."""
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
# In order to be consistent with the source code,
# reset the Conv2d initialization parameters
m.reset_parameters()
if self.init_cfg is None:
"""Initialize the parameters."""
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
# In order to be consistent with the source code,
# reset the Conv2d initialization parameters
m.reset_parameters()
else:
super().init_weights()
def build_reduce_layer(self, idx: int) -> nn.Module:
"""build reduce layer.

View File

@ -7,8 +7,8 @@ from mmcv.cnn import ConvModule
from mmdet.utils import ConfigType, OptMultiConfig
from mmyolo.registry import MODELS
from ..layers import RepStageBlock, RepVGGBlock
from ..utils import make_divisible, make_round
from ..layers import BepC3StageBlock, RepStageBlock
from ..utils import make_round
from .base_yolo_neck import BaseYOLONeck
@ -29,8 +29,8 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='ReLU', inplace=True).
block (nn.Module): block used to build each layer.
Defaults to RepVGGBlock.
block_cfg (dict): Config dict for the block used to build each
layer. Defaults to dict(type='RepVGGBlock').
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
@ -45,10 +45,10 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
block: nn.Module = RepVGGBlock,
block_cfg: ConfigType = dict(type='RepVGGBlock'),
init_cfg: OptMultiConfig = None):
self.num_csp_blocks = num_csp_blocks
self.block = block
self.block_cfg = block_cfg
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
@ -64,16 +64,14 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
Args:
idx (int): layer idx.
Returns:
nn.Module: The reduce layer.
"""
if idx == 2:
layer = ConvModule(
in_channels=make_divisible(self.in_channels[idx],
self.widen_factor),
out_channels=make_divisible(self.out_channels[idx - 1],
self.widen_factor),
in_channels=int(self.in_channels[idx] * self.widen_factor),
out_channels=int(self.out_channels[idx - 1] *
self.widen_factor),
kernel_size=1,
stride=1,
norm_cfg=self.norm_cfg,
@ -88,15 +86,12 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
Args:
idx (int): layer idx.
Returns:
nn.Module: The upsample layer.
"""
return nn.ConvTranspose2d(
in_channels=make_divisible(self.out_channels[idx - 1],
self.widen_factor),
out_channels=make_divisible(self.out_channels[idx - 1],
self.widen_factor),
in_channels=int(self.out_channels[idx - 1] * self.widen_factor),
out_channels=int(self.out_channels[idx - 1] * self.widen_factor),
kernel_size=2,
stride=2,
bias=True)
@ -106,26 +101,27 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
Args:
idx (int): layer idx.
Returns:
nn.Module: The top down layer.
"""
block_cfg = self.block_cfg.copy()
layer0 = RepStageBlock(
in_channels=make_divisible(
self.out_channels[idx - 1] + self.in_channels[idx - 1],
in_channels=int(
(self.out_channels[idx - 1] + self.in_channels[idx - 1]) *
self.widen_factor),
out_channels=make_divisible(self.out_channels[idx - 1],
self.widen_factor),
n=make_round(self.num_csp_blocks, self.deepen_factor),
block=self.block)
out_channels=int(self.out_channels[idx - 1] * self.widen_factor),
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
block_cfg=block_cfg)
if idx == 1:
return layer0
elif idx == 2:
layer1 = ConvModule(
in_channels=make_divisible(self.out_channels[idx - 1],
self.widen_factor),
out_channels=make_divisible(self.out_channels[idx - 2],
self.widen_factor),
in_channels=int(self.out_channels[idx - 1] *
self.widen_factor),
out_channels=int(self.out_channels[idx - 2] *
self.widen_factor),
kernel_size=1,
stride=1,
norm_cfg=self.norm_cfg,
@ -137,15 +133,12 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
Args:
idx (int): layer idx.
Returns:
nn.Module: The downsample layer.
"""
return ConvModule(
in_channels=make_divisible(self.out_channels[idx],
self.widen_factor),
out_channels=make_divisible(self.out_channels[idx],
self.widen_factor),
in_channels=int(self.out_channels[idx] * self.widen_factor),
out_channels=int(self.out_channels[idx] * self.widen_factor),
kernel_size=3,
stride=2,
padding=3 // 2,
@ -157,26 +150,136 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
Args:
idx (int): layer idx.
Returns:
nn.Module: The bottom up layer.
"""
block_cfg = self.block_cfg.copy()
return RepStageBlock(
in_channels=make_divisible(self.out_channels[idx] * 2,
self.widen_factor),
out_channels=make_divisible(self.out_channels[idx + 1],
self.widen_factor),
n=make_round(self.num_csp_blocks, self.deepen_factor),
block=self.block)
in_channels=int(self.out_channels[idx] * 2 * self.widen_factor),
out_channels=int(self.out_channels[idx + 1] * self.widen_factor),
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
block_cfg=block_cfg)
def build_out_layer(self, *args, **kwargs) -> nn.Module:
"""build out layer."""
return nn.Identity()
def init_weights(self):
"""Initialize the parameters."""
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
# In order to be consistent with the source code,
# reset the Conv2d initialization parameters
m.reset_parameters()
if self.init_cfg is None:
"""Initialize the parameters."""
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
# In order to be consistent with the source code,
# reset the Conv2d initialization parameters
m.reset_parameters()
else:
super().init_weights()
@MODELS.register_module()
class YOLOv6CSPRepPAFPN(YOLOv6RepPAFPN):
"""Path Aggregation Network used in YOLOv6.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale)
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
hidden_ratio (float): Hidden channel expansion ratio of the
BepC3StageBlock in each layer. Defaults to 0.5.
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Defaults to 12.
freeze_all (bool): Whether to freeze the model. Defaults to False.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='ReLU', inplace=True).
block_cfg (dict): Config dict for the block used to build each
layer. Defaults to dict(type='RepVGGBlock').
block_act_cfg (dict): Config dict for activation layer used in each
stage. Defaults to dict(type='SiLU', inplace=True).
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_channels: List[int],
out_channels: int,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
hidden_ratio: float = 0.5,
num_csp_blocks: int = 12,
freeze_all: bool = False,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
block_act_cfg: ConfigType = dict(type='SiLU', inplace=True),
block_cfg: ConfigType = dict(type='RepVGGBlock'),
init_cfg: OptMultiConfig = None):
self.hidden_ratio = hidden_ratio
self.block_act_cfg = block_act_cfg
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
num_csp_blocks=num_csp_blocks,
freeze_all=freeze_all,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
block_cfg=block_cfg,
init_cfg=init_cfg)
def build_top_down_layer(self, idx: int) -> nn.Module:
"""build top down layer.
Args:
idx (int): layer idx.
Returns:
nn.Module: The top down layer.
"""
block_cfg = self.block_cfg.copy()
layer0 = BepC3StageBlock(
in_channels=int(
(self.out_channels[idx - 1] + self.in_channels[idx - 1]) *
self.widen_factor),
out_channels=int(self.out_channels[idx - 1] * self.widen_factor),
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
block_cfg=block_cfg,
hidden_ratio=self.hidden_ratio,
norm_cfg=self.norm_cfg,
act_cfg=self.block_act_cfg)
if idx == 1:
return layer0
elif idx == 2:
layer1 = ConvModule(
in_channels=int(self.out_channels[idx - 1] *
self.widen_factor),
out_channels=int(self.out_channels[idx - 2] *
self.widen_factor),
kernel_size=1,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
return nn.Sequential(layer0, layer1)
def build_bottom_up_layer(self, idx: int) -> nn.Module:
"""build bottom up layer.
Args:
idx (int): layer idx.
Returns:
nn.Module: The bottom up layer.
"""
block_cfg = self.block_cfg.copy()
return BepC3StageBlock(
in_channels=int(self.out_channels[idx] * 2 * self.widen_factor),
out_channels=int(self.out_channels[idx + 1] * self.widen_factor),
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
block_cfg=block_cfg,
hidden_ratio=self.hidden_ratio,
norm_cfg=self.norm_cfg,
act_cfg=self.block_act_cfg)

View File

@ -5,7 +5,7 @@ import pytest
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from mmyolo.models.backbones import YOLOv6EfficientRep
from mmyolo.models.backbones import YOLOv6CSPBep, YOLOv6EfficientRep
from mmyolo.utils import register_all_modules
from .utils import check_norm_state, is_norm
@ -23,7 +23,7 @@ class TestYOLOv6EfficientRep(TestCase):
# frozen_stages must in range(-1, len(arch_setting) + 1)
YOLOv6EfficientRep(frozen_stages=6)
def test_forward(self):
def test_YOLOv6EfficientRep_forward(self):
# Test YOLOv6EfficientRep with first stage frozen
frozen_stages = 1
model = YOLOv6EfficientRep(frozen_stages=frozen_stages)
@ -111,3 +111,92 @@ class TestYOLOv6EfficientRep(TestCase):
assert feat[0].shape == torch.Size((1, 256, 32, 32))
assert feat[1].shape == torch.Size((1, 512, 16, 16))
assert feat[2].shape == torch.Size((1, 1024, 8, 8))
def test_YOLOv6CSPBep_forward(self):
# Test YOLOv6CSPBep with first stage frozen
frozen_stages = 1
model = YOLOv6CSPBep(frozen_stages=frozen_stages)
model.init_weights()
model.train()
for mod in model.stem.modules():
for param in mod.parameters():
assert param.requires_grad is False
for i in range(1, frozen_stages + 1):
layer = getattr(model, f'stage{i}')
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for param in layer.parameters():
assert param.requires_grad is False
# Test YOLOv6CSPBep with norm_eval=True
model = YOLOv6CSPBep(norm_eval=True)
model.train()
assert check_norm_state(model.modules(), False)
# Test YOLOv6CSPBep forward with widen_factor=0.25
model = YOLOv6CSPBep(
arch='P5', widen_factor=0.25, out_indices=range(0, 5))
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 16, 32, 32))
assert feat[1].shape == torch.Size((1, 32, 16, 16))
assert feat[2].shape == torch.Size((1, 64, 8, 8))
assert feat[3].shape == torch.Size((1, 128, 4, 4))
assert feat[4].shape == torch.Size((1, 256, 2, 2))
# Test YOLOv6CSPBep forward with dict(type='ReLU')
model = YOLOv6CSPBep(
widen_factor=0.125,
act_cfg=dict(type='ReLU'),
out_indices=range(0, 5))
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 8, 32, 32))
assert feat[1].shape == torch.Size((1, 16, 16, 16))
assert feat[2].shape == torch.Size((1, 32, 8, 8))
assert feat[3].shape == torch.Size((1, 64, 4, 4))
assert feat[4].shape == torch.Size((1, 128, 2, 2))
# Test YOLOv6CSPBep with BatchNorm forward
model = YOLOv6CSPBep(widen_factor=0.125, out_indices=range(0, 5))
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 8, 32, 32))
assert feat[1].shape == torch.Size((1, 16, 16, 16))
assert feat[2].shape == torch.Size((1, 32, 8, 8))
assert feat[3].shape == torch.Size((1, 64, 4, 4))
assert feat[4].shape == torch.Size((1, 128, 2, 2))
# Test YOLOv6CSPBep with BatchNorm forward
model = YOLOv6CSPBep(plugins=[
dict(
cfg=dict(type='mmdet.DropBlock', drop_prob=0.1, block_size=3),
stages=(False, False, True, True)),
])
assert len(model.stage1) == 1
assert len(model.stage2) == 1
assert len(model.stage3) == 2 # +DropBlock
assert len(model.stage4) == 3 # +SPPF+DropBlock
model.train()
imgs = torch.randn(1, 3, 256, 256)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 256, 32, 32))
assert feat[1].shape == torch.Size((1, 512, 16, 16))
assert feat[2].shape == torch.Size((1, 1024, 8, 8))

View File

@ -3,15 +3,15 @@ from unittest import TestCase
import torch
from mmyolo.models.necks import YOLOv6RepPAFPN
from mmyolo.models.necks import YOLOv6CSPRepPAFPN, YOLOv6RepPAFPN
from mmyolo.utils import register_all_modules
register_all_modules()
class TestYOLOv6RepPAFPN(TestCase):
class TestYOLOv6PAFPN(TestCase):
def test_forward(self):
def test_YOLOv6RepPAFPN_forward(self):
s = 64
in_channels = [8, 16, 32]
feat_sizes = [s // 2**i for i in range(4)] # [32, 16, 8]
@ -27,3 +27,20 @@ class TestYOLOv6RepPAFPN(TestCase):
for i in range(len(feats)):
assert outs[i].shape[1] == out_channels[i]
assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
def test_YOLOv6CSPRepPAFPN_forward(self):
s = 64
in_channels = [8, 16, 32]
feat_sizes = [s // 2**i for i in range(4)] # [32, 16, 8]
out_channels = [8, 16, 32]
feats = [
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
for i in range(len(in_channels))
]
neck = YOLOv6CSPRepPAFPN(
in_channels=in_channels, out_channels=out_channels)
outs = neck(feats)
assert len(outs) == len(feats)
for i in range(len(feats)):
assert outs[i].shape[1] == out_channels[i]
assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)

View File

@ -28,12 +28,28 @@ def convert(src, dst):
if 'ERBlock_2' in k:
name = k.replace('ERBlock_2', 'stage1.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'ERBlock_3' in k:
name = k.replace('ERBlock_3', 'stage2.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'ERBlock_4' in k:
name = k.replace('ERBlock_4', 'stage3.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'ERBlock_5' in k:
name = k.replace('ERBlock_5', 'stage4.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
if 'stage4.0.2' in name:
name = name.replace('stage4.0.2', 'stage4.1')
name = name.replace('cv', 'conv')
@ -41,10 +57,22 @@ def convert(src, dst):
name = k.replace('reduce_layer0', 'reduce_layers.2')
elif 'Rep_p4' in k:
name = k.replace('Rep_p4', 'top_down_layers.0.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'reduce_layer1' in k:
name = k.replace('reduce_layer1', 'top_down_layers.0.1')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'Rep_p3' in k:
name = k.replace('Rep_p3', 'top_down_layers.1')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'upsample0' in k:
name = k.replace('upsample0.upsample_transpose',
'upsample_layers.0')
@ -53,8 +81,16 @@ def convert(src, dst):
'upsample_layers.1')
elif 'Rep_n3' in k:
name = k.replace('Rep_n3', 'bottom_up_layers.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'Rep_n4' in k:
name = k.replace('Rep_n4', 'bottom_up_layers.1')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'downsample2' in k:
name = k.replace('downsample2', 'downsample_layers.0')
elif 'downsample1' in k: