[Feature] Add Backbone;Neck (#2)

* add backbone neck

* update
wanghonglie 2022-09-18 10:47:44 +08:00 committed by GitHub
parent 9059d1b446
commit 85e504fe67
13 changed files with 1314 additions and 0 deletions

mmyolo/models/backbones/__init__.py

@@ -0,0 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_backbone import BaseBackbone
from .csp_darknet import YOLOv5CSPDarknet, YOLOXCSPDarknet
from .efficient_rep import YOLOv6EfficientRep
__all__ = [
'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOXCSPDarknet'
]
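
Since every backbone here is registered in MODELS, it can also be built from a config dict instead of a direct import. A minimal sketch, assuming an installed mmyolo (register_all_modules comes from mmyolo.utils, as the tests further down also use); the printed shapes follow from the default out_indices=(2, 3, 4) and the scaling factors:

import torch
from mmyolo.registry import MODELS
from mmyolo.utils import register_all_modules

register_all_modules()  # populate the MODELS registry

cfg = dict(type='YOLOv5CSPDarknet', deepen_factor=0.33, widen_factor=0.5)
backbone = MODELS.build(cfg)
backbone.eval()
outs = backbone(torch.rand(1, 3, 640, 640))
print([tuple(o.shape) for o in outs])
# strides 8/16/32 -> [(1, 128, 80, 80), (1, 256, 40, 40), (1, 512, 20, 20)]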

mmyolo/models/backbones/base_backbone.py

@@ -0,0 +1,125 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from mmdet.utils import OptConfigType, OptMultiConfig
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from mmyolo.registry import MODELS
@MODELS.register_module()
class BaseBackbone(BaseModule, metaclass=ABCMeta):
"""BaseBackbone backbone used in YOLO series.
Args:
arch_setting (list): Architecture of BaseBackbone.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
input_channels (int): Number of input image channels. Defaults to 3.
out_indices (Sequence[int]): Output from which stages.
Defaults to (2, 3, 4).
frozen_stages (int): Stages to be frozen (stop grad and set eval
mode). -1 means not freezing any parameters. Defaults to -1.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to None.
act_cfg (dict): Config dict for activation layer.
Defaults to None.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
arch_setting: dict,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
input_channels: int = 3,
out_indices: tuple = (2, 3, 4),
frozen_stages: int = -1,
norm_cfg: OptConfigType = None,
act_cfg: OptConfigType = None,
norm_eval: bool = False,
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg)
self.num_stages = len(arch_setting)
self.arch_setting = arch_setting
assert set(out_indices).issubset(
i for i in range(len(arch_setting) + 1))
if frozen_stages not in range(-1, len(arch_setting) + 1):
raise ValueError('frozen_stages must be in range(-1, '
'len(arch_setting) + 1). But received '
f'{frozen_stages}')
self.input_channels = input_channels
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.widen_factor = widen_factor
self.deepen_factor = deepen_factor
self.norm_eval = norm_eval
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.stem = self.build_stem_layer()
self.layers = ['stem']
for idx, setting in enumerate(arch_setting):
stage = []
stage += self.build_stage_layer(idx, setting)
self.add_module(f'stage{idx + 1}', nn.Sequential(*stage))
self.layers.append(f'stage{idx + 1}')
@abstractmethod
def build_stem_layer(self):
"""Build a stem layer."""
pass
@abstractmethod
def build_stage_layer(self, stage_idx, setting):
"""Build a stage layer.
Args:
stage_idx (int): The index of a stage layer.
setting (list): The architecture setting of a stage layer.
"""
pass
def _freeze_stages(self):
"""Freeze the parameters of the specified stage so that they are no
longer updated."""
if self.frozen_stages >= 0:
for i in range(self.frozen_stages + 1):
m = getattr(self, self.layers[i])
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
freezed."""
super().train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
def forward(self, x: torch.Tensor) -> tuple:
"""Forward batch_inputs from the data_preprocessor."""
outs = []
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
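
BaseBackbone is a template: a subclass only supplies the stem and per-stage builders and inherits stage assembly, freezing, norm_eval handling and forward. A toy sketch with a hypothetical TinyBackbone (illustrative only, not part of mmyolo):

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule


class TinyBackbone(BaseBackbone):
    # hypothetical plan: in_channels, out_channels, num_blocks
    arch_settings = [[16, 32, 1], [32, 64, 2]]

    def __init__(self, **kwargs):
        super().__init__(self.arch_settings, out_indices=(1, 2), **kwargs)

    def build_stem_layer(self) -> nn.Module:
        return ConvModule(self.input_channels, 16, 3, stride=2, padding=1)

    def build_stage_layer(self, stage_idx: int, setting: list) -> list:
        in_ch, out_ch, num_blocks = setting
        layers = [ConvModule(in_ch, out_ch, 3, stride=2, padding=1)]
        layers += [
            ConvModule(out_ch, out_ch, 3, padding=1)
            for _ in range(num_blocks - 1)
        ]
        return layers


feats = TinyBackbone()(torch.rand(1, 3, 64, 64))
print([tuple(f.shape) for f in feats])  # [(1, 32, 16, 16), (1, 64, 8, 8)]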

mmyolo/models/backbones/csp_darknet.py

@@ -0,0 +1,260 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.models.backbones.csp_darknet import CSPLayer, Focus
from mmdet.utils import ConfigType, OptMultiConfig
from mmyolo.registry import MODELS
from ..layers import SPPFBottleneck
from ..utils import make_divisible, make_round
from .base_backbone import BaseBackbone
@MODELS.register_module()
class YOLOv5CSPDarknet(BaseBackbone):
"""CSP-Darknet backbone used in YOLOv5.
Args:
arch (str): Architecture of CSP-Darknet, from {P5, P6}.
Defaults to P5.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
input_channels (int): Number of input image channels. Defaults to 3.
out_indices (Tuple[int]): Output from which stages.
Defaults to (2, 3, 4).
frozen_stages (int): Stages to be frozen (stop grad and set eval
mode). -1 means not freezing any parameters. Defaults to -1.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='SiLU', inplace=True).
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
init_cfg (Union[dict,list[dict]], optional): Initialization config
dict. Defaults to None.
Example:
>>> from mmyolo.models import YOLOv5CSPDarknet
>>> import torch
>>> model = YOLOv5CSPDarknet()
>>> model.eval()
>>> inputs = torch.rand(1, 3, 416, 416)
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
...
(1, 256, 52, 52)
(1, 512, 26, 26)
(1, 1024, 13, 13)
"""
# From left to right:
# in_channels, out_channels, num_blocks, add_identity, use_spp
arch_settings = {
'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
[256, 512, 9, True, False], [512, 1024, 3, True, True]]
}
def __init__(self,
arch: str = 'P5',
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
input_channels: int = 3,
out_indices: Tuple[int] = (2, 3, 4),
frozen_stages: int = -1,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
norm_eval: bool = False,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(
self.arch_settings[arch],
deepen_factor,
widen_factor,
input_channels=input_channels,
out_indices=out_indices,
frozen_stages=frozen_stages,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
norm_eval=norm_eval,
init_cfg=init_cfg)
def build_stem_layer(self) -> nn.Module:
"""Build a stem layer."""
return ConvModule(
self.input_channels,
make_divisible(self.arch_setting[0][0], self.widen_factor),
kernel_size=6,
stride=2,
padding=2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
"""Build a stage layer.
Args:
stage_idx (int): The index of a stage layer.
setting (list): The architecture setting of a stage layer.
"""
in_channels, out_channels, num_blocks, add_identity, use_spp = setting
in_channels = make_divisible(in_channels, self.widen_factor)
out_channels = make_divisible(out_channels, self.widen_factor)
num_blocks = make_round(num_blocks, self.deepen_factor)
stage = []
conv_layer = ConvModule(
in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.append(conv_layer)
csp_layer = CSPLayer(
out_channels,
out_channels,
num_blocks=num_blocks,
add_identity=add_identity,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.append(csp_layer)
if use_spp:
spp = SPPFBottleneck(
out_channels,
out_channels,
kernel_sizes=5,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.append(spp)
return stage
def init_weights(self):
"""Initialize the parameters."""
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
# In order to be consistent with the source code,
# reset the Conv2d initialization parameters
m.reset_parameters()
@MODELS.register_module()
class YOLOXCSPDarknet(BaseBackbone):
"""CSP-Darknet backbone used in YOLOX.
Args:
arch (str): Architecture of CSP-Darknet, from {P5, P6}.
Defaults to P5.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
input_channels (int): Number of input image channels. Defaults to 3.
out_indices (Tuple[int]): Output from which stages.
Defaults to (2, 3, 4).
frozen_stages (int): Stages to be frozen (stop grad and set eval
mode). -1 means not freezing any parameters. Defaults to -1.
spp_kernal_sizes (tuple[int]): Sequence of kernel sizes of SPP
layers. Defaults to (5, 9, 13).
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='SiLU', inplace=True).
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
init_cfg (Union[dict,list[dict]], optional): Initialization config
dict. Defaults to None.
Example:
>>> from mmyolo.models import YOLOXCSPDarknet
>>> import torch
>>> model = YOLOXCSPDarknet()
>>> model.eval()
>>> inputs = torch.rand(1, 3, 416, 416)
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
...
(1, 256, 52, 52)
(1, 512, 26, 26)
(1, 1024, 13, 13)
"""
# From left to right:
# in_channels, out_channels, num_blocks, add_identity, use_spp
arch_settings = {
'P5': [[64, 128, 3, True, False], [128, 256, 9, True, False],
[256, 512, 9, True, False], [512, 1024, 3, False, True]],
}
def __init__(self,
arch: str = 'P5',
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
input_channels: int = 3,
out_indices: Tuple[int] = (2, 3, 4),
frozen_stages: int = -1,
spp_kernal_sizes: Tuple[int] = (5, 9, 13),
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
norm_eval: bool = False,
init_cfg: OptMultiConfig = None):
self.spp_kernal_sizes = spp_kernal_sizes
super().__init__(self.arch_settings[arch], deepen_factor, widen_factor,
input_channels, out_indices, frozen_stages, norm_cfg,
act_cfg, norm_eval, init_cfg)
def build_stem_layer(self) -> nn.Module:
"""Build a stem layer."""
return Focus(
self.input_channels,
make_divisible(64, self.widen_factor),
kernel_size=3,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
"""Build a stage layer.
Args:
stage_idx (int): The index of a stage layer.
setting (list): The architecture setting of a stage layer.
"""
in_channels, out_channels, num_blocks, add_identity, use_spp = setting
in_channels = make_divisible(in_channels, self.widen_factor)
out_channels = make_divisible(out_channels, self.widen_factor)
num_blocks = make_round(num_blocks, self.deepen_factor)
stage = []
conv_layer = ConvModule(
in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.append(conv_layer)
if use_spp:
spp = SPPFBottleneck(
out_channels,
out_channels,
kernel_sizes=self.spp_kernal_sizes,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.append(spp)
csp_layer = CSPLayer(
out_channels,
out_channels,
num_blocks=num_blocks,
add_identity=add_identity,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.append(csp_layer)
return stage
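
Both builders above shrink the P5 plan through make_divisible and make_round. The bodies below are an assumption consistent with the call sites, not a copy of mmyolo's utils: channel counts stay divisible by 8, block counts stay integral and at least 1:

import math


def make_divisible(x: float, widen_factor: float = 1.0,
                   divisor: int = 8) -> int:
    # keep scaled channel counts divisible by `divisor` (assumed behaviour)
    return math.ceil(x * widen_factor / divisor) * divisor


def make_round(x: float, deepen_factor: float = 1.0) -> int:
    # keep scaled block counts integral and >= 1 (assumed behaviour)
    return max(round(x * deepen_factor), 1) if x > 1 else x


# e.g. YOLOv5-s style scaling (deepen_factor=0.33, widen_factor=0.5):
assert make_divisible(1024, 0.5) == 512
assert make_round(9, 0.33) == 3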

mmyolo/models/backbones/efficient_rep.py

@@ -0,0 +1,143 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
import torch.nn as nn
from mmdet.utils import ConfigType, OptMultiConfig
from mmyolo.models.layers.yolo_bricks import SPPFBottleneck
from mmyolo.registry import MODELS
from ..layers import RepStageBlock, RepVGGBlock
from ..utils import make_divisible, make_round
from .base_backbone import BaseBackbone
@MODELS.register_module()
class YOLOv6EfficientRep(BaseBackbone):
"""EfficientRep backbone used in YOLOv6.
Args:
arch (str): Architecture of EfficientRep, from {P5, P6}.
Defaults to P5.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
input_channels (int): Number of input image channels. Defaults to 3.
out_indices (Tuple[int]): Output from which stages.
Defaults to (2, 3, 4).
frozen_stages (int): Stages to be frozen (stop grad and set eval
mode). -1 means not freezing any parameters. Defaults to -1.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='ReLU', inplace=True).
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
block (nn.Module): Block used to build each stage.
Defaults to RepVGGBlock.
init_cfg (Union[dict, list[dict]], optional): Initialization config
dict. Defaults to None.
Example:
>>> from mmyolo.models import YOLOv6EfficientRep
>>> import torch
>>> model = YOLOv6EfficientRep()
>>> model.eval()
>>> inputs = torch.rand(1, 3, 416, 416)
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
...
(1, 256, 52, 52)
(1, 512, 26, 26)
(1, 1024, 13, 13)
"""
# From left to right:
# in_channels, out_channels, num_blocks, use_spp
arch_settings = {
'P5': [[64, 128, 6, False], [128, 256, 12, False],
[256, 512, 18, False], [512, 1024, 6, True]]
}
def __init__(self,
arch: str = 'P5',
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
input_channels: int = 3,
out_indices: Tuple[int] = (2, 3, 4),
frozen_stages: int = -1,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
norm_eval: bool = False,
block: nn.Module = RepVGGBlock,
init_cfg: OptMultiConfig = None):
self.block = block
super().__init__(
self.arch_settings[arch],
deepen_factor,
widen_factor,
input_channels=input_channels,
out_indices=out_indices,
frozen_stages=frozen_stages,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
norm_eval=norm_eval,
init_cfg=init_cfg)
def build_stem_layer(self) -> nn.Module:
"""Build a stem layer."""
return self.block(
in_channels=self.input_channels,
out_channels=make_divisible(self.arch_setting[0][0],
self.widen_factor),
kernel_size=3,
stride=2)
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
"""Build a stage layer.
Args:
stage_idx (int): The index of a stage layer.
setting (list): The architecture setting of a stage layer.
"""
in_channels, out_channels, num_blocks, use_spp = setting
in_channels = make_divisible(in_channels, self.widen_factor)
out_channels = make_divisible(out_channels, self.widen_factor)
num_blocks = make_round(num_blocks, self.deepen_factor)
stage = []
ef_block = nn.Sequential(
self.block(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2),
RepStageBlock(
in_channels=out_channels,
out_channels=out_channels,
n=num_blocks,
block=self.block,
))
stage.append(ef_block)
if use_spp:
spp = SPPFBottleneck(
in_channels=out_channels,
out_channels=out_channels,
kernel_sizes=5,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.append(spp)
return stage
def init_weights(self):
"""Initialize the parameters."""
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
# In order to be consistent with the source code,
# reset the Conv2d initialization parameters
m.reset_parameters()
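
The basic block is injected through the block argument, so the same skeleton can be rebuilt around a different unit. A sketch with a hypothetical PlainConvBlock; it assumes, as the calls above suggest, that a block only needs to accept in_channels, out_channels, kernel_size and stride, and that RepStageBlock builds its inner blocks with the same contract:

import torch
import torch.nn as nn


class PlainConvBlock(nn.Module):
    """Hypothetical stand-in for RepVGGBlock (illustrative only)."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=kernel_size // 2)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(self.conv(x))


model = YOLOv6EfficientRep(block=PlainConvBlock)
outs = model(torch.rand(1, 3, 64, 64))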

mmyolo/models/necks/__init__.py

@@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_yolo_neck import BaseYOLONeck
from .yolov5_pafpn import YOLOv5PAFPN
from .yolov6_pafpn import YOLOv6RepPAFPN
from .yolox_pafpn import YOLOXPAFPN
__all__ = ['YOLOv5PAFPN', 'BaseYOLONeck', 'YOLOv6RepPAFPN', 'YOLOXPAFPN']

mmyolo/models/necks/base_yolo_neck.py

@@ -0,0 +1,150 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List
import torch
import torch.nn as nn
from mmdet.utils import OptConfigType, OptMultiConfig
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from mmyolo.registry import MODELS
@MODELS.register_module()
class BaseYOLONeck(BaseModule, metaclass=ABCMeta):
"""Base neck used in YOLO series.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale).
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
freeze_all (bool): Whether to freeze the model. Defaults to False.
norm_cfg (dict): Config dict for normalization layer.
Defaults to None.
act_cfg (dict): Config dict for activation layer.
Defaults to None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_channels: List[int],
out_channels: int,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
freeze_all: bool = False,
norm_cfg: OptConfigType = None,
act_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.deepen_factor = deepen_factor
self.widen_factor = widen_factor
self.freeze_all = freeze_all
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.reduce_layers = nn.ModuleList()
for idx in range(len(in_channels)):
self.reduce_layers.append(self.build_reduce_layer(idx))
# build top-down blocks
self.upsample_layers = nn.ModuleList()
self.top_down_layers = nn.ModuleList()
for idx in range(len(in_channels) - 1, 0, -1):
self.upsample_layers.append(self.build_upsample_layer(idx))
self.top_down_layers.append(self.build_top_down_layer(idx))
# build bottom-up blocks
self.downsample_layers = nn.ModuleList()
self.bottom_up_layers = nn.ModuleList()
for idx in range(len(in_channels) - 1):
self.downsample_layers.append(self.build_downsample_layer(idx))
self.bottom_up_layers.append(self.build_bottom_up_layer(idx))
self.out_layers = nn.ModuleList()
for idx in range(len(in_channels)):
self.out_layers.append(self.build_out_layer(idx))
@abstractmethod
def build_reduce_layer(self, idx):
"""Build a reduce layer for the input at index ``idx``."""
pass
@abstractmethod
def build_upsample_layer(self, idx):
"""Build an upsample layer used in the top-down path."""
pass
@abstractmethod
def build_top_down_layer(self, idx):
"""Build a top-down (FPN) fusion layer for index ``idx``."""
pass
@abstractmethod
def build_downsample_layer(self, idx):
"""Build a downsample layer used in the bottom-up path."""
pass
@abstractmethod
def build_bottom_up_layer(self, idx):
"""Build a bottom-up (PAN) fusion layer for index ``idx``."""
pass
@abstractmethod
def build_out_layer(self, idx):
"""Build an output projection layer for index ``idx``."""
pass
def _freeze_all(self):
"""Freeze the model."""
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode=True):
"""Convert the model into training mode while keep the normalization
layer freezed."""
super().train(mode)
if self.freeze_all:
self._freeze_all()
def forward(self, inputs: List[torch.Tensor]) -> tuple:
"""Forward function."""
assert len(inputs) == len(self.in_channels)
# reduce layers
reduce_outs = []
for idx in range(len(self.in_channels)):
reduce_outs.append(self.reduce_layers[idx](inputs[idx]))
# top-down path
inner_outs = [reduce_outs[-1]]
for idx in range(len(self.in_channels) - 1, 0, -1):
feat_high = inner_outs[0]
feat_low = reduce_outs[idx - 1]
upsample_feat = self.upsample_layers[len(self.in_channels) - 1 -
idx](feat_high)
inner_out = self.top_down_layers[len(self.in_channels) - 1 - idx](
torch.cat([upsample_feat, feat_low], 1))
inner_outs.insert(0, inner_out)
# bottom-up path
outs = [inner_outs[0]]
for idx in range(len(self.in_channels) - 1):
feat_low = outs[-1]
feat_high = inner_outs[idx + 1]
downsample_feat = self.downsample_layers[idx](feat_low)
out = self.bottom_up_layers[idx](
torch.cat([downsample_feat, feat_high], 1))
outs.append(out)
# out_layers
results = []
for idx in range(len(self.in_channels)):
results.append(self.out_layers[idx](outs[idx]))
return tuple(results)
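
The reversed indexing in forward is easiest to follow as a concrete three-scale trace (pseudocode in comments; r/t/b are just names for the intermediate tensors):

# inputs:    x0 (stride 8), x1 (stride 16), x2 (stride 32)
# reduce:    r0, r1, r2 = reduce[0](x0), reduce[1](x1), reduce[2](x2)
# top-down:  idx=2: t1 = top_down[0](cat(upsample[0](r2), r1))
#            idx=1: t0 = top_down[1](cat(upsample[1](t1), r0))
#            inner_outs == [t0, t1, r2]
# bottom-up: idx=0: b1 = bottom_up[0](cat(downsample[0](t0), t1))
#            idx=1: b2 = bottom_up[1](cat(downsample[1](b1), r2))
# outputs:   (out[0](t0), out[1](b1), out[2](b2))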

mmyolo/models/necks/yolov5_pafpn.py

@@ -0,0 +1,133 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.models.backbones.csp_darknet import CSPLayer
from mmdet.utils import ConfigType, OptMultiConfig
from mmyolo.registry import MODELS
from ..utils import make_divisible, make_round
from .base_yolo_neck import BaseYOLONeck
@MODELS.register_module()
class YOLOv5PAFPN(BaseYOLONeck):
"""Path Aggregation Network used in YOLOv5.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale).
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Defaults to 1.
freeze_all (bool): Whether to freeze the model. Defaults to False.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='SiLU', inplace=True).
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_channels: List[int],
out_channels: int,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
num_csp_blocks: int = 1,
freeze_all: bool = False,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
init_cfg: OptMultiConfig = None):
self.num_csp_blocks = num_csp_blocks
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
freeze_all=freeze_all,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
init_cfg=init_cfg)
def init_weights(self):
"""Initialize the parameters."""
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
# In order to be consistent with the source code,
# reset the Conv2d initialization parameters
m.reset_parameters()
def build_reduce_layer(self, idx: int) -> nn.Module:
"""Build a 1x1 reduce layer for the deepest scale, identity otherwise."""
if idx == 2:
layer = ConvModule(
make_divisible(self.in_channels[idx], self.widen_factor),
make_divisible(self.in_channels[idx - 1], self.widen_factor),
1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
else:
layer = nn.Identity()
return layer
def build_upsample_layer(self, *args, **kwargs) -> nn.Module:
return nn.Upsample(scale_factor=2, mode='nearest')
def build_top_down_layer(self, idx: int) -> nn.Module:
if idx == 1:
return CSPLayer(
make_divisible(self.in_channels[idx - 1] * 2,
self.widen_factor),
make_divisible(self.in_channels[idx - 1], self.widen_factor),
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
add_identity=False,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
elif idx == 2:
return nn.Sequential(
CSPLayer(
make_divisible(self.in_channels[idx - 1] * 2,
self.widen_factor),
make_divisible(self.in_channels[idx - 1],
self.widen_factor),
num_blocks=make_round(self.num_csp_blocks,
self.deepen_factor),
add_identity=False,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
make_divisible(self.in_channels[idx - 1],
self.widen_factor),
make_divisible(self.in_channels[idx - 2],
self.widen_factor),
kernel_size=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
def build_downsample_layer(self, idx: int) -> nn.Module:
return ConvModule(
make_divisible(self.in_channels[idx], self.widen_factor),
make_divisible(self.in_channels[idx], self.widen_factor),
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def build_bottom_up_layer(self, idx: int) -> nn.Module:
return CSPLayer(
make_divisible(self.in_channels[idx] * 2, self.widen_factor),
make_divisible(self.in_channels[idx + 1], self.widen_factor),
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
add_identity=False,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def build_out_layer(self, *args, **kwargs) -> nn.Module:
return nn.Identity()
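
A minimal usage sketch for the neck above. Since build_out_layer returns nn.Identity, the outputs keep the fused channel counts; shapes assume widen_factor=1.0 (out_channels is stored by the base class but unused by this neck):

import torch

neck = YOLOv5PAFPN(in_channels=[256, 512, 1024],
                   out_channels=[256, 512, 1024])
neck.eval()
feats = [
    torch.rand(1, 256, 80, 80),   # stride 8
    torch.rand(1, 512, 40, 40),   # stride 16
    torch.rand(1, 1024, 20, 20),  # stride 32
]
outs = neck(feats)
print([tuple(o.shape) for o in outs])
# [(1, 256, 80, 80), (1, 512, 40, 40), (1, 1024, 20, 20)]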

mmyolo/models/necks/yolov6_pafpn.py

@@ -0,0 +1,141 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.utils import ConfigType, OptMultiConfig
from mmyolo.registry import MODELS
from ..layers import RepStageBlock, RepVGGBlock
from ..utils import make_divisible, make_round
from .base_yolo_neck import BaseYOLONeck
@MODELS.register_module()
class YOLOv6RepPAFPN(BaseYOLONeck):
"""Path Aggregation Network used in YOLOv6.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (List[int]): Number of output channels per scale.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Defaults to 12.
freeze_all (bool): Whether to freeze the model. Defaults to False.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='ReLU', inplace=True).
block (nn.Module): block used to build each layer.
Defaults to RepVGGBlock.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_channels: List[int],
out_channels: int,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
num_csp_blocks: int = 12,
freeze_all: bool = False,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
block: nn.Module = RepVGGBlock,
init_cfg: OptMultiConfig = None):
self.num_csp_blocks = num_csp_blocks
self.block = block
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
freeze_all=freeze_all,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
init_cfg=init_cfg)
def build_reduce_layer(self, idx: int) -> nn.Module:
"""Build a 1x1 reduce layer for the deepest scale, identity otherwise."""
if idx == 2:
layer = ConvModule(
in_channels=make_divisible(self.in_channels[idx],
self.widen_factor),
out_channels=make_divisible(self.out_channels[idx - 1],
self.widen_factor),
kernel_size=1,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
else:
layer = nn.Identity()
return layer
def build_upsample_layer(self, idx: int) -> nn.Module:
return nn.ConvTranspose2d(
in_channels=make_divisible(self.out_channels[idx - 1],
self.widen_factor),
out_channels=make_divisible(self.out_channels[idx - 1],
self.widen_factor),
kernel_size=2,
stride=2,
bias=True)
def build_top_down_layer(self, idx: int) -> nn.Module:
layer0 = RepStageBlock(
in_channels=make_divisible(
self.out_channels[idx - 1] + self.in_channels[idx - 1],
self.widen_factor),
out_channels=make_divisible(self.out_channels[idx - 1],
self.widen_factor),
n=make_round(self.num_csp_blocks, self.deepen_factor),
block=self.block)
if idx == 1:
return layer0
elif idx == 2:
layer1 = ConvModule(
in_channels=make_divisible(self.out_channels[idx - 1],
self.widen_factor),
out_channels=make_divisible(self.out_channels[idx - 2],
self.widen_factor),
kernel_size=1,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
return nn.Sequential(layer0, layer1)
def build_downsample_layer(self, idx: int) -> nn.Module:
return ConvModule(
in_channels=make_divisible(self.out_channels[idx],
self.widen_factor),
out_channels=make_divisible(self.out_channels[idx],
self.widen_factor),
kernel_size=3,
stride=2,
padding=3 // 2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def build_bottom_up_layer(self, idx: int) -> nn.Module:
return RepStageBlock(
in_channels=make_divisible(self.out_channels[idx] * 2,
self.widen_factor),
out_channels=make_divisible(self.out_channels[idx + 1],
self.widen_factor),
n=make_round(self.num_csp_blocks, self.deepen_factor),
block=self.block)
def build_out_layer(self, *args, **kwargs) -> nn.Module:
return nn.Identity()
def init_weights(self):
"""Initialize the parameters."""
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
# In order to be consistent with the source code,
# reset the Conv2d initialization parameters
m.reset_parameters()
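
Usage sketch for the v6 neck. Unlike YOLOv5PAFPN it indexes out_channels per scale, so a list is required; the values below mirror a YOLOv6-n style config and are illustrative. Input channels must already be scaled by widen_factor through make_divisible:

import torch

neck = YOLOv6RepPAFPN(
    in_channels=[256, 512, 1024],
    out_channels=[128, 256, 512],
    widen_factor=0.25)
neck.eval()
feats = [
    torch.rand(1, 64, 80, 80),    # make_divisible(256, 0.25)
    torch.rand(1, 128, 40, 40),   # make_divisible(512, 0.25)
    torch.rand(1, 256, 20, 20),   # make_divisible(1024, 0.25)
]
outs = neck(feats)
print([tuple(o.shape) for o in outs])
# [(1, 32, 80, 80), (1, 64, 40, 40), (1, 128, 20, 20)]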

mmyolo/models/necks/yolox_pafpn.py

@@ -0,0 +1,125 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.models.backbones.csp_darknet import CSPLayer
from mmdet.utils import ConfigType, OptMultiConfig
from mmyolo.registry import MODELS
from .base_yolo_neck import BaseYOLONeck
@MODELS.register_module()
class YOLOXPAFPN(BaseYOLONeck):
"""Path Aggregation Network used in YOLOX.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale).
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Defaults to 3.
freeze_all(bool): Whether to freeze the model. Defaults to False.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='SiLU', inplace=True).
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_channels: List[int],
out_channels: int,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
num_csp_blocks: int = 3,
freeze_all: bool = False,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
init_cfg: OptMultiConfig = None):
self.num_csp_blocks = round(num_csp_blocks * deepen_factor)
super().__init__(
in_channels=[
int(channel * widen_factor) for channel in in_channels
],
out_channels=int(out_channels * widen_factor),
deepen_factor=deepen_factor,
widen_factor=widen_factor,
freeze_all=freeze_all,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
init_cfg=init_cfg)
def build_reduce_layer(self, idx: int) -> nn.Module:
if idx == 2:
layer = ConvModule(
self.in_channels[idx],
self.in_channels[idx - 1],
1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
else:
layer = nn.Identity()
return layer
def build_upsample_layer(self, *args, **kwargs) -> nn.Module:
return nn.Upsample(scale_factor=2, mode='nearest')
def build_top_down_layer(self, idx: int) -> nn.Module:
if idx == 1:
return CSPLayer(
self.in_channels[idx - 1] * 2,
self.in_channels[idx - 1],
num_blocks=self.num_csp_blocks,
add_identity=False,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
elif idx == 2:
return nn.Sequential(
CSPLayer(
self.in_channels[idx - 1] * 2,
self.in_channels[idx - 1],
num_blocks=self.num_csp_blocks,
add_identity=False,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
self.in_channels[idx - 1],
self.in_channels[idx - 2],
kernel_size=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
def build_downsample_layer(self, idx: int) -> nn.Module:
return ConvModule(
self.in_channels[idx],
self.in_channels[idx],
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def build_bottom_up_layer(self, idx: int) -> nn.Module:
return CSPLayer(
self.in_channels[idx] * 2,
self.in_channels[idx + 1],
num_blocks=self.num_csp_blocks,
add_identity=False,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def build_out_layer(self, idx: int) -> nn.Module:
return ConvModule(
self.in_channels[idx],
self.out_channels,
1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
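
Usage sketch: unlike the v5/v6 necks above, YOLOXPAFPN ends every scale with a 1x1 ConvModule, so all outputs share out_channels (shapes assume widen_factor=1.0):

import torch

neck = YOLOXPAFPN(in_channels=[256, 512, 1024], out_channels=256)
neck.eval()
feats = [
    torch.rand(1, 256, 80, 80),   # stride 8
    torch.rand(1, 512, 40, 40),   # stride 16
    torch.rand(1, 1024, 20, 20),  # stride 32
]
outs = neck(feats)
print([tuple(o.shape) for o in outs])
# [(1, 256, 80, 80), (1, 256, 40, 40), (1, 256, 20, 20)]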

tests/test_models/test_backbones/test_csp_darknet.py

@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
import torch
from parameterized import parameterized
from torch.nn.modules.batchnorm import _BatchNorm
from mmyolo.models.backbones import YOLOv5CSPDarknet, YOLOXCSPDarknet
from mmyolo.utils import register_all_modules
from .utils import check_norm_state, is_norm
register_all_modules()
class TestCSPDarknet(TestCase):
@parameterized.expand([(YOLOv5CSPDarknet, ), (YOLOXCSPDarknet, )])
def test_init(self, module_class):
# out_indices in range(len(arch_setting) + 1)
with pytest.raises(AssertionError):
module_class(out_indices=(6, ))
with pytest.raises(ValueError):
# frozen_stages must be in range(-1, len(arch_setting) + 1)
module_class(frozen_stages=6)
@parameterized.expand([(YOLOv5CSPDarknet, ), (YOLOXCSPDarknet, )])
def test_forward(self, module_class):
# Test CSPDarknet with first stage frozen
frozen_stages = 1
model = module_class(frozen_stages=frozen_stages)
model.init_weights()
model.train()
for mod in model.stem.modules():
for param in mod.parameters():
assert param.requires_grad is False
for i in range(1, frozen_stages + 1):
layer = getattr(model, f'stage{i}')
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for param in layer.parameters():
assert param.requires_grad is False
# Test CSPDarknet with norm_eval=True
model = module_class(norm_eval=True)
model.train()
assert check_norm_state(model.modules(), False)
# Test CSPDarknet-P5 forward with widen_factor=0.25
model = module_class(
arch='P5', widen_factor=0.25, out_indices=range(0, 5))
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 16, 32, 32))
assert feat[1].shape == torch.Size((1, 32, 16, 16))
assert feat[2].shape == torch.Size((1, 64, 8, 8))
assert feat[3].shape == torch.Size((1, 128, 4, 4))
assert feat[4].shape == torch.Size((1, 256, 2, 2))
# Test CSPDarknet forward with dict(type='ReLU')
model = module_class(
widen_factor=0.125,
act_cfg=dict(type='ReLU'),
out_indices=range(0, 5))
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 8, 32, 32))
assert feat[1].shape == torch.Size((1, 16, 16, 16))
assert feat[2].shape == torch.Size((1, 32, 8, 8))
assert feat[3].shape == torch.Size((1, 64, 4, 4))
assert feat[4].shape == torch.Size((1, 128, 2, 2))
# Test CSPDarknet with BatchNorm forward
model = module_class(widen_factor=0.125, out_indices=range(0, 5))
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 8, 32, 32))
assert feat[1].shape == torch.Size((1, 16, 16, 16))
assert feat[2].shape == torch.Size((1, 32, 8, 8))
assert feat[3].shape == torch.Size((1, 64, 4, 4))
assert feat[4].shape == torch.Size((1, 128, 2, 2))
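
Where the asserted shapes come from: a 64x64 input is halved five times (stem plus four stride-2 stages), giving sizes 32, 16, 8, 4, 2 for out_indices range(0, 5), while widen_factor shrinks the P5 channel plan (64, 128, 256, 512, 1024) through the same make_divisible the backbones use (assuming it behaves as sketched earlier):

from mmyolo.models.utils import make_divisible

assert [make_divisible(c, 0.25) for c in (64, 128, 256, 512, 1024)] == \
    [16, 32, 64, 128, 256]
assert [make_divisible(c, 0.125) for c in (64, 128, 256, 512, 1024)] == \
    [8, 16, 32, 64, 128]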

tests/test_models/test_backbones/test_efficient_rep.py

@@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from mmyolo.models.backbones import YOLOv6EfficientRep
from mmyolo.utils import register_all_modules
from .utils import check_norm_state, is_norm
register_all_modules()
class TestYOLOv6EfficientRep(TestCase):
def test_init(self):
# out_indices in range(len(arch_setting) + 1)
with pytest.raises(AssertionError):
YOLOv6EfficientRep(out_indices=(6, ))
with pytest.raises(ValueError):
# frozen_stages must be in range(-1, len(arch_setting) + 1)
YOLOv6EfficientRep(frozen_stages=6)
def test_forward(self):
# Test YOLOv6EfficientRep with first stage frozen
frozen_stages = 1
model = YOLOv6EfficientRep(frozen_stages=frozen_stages)
model.init_weights()
model.train()
for mod in model.stem.modules():
for param in mod.parameters():
assert param.requires_grad is False
for i in range(1, frozen_stages + 1):
layer = getattr(model, f'stage{i}')
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for param in layer.parameters():
assert param.requires_grad is False
# Test YOLOv6EfficientRep with norm_eval=True
model = YOLOv6EfficientRep(norm_eval=True)
model.train()
assert check_norm_state(model.modules(), False)
# Test YOLOv6EfficientRep-P5 forward with widen_factor=0.25
model = YOLOv6EfficientRep(
arch='P5', widen_factor=0.25, out_indices=range(0, 5))
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 16, 32, 32))
assert feat[1].shape == torch.Size((1, 32, 16, 16))
assert feat[2].shape == torch.Size((1, 64, 8, 8))
assert feat[3].shape == torch.Size((1, 128, 4, 4))
assert feat[4].shape == torch.Size((1, 256, 2, 2))
# Test YOLOv6EfficientRep forward with dict(type='ReLU')
model = YOLOv6EfficientRep(
widen_factor=0.125,
act_cfg=dict(type='ReLU'),
out_indices=range(0, 5))
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 8, 32, 32))
assert feat[1].shape == torch.Size((1, 16, 16, 16))
assert feat[2].shape == torch.Size((1, 32, 8, 8))
assert feat[3].shape == torch.Size((1, 64, 4, 4))
assert feat[4].shape == torch.Size((1, 128, 2, 2))
# Test YOLOv6EfficientRep with BatchNorm forward
model = YOLOv6EfficientRep(widen_factor=0.125, out_indices=range(0, 5))
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 8, 32, 32))
assert feat[1].shape == torch.Size((1, 16, 16, 16))
assert feat[2].shape == torch.Size((1, 32, 8, 8))
assert feat[3].shape == torch.Size((1, 64, 4, 4))
assert feat[4].shape == torch.Size((1, 128, 2, 2))

tests/test_models/test_backbones/utils.py

@@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.backbones.res2net import Bottle2neck
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from mmdet.models.backbones.resnext import Bottleneck as BottleneckX
from mmdet.models.layers import SimplifiedBasicBlock
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
def is_block(modules):
"""Check if is ResNet building block."""
if isinstance(modules, (BasicBlock, Bottleneck, BottleneckX, Bottle2neck,
SimplifiedBasicBlock)):
return True
return False
def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (GroupNorm, _BatchNorm)):
return True
return False
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True