mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Support YOLOv6 ML model (#265)
* align larger infer * align large test * update docstr * add ut * Add yolov6m pipeline * Add yolov6m pipeline * Improve coding * update l config * align m test * rename config * support training * update * block_cfg -> stage_block_cfg * update * Add docstr * Fix lint * update convert script * update * rename param * update readme&metafile * rm duplicate docde * fix lint * update * update * update Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>pull/278/head
parent
178b0bf3ff
commit
ab4e7c5158
|
@ -16,17 +16,20 @@ For years, YOLO series have been de facto industry-level standard for efficient
|
|||
|
||||
### COCO
|
||||
|
||||
| Backbone | Arch | size | SyncBN | AMP | Mem (GB) | box AP | Config | Download |
|
||||
| :------: | :--: | :--: | :----: | :-: | :------: | :----: | :---------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| YOLOv6-n | P5 | 640 | Yes | Yes | 6.04 | 36.2 | [config](../yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco/yolov6_n_syncbn_fast_8xb32-400e_coco_20221030_202726-d99b2e82.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco/yolov6_n_syncbn_fast_8xb32-400e_coco_20221030_202726.log.json) |
|
||||
| YOLOv6-t | P5 | 640 | Yes | Yes | 8.13 | 41.0 | [config](../yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755-cf0d278f.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755.log.json) |
|
||||
| YOLOv6-s | P5 | 640 | Yes | Yes | 8.88 | 44.0 | [config](../yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221102_203035-932e1d91.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221102_203035.log.json) |
|
||||
| Backbone | Arch | Size | Epoch | SyncBN | AMP | Mem (GB) | Box AP | Config | Download |
|
||||
| :------: | :--: | :--: | :---: | :----: | :-: | :------: | :----: | :---------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| YOLOv6-n | P5 | 640 | 400 | Yes | Yes | 6.04 | 36.2 | [config](../yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco/yolov6_n_syncbn_fast_8xb32-400e_coco_20221030_202726-d99b2e82.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco/yolov6_n_syncbn_fast_8xb32-400e_coco_20221030_202726.log.json) |
|
||||
| YOLOv6-t | P5 | 640 | 400 | Yes | Yes | 8.13 | 41.0 | [config](../yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755-cf0d278f.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755.log.json) |
|
||||
| YOLOv6-s | P5 | 640 | 400 | Yes | Yes | 8.88 | 44.0 | [config](../yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221102_203035-932e1d91.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221102_203035.log.json) |
|
||||
| YOLOv6-m | P5 | 640 | 300 | Yes | Yes | 16.69 | 48.4 | [config](../yolov6/yolov6_m_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_m_syncbn_fast_8xb32-300e_coco/yolov6_m_syncbn_fast_8xb32-300e_coco_20221109_182658-85bda3f4.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_m_syncbn_fast_8xb32-300e_coco/yolov6_m_syncbn_fast_8xb32-300e_coco_20221109_182658.log.json) |
|
||||
| YOLOv6-l | P5 | 640 | 300 | Yes | Yes | 20.86 | 51.0 | [config](../yolov6/yolov6_l_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_l_syncbn_fast_8xb32-300e_coco/yolov6_l_syncbn_fast_8xb32-300e_coco_20221109_183156-91e3c447.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_l_syncbn_fast_8xb32-300e_coco/yolov6_l_syncbn_fast_8xb32-300e_coco_20221109_183156.log.json) |
|
||||
|
||||
**Note**:
|
||||
|
||||
1. The performance is unstable and may fluctuate by about 0.3 mAP.
|
||||
2. YOLOv6-m,l,x will be supported in later version.
|
||||
3. If users need the weight of 300 epoch, they can train according to the configs of 300 epoch provided by us, or convert the official weight according to the [converter script](../../tools/model_converters/).
|
||||
1. The official m and l models use knowledge distillation, but our version does not support it, which will be implemented in [MMRazor](https://github.com/open-mmlab/mmrazor) in the future.
|
||||
2. The performance is unstable and may fluctuate by about 0.3 mAP.
|
||||
3. If users need the weight of 300 epoch for nano, tiny and small model, they can train according to the configs of 300 epoch provided by us, or convert the official weight according to the [converter script](../../tools/model_converters/).
|
||||
4. We have observed that the [base model](https://github.com/meituan/YOLOv6/tree/main/configs/base) has been officially released in v6 recently. Although the accuracy has decreased, it is more efficient. We will also provide the base model configuration in the future.
|
||||
|
||||
## Citation
|
||||
|
||||
|
|
|
@ -57,3 +57,27 @@ Models:
|
|||
Metrics:
|
||||
box AP: 41.0
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755-cf0d278f.pth
|
||||
- Name: yolov6_m_syncbn_fast_8xb32-300e_coco
|
||||
In Collection: YOLOv6
|
||||
Config: configs/yolov6/yolov6_m_syncbn_fast_8xb32-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 16.69
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 48.4
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_m_syncbn_fast_8xb32-300e_coco/yolov6_m_syncbn_fast_8xb32-300e_coco_20221109_182658-85bda3f4.pth
|
||||
- Name: yolov6_l_syncbn_fast_8xb32-300e_coco
|
||||
In Collection: YOLOv6
|
||||
Config: configs/yolov6/yolov6_l_syncbn_fast_8xb32-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 20.86
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 51.0
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_l_syncbn_fast_8xb32-300e_coco/yolov6_l_syncbn_fast_8xb32-300e_coco_20221109_183156-91e3c447.pth
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
_base_ = './yolov6_m_syncbn_fast_8xb32-300e_coco.py'
|
||||
|
||||
deepen_factor = 1
|
||||
widen_factor = 1
|
||||
|
||||
model = dict(
|
||||
backbone=dict(
|
||||
deepen_factor=deepen_factor,
|
||||
widen_factor=widen_factor,
|
||||
hidden_ratio=1. / 2,
|
||||
block_cfg=dict(
|
||||
type='ConvWrapper',
|
||||
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
|
||||
act_cfg=dict(type='SiLU', inplace=True)),
|
||||
neck=dict(
|
||||
deepen_factor=deepen_factor,
|
||||
widen_factor=widen_factor,
|
||||
hidden_ratio=1. / 2,
|
||||
block_cfg=dict(
|
||||
type='ConvWrapper',
|
||||
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
|
||||
block_act_cfg=dict(type='SiLU', inplace=True)),
|
||||
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
|
|
@ -0,0 +1,54 @@
|
|||
_base_ = './yolov6_s_syncbn_fast_8xb32-300e_coco.py'
|
||||
|
||||
deepen_factor = 0.6
|
||||
widen_factor = 0.75
|
||||
affine_scale = 0.9
|
||||
|
||||
model = dict(
|
||||
backbone=dict(
|
||||
type='YOLOv6CSPBep',
|
||||
deepen_factor=deepen_factor,
|
||||
widen_factor=widen_factor,
|
||||
hidden_ratio=2. / 3,
|
||||
block_cfg=dict(type='RepVGGBlock'),
|
||||
act_cfg=dict(type='ReLU', inplace=True)),
|
||||
neck=dict(
|
||||
type='YOLOv6CSPRepPAFPN',
|
||||
deepen_factor=deepen_factor,
|
||||
widen_factor=widen_factor,
|
||||
block_cfg=dict(type='RepVGGBlock'),
|
||||
hidden_ratio=2. / 3,
|
||||
block_act_cfg=dict(type='ReLU', inplace=True)),
|
||||
bbox_head=dict(
|
||||
type='YOLOv6Head', head_module=dict(widen_factor=widen_factor)))
|
||||
|
||||
mosaic_affine_pipeline = [
|
||||
dict(
|
||||
type='Mosaic',
|
||||
img_scale=_base_.img_scale,
|
||||
pad_val=114.0,
|
||||
pre_transform=_base_.pre_transform),
|
||||
dict(
|
||||
type='YOLOv5RandomAffine',
|
||||
max_rotate_degree=0.0,
|
||||
max_shear_degree=0.0,
|
||||
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
|
||||
border=(-_base_.img_scale[0] // 2, -_base_.img_scale[1] // 2),
|
||||
border_val=(114, 114, 114))
|
||||
]
|
||||
|
||||
train_pipeline = [
|
||||
*_base_.pre_transform, *mosaic_affine_pipeline,
|
||||
dict(
|
||||
type='YOLOv5MixUp',
|
||||
prob=0.1,
|
||||
pre_transform=[*_base_.pre_transform, *mosaic_affine_pipeline]),
|
||||
dict(type='YOLOv5HSVRandomAug'),
|
||||
dict(type='mmdet.RandomFlip', prob=0.5),
|
||||
dict(
|
||||
type='mmdet.PackDetInputs',
|
||||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
|
||||
'flip_direction'))
|
||||
]
|
||||
|
||||
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
|
|
@ -12,6 +12,7 @@ num_classes = 80
|
|||
img_scale = (640, 640) # height, width
|
||||
deepen_factor = 0.33
|
||||
widen_factor = 0.5
|
||||
affine_scale = 0.5
|
||||
save_epoch_intervals = 10
|
||||
train_batch_size_per_gpu = 32
|
||||
train_num_workers = 8
|
||||
|
@ -112,7 +113,7 @@ train_pipeline = [
|
|||
type='YOLOv5RandomAffine',
|
||||
max_rotate_degree=0.0,
|
||||
max_translate_ratio=0.1,
|
||||
scaling_ratio_range=(0.5, 1.5),
|
||||
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
|
||||
border=(-img_scale[0] // 2, -img_scale[1] // 2),
|
||||
border_val=(114, 114, 114),
|
||||
max_shear_degree=0.0),
|
||||
|
@ -136,7 +137,7 @@ train_pipeline_stage2 = [
|
|||
type='YOLOv5RandomAffine',
|
||||
max_rotate_degree=0.0,
|
||||
max_translate_ratio=0.1,
|
||||
scaling_ratio_range=(0.5, 1.5),
|
||||
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
|
||||
max_shear_degree=0.0,
|
||||
),
|
||||
dict(type='YOLOv5HSVRandomAug'),
|
||||
|
|
|
@ -3,10 +3,10 @@ from .base_backbone import BaseBackbone
|
|||
from .csp_darknet import YOLOv5CSPDarknet, YOLOXCSPDarknet
|
||||
from .csp_resnet import PPYOLOECSPResNet
|
||||
from .cspnext import CSPNeXt
|
||||
from .efficient_rep import YOLOv6EfficientRep
|
||||
from .efficient_rep import YOLOv6CSPBep, YOLOv6EfficientRep
|
||||
from .yolov7_backbone import YOLOv7Backbone
|
||||
|
||||
__all__ = [
|
||||
'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep',
|
||||
'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOv6CSPBep',
|
||||
'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet'
|
||||
]
|
||||
|
|
|
@ -87,7 +87,6 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
|
|||
norm_eval: bool = False,
|
||||
init_cfg: OptMultiConfig = None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.num_stages = len(arch_setting)
|
||||
self.arch_setting = arch_setting
|
||||
|
||||
|
|
|
@ -8,20 +8,18 @@ from mmdet.utils import ConfigType, OptMultiConfig
|
|||
|
||||
from mmyolo.models.layers.yolo_bricks import SPPFBottleneck
|
||||
from mmyolo.registry import MODELS
|
||||
from ..layers import RepStageBlock, RepVGGBlock
|
||||
from ..utils import make_divisible, make_round
|
||||
from ..layers import BepC3StageBlock, RepStageBlock
|
||||
from ..utils import make_round
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class YOLOv6EfficientRep(BaseBackbone):
|
||||
"""EfficientRep backbone used in YOLOv6.
|
||||
|
||||
Args:
|
||||
arch (str): Architecture of BaseDarknet, from {P5, P6}.
|
||||
Defaults to P5.
|
||||
plugins (list[dict]): List of plugins for stages, each dict contains:
|
||||
|
||||
- cfg (dict, required): Cfg dict to build plugin.
|
||||
- stages (tuple[bool], optional): Stages to apply plugin, length
|
||||
should be same as 'num_stages'.
|
||||
|
@ -41,10 +39,10 @@ class YOLOv6EfficientRep(BaseBackbone):
|
|||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Defaults to False.
|
||||
block (nn.Module): block used to build each stage.
|
||||
block_cfg (dict): Config dict for the block used to build each
|
||||
layer. Defaults to dict(type='RepVGGBlock').
|
||||
init_cfg (Union[dict, list[dict]], optional): Initialization config
|
||||
dict. Defaults to None.
|
||||
|
||||
Example:
|
||||
>>> from mmyolo.models import YOLOv6EfficientRep
|
||||
>>> import torch
|
||||
|
@ -78,9 +76,9 @@ class YOLOv6EfficientRep(BaseBackbone):
|
|||
type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
|
||||
norm_eval: bool = False,
|
||||
block: nn.Module = RepVGGBlock,
|
||||
block_cfg: ConfigType = dict(type='RepVGGBlock'),
|
||||
init_cfg: OptMultiConfig = None):
|
||||
self.block = block
|
||||
self.block_cfg = block_cfg
|
||||
super().__init__(
|
||||
self.arch_settings[arch],
|
||||
deepen_factor,
|
||||
|
@ -96,12 +94,16 @@ class YOLOv6EfficientRep(BaseBackbone):
|
|||
|
||||
def build_stem_layer(self) -> nn.Module:
|
||||
"""Build a stem layer."""
|
||||
return self.block(
|
||||
in_channels=self.input_channels,
|
||||
out_channels=make_divisible(self.arch_setting[0][0],
|
||||
self.widen_factor),
|
||||
kernel_size=3,
|
||||
stride=2)
|
||||
|
||||
block_cfg = self.block_cfg.copy()
|
||||
block_cfg.update(
|
||||
dict(
|
||||
in_channels=self.input_channels,
|
||||
out_channels=int(self.arch_setting[0][0] * self.widen_factor),
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
))
|
||||
return MODELS.build(block_cfg)
|
||||
|
||||
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
|
||||
"""Build a stage layer.
|
||||
|
@ -112,24 +114,28 @@ class YOLOv6EfficientRep(BaseBackbone):
|
|||
"""
|
||||
in_channels, out_channels, num_blocks, use_spp = setting
|
||||
|
||||
in_channels = make_divisible(in_channels, self.widen_factor)
|
||||
out_channels = make_divisible(out_channels, self.widen_factor)
|
||||
in_channels = int(in_channels * self.widen_factor)
|
||||
out_channels = int(out_channels * self.widen_factor)
|
||||
num_blocks = make_round(num_blocks, self.deepen_factor)
|
||||
|
||||
stage = []
|
||||
rep_stage_block = RepStageBlock(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
num_blocks=num_blocks,
|
||||
block_cfg=self.block_cfg,
|
||||
)
|
||||
|
||||
ef_block = nn.Sequential(
|
||||
self.block(
|
||||
block_cfg = self.block_cfg.copy()
|
||||
block_cfg.update(
|
||||
dict(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2),
|
||||
RepStageBlock(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
n=num_blocks,
|
||||
block=self.block,
|
||||
))
|
||||
stride=2))
|
||||
stage = []
|
||||
|
||||
ef_block = nn.Sequential(MODELS.build(block_cfg), rep_stage_block)
|
||||
|
||||
stage.append(ef_block)
|
||||
|
||||
if use_spp:
|
||||
|
@ -152,3 +158,130 @@ class YOLOv6EfficientRep(BaseBackbone):
|
|||
m.reset_parameters()
|
||||
else:
|
||||
super().init_weights()
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class YOLOv6CSPBep(YOLOv6EfficientRep):
|
||||
"""CSPBep backbone used in YOLOv6.
|
||||
Args:
|
||||
arch (str): Architecture of BaseDarknet, from {P5, P6}.
|
||||
Defaults to P5.
|
||||
plugins (list[dict]): List of plugins for stages, each dict contains:
|
||||
- cfg (dict, required): Cfg dict to build plugin.
|
||||
- stages (tuple[bool], optional): Stages to apply plugin, length
|
||||
should be same as 'num_stages'.
|
||||
deepen_factor (float): Depth multiplier, multiply number of
|
||||
blocks in CSP layer by this amount. Defaults to 1.0.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
channels in each layer by this amount. Defaults to 1.0.
|
||||
input_channels (int): Number of input image channels. Defaults to 3.
|
||||
out_indices (Tuple[int]): Output from which stages.
|
||||
Defaults to (2, 3, 4).
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval
|
||||
mode). -1 means not freezing any parameters. Defaults to -1.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Defaults to dict(type='BN', requires_grad=True).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Defaults to dict(type='LeakyReLU', negative_slope=0.1).
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Defaults to False.
|
||||
block_cfg (dict): Config dict for the block used to build each
|
||||
layer. Defaults to dict(type='RepVGGBlock').
|
||||
block_act_cfg (dict): Config dict for activation layer used in each
|
||||
stage. Defaults to dict(type='SiLU', inplace=True).
|
||||
init_cfg (Union[dict, list[dict]], optional): Initialization config
|
||||
dict. Defaults to None.
|
||||
Example:
|
||||
>>> from mmyolo.models import YOLOv6CSPBep
|
||||
>>> import torch
|
||||
>>> model = YOLOv6CSPBep()
|
||||
>>> model.eval()
|
||||
>>> inputs = torch.rand(1, 3, 416, 416)
|
||||
>>> level_outputs = model(inputs)
|
||||
>>> for level_out in level_outputs:
|
||||
... print(tuple(level_out.shape))
|
||||
...
|
||||
(1, 256, 52, 52)
|
||||
(1, 512, 26, 26)
|
||||
(1, 1024, 13, 13)
|
||||
"""
|
||||
# From left to right:
|
||||
# in_channels, out_channels, num_blocks, use_spp
|
||||
arch_settings = {
|
||||
'P5': [[64, 128, 6, False], [128, 256, 12, False],
|
||||
[256, 512, 18, False], [512, 1024, 6, True]]
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
arch: str = 'P5',
|
||||
plugins: Union[dict, List[dict]] = None,
|
||||
deepen_factor: float = 1.0,
|
||||
widen_factor: float = 1.0,
|
||||
input_channels: int = 3,
|
||||
hidden_ratio: float = 0.5,
|
||||
out_indices: Tuple[int] = (2, 3, 4),
|
||||
frozen_stages: int = -1,
|
||||
norm_cfg: ConfigType = dict(
|
||||
type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
||||
norm_eval: bool = False,
|
||||
block_cfg: ConfigType = dict(type='ConvWrapper'),
|
||||
init_cfg: OptMultiConfig = None):
|
||||
self.hidden_ratio = hidden_ratio
|
||||
super().__init__(
|
||||
arch=arch,
|
||||
deepen_factor=deepen_factor,
|
||||
widen_factor=widen_factor,
|
||||
input_channels=input_channels,
|
||||
out_indices=out_indices,
|
||||
plugins=plugins,
|
||||
frozen_stages=frozen_stages,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
norm_eval=norm_eval,
|
||||
block_cfg=block_cfg,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
|
||||
"""Build a stage layer.
|
||||
|
||||
Args:
|
||||
stage_idx (int): The index of a stage layer.
|
||||
setting (list): The architecture setting of a stage layer.
|
||||
"""
|
||||
in_channels, out_channels, num_blocks, use_spp = setting
|
||||
in_channels = int(in_channels * self.widen_factor)
|
||||
out_channels = int(out_channels * self.widen_factor)
|
||||
num_blocks = make_round(num_blocks, self.deepen_factor)
|
||||
|
||||
rep_stage_block = BepC3StageBlock(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
num_blocks=num_blocks,
|
||||
hidden_ratio=self.hidden_ratio,
|
||||
block_cfg=self.block_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
block_cfg = self.block_cfg.copy()
|
||||
block_cfg.update(
|
||||
dict(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2))
|
||||
stage = []
|
||||
|
||||
ef_block = nn.Sequential(MODELS.build(block_cfg), rep_stage_block)
|
||||
|
||||
stage.append(ef_block)
|
||||
|
||||
if use_spp:
|
||||
spp = SPPFBottleneck(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_sizes=5,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
stage.append(spp)
|
||||
return stage
|
||||
|
|
|
@ -14,7 +14,6 @@ from mmengine.structures import InstanceData
|
|||
from torch import Tensor
|
||||
|
||||
from mmyolo.registry import MODELS, TASK_UTILS
|
||||
from ..utils import make_divisible
|
||||
from .yolov5_head import YOLOv5Head
|
||||
|
||||
|
||||
|
@ -65,12 +64,10 @@ class YOLOv6HeadModule(BaseModule):
|
|||
self.act_cfg = act_cfg
|
||||
|
||||
if isinstance(in_channels, int):
|
||||
self.in_channels = [make_divisible(in_channels, widen_factor)
|
||||
self.in_channels = [int(in_channels * widen_factor)
|
||||
] * self.num_levels
|
||||
else:
|
||||
self.in_channels = [
|
||||
make_divisible(i, widen_factor) for i in in_channels
|
||||
]
|
||||
self.in_channels = [int(i * widen_factor) for i in in_channels]
|
||||
|
||||
self._init_layers()
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .ema import ExpMomentumEMA
|
||||
from .yolo_bricks import (EffectiveSELayer, ELANBlock,
|
||||
from .yolo_bricks import (BepC3StageBlock, EffectiveSELayer, ELANBlock,
|
||||
MaxPoolAndStrideConvBlock, PPYOLOEBasicBlock,
|
||||
RepStageBlock, RepVGGBlock, SPPFBottleneck,
|
||||
SPPFCSPBlock)
|
||||
|
@ -8,5 +8,5 @@ from .yolo_bricks import (EffectiveSELayer, ELANBlock,
|
|||
__all__ = [
|
||||
'SPPFBottleneck', 'RepVGGBlock', 'RepStageBlock', 'ExpMomentumEMA',
|
||||
'ELANBlock', 'MaxPoolAndStrideConvBlock', 'SPPFCSPBlock',
|
||||
'PPYOLOEBasicBlock', 'EffectiveSELayer'
|
||||
'PPYOLOEBasicBlock', 'EffectiveSELayer', 'BepC3StageBlock'
|
||||
]
|
||||
|
|
|
@ -9,6 +9,7 @@ from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
|
|||
from mmengine.model import BaseModule
|
||||
from mmengine.utils import digit_version
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from mmyolo.registry import MODELS
|
||||
|
||||
|
@ -22,7 +23,7 @@ else:
|
|||
def __init__(self, inplace=True):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, inputs) -> torch.Tensor:
|
||||
def forward(self, inputs) -> Tensor:
|
||||
return inputs * torch.sigmoid(inputs)
|
||||
|
||||
MODELS.register_module(module=SiLU, name='SiLU')
|
||||
|
@ -31,7 +32,6 @@ else:
|
|||
class SPPFBottleneck(BaseModule):
|
||||
"""Spatial pyramid pooling - Fast (SPPF) layer for
|
||||
YOLOv5, YOLOX and PPYOLOE by Glenn Jocher
|
||||
|
||||
Args:
|
||||
in_channels (int): The input channels of this Module.
|
||||
out_channels (int): The output channels of this Module.
|
||||
|
@ -100,7 +100,7 @@ class SPPFBottleneck(BaseModule):
|
|||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Forward process
|
||||
Args:
|
||||
x (Tensor): The input tensor.
|
||||
|
@ -118,6 +118,7 @@ class SPPFBottleneck(BaseModule):
|
|||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class RepVGGBlock(nn.Module):
|
||||
"""RepVGGBlock is a basic rep-style block, including training and deploy
|
||||
status This code is based on
|
||||
|
@ -227,11 +228,11 @@ class RepVGGBlock(nn.Module):
|
|||
norm_cfg=norm_cfg,
|
||||
act_cfg=None)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, inputs: Tensor) -> Tensor:
|
||||
"""Forward process.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor.
|
||||
"""
|
||||
|
@ -270,9 +271,9 @@ class RepVGGBlock(nn.Module):
|
|||
|
||||
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
|
||||
"""Pad 1x1 tensor to 3x3.
|
||||
|
||||
Args:
|
||||
kernel1x1 (Tensor): The input 1x1 kernel need to be padded.
|
||||
|
||||
Returns:
|
||||
Tensor: 3x3 kernel after padded.
|
||||
"""
|
||||
|
@ -281,14 +282,12 @@ class RepVGGBlock(nn.Module):
|
|||
else:
|
||||
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
|
||||
|
||||
def _fuse_bn_tensor(self,
|
||||
branch: nn.Module) -> Tuple[np.ndarray, torch.Tensor]:
|
||||
def _fuse_bn_tensor(self, branch: nn.Module) -> Tuple[np.ndarray, Tensor]:
|
||||
"""Derives the equivalent kernel and bias of a specific branch layer.
|
||||
|
||||
Args:
|
||||
branch (nn.Module): The layer that needs to be equivalently
|
||||
transformed, which can be nn.Sequential or nn.Batchnorm2d
|
||||
|
||||
Returns:
|
||||
tuple: Equivalent kernel and bias
|
||||
"""
|
||||
|
@ -348,38 +347,177 @@ class RepVGGBlock(nn.Module):
|
|||
self.deploy = True
|
||||
|
||||
|
||||
class RepStageBlock(nn.Module):
|
||||
"""RepStageBlock is a stage block with rep-style basic block.
|
||||
@MODELS.register_module()
|
||||
class BepC3StageBlock(nn.Module):
|
||||
"""Beer-mug RepC3 Block.
|
||||
|
||||
Args:
|
||||
in_channels (int): The input channels of this Module.
|
||||
out_channels (int): The output channels of this Module.
|
||||
n (int, tuple[int]): Number of blocks. Defaults to 1.
|
||||
block (nn.Module): Basic unit of RepStage. Defaults to RepVGGBlock.
|
||||
in_channels (int): Number of channels in the input image
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
num_blocks (int): Number of blocks. Defaults to 1
|
||||
hidden_ratio (float): Hidden channel expansion.
|
||||
Default: 0.5
|
||||
concat_all_layer (bool): Concat all layer when forward calculate.
|
||||
Default: True
|
||||
block_cfg (dict): Config dict for the block used to build each
|
||||
layer. Defaults to dict(type='RepVGGBlock').
|
||||
norm_cfg (ConfigType): Config dict for normalization layer.
|
||||
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
|
||||
act_cfg (ConfigType): Config dict for activation layer.
|
||||
Defaults to dict(type='ReLU', inplace=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
n: int = 1,
|
||||
block: nn.Module = RepVGGBlock):
|
||||
num_blocks: int = 1,
|
||||
hidden_ratio: float = 0.5,
|
||||
concat_all_layer: bool = True,
|
||||
block_cfg: ConfigType = dict(type='RepVGGBlock'),
|
||||
norm_cfg: ConfigType = dict(
|
||||
type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg: ConfigType = dict(type='ReLU', inplace=True)):
|
||||
super().__init__()
|
||||
self.conv1 = block(in_channels, out_channels)
|
||||
self.block = nn.Sequential(*(block(out_channels, out_channels)
|
||||
for _ in range(n - 1))) if n > 1 else None
|
||||
hidden_channels = int(out_channels * hidden_ratio)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward process.
|
||||
Args:
|
||||
inputs (Tensor): The input tensor.
|
||||
self.conv1 = ConvModule(
|
||||
in_channels,
|
||||
hidden_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.conv2 = ConvModule(
|
||||
in_channels,
|
||||
hidden_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.conv3 = ConvModule(
|
||||
2 * hidden_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.block = RepStageBlock(
|
||||
in_channels=hidden_channels,
|
||||
out_channels=hidden_channels,
|
||||
num_blocks=num_blocks,
|
||||
block_cfg=block_cfg,
|
||||
bottle_block=BottleRep)
|
||||
self.concat_all_layer = concat_all_layer
|
||||
if not concat_all_layer:
|
||||
self.conv3 = ConvModule(
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor.
|
||||
"""
|
||||
x = self.conv1(x)
|
||||
if self.block is not None:
|
||||
x = self.block(x)
|
||||
return x
|
||||
def forward(self, x):
|
||||
if self.concat_all_layer is True:
|
||||
return self.conv3(
|
||||
torch.cat((self.block(self.conv1(x)), self.conv2(x)), dim=1))
|
||||
else:
|
||||
return self.conv3(self.block(self.conv1(x)))
|
||||
|
||||
|
||||
class BottleRep(nn.Module):
|
||||
"""Bottle Rep Block.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input image
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
block_cfg (dict): Config dict for the block used to build each
|
||||
layer. Defaults to dict(type='RepVGGBlock').
|
||||
adaptive_weight (bool): Add adaptive_weight when forward calculate.
|
||||
Defaults False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
block_cfg: ConfigType = dict(type='RepVGGBlock'),
|
||||
adaptive_weight: bool = False):
|
||||
super().__init__()
|
||||
conv1_cfg = block_cfg.copy()
|
||||
conv2_cfg = block_cfg.copy()
|
||||
|
||||
conv1_cfg.update(
|
||||
dict(in_channels=in_channels, out_channels=out_channels))
|
||||
conv2_cfg.update(
|
||||
dict(in_channels=out_channels, out_channels=out_channels))
|
||||
|
||||
self.conv1 = MODELS.build(conv1_cfg)
|
||||
self.conv2 = MODELS.build(conv2_cfg)
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = False
|
||||
else:
|
||||
self.shortcut = True
|
||||
if adaptive_weight:
|
||||
self.alpha = Parameter(torch.ones(1))
|
||||
else:
|
||||
self.alpha = 1.0
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
outputs = self.conv1(x)
|
||||
outputs = self.conv2(outputs)
|
||||
return outputs + self.alpha * x if self.shortcut else outputs
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ConvWrapper(nn.Module):
|
||||
"""Wrapper for normal Conv with SiLU activation.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input image
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
kernel_size (int or tuple): Size of the convolving kernel
|
||||
stride (int or tuple): Stride of the convolution. Default: 1
|
||||
groups (int, optional): Number of blocked connections from input
|
||||
channels to output channels. Default: 1
|
||||
bias (bool, optional): Conv bias. Default: True.
|
||||
norm_cfg (ConfigType): Config dict for normalization layer.
|
||||
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
|
||||
act_cfg (ConfigType): Config dict for activation layer.
|
||||
Defaults to dict(type='ReLU', inplace=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
norm_cfg: ConfigType = None,
|
||||
act_cfg: ConfigType = dict(type='SiLU')):
|
||||
super().__init__()
|
||||
self.block = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=kernel_size // 2,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.block(x)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
@ -390,7 +528,6 @@ class EffectiveSELayer(nn.Module):
|
|||
arxiv (https://arxiv.org/abs/1911.06667)
|
||||
This code referenced to
|
||||
https://github.com/youngwanLEE/CenterMask/blob/72147e8aae673fcaf4103ee90a6a6b73863e7fa1/maskrcnn_benchmark/modeling/backbone/vovnet.py#L108-L121 # noqa
|
||||
|
||||
Args:
|
||||
channels (int): The input and output channels of this Module.
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
|
@ -419,13 +556,13 @@ class EffectiveSELayer(nn.Module):
|
|||
|
||||
class PPYOLOESELayer(nn.Module):
|
||||
"""Squeeze-and-Excitation Attention Module for PPYOLOE.
|
||||
There are some differences between the current implementation and
|
||||
|
||||
There are some differences between the current implementation and
|
||||
SELayer in mmdet:
|
||||
1. For fast speed and avoiding double inference in ppyoloe,
|
||||
use `F.adaptive_avg_pool2d` before PPYOLOESELayer.
|
||||
2. Special ways to init weights.
|
||||
3. Different convolution order.
|
||||
|
||||
Args:
|
||||
feat_channels (int): The input (and output) channels of the SE layer.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
|
@ -473,7 +610,6 @@ class ELANBlock(BaseModule):
|
|||
- if mode is `no_change_channel`, the output channel does not change.
|
||||
- if mode is `expand_channel_2x`, the output channel will be
|
||||
expanded by a factor of 2
|
||||
|
||||
Args:
|
||||
in_channels (int): The input channels of this Module.
|
||||
mode (str): Output channel mode. Defaults to `expand_channel_2x`.
|
||||
|
@ -597,7 +733,6 @@ class MaxPoolAndStrideConvBlock(BaseModule):
|
|||
- if mode is `reduce_channel_2x`, the output channel will
|
||||
be reduced by a factor of 2
|
||||
- if mode is `no_change_channel`, the output channel does not change.
|
||||
|
||||
Args:
|
||||
in_channels (int): The input channels of this Module.
|
||||
mode (str): Output channel mode. `reduce_channel_2x` or
|
||||
|
@ -669,7 +804,6 @@ class MaxPoolAndStrideConvBlock(BaseModule):
|
|||
class SPPFCSPBlock(BaseModule):
|
||||
"""Spatial pyramid pooling - Fast (SPPF) layer with CSP for
|
||||
YOLOv7
|
||||
|
||||
Args:
|
||||
in_channels (int): The input channels of this Module.
|
||||
out_channels (int): The output channels of this Module.
|
||||
|
@ -832,9 +966,9 @@ class PPYOLOEBasicBlock(nn.Module):
|
|||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Forward process.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor.
|
||||
"""
|
||||
|
@ -986,3 +1120,66 @@ class CSPResLayer(nn.Module):
|
|||
y = self.attn(y)
|
||||
y = self.conv3(y)
|
||||
return y
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class RepStageBlock(nn.Module):
|
||||
"""RepStageBlock is a stage block with rep-style basic block.
|
||||
|
||||
Args:
|
||||
in_channels (int): The input channels of this Module.
|
||||
out_channels (int): The output channels of this Module.
|
||||
num_blocks (int, tuple[int]): Number of blocks. Defaults to 1.
|
||||
bottle_block (nn.Module): Basic unit of RepStage.
|
||||
Defaults to RepVGGBlock.
|
||||
block_cfg (ConfigType): Config of RepStage.
|
||||
Defaults to 'RepVGGBlock'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_blocks: int = 1,
|
||||
bottle_block: nn.Module = RepVGGBlock,
|
||||
block_cfg: ConfigType = dict(type='RepVGGBlock')):
|
||||
super().__init__()
|
||||
block_cfg = block_cfg.copy()
|
||||
|
||||
block_cfg.update(
|
||||
dict(in_channels=in_channels, out_channels=out_channels))
|
||||
|
||||
self.conv1 = MODELS.build(block_cfg)
|
||||
|
||||
block_cfg.update(
|
||||
dict(in_channels=out_channels, out_channels=out_channels))
|
||||
self.block = nn.Sequential(*(
|
||||
MODELS.build(block_cfg)
|
||||
for _ in range(num_blocks - 1))) if num_blocks > 1 else None
|
||||
|
||||
if bottle_block == BottleRep:
|
||||
self.conv1 = BottleRep(
|
||||
in_channels,
|
||||
out_channels,
|
||||
block_cfg=block_cfg,
|
||||
adaptive_weight=True)
|
||||
num_blocks = num_blocks // 2
|
||||
self.block = nn.Sequential(*(
|
||||
BottleRep(
|
||||
out_channels,
|
||||
out_channels,
|
||||
block_cfg=block_cfg,
|
||||
adaptive_weight=True)
|
||||
for _ in range(num_blocks - 1))) if num_blocks > 1 else None
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Forward process.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input tensor.
|
||||
Returns:
|
||||
Tensor: The output tensor.
|
||||
"""
|
||||
x = self.conv1(x)
|
||||
if self.block is not None:
|
||||
x = self.block(x)
|
||||
return x
|
||||
|
|
|
@ -3,11 +3,11 @@ from .base_yolo_neck import BaseYOLONeck
|
|||
from .cspnext_pafpn import CSPNeXtPAFPN
|
||||
from .ppyoloe_csppan import PPYOLOECSPPAFPN
|
||||
from .yolov5_pafpn import YOLOv5PAFPN
|
||||
from .yolov6_pafpn import YOLOv6RepPAFPN
|
||||
from .yolov6_pafpn import YOLOv6CSPRepPAFPN, YOLOv6RepPAFPN
|
||||
from .yolov7_pafpn import YOLOv7PAFPN
|
||||
from .yolox_pafpn import YOLOXPAFPN
|
||||
|
||||
__all__ = [
|
||||
'YOLOv5PAFPN', 'BaseYOLONeck', 'YOLOv6RepPAFPN', 'YOLOXPAFPN',
|
||||
'CSPNeXtPAFPN', 'YOLOv7PAFPN', 'PPYOLOECSPPAFPN'
|
||||
'CSPNeXtPAFPN', 'YOLOv7PAFPN', 'PPYOLOECSPPAFPN', 'YOLOv6CSPRepPAFPN'
|
||||
]
|
||||
|
|
|
@ -56,12 +56,15 @@ class YOLOv5PAFPN(BaseYOLONeck):
|
|||
init_cfg=init_cfg)
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize the parameters."""
|
||||
for m in self.modules():
|
||||
if isinstance(m, torch.nn.Conv2d):
|
||||
# In order to be consistent with the source code,
|
||||
# reset the Conv2d initialization parameters
|
||||
m.reset_parameters()
|
||||
if self.init_cfg is None:
|
||||
"""Initialize the parameters."""
|
||||
for m in self.modules():
|
||||
if isinstance(m, torch.nn.Conv2d):
|
||||
# In order to be consistent with the source code,
|
||||
# reset the Conv2d initialization parameters
|
||||
m.reset_parameters()
|
||||
else:
|
||||
super().init_weights()
|
||||
|
||||
def build_reduce_layer(self, idx: int) -> nn.Module:
|
||||
"""build reduce layer.
|
||||
|
|
|
@ -7,8 +7,8 @@ from mmcv.cnn import ConvModule
|
|||
from mmdet.utils import ConfigType, OptMultiConfig
|
||||
|
||||
from mmyolo.registry import MODELS
|
||||
from ..layers import RepStageBlock, RepVGGBlock
|
||||
from ..utils import make_divisible, make_round
|
||||
from ..layers import BepC3StageBlock, RepStageBlock
|
||||
from ..utils import make_round
|
||||
from .base_yolo_neck import BaseYOLONeck
|
||||
|
||||
|
||||
|
@ -29,8 +29,8 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
|
|||
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Defaults to dict(type='ReLU', inplace=True).
|
||||
block (nn.Module): block used to build each layer.
|
||||
Defaults to RepVGGBlock.
|
||||
block_cfg (dict): Config dict for the block used to build each
|
||||
layer. Defaults to dict(type='RepVGGBlock').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
@ -45,10 +45,10 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
|
|||
norm_cfg: ConfigType = dict(
|
||||
type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
|
||||
block: nn.Module = RepVGGBlock,
|
||||
block_cfg: ConfigType = dict(type='RepVGGBlock'),
|
||||
init_cfg: OptMultiConfig = None):
|
||||
self.num_csp_blocks = num_csp_blocks
|
||||
self.block = block
|
||||
self.block_cfg = block_cfg
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
|
@ -64,16 +64,14 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
|
|||
|
||||
Args:
|
||||
idx (int): layer idx.
|
||||
|
||||
Returns:
|
||||
nn.Module: The reduce layer.
|
||||
"""
|
||||
if idx == 2:
|
||||
layer = ConvModule(
|
||||
in_channels=make_divisible(self.in_channels[idx],
|
||||
self.widen_factor),
|
||||
out_channels=make_divisible(self.out_channels[idx - 1],
|
||||
self.widen_factor),
|
||||
in_channels=int(self.in_channels[idx] * self.widen_factor),
|
||||
out_channels=int(self.out_channels[idx - 1] *
|
||||
self.widen_factor),
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
|
@ -88,15 +86,12 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
|
|||
|
||||
Args:
|
||||
idx (int): layer idx.
|
||||
|
||||
Returns:
|
||||
nn.Module: The upsample layer.
|
||||
"""
|
||||
return nn.ConvTranspose2d(
|
||||
in_channels=make_divisible(self.out_channels[idx - 1],
|
||||
self.widen_factor),
|
||||
out_channels=make_divisible(self.out_channels[idx - 1],
|
||||
self.widen_factor),
|
||||
in_channels=int(self.out_channels[idx - 1] * self.widen_factor),
|
||||
out_channels=int(self.out_channels[idx - 1] * self.widen_factor),
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
bias=True)
|
||||
|
@ -106,26 +101,27 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
|
|||
|
||||
Args:
|
||||
idx (int): layer idx.
|
||||
|
||||
Returns:
|
||||
nn.Module: The top down layer.
|
||||
"""
|
||||
block_cfg = self.block_cfg.copy()
|
||||
|
||||
layer0 = RepStageBlock(
|
||||
in_channels=make_divisible(
|
||||
self.out_channels[idx - 1] + self.in_channels[idx - 1],
|
||||
in_channels=int(
|
||||
(self.out_channels[idx - 1] + self.in_channels[idx - 1]) *
|
||||
self.widen_factor),
|
||||
out_channels=make_divisible(self.out_channels[idx - 1],
|
||||
self.widen_factor),
|
||||
n=make_round(self.num_csp_blocks, self.deepen_factor),
|
||||
block=self.block)
|
||||
out_channels=int(self.out_channels[idx - 1] * self.widen_factor),
|
||||
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
|
||||
block_cfg=block_cfg)
|
||||
|
||||
if idx == 1:
|
||||
return layer0
|
||||
elif idx == 2:
|
||||
layer1 = ConvModule(
|
||||
in_channels=make_divisible(self.out_channels[idx - 1],
|
||||
self.widen_factor),
|
||||
out_channels=make_divisible(self.out_channels[idx - 2],
|
||||
self.widen_factor),
|
||||
in_channels=int(self.out_channels[idx - 1] *
|
||||
self.widen_factor),
|
||||
out_channels=int(self.out_channels[idx - 2] *
|
||||
self.widen_factor),
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
|
@ -137,15 +133,12 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
|
|||
|
||||
Args:
|
||||
idx (int): layer idx.
|
||||
|
||||
Returns:
|
||||
nn.Module: The downsample layer.
|
||||
"""
|
||||
return ConvModule(
|
||||
in_channels=make_divisible(self.out_channels[idx],
|
||||
self.widen_factor),
|
||||
out_channels=make_divisible(self.out_channels[idx],
|
||||
self.widen_factor),
|
||||
in_channels=int(self.out_channels[idx] * self.widen_factor),
|
||||
out_channels=int(self.out_channels[idx] * self.widen_factor),
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=3 // 2,
|
||||
|
@ -157,26 +150,136 @@ class YOLOv6RepPAFPN(BaseYOLONeck):
|
|||
|
||||
Args:
|
||||
idx (int): layer idx.
|
||||
|
||||
Returns:
|
||||
nn.Module: The bottom up layer.
|
||||
"""
|
||||
block_cfg = self.block_cfg.copy()
|
||||
|
||||
return RepStageBlock(
|
||||
in_channels=make_divisible(self.out_channels[idx] * 2,
|
||||
self.widen_factor),
|
||||
out_channels=make_divisible(self.out_channels[idx + 1],
|
||||
self.widen_factor),
|
||||
n=make_round(self.num_csp_blocks, self.deepen_factor),
|
||||
block=self.block)
|
||||
in_channels=int(self.out_channels[idx] * 2 * self.widen_factor),
|
||||
out_channels=int(self.out_channels[idx + 1] * self.widen_factor),
|
||||
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
|
||||
block_cfg=block_cfg)
|
||||
|
||||
def build_out_layer(self, *args, **kwargs) -> nn.Module:
|
||||
"""build out layer."""
|
||||
return nn.Identity()
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize the parameters."""
|
||||
for m in self.modules():
|
||||
if isinstance(m, torch.nn.Conv2d):
|
||||
# In order to be consistent with the source code,
|
||||
# reset the Conv2d initialization parameters
|
||||
m.reset_parameters()
|
||||
if self.init_cfg is None:
|
||||
"""Initialize the parameters."""
|
||||
for m in self.modules():
|
||||
if isinstance(m, torch.nn.Conv2d):
|
||||
# In order to be consistent with the source code,
|
||||
# reset the Conv2d initialization parameters
|
||||
m.reset_parameters()
|
||||
else:
|
||||
super().init_weights()
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class YOLOv6CSPRepPAFPN(YOLOv6RepPAFPN):
|
||||
"""Path Aggregation Network used in YOLOv6.
|
||||
|
||||
Args:
|
||||
in_channels (List[int]): Number of input channels per scale.
|
||||
out_channels (int): Number of output channels (used at each scale)
|
||||
deepen_factor (float): Depth multiplier, multiply number of
|
||||
blocks in CSP layer by this amount. Defaults to 1.0.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
channels in each layer by this amount. Defaults to 1.0.
|
||||
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Defaults to 1.
|
||||
freeze_all(bool): Whether to freeze the model.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Defaults to dict(type='ReLU', inplace=True).
|
||||
block_cfg (dict): Config dict for the block used to build each
|
||||
layer. Defaults to dict(type='RepVGGBlock').
|
||||
block_act_cfg (dict): Config dict for activation layer used in each
|
||||
stage. Defaults to dict(type='SiLU', inplace=True).
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: List[int],
|
||||
out_channels: int,
|
||||
deepen_factor: float = 1.0,
|
||||
widen_factor: float = 1.0,
|
||||
hidden_ratio: float = 0.5,
|
||||
num_csp_blocks: int = 12,
|
||||
freeze_all: bool = False,
|
||||
norm_cfg: ConfigType = dict(
|
||||
type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
|
||||
block_act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
||||
block_cfg: ConfigType = dict(type='RepVGGBlock'),
|
||||
init_cfg: OptMultiConfig = None):
|
||||
self.hidden_ratio = hidden_ratio
|
||||
self.block_act_cfg = block_act_cfg
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
deepen_factor=deepen_factor,
|
||||
widen_factor=widen_factor,
|
||||
num_csp_blocks=num_csp_blocks,
|
||||
freeze_all=freeze_all,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
block_cfg=block_cfg,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
def build_top_down_layer(self, idx: int) -> nn.Module:
|
||||
"""build top down layer.
|
||||
|
||||
Args:
|
||||
idx (int): layer idx.
|
||||
Returns:
|
||||
nn.Module: The top down layer.
|
||||
"""
|
||||
block_cfg = self.block_cfg.copy()
|
||||
|
||||
layer0 = BepC3StageBlock(
|
||||
in_channels=int(
|
||||
(self.out_channels[idx - 1] + self.in_channels[idx - 1]) *
|
||||
self.widen_factor),
|
||||
out_channels=int(self.out_channels[idx - 1] * self.widen_factor),
|
||||
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
|
||||
block_cfg=block_cfg,
|
||||
hidden_ratio=self.hidden_ratio,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.block_act_cfg)
|
||||
|
||||
if idx == 1:
|
||||
return layer0
|
||||
elif idx == 2:
|
||||
layer1 = ConvModule(
|
||||
in_channels=int(self.out_channels[idx - 1] *
|
||||
self.widen_factor),
|
||||
out_channels=int(self.out_channels[idx - 2] *
|
||||
self.widen_factor),
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
return nn.Sequential(layer0, layer1)
|
||||
|
||||
def build_bottom_up_layer(self, idx: int) -> nn.Module:
|
||||
"""build bottom up layer.
|
||||
|
||||
Args:
|
||||
idx (int): layer idx.
|
||||
Returns:
|
||||
nn.Module: The bottom up layer.
|
||||
"""
|
||||
block_cfg = self.block_cfg.copy()
|
||||
|
||||
return BepC3StageBlock(
|
||||
in_channels=int(self.out_channels[idx] * 2 * self.widen_factor),
|
||||
out_channels=int(self.out_channels[idx + 1] * self.widen_factor),
|
||||
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
|
||||
block_cfg=block_cfg,
|
||||
hidden_ratio=self.hidden_ratio,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.block_act_cfg)
|
||||
|
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
import torch
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmyolo.models.backbones import YOLOv6EfficientRep
|
||||
from mmyolo.models.backbones import YOLOv6CSPBep, YOLOv6EfficientRep
|
||||
from mmyolo.utils import register_all_modules
|
||||
from .utils import check_norm_state, is_norm
|
||||
|
||||
|
@ -23,7 +23,7 @@ class TestYOLOv6EfficientRep(TestCase):
|
|||
# frozen_stages must in range(-1, len(arch_setting) + 1)
|
||||
YOLOv6EfficientRep(frozen_stages=6)
|
||||
|
||||
def test_forward(self):
|
||||
def test_YOLOv6EfficientRep_forward(self):
|
||||
# Test YOLOv6EfficientRep with first stage frozen
|
||||
frozen_stages = 1
|
||||
model = YOLOv6EfficientRep(frozen_stages=frozen_stages)
|
||||
|
@ -111,3 +111,92 @@ class TestYOLOv6EfficientRep(TestCase):
|
|||
assert feat[0].shape == torch.Size((1, 256, 32, 32))
|
||||
assert feat[1].shape == torch.Size((1, 512, 16, 16))
|
||||
assert feat[2].shape == torch.Size((1, 1024, 8, 8))
|
||||
|
||||
def test_YOLOv6CSPBep_forward(self):
|
||||
# Test YOLOv6CSPBep with first stage frozen
|
||||
frozen_stages = 1
|
||||
model = YOLOv6CSPBep(frozen_stages=frozen_stages)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
for mod in model.stem.modules():
|
||||
for param in mod.parameters():
|
||||
assert param.requires_grad is False
|
||||
for i in range(1, frozen_stages + 1):
|
||||
layer = getattr(model, f'stage{i}')
|
||||
for mod in layer.modules():
|
||||
if isinstance(mod, _BatchNorm):
|
||||
assert mod.training is False
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
|
||||
# Test YOLOv6CSPBep with norm_eval=True
|
||||
model = YOLOv6CSPBep(norm_eval=True)
|
||||
model.train()
|
||||
|
||||
assert check_norm_state(model.modules(), False)
|
||||
|
||||
# Test YOLOv6CSPBep forward with widen_factor=0.25
|
||||
model = YOLOv6CSPBep(
|
||||
arch='P5', widen_factor=0.25, out_indices=range(0, 5))
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 64, 64)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == torch.Size((1, 16, 32, 32))
|
||||
assert feat[1].shape == torch.Size((1, 32, 16, 16))
|
||||
assert feat[2].shape == torch.Size((1, 64, 8, 8))
|
||||
assert feat[3].shape == torch.Size((1, 128, 4, 4))
|
||||
assert feat[4].shape == torch.Size((1, 256, 2, 2))
|
||||
|
||||
# Test YOLOv6CSPBep forward with dict(type='ReLU')
|
||||
model = YOLOv6CSPBep(
|
||||
widen_factor=0.125,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
out_indices=range(0, 5))
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 64, 64)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == torch.Size((1, 8, 32, 32))
|
||||
assert feat[1].shape == torch.Size((1, 16, 16, 16))
|
||||
assert feat[2].shape == torch.Size((1, 32, 8, 8))
|
||||
assert feat[3].shape == torch.Size((1, 64, 4, 4))
|
||||
assert feat[4].shape == torch.Size((1, 128, 2, 2))
|
||||
|
||||
# Test YOLOv6CSPBep with BatchNorm forward
|
||||
model = YOLOv6CSPBep(widen_factor=0.125, out_indices=range(0, 5))
|
||||
for m in model.modules():
|
||||
if is_norm(m):
|
||||
assert isinstance(m, _BatchNorm)
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 64, 64)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == torch.Size((1, 8, 32, 32))
|
||||
assert feat[1].shape == torch.Size((1, 16, 16, 16))
|
||||
assert feat[2].shape == torch.Size((1, 32, 8, 8))
|
||||
assert feat[3].shape == torch.Size((1, 64, 4, 4))
|
||||
assert feat[4].shape == torch.Size((1, 128, 2, 2))
|
||||
|
||||
# Test YOLOv6CSPBep with BatchNorm forward
|
||||
model = YOLOv6CSPBep(plugins=[
|
||||
dict(
|
||||
cfg=dict(type='mmdet.DropBlock', drop_prob=0.1, block_size=3),
|
||||
stages=(False, False, True, True)),
|
||||
])
|
||||
|
||||
assert len(model.stage1) == 1
|
||||
assert len(model.stage2) == 1
|
||||
assert len(model.stage3) == 2 # +DropBlock
|
||||
assert len(model.stage4) == 3 # +SPPF+DropBlock
|
||||
model.train()
|
||||
imgs = torch.randn(1, 3, 256, 256)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size((1, 256, 32, 32))
|
||||
assert feat[1].shape == torch.Size((1, 512, 16, 16))
|
||||
assert feat[2].shape == torch.Size((1, 1024, 8, 8))
|
||||
|
|
|
@ -3,15 +3,15 @@ from unittest import TestCase
|
|||
|
||||
import torch
|
||||
|
||||
from mmyolo.models.necks import YOLOv6RepPAFPN
|
||||
from mmyolo.models.necks import YOLOv6CSPRepPAFPN, YOLOv6RepPAFPN
|
||||
from mmyolo.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class TestYOLOv6RepPAFPN(TestCase):
|
||||
class TestYOLOv6PAFPN(TestCase):
|
||||
|
||||
def test_forward(self):
|
||||
def test_YOLOv6RepPAFP_forward(self):
|
||||
s = 64
|
||||
in_channels = [8, 16, 32]
|
||||
feat_sizes = [s // 2**i for i in range(4)] # [32, 16, 8]
|
||||
|
@ -27,3 +27,20 @@ class TestYOLOv6RepPAFPN(TestCase):
|
|||
for i in range(len(feats)):
|
||||
assert outs[i].shape[1] == out_channels[i]
|
||||
assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
|
||||
|
||||
def test_YOLOv6CSPRepPAFPN_forward(self):
|
||||
s = 64
|
||||
in_channels = [8, 16, 32]
|
||||
feat_sizes = [s // 2**i for i in range(4)] # [32, 16, 8]
|
||||
out_channels = [8, 16, 32]
|
||||
feats = [
|
||||
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
|
||||
for i in range(len(in_channels))
|
||||
]
|
||||
neck = YOLOv6CSPRepPAFPN(
|
||||
in_channels=in_channels, out_channels=out_channels)
|
||||
outs = neck(feats)
|
||||
assert len(outs) == len(feats)
|
||||
for i in range(len(feats)):
|
||||
assert outs[i].shape[1] == out_channels[i]
|
||||
assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
|
||||
|
|
|
@ -28,12 +28,28 @@ def convert(src, dst):
|
|||
|
||||
if 'ERBlock_2' in k:
|
||||
name = k.replace('ERBlock_2', 'stage1.0')
|
||||
if '.cv' in k:
|
||||
name = name.replace('.cv', '.conv')
|
||||
if '.m.' in k:
|
||||
name = name.replace('.m.', '.block.')
|
||||
elif 'ERBlock_3' in k:
|
||||
name = k.replace('ERBlock_3', 'stage2.0')
|
||||
if '.cv' in k:
|
||||
name = name.replace('.cv', '.conv')
|
||||
if '.m.' in k:
|
||||
name = name.replace('.m.', '.block.')
|
||||
elif 'ERBlock_4' in k:
|
||||
name = k.replace('ERBlock_4', 'stage3.0')
|
||||
if '.cv' in k:
|
||||
name = name.replace('.cv', '.conv')
|
||||
if '.m.' in k:
|
||||
name = name.replace('.m.', '.block.')
|
||||
elif 'ERBlock_5' in k:
|
||||
name = k.replace('ERBlock_5', 'stage4.0')
|
||||
if '.cv' in k:
|
||||
name = name.replace('.cv', '.conv')
|
||||
if '.m.' in k:
|
||||
name = name.replace('.m.', '.block.')
|
||||
if 'stage4.0.2' in name:
|
||||
name = name.replace('stage4.0.2', 'stage4.1')
|
||||
name = name.replace('cv', 'conv')
|
||||
|
@ -41,10 +57,22 @@ def convert(src, dst):
|
|||
name = k.replace('reduce_layer0', 'reduce_layers.2')
|
||||
elif 'Rep_p4' in k:
|
||||
name = k.replace('Rep_p4', 'top_down_layers.0.0')
|
||||
if '.cv' in k:
|
||||
name = name.replace('.cv', '.conv')
|
||||
if '.m.' in k:
|
||||
name = name.replace('.m.', '.block.')
|
||||
elif 'reduce_layer1' in k:
|
||||
name = k.replace('reduce_layer1', 'top_down_layers.0.1')
|
||||
if '.cv' in k:
|
||||
name = name.replace('.cv', '.conv')
|
||||
if '.m.' in k:
|
||||
name = name.replace('.m.', '.block.')
|
||||
elif 'Rep_p3' in k:
|
||||
name = k.replace('Rep_p3', 'top_down_layers.1')
|
||||
if '.cv' in k:
|
||||
name = name.replace('.cv', '.conv')
|
||||
if '.m.' in k:
|
||||
name = name.replace('.m.', '.block.')
|
||||
elif 'upsample0' in k:
|
||||
name = k.replace('upsample0.upsample_transpose',
|
||||
'upsample_layers.0')
|
||||
|
@ -53,8 +81,16 @@ def convert(src, dst):
|
|||
'upsample_layers.1')
|
||||
elif 'Rep_n3' in k:
|
||||
name = k.replace('Rep_n3', 'bottom_up_layers.0')
|
||||
if '.cv' in k:
|
||||
name = name.replace('.cv', '.conv')
|
||||
if '.m.' in k:
|
||||
name = name.replace('.m.', '.block.')
|
||||
elif 'Rep_n4' in k:
|
||||
name = k.replace('Rep_n4', 'bottom_up_layers.1')
|
||||
if '.cv' in k:
|
||||
name = name.replace('.cv', '.conv')
|
||||
if '.m.' in k:
|
||||
name = name.replace('.m.', '.block.')
|
||||
elif 'downsample2' in k:
|
||||
name = k.replace('downsample2', 'downsample_layers.0')
|
||||
elif 'downsample1' in k:
|
||||
|
|
Loading…
Reference in New Issue