mirror of https://github.com/open-mmlab/mmyolo.git
parent 9059d1b446
commit 85e504fe67

@@ -0,0 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_backbone import BaseBackbone
from .csp_darknet import YOLOv5CSPDarknet, YOLOXCSPDarknet
from .efficient_rep import YOLOv6EfficientRep

__all__ = [
    'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOXCSPDarknet'
]
@@ -0,0 +1,125 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn
from mmdet.utils import ConfigType, OptMultiConfig
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

from mmyolo.registry import MODELS


@MODELS.register_module()
class BaseBackbone(BaseModule, metaclass=ABCMeta):
    """BaseBackbone backbone used in YOLO series.

    Args:
        arch_setting (dict): Architecture of BaseBackbone.
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        input_channels (int): Number of input image channels. Defaults to 3.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to (2, 3, 4).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to None.
        act_cfg (dict): Config dict for activation layer.
            Defaults to None.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 arch_setting: dict,
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 input_channels: int = 3,
                 out_indices: tuple = (2, 3, 4),
                 frozen_stages: int = -1,
                 norm_cfg: ConfigType = None,
                 act_cfg: ConfigType = None,
                 norm_eval: bool = False,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg)

        self.num_stages = len(arch_setting)
        self.arch_setting = arch_setting

        assert set(out_indices).issubset(
            i for i in range(len(arch_setting) + 1))

        if frozen_stages not in range(-1, len(arch_setting) + 1):
            raise ValueError('frozen_stages must be in range(-1, '
                             'len(arch_setting) + 1). But received '
                             f'{frozen_stages}')

        self.input_channels = input_channels
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.widen_factor = widen_factor
        self.deepen_factor = deepen_factor
        self.norm_eval = norm_eval
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.stem = self.build_stem_layer()
        self.layers = ['stem']

        for idx, setting in enumerate(arch_setting):
            stage = []
            stage += self.build_stage_layer(idx, setting)
            self.add_module(f'stage{idx + 1}', nn.Sequential(*stage))
            self.layers.append(f'stage{idx + 1}')

    @abstractmethod
    def build_stem_layer(self):
        """Build a stem layer."""
        pass

    @abstractmethod
    def build_stage_layer(self, stage_idx, setting):
        """Build a stage layer.

        Args:
            stage_idx (int): The index of a stage layer.
            setting (list): The architecture setting of a stage layer.
        """
        pass

    def _freeze_stages(self):
        """Freeze the parameters of the specified stages so that they are no
        longer updated."""
        if self.frozen_stages >= 0:
            for i in range(self.frozen_stages + 1):
                m = getattr(self, self.layers[i])
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def forward(self, x: torch.Tensor) -> tuple:
        """Forward batch_inputs from the data_preprocessor."""
        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)
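BaseBackbone is a template class: a subclass only supplies the stem and the per-stage modules, while the base class handles module registration, stage freezing, and the multi-scale forward. A minimal sketch of that contract with a hypothetical two-stage toy backbone (the class name, channel numbers, and arch_setting below are illustrative, not part of mmyolo):

import torch
import torch.nn as nn

from mmyolo.models.backbones import BaseBackbone


class ToyBackbone(BaseBackbone):
    """Hypothetical subclass implementing the two abstract hooks."""

    def build_stem_layer(self):
        # stride-2 stem: input_channels -> 16 channels
        return nn.Conv2d(self.input_channels, 16, 3, stride=2, padding=1)

    def build_stage_layer(self, stage_idx, setting):
        in_c, out_c = setting
        # each stage halves the resolution again
        return [nn.Conv2d(in_c, out_c, 3, stride=2, padding=1), nn.ReLU()]


model = ToyBackbone(arch_setting=[[16, 32], [32, 64]], out_indices=(1, 2))
feats = model(torch.rand(1, 3, 64, 64))
print([tuple(f.shape) for f in feats])
# [(1, 32, 16, 16), (1, 64, 8, 8)]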
@@ -0,0 +1,260 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.models.backbones.csp_darknet import CSPLayer, Focus
from mmdet.utils import ConfigType, OptMultiConfig

from mmyolo.registry import MODELS
from ..layers import SPPFBottleneck
from ..utils import make_divisible, make_round
from .base_backbone import BaseBackbone


@MODELS.register_module()
class YOLOv5CSPDarknet(BaseBackbone):
    """CSP-Darknet backbone used in YOLOv5.

    Args:
        arch (str): Architecture of CSP-Darknet, from {P5, P6}.
            Defaults to P5.
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        input_channels (int): Number of input image channels. Defaults to 3.
        out_indices (Tuple[int]): Output from which stages.
            Defaults to (2, 3, 4).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (Union[dict, list[dict]], optional): Initialization config
            dict. Defaults to None.

    Example:
        >>> from mmyolo.models import YOLOv5CSPDarknet
        >>> import torch
        >>> model = YOLOv5CSPDarknet()
        >>> model.eval()
        >>> inputs = torch.rand(1, 3, 416, 416)
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        ...
        (1, 256, 52, 52)
        (1, 512, 26, 26)
        (1, 1024, 13, 13)
    """
    # From left to right:
    # in_channels, out_channels, num_blocks, add_identity, use_spp
    arch_settings = {
        'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
               [256, 512, 9, True, False], [512, 1024, 3, True, True]]
    }

    def __init__(self,
                 arch: str = 'P5',
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 input_channels: int = 3,
                 out_indices: Tuple[int] = (2, 3, 4),
                 frozen_stages: int = -1,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 norm_eval: bool = False,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            self.arch_settings[arch],
            deepen_factor,
            widen_factor,
            input_channels=input_channels,
            out_indices=out_indices,
            frozen_stages=frozen_stages,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            norm_eval=norm_eval,
            init_cfg=init_cfg)

    def build_stem_layer(self) -> nn.Module:
        """Build a stem layer."""
        return ConvModule(
            self.input_channels,
            make_divisible(self.arch_setting[0][0], self.widen_factor),
            kernel_size=6,
            stride=2,
            padding=2,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def build_stage_layer(self, stage_idx: int, setting: list) -> list:
        """Build a stage layer.

        Args:
            stage_idx (int): The index of a stage layer.
            setting (list): The architecture setting of a stage layer.
        """
        in_channels, out_channels, num_blocks, add_identity, use_spp = setting

        in_channels = make_divisible(in_channels, self.widen_factor)
        out_channels = make_divisible(out_channels, self.widen_factor)
        num_blocks = make_round(num_blocks, self.deepen_factor)
        stage = []
        conv_layer = ConvModule(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        stage.append(conv_layer)
        csp_layer = CSPLayer(
            out_channels,
            out_channels,
            num_blocks=num_blocks,
            add_identity=add_identity,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        stage.append(csp_layer)
        if use_spp:
            spp = SPPFBottleneck(
                out_channels,
                out_channels,
                kernel_sizes=5,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            stage.append(spp)
        return stage

    def init_weights(self):
        """Initialize the parameters."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                # In order to be consistent with the source code,
                # reset the Conv2d initialization parameters
                m.reset_parameters()


@MODELS.register_module()
class YOLOXCSPDarknet(BaseBackbone):
    """CSP-Darknet backbone used in YOLOX.

    Args:
        arch (str): Architecture of CSP-Darknet, from {P5, P6}.
            Defaults to P5.
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        input_channels (int): Number of input image channels. Defaults to 3.
        out_indices (Tuple[int]): Output from which stages.
            Defaults to (2, 3, 4).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        spp_kernal_sizes (Tuple[int]): Sequence of kernel sizes of SPP
            layers. Defaults to (5, 9, 13).
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (Union[dict, list[dict]], optional): Initialization config
            dict. Defaults to None.

    Example:
        >>> from mmyolo.models import YOLOXCSPDarknet
        >>> import torch
        >>> model = YOLOXCSPDarknet()
        >>> model.eval()
        >>> inputs = torch.rand(1, 3, 416, 416)
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        ...
        (1, 256, 52, 52)
        (1, 512, 26, 26)
        (1, 1024, 13, 13)
    """
    # From left to right:
    # in_channels, out_channels, num_blocks, add_identity, use_spp
    arch_settings = {
        'P5': [[64, 128, 3, True, False], [128, 256, 9, True, False],
               [256, 512, 9, True, False], [512, 1024, 3, False, True]],
    }

    def __init__(self,
                 arch: str = 'P5',
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 input_channels: int = 3,
                 out_indices: Tuple[int] = (2, 3, 4),
                 frozen_stages: int = -1,
                 spp_kernal_sizes: Tuple[int] = (5, 9, 13),
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 norm_eval: bool = False,
                 init_cfg: OptMultiConfig = None):
        self.spp_kernal_sizes = spp_kernal_sizes
        super().__init__(self.arch_settings[arch], deepen_factor, widen_factor,
                         input_channels, out_indices, frozen_stages, norm_cfg,
                         act_cfg, norm_eval, init_cfg)

    def build_stem_layer(self) -> nn.Module:
        """Build a stem layer."""
        return Focus(
            self.input_channels,
            make_divisible(64, self.widen_factor),
            kernel_size=3,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def build_stage_layer(self, stage_idx: int, setting: list) -> list:
        """Build a stage layer.

        Args:
            stage_idx (int): The index of a stage layer.
            setting (list): The architecture setting of a stage layer.
        """
        in_channels, out_channels, num_blocks, add_identity, use_spp = setting

        in_channels = make_divisible(in_channels, self.widen_factor)
        out_channels = make_divisible(out_channels, self.widen_factor)
        num_blocks = make_round(num_blocks, self.deepen_factor)
        stage = []
        conv_layer = ConvModule(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        stage.append(conv_layer)
        if use_spp:
            spp = SPPFBottleneck(
                out_channels,
                out_channels,
                kernel_sizes=self.spp_kernal_sizes,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            stage.append(spp)
        csp_layer = CSPLayer(
            out_channels,
            out_channels,
            num_blocks=num_blocks,
            add_identity=add_identity,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        stage.append(csp_layer)
        return stage
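Both backbones scale widths and depths through two helpers imported from ..utils. A sketch of their usual YOLO-style semantics, assuming the conventional divisor of 8; the authoritative definitions live in mmyolo's models/utils module:

import math


def make_divisible(x: int, widen_factor: float = 1.0, divisor: int = 8) -> int:
    # Assumed behaviour: widen the channel count, then round up to a
    # multiple of ``divisor``.
    return math.ceil(x * widen_factor / divisor) * divisor


def make_round(x: int, deepen_factor: float = 1.0) -> int:
    # Assumed behaviour: deepen the block count, keeping at least one block.
    return max(round(x * deepen_factor), 1) if x > 1 else x


print(make_divisible(1024, widen_factor=0.25))  # 256
print(make_round(9, deepen_factor=0.33))        # 3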
@@ -0,0 +1,143 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
import torch.nn as nn
from mmdet.utils import ConfigType, OptMultiConfig

from mmyolo.models.layers.yolo_bricks import SPPFBottleneck
from mmyolo.registry import MODELS
from ..layers import RepStageBlock, RepVGGBlock
from ..utils import make_divisible, make_round
from .base_backbone import BaseBackbone


@MODELS.register_module()
class YOLOv6EfficientRep(BaseBackbone):
    """EfficientRep backbone used in YOLOv6.

    Args:
        arch (str): Architecture of EfficientRep, from {P5, P6}.
            Defaults to P5.
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        input_channels (int): Number of input image channels. Defaults to 3.
        out_indices (Tuple[int]): Output from which stages.
            Defaults to (2, 3, 4).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='ReLU', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        block (nn.Module): Block used to build each stage.
            Defaults to RepVGGBlock.
        init_cfg (Union[dict, list[dict]], optional): Initialization config
            dict. Defaults to None.

    Example:
        >>> from mmyolo.models import YOLOv6EfficientRep
        >>> import torch
        >>> model = YOLOv6EfficientRep()
        >>> model.eval()
        >>> inputs = torch.rand(1, 3, 416, 416)
        >>> level_outputs = model(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        ...
        (1, 256, 52, 52)
        (1, 512, 26, 26)
        (1, 1024, 13, 13)
    """
    # From left to right:
    # in_channels, out_channels, num_blocks, use_spp
    arch_settings = {
        'P5': [[64, 128, 6, False], [128, 256, 12, False],
               [256, 512, 18, False], [512, 1024, 6, True]]
    }

    def __init__(self,
                 arch: str = 'P5',
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 input_channels: int = 3,
                 out_indices: Tuple[int] = (2, 3, 4),
                 frozen_stages: int = -1,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='ReLU', inplace=True),
                 norm_eval: bool = False,
                 block: nn.Module = RepVGGBlock,
                 init_cfg: OptMultiConfig = None):
        self.block = block
        super().__init__(
            self.arch_settings[arch],
            deepen_factor,
            widen_factor,
            input_channels=input_channels,
            out_indices=out_indices,
            frozen_stages=frozen_stages,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            norm_eval=norm_eval,
            init_cfg=init_cfg)

    def build_stem_layer(self) -> nn.Module:
        """Build a stem layer."""
        return self.block(
            in_channels=self.input_channels,
            out_channels=make_divisible(self.arch_setting[0][0],
                                        self.widen_factor),
            kernel_size=3,
            stride=2)

    def build_stage_layer(self, stage_idx: int, setting: list) -> list:
        """Build a stage layer.

        Args:
            stage_idx (int): The index of a stage layer.
            setting (list): The architecture setting of a stage layer.
        """
        in_channels, out_channels, num_blocks, use_spp = setting

        in_channels = make_divisible(in_channels, self.widen_factor)
        out_channels = make_divisible(out_channels, self.widen_factor)
        num_blocks = make_round(num_blocks, self.deepen_factor)

        stage = []

        ef_block = nn.Sequential(
            self.block(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=2),
            RepStageBlock(
                in_channels=out_channels,
                out_channels=out_channels,
                n=num_blocks,
                block=self.block,
            ))
        stage.append(ef_block)

        if use_spp:
            spp = SPPFBottleneck(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_sizes=5,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            stage.append(spp)
        return stage

    def init_weights(self):
        """Initialize the parameters."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                # In order to be consistent with the source code,
                # reset the Conv2d initialization parameters
                m.reset_parameters()
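Because YOLOv6EfficientRep is registered in MODELS, configs build it by type name. A minimal sketch; the deepen/widen factors are illustrative small-model values, not taken from any particular YOLOv6 config:

from mmyolo.registry import MODELS
from mmyolo.utils import register_all_modules

register_all_modules()

backbone = MODELS.build(
    dict(
        type='YOLOv6EfficientRep',
        deepen_factor=0.33,  # illustrative value
        widen_factor=0.5,  # illustrative value
        out_indices=(2, 3, 4)))
print(type(backbone).__name__)  # YOLOv6EfficientRep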
@@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_yolo_neck import BaseYOLONeck
from .yolov5_pafpn import YOLOv5PAFPN
from .yolov6_pafpn import YOLOv6RepPAFPN
from .yolox_pafpn import YOLOXPAFPN

__all__ = ['YOLOv5PAFPN', 'BaseYOLONeck', 'YOLOv6RepPAFPN', 'YOLOXPAFPN']
@@ -0,0 +1,150 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List

import torch
import torch.nn as nn
from mmdet.utils import ConfigType, OptMultiConfig
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

from mmyolo.registry import MODELS


@MODELS.register_module()
class BaseYOLONeck(BaseModule, metaclass=ABCMeta):
    """Base neck used in YOLO series.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        freeze_all (bool): Whether to freeze the model. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to None.
        act_cfg (dict): Config dict for activation layer.
            Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: List[int],
                 out_channels: int,
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 freeze_all: bool = False,
                 norm_cfg: ConfigType = None,
                 act_cfg: ConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.deepen_factor = deepen_factor
        self.widen_factor = widen_factor
        self.freeze_all = freeze_all

        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.reduce_layers = nn.ModuleList()
        for idx in range(len(in_channels)):
            self.reduce_layers.append(self.build_reduce_layer(idx))

        # build top-down blocks
        self.upsample_layers = nn.ModuleList()
        self.top_down_layers = nn.ModuleList()
        for idx in range(len(in_channels) - 1, 0, -1):
            self.upsample_layers.append(self.build_upsample_layer(idx))
            self.top_down_layers.append(self.build_top_down_layer(idx))

        # build bottom-up blocks
        self.downsample_layers = nn.ModuleList()
        self.bottom_up_layers = nn.ModuleList()
        for idx in range(len(in_channels) - 1):
            self.downsample_layers.append(self.build_downsample_layer(idx))
            self.bottom_up_layers.append(self.build_bottom_up_layer(idx))

        self.out_layers = nn.ModuleList()
        for idx in range(len(in_channels)):
            self.out_layers.append(self.build_out_layer(idx))

    @abstractmethod
    def build_reduce_layer(self, idx):
        """Build a reduce layer."""
        pass

    @abstractmethod
    def build_upsample_layer(self, idx):
        """Build an upsample layer."""
        pass

    @abstractmethod
    def build_top_down_layer(self, idx):
        """Build a top-down layer."""
        pass

    @abstractmethod
    def build_downsample_layer(self, idx):
        """Build a downsample layer."""
        pass

    @abstractmethod
    def build_bottom_up_layer(self, idx):
        """Build a bottom-up layer."""
        pass

    @abstractmethod
    def build_out_layer(self, idx):
        """Build an out layer."""
        pass

    def _freeze_all(self):
        """Freeze the model."""
        for m in self.modules():
            if isinstance(m, _BatchNorm):
                m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super().train(mode)
        if self.freeze_all:
            self._freeze_all()

    def forward(self, inputs: List[torch.Tensor]) -> tuple:
        """Forward function."""
        assert len(inputs) == len(self.in_channels)
        # reduce layers
        reduce_outs = []
        for idx in range(len(self.in_channels)):
            reduce_outs.append(self.reduce_layers[idx](inputs[idx]))

        # top-down path
        inner_outs = [reduce_outs[-1]]
        for idx in range(len(self.in_channels) - 1, 0, -1):
            feat_high = inner_outs[0]
            feat_low = reduce_outs[idx - 1]
            upsample_feat = self.upsample_layers[len(self.in_channels) - 1 -
                                                 idx](feat_high)
            inner_out = self.top_down_layers[len(self.in_channels) - 1 - idx](
                torch.cat([upsample_feat, feat_low], 1))
            inner_outs.insert(0, inner_out)

        # bottom-up path
        outs = [inner_outs[0]]
        for idx in range(len(self.in_channels) - 1):
            feat_low = outs[-1]
            feat_high = inner_outs[idx + 1]
            downsample_feat = self.downsample_layers[idx](feat_low)
            out = self.bottom_up_layers[idx](
                torch.cat([downsample_feat, feat_high], 1))
            outs.append(out)

        # out_layers
        results = []
        for idx in range(len(self.in_channels)):
            results.append(self.out_layers[idx](outs[idx]))

        return tuple(results)
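The forward contract is symmetric: a tuple of multi-scale features in, a same-length tuple out, highest resolution first. A hypothetical shape setup for a three-scale neck (the channel numbers and implied 640x640 input are illustrative):

import torch

# Stride-8/16/32 features matching in_channels=[256, 512, 1024].
inputs = (
    torch.rand(1, 256, 80, 80),
    torch.rand(1, 512, 40, 40),
    torch.rand(1, 1024, 20, 20),
)
# Any concrete subclass (e.g. YOLOv5PAFPN below) maps this tuple to a
# same-length tuple of fused features at the same three resolutions.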
@@ -0,0 +1,133 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.models.backbones.csp_darknet import CSPLayer
from mmdet.utils import ConfigType, OptMultiConfig

from mmyolo.registry import MODELS
from ..utils import make_divisible, make_round
from .base_yolo_neck import BaseYOLONeck


@MODELS.register_module()
class YOLOv5PAFPN(BaseYOLONeck):
    """Path Aggregation Network used in YOLOv5.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_csp_blocks (int): Number of bottlenecks in CSPLayer. Defaults to 1.
        freeze_all (bool): Whether to freeze the model. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: List[int],
                 out_channels: int,
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 num_csp_blocks: int = 1,
                 freeze_all: bool = False,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 init_cfg: OptMultiConfig = None):
        self.num_csp_blocks = num_csp_blocks
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            deepen_factor=deepen_factor,
            widen_factor=widen_factor,
            freeze_all=freeze_all,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=init_cfg)

    def init_weights(self):
        """Initialize the parameters."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                # In order to be consistent with the source code,
                # reset the Conv2d initialization parameters
                m.reset_parameters()

    def build_reduce_layer(self, idx: int) -> nn.Module:
        """Build a reduce layer."""
        if idx == 2:
            layer = ConvModule(
                make_divisible(self.in_channels[idx], self.widen_factor),
                make_divisible(self.in_channels[idx - 1], self.widen_factor),
                1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        else:
            layer = nn.Identity()

        return layer

    def build_upsample_layer(self, *args, **kwargs) -> nn.Module:
        """Build an upsample layer."""
        return nn.Upsample(scale_factor=2, mode='nearest')

    def build_top_down_layer(self, idx: int) -> nn.Module:
        """Build a top-down layer."""
        if idx == 1:
            return CSPLayer(
                make_divisible(self.in_channels[idx - 1] * 2,
                               self.widen_factor),
                make_divisible(self.in_channels[idx - 1], self.widen_factor),
                num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
                add_identity=False,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        elif idx == 2:
            return nn.Sequential(
                CSPLayer(
                    make_divisible(self.in_channels[idx - 1] * 2,
                                   self.widen_factor),
                    make_divisible(self.in_channels[idx - 1],
                                   self.widen_factor),
                    num_blocks=make_round(self.num_csp_blocks,
                                          self.deepen_factor),
                    add_identity=False,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                ConvModule(
                    make_divisible(self.in_channels[idx - 1],
                                   self.widen_factor),
                    make_divisible(self.in_channels[idx - 2],
                                   self.widen_factor),
                    kernel_size=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

    def build_downsample_layer(self, idx: int) -> nn.Module:
        """Build a downsample layer."""
        return ConvModule(
            make_divisible(self.in_channels[idx], self.widen_factor),
            make_divisible(self.in_channels[idx], self.widen_factor),
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def build_bottom_up_layer(self, idx: int) -> nn.Module:
        """Build a bottom-up layer."""
        return CSPLayer(
            make_divisible(self.in_channels[idx] * 2, self.widen_factor),
            make_divisible(self.in_channels[idx + 1], self.widen_factor),
            num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
            add_identity=False,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def build_out_layer(self, *args, **kwargs) -> nn.Module:
        """Build an out layer."""
        return nn.Identity()
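A sketch wiring the YOLOv5 backbone and this neck together; the shape trace follows BaseYOLONeck.forward and the P5 arch_settings above, with widen/deepen factors left at 1.0. out_channels is passed list-style as YOLOv5 configs do; with Identity out layers it is bookkeeping only:

import torch

from mmyolo.models.backbones import YOLOv5CSPDarknet
from mmyolo.models.necks import YOLOv5PAFPN

backbone = YOLOv5CSPDarknet(out_indices=(2, 3, 4))
neck = YOLOv5PAFPN(
    in_channels=[256, 512, 1024], out_channels=[256, 512, 1024])

x = torch.rand(1, 3, 640, 640)
feats = backbone(x)  # (1, 256, 80, 80), (1, 512, 40, 40), (1, 1024, 20, 20)
outs = neck(feats)
print([tuple(o.shape) for o in outs])
# [(1, 256, 80, 80), (1, 512, 40, 40), (1, 1024, 20, 20)]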
@@ -0,0 +1,141 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.utils import ConfigType, OptMultiConfig

from mmyolo.registry import MODELS
from ..layers import RepStageBlock, RepVGGBlock
from ..utils import make_divisible, make_round
from .base_yolo_neck import BaseYOLONeck


@MODELS.register_module()
class YOLOv6RepPAFPN(BaseYOLONeck):
    """Path Aggregation Network used in YOLOv6.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (List[int]): Number of output channels per scale.
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_csp_blocks (int): Number of bottlenecks in CSPLayer. Defaults
            to 12.
        freeze_all (bool): Whether to freeze the model. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='ReLU', inplace=True).
        block (nn.Module): Block used to build each layer.
            Defaults to RepVGGBlock.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: List[int],
                 out_channels: List[int],
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 num_csp_blocks: int = 12,
                 freeze_all: bool = False,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='ReLU', inplace=True),
                 block: nn.Module = RepVGGBlock,
                 init_cfg: OptMultiConfig = None):
        self.num_csp_blocks = num_csp_blocks
        self.block = block
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            deepen_factor=deepen_factor,
            widen_factor=widen_factor,
            freeze_all=freeze_all,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=init_cfg)

    def build_reduce_layer(self, idx: int) -> nn.Module:
        """Build a reduce layer."""
        if idx == 2:
            layer = ConvModule(
                in_channels=make_divisible(self.in_channels[idx],
                                           self.widen_factor),
                out_channels=make_divisible(self.out_channels[idx - 1],
                                            self.widen_factor),
                kernel_size=1,
                stride=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        else:
            layer = nn.Identity()

        return layer

    def build_upsample_layer(self, idx: int) -> nn.Module:
        """Build an upsample layer."""
        return nn.ConvTranspose2d(
            in_channels=make_divisible(self.out_channels[idx - 1],
                                       self.widen_factor),
            out_channels=make_divisible(self.out_channels[idx - 1],
                                        self.widen_factor),
            kernel_size=2,
            stride=2,
            bias=True)

    def build_top_down_layer(self, idx: int) -> nn.Module:
        """Build a top-down layer."""
        layer0 = RepStageBlock(
            in_channels=make_divisible(
                self.out_channels[idx - 1] + self.in_channels[idx - 1],
                self.widen_factor),
            out_channels=make_divisible(self.out_channels[idx - 1],
                                        self.widen_factor),
            n=make_round(self.num_csp_blocks, self.deepen_factor),
            block=self.block)
        if idx == 1:
            return layer0
        elif idx == 2:
            layer1 = ConvModule(
                in_channels=make_divisible(self.out_channels[idx - 1],
                                           self.widen_factor),
                out_channels=make_divisible(self.out_channels[idx - 2],
                                            self.widen_factor),
                kernel_size=1,
                stride=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            return nn.Sequential(layer0, layer1)

    def build_downsample_layer(self, idx: int) -> nn.Module:
        """Build a downsample layer."""
        return ConvModule(
            in_channels=make_divisible(self.out_channels[idx],
                                       self.widen_factor),
            out_channels=make_divisible(self.out_channels[idx],
                                        self.widen_factor),
            kernel_size=3,
            stride=2,
            padding=3 // 2,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def build_bottom_up_layer(self, idx: int) -> nn.Module:
        """Build a bottom-up layer."""
        return RepStageBlock(
            in_channels=make_divisible(self.out_channels[idx] * 2,
                                       self.widen_factor),
            out_channels=make_divisible(self.out_channels[idx + 1],
                                        self.widen_factor),
            n=make_round(self.num_csp_blocks, self.deepen_factor),
            block=self.block)

    def build_out_layer(self, *args, **kwargs) -> nn.Module:
        """Build an out layer."""
        return nn.Identity()

    def init_weights(self):
        """Initialize the parameters."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                # In order to be consistent with the source code,
                # reset the Conv2d initialization parameters
                m.reset_parameters()
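Unlike the nn.Upsample used by the YOLOv5 and YOLOX necks, YOLOv6 upsamples with a learned 2x2, stride-2 transposed convolution. A quick shape check of that building block in isolation (the channel count is arbitrary):

import torch
import torch.nn as nn

up = nn.ConvTranspose2d(
    in_channels=128, out_channels=128, kernel_size=2, stride=2, bias=True)
print(up(torch.rand(1, 128, 20, 20)).shape)
# torch.Size([1, 128, 40, 40])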
@@ -0,0 +1,125 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.models.backbones.csp_darknet import CSPLayer
from mmdet.utils import ConfigType, OptMultiConfig

from mmyolo.registry import MODELS
from .base_yolo_neck import BaseYOLONeck


@MODELS.register_module()
class YOLOXPAFPN(BaseYOLONeck):
    """Path Aggregation Network used in YOLOX.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_csp_blocks (int): Number of bottlenecks in CSPLayer. Defaults to 3.
        freeze_all (bool): Whether to freeze the model. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: List[int],
                 out_channels: int,
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 num_csp_blocks: int = 3,
                 freeze_all: bool = False,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 init_cfg: OptMultiConfig = None):
        self.num_csp_blocks = round(num_csp_blocks * deepen_factor)

        super().__init__(
            in_channels=[
                int(channel * widen_factor) for channel in in_channels
            ],
            out_channels=int(out_channels * widen_factor),
            deepen_factor=deepen_factor,
            widen_factor=widen_factor,
            freeze_all=freeze_all,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=init_cfg)

    def build_reduce_layer(self, idx: int) -> nn.Module:
        """Build a reduce layer."""
        if idx == 2:
            layer = ConvModule(
                self.in_channels[idx],
                self.in_channels[idx - 1],
                1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        else:
            layer = nn.Identity()

        return layer

    def build_upsample_layer(self, *args, **kwargs) -> nn.Module:
        """Build an upsample layer."""
        return nn.Upsample(scale_factor=2, mode='nearest')

    def build_top_down_layer(self, idx: int) -> nn.Module:
        """Build a top-down layer."""
        if idx == 1:
            return CSPLayer(
                self.in_channels[idx - 1] * 2,
                self.in_channels[idx - 1],
                num_blocks=self.num_csp_blocks,
                add_identity=False,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        elif idx == 2:
            return nn.Sequential(
                CSPLayer(
                    self.in_channels[idx - 1] * 2,
                    self.in_channels[idx - 1],
                    num_blocks=self.num_csp_blocks,
                    add_identity=False,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                ConvModule(
                    self.in_channels[idx - 1],
                    self.in_channels[idx - 2],
                    kernel_size=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

    def build_downsample_layer(self, idx: int) -> nn.Module:
        """Build a downsample layer."""
        return ConvModule(
            self.in_channels[idx],
            self.in_channels[idx],
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def build_bottom_up_layer(self, idx: int) -> nn.Module:
        """Build a bottom-up layer."""
        return CSPLayer(
            self.in_channels[idx] * 2,
            self.in_channels[idx + 1],
            num_blocks=self.num_csp_blocks,
            add_identity=False,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def build_out_layer(self, idx: int) -> nn.Module:
        """Build an out layer."""
        return ConvModule(
            self.in_channels[idx],
            self.out_channels,
            1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
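One design difference worth noting: YOLOXPAFPN widens its channel numbers once in __init__, so every later builder indexes already-scaled values, whereas YOLOv5PAFPN calls make_divisible at each site. A small illustration with hypothetical numbers:

in_channels = [256, 512, 1024]
widen_factor = 0.5

# What YOLOXPAFPN hands to its parent class as in_channels:
print([int(c * widen_factor) for c in in_channels])  # [128, 256, 512]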
@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import pytest
import torch
from parameterized import parameterized
from torch.nn.modules.batchnorm import _BatchNorm

from mmyolo.models.backbones import YOLOv5CSPDarknet, YOLOXCSPDarknet
from mmyolo.utils import register_all_modules
from .utils import check_norm_state, is_norm

register_all_modules()


class TestCSPDarknet(TestCase):

    @parameterized.expand([(YOLOv5CSPDarknet, ), (YOLOXCSPDarknet, )])
    def test_init(self, module_class):
        # out_indices must be in range(len(arch_setting) + 1)
        with pytest.raises(AssertionError):
            module_class(out_indices=(6, ))

        with pytest.raises(ValueError):
            # frozen_stages must be in range(-1, len(arch_setting) + 1)
            module_class(frozen_stages=6)

    @parameterized.expand([(YOLOv5CSPDarknet, ), (YOLOXCSPDarknet, )])
    def test_forward(self, module_class):
        # Test CSPDarknet with first stage frozen
        frozen_stages = 1
        model = module_class(frozen_stages=frozen_stages)
        model.init_weights()
        model.train()

        for mod in model.stem.modules():
            for param in mod.parameters():
                assert param.requires_grad is False
        for i in range(1, frozen_stages + 1):
            layer = getattr(model, f'stage{i}')
            for mod in layer.modules():
                if isinstance(mod, _BatchNorm):
                    assert mod.training is False
            for param in layer.parameters():
                assert param.requires_grad is False

        # Test CSPDarknet with norm_eval=True
        model = module_class(norm_eval=True)
        model.train()

        assert check_norm_state(model.modules(), False)

        # Test CSPDarknet-P5 forward with widen_factor=0.25
        model = module_class(
            arch='P5', widen_factor=0.25, out_indices=range(0, 5))
        model.train()

        imgs = torch.randn(1, 3, 64, 64)
        feat = model(imgs)
        assert len(feat) == 5
        assert feat[0].shape == torch.Size((1, 16, 32, 32))
        assert feat[1].shape == torch.Size((1, 32, 16, 16))
        assert feat[2].shape == torch.Size((1, 64, 8, 8))
        assert feat[3].shape == torch.Size((1, 128, 4, 4))
        assert feat[4].shape == torch.Size((1, 256, 2, 2))

        # Test CSPDarknet forward with dict(type='ReLU')
        model = module_class(
            widen_factor=0.125,
            act_cfg=dict(type='ReLU'),
            out_indices=range(0, 5))
        model.train()

        imgs = torch.randn(1, 3, 64, 64)
        feat = model(imgs)
        assert len(feat) == 5
        assert feat[0].shape == torch.Size((1, 8, 32, 32))
        assert feat[1].shape == torch.Size((1, 16, 16, 16))
        assert feat[2].shape == torch.Size((1, 32, 8, 8))
        assert feat[3].shape == torch.Size((1, 64, 4, 4))
        assert feat[4].shape == torch.Size((1, 128, 2, 2))

        # Test CSPDarknet with BatchNorm forward
        model = module_class(widen_factor=0.125, out_indices=range(0, 5))
        for m in model.modules():
            if is_norm(m):
                assert isinstance(m, _BatchNorm)
        model.train()

        imgs = torch.randn(1, 3, 64, 64)
        feat = model(imgs)
        assert len(feat) == 5
        assert feat[0].shape == torch.Size((1, 8, 32, 32))
        assert feat[1].shape == torch.Size((1, 16, 16, 16))
        assert feat[2].shape == torch.Size((1, 32, 8, 8))
        assert feat[3].shape == torch.Size((1, 64, 4, 4))
        assert feat[4].shape == torch.Size((1, 128, 2, 2))
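Why the widen_factor=0.25 case expects 16 stem channels: the stem widens arch_setting[0][0] == 64 and, assuming make_divisible's conventional divisor of 8, rounds up to a multiple of 8. The same arithmetic reproduces every channel count asserted above:

import math

# Assumed make_divisible behaviour: ceil(x * widen / 8) * 8.
for base in (64, 128, 256, 512, 1024):
    print(base, '->', math.ceil(base * 0.25 / 8) * 8)
# 64 -> 16, 128 -> 32, 256 -> 64, 512 -> 128, 1024 -> 256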
@@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import pytest
import torch
from torch.nn.modules.batchnorm import _BatchNorm

from mmyolo.models.backbones import YOLOv6EfficientRep
from mmyolo.utils import register_all_modules
from .utils import check_norm_state, is_norm

register_all_modules()


class TestYOLOv6EfficientRep(TestCase):

    def test_init(self):
        # out_indices must be in range(len(arch_setting) + 1)
        with pytest.raises(AssertionError):
            YOLOv6EfficientRep(out_indices=(6, ))

        with pytest.raises(ValueError):
            # frozen_stages must be in range(-1, len(arch_setting) + 1)
            YOLOv6EfficientRep(frozen_stages=6)

    def test_forward(self):
        # Test YOLOv6EfficientRep with first stage frozen
        frozen_stages = 1
        model = YOLOv6EfficientRep(frozen_stages=frozen_stages)
        model.init_weights()
        model.train()

        for mod in model.stem.modules():
            for param in mod.parameters():
                assert param.requires_grad is False
        for i in range(1, frozen_stages + 1):
            layer = getattr(model, f'stage{i}')
            for mod in layer.modules():
                if isinstance(mod, _BatchNorm):
                    assert mod.training is False
            for param in layer.parameters():
                assert param.requires_grad is False

        # Test YOLOv6EfficientRep with norm_eval=True
        model = YOLOv6EfficientRep(norm_eval=True)
        model.train()

        assert check_norm_state(model.modules(), False)

        # Test YOLOv6EfficientRep-P5 forward with widen_factor=0.25
        model = YOLOv6EfficientRep(
            arch='P5', widen_factor=0.25, out_indices=range(0, 5))
        model.train()

        imgs = torch.randn(1, 3, 64, 64)
        feat = model(imgs)
        assert len(feat) == 5
        assert feat[0].shape == torch.Size((1, 16, 32, 32))
        assert feat[1].shape == torch.Size((1, 32, 16, 16))
        assert feat[2].shape == torch.Size((1, 64, 8, 8))
        assert feat[3].shape == torch.Size((1, 128, 4, 4))
        assert feat[4].shape == torch.Size((1, 256, 2, 2))

        # Test YOLOv6EfficientRep forward with dict(type='ReLU')
        model = YOLOv6EfficientRep(
            widen_factor=0.125,
            act_cfg=dict(type='ReLU'),
            out_indices=range(0, 5))
        model.train()

        imgs = torch.randn(1, 3, 64, 64)
        feat = model(imgs)
        assert len(feat) == 5
        assert feat[0].shape == torch.Size((1, 8, 32, 32))
        assert feat[1].shape == torch.Size((1, 16, 16, 16))
        assert feat[2].shape == torch.Size((1, 32, 8, 8))
        assert feat[3].shape == torch.Size((1, 64, 4, 4))
        assert feat[4].shape == torch.Size((1, 128, 2, 2))

        # Test YOLOv6EfficientRep with BatchNorm forward
        model = YOLOv6EfficientRep(widen_factor=0.125, out_indices=range(0, 5))
        for m in model.modules():
            if is_norm(m):
                assert isinstance(m, _BatchNorm)
        model.train()

        imgs = torch.randn(1, 3, 64, 64)
        feat = model(imgs)
        assert len(feat) == 5
        assert feat[0].shape == torch.Size((1, 8, 32, 32))
        assert feat[1].shape == torch.Size((1, 16, 16, 16))
        assert feat[2].shape == torch.Size((1, 32, 8, 8))
        assert feat[3].shape == torch.Size((1, 64, 4, 4))
        assert feat[4].shape == torch.Size((1, 128, 2, 2))
@@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.backbones.res2net import Bottle2neck
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from mmdet.models.backbones.resnext import Bottleneck as BottleneckX
from mmdet.models.layers import SimplifiedBasicBlock
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm


def is_block(modules):
    """Check if is ResNet building block."""
    if isinstance(modules, (BasicBlock, Bottleneck, BottleneckX, Bottle2neck,
                            SimplifiedBasicBlock)):
        return True
    return False


def is_norm(modules):
    """Check if is one of the norms."""
    if isinstance(modules, (GroupNorm, _BatchNorm)):
        return True
    return False


def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True