mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Support YOLOv8 inference (#445)
* Add backbone
* Improve code
* fix
* Add neck
* Add config
* Fix layer param
* Fix layer param
* Add head
* Add head
* Add model converter
* Add model converter
* Add model converter
* Improve code
* update
* update
* align test
* Improve code
* Improve code
* Improve code
* Fix lint
* Fix lint
* Fix lint
* Fix lint
* Improve code
* Improve code
* Improve code
* Add configs
* Improve doc
* Improve doc
* update
* Fix config
* update
* Fix config
* Fix config
* update
* Fix config epoch
* update
* Fix docstr
* Add UT
* Add UT
* Fix doc
* update
* Fix lint
* Fix doc
* Fix config name
* Improve config
* Improve default
* Add docstr
* Improve code
* Improve code
* Drop `bbox_head.head_module.dfl.conv.weight` when convert to mmyolo weight
* Delete useless code
* Add batch_shapes_cfg but not enable it.
* update

Co-authored-by: huanghaian <huanghaian@sensetime.com>
pull/456/head
parent
f1855ca618
commit
935d710f79
|
@ -171,6 +171,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
|
|||
- [x] [YOLOv6](configs/yolov6)
|
||||
- [x] [YOLOv7](configs/yolov7)
|
||||
- [x] [PPYOLOE](configs/ppyoloe)
|
||||
- [ ] [YOLOv8](configs/yolov8) (Inference only)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -190,6 +190,7 @@ MMYOLO usage is almost identical to MMDetection, all tutorials are shared, and you can also
|
|||
- [x] [YOLOv6](configs/yolov6)
|
||||
- [x] [YOLOv7](configs/yolov7)
|
||||
- [x] [PPYOLOE](configs/ppyoloe)
|
||||
- [ ] [YOLOv8](configs/yolov8) (inference only)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
# YOLOv8
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
YOLOv8, developed by Ultralytics, is a state-of-the-art (SOTA) model that builds on previous YOLO versions and introduces new features and improvements to boost performance and flexibility. It is designed to be fast, accurate, and easy to use, which makes it suitable for a wide range of object detection, image segmentation, and image classification tasks.
|
||||
|
||||
## Results and models
|
||||
|
||||
### COCO
|
||||
|
||||
| Backbone | Arch | Size | AMP | Mem (GB) | box AP | Config | Download |
| :------: | :--: | :--: | :-: | :------: | :----: | :-------------------------------------------------------------------------------------------------------: | :--------------------: |
| YOLOv8-n | P5 | 640 | Yes | x | 37.3 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolov8/yolov8_n_syncbn_8xb16-500e_coco.py) | [model](x) \| [log](x) |
| YOLOv8-s | P5 | 640 | Yes | x | 44.9 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolov8/yolov8_s_syncbn_8xb16-500e_coco.py) | [model](x) \| [log](x) |
| YOLOv8-m | P5 | 640 | Yes | x | 50.3 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolov8/yolov8_m_syncbn_8xb16-500e_coco.py) | [model](x) \| [log](x) |
| YOLOv8-l | P5 | 640 | Yes | x | 52.8 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolov8/yolov8_l_syncbn_8xb16-500e_coco.py) | [model](x) \| [log](x) |
| YOLOv8-x | P5 | 640 | Yes | x | 53.8 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolov8/yolov8_x_syncbn_8xb16-500e_coco.py) | [model](x) \| [log](x) |
|
||||
|
||||
**Note**: The AP values above were obtained by testing the official YOLOv8 weights after converting them to MMYOLO format. The [yolov8_to_mmyolo](https://github.com/open-mmlab/mmyolo/tree/dev/tools/model_converters/yolov8_to_mmyolo.py) script performs this conversion.
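
The commit message mentions that `bbox_head.head_module.dfl.conv.weight` is dropped during conversion. The sketch below illustrates that kind of key filtering on a plain dictionary. It is not the converter script itself: `DROPPED_KEYS`, `drop_unused_keys` and the stand-in values are hypothetical, and only the dropped key name comes from the commit message.

```python
# Hypothetical sketch of filtering a converted state dict before saving it.
# Only the dropped key name is taken from the commit message.
DROPPED_KEYS = {'bbox_head.head_module.dfl.conv.weight'}


def drop_unused_keys(state_dict, dropped=DROPPED_KEYS):
    """Return a copy of ``state_dict`` without the keys listed in ``dropped``."""
    return {k: v for k, v in state_dict.items() if k not in dropped}


# Stand-in values; a real checkpoint would hold tensors.
converted = {
    'backbone.stem.conv.weight': 'tensor-a',
    'bbox_head.head_module.dfl.conv.weight': 'tensor-b',
}
cleaned = drop_unused_keys(converted)
print(sorted(cleaned))
```

The head module in this commit registers its projection vector `proj` as a non-persistent buffer, which may be why a stored DFL weight is not needed on the MMYOLO side.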
|
||||
|
||||
## Citation
|
|
@ -0,0 +1,21 @@
|
|||
_base_ = './yolov8_m_syncbn_fast_8xb16-500e_coco.py'

deepen_factor = 1.00
widen_factor = 1.00

last_stage_out_channels = 512

model = dict(
    backbone=dict(
        last_stage_out_channels=last_stage_out_channels,
        deepen_factor=deepen_factor,
        widen_factor=widen_factor),
    neck=dict(
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=[256, 512, last_stage_out_channels],
        out_channels=[256, 512, last_stage_out_channels]),
    bbox_head=dict(
        head_module=dict(
            widen_factor=widen_factor,
            in_channels=[256, 512, last_stage_out_channels])))
|
|
@ -0,0 +1,21 @@
|
|||
_base_ = './yolov8_s_syncbn_fast_8xb16-500e_coco.py'

deepen_factor = 0.67
widen_factor = 0.75

last_stage_out_channels = 768

model = dict(
    backbone=dict(
        last_stage_out_channels=last_stage_out_channels,
        deepen_factor=deepen_factor,
        widen_factor=widen_factor),
    neck=dict(
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=[256, 512, last_stage_out_channels],
        out_channels=[256, 512, last_stage_out_channels]),
    bbox_head=dict(
        head_module=dict(
            widen_factor=widen_factor,
            in_channels=[256, 512, last_stage_out_channels])))
|
|
@ -0,0 +1,9 @@
|
|||
_base_ = './yolov8_s_syncbn_fast_8xb16-500e_coco.py'

deepen_factor = 0.33
widen_factor = 0.25

model = dict(
    backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
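
Reviewer note: the configs differ only in `deepen_factor`, `widen_factor` and `last_stage_out_channels`. The sketch below shows how those three numbers could translate into per-stage channels and block counts. It assumes that `make_divisible` rounds the scaled channel count up to a multiple of 8 and that `make_round` rounds the scaled block count with a floor of one; both helpers are re-implemented here as assumptions rather than imported from MMYOLO. The base values mirror the `P5` arch settings of `YOLOv8CSPDarknet` in this commit.

```python
import math


def make_divisible(channels, widen_factor=1.0, divisor=8):
    """Scale a channel count and round it up to a multiple of ``divisor``."""
    return math.ceil(channels * widen_factor / divisor) * divisor


def make_round(num_blocks, deepen_factor=1.0):
    """Scale a block count, keeping at least one block."""
    if num_blocks > 1:
        return max(round(num_blocks * deepen_factor), 1)
    return num_blocks


# (deepen_factor, widen_factor, last_stage_out_channels) as set in the configs.
VARIANTS = {
    'n': (0.33, 0.25, 1024),
    's': (0.33, 0.5, 1024),
    'm': (0.67, 0.75, 768),
    'l': (1.0, 1.0, 512),
    'x': (1.0, 1.25, 512),
}


def scaled_backbone(variant):
    """Return (stage output channels, stage block counts) for a variant."""
    deepen, widen, last = VARIANTS[variant]
    base_out_channels = [128, 256, 512, last]
    base_num_blocks = [3, 6, 6, 3]
    channels = [make_divisible(c, widen) for c in base_out_channels]
    blocks = [make_round(b, deepen) for b in base_num_blocks]
    return channels, blocks


for name in VARIANTS:
    print(name, *scaled_backbone(name))
```

Under these assumptions the product `last_stage_out_channels * widen_factor` stays moderate for the larger variants (for example 768 * 0.75 = 576 for `m`), which is presumably why the last stage is configured separately from the fixed `[256, 512, ...]` entries.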
|
|
@ -0,0 +1,120 @@
|
|||
_base_ = '../_base_/default_runtime.py'

# dataset settings
data_root = 'data/coco/'
dataset_type = 'YOLOv5CocoDataset'

# parameters that often need to be modified
num_classes = 80
img_scale = (640, 640)  # height, width
deepen_factor = 0.33
widen_factor = 0.5
val_batch_size_per_gpu = 1
val_num_workers = 2

# persistent_workers must be False if num_workers is 0.
persistent_workers = True

strides = [8, 16, 32]
num_det_layers = 3

last_stage_out_channels = 1024

# cudnn_benchmark is recommended for single-scale training,
# where it can speed up training.
env_cfg = dict(cudnn_benchmark=True)

model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        type='YOLOv5DetDataPreprocessor',
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        bgr_to_rgb=True),
    backbone=dict(
        type='YOLOv8CSPDarknet',
        arch='P5',
        last_stage_out_channels=last_stage_out_channels,
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='SiLU', inplace=True)),
    neck=dict(
        type='YOLOv8PAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=[256, 512, last_stage_out_channels],
        out_channels=[256, 512, last_stage_out_channels],
        num_csp_blocks=3,
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='YOLOv8Head',
        head_module=dict(
            type='YOLOv8HeadModule',
            num_classes=num_classes,
            in_channels=[256, 512, last_stage_out_channels],
            widen_factor=widen_factor,
            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg=dict(type='SiLU', inplace=True),
            featmap_strides=[8, 16, 32])),
    test_cfg=dict(
        multi_label=True,
        nms_pre=30000,
        score_thr=0.001,
        nms=dict(type='nms', iou_threshold=0.7),
        max_per_img=300))

test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param'))
]

# Validation only:
# you can enable `batch_shapes_cfg`;
# in our test, YOLOv8-m scored 0.02 higher with it than without.
batch_shapes_cfg = None
# batch_shapes_cfg = dict(
#     type='BatchShapePolicy',
#     batch_size=val_batch_size_per_gpu,
#     img_size=img_scale[0],
#     size_divisor=32,
#     extra_pad_ratio=0.5)

val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(img='val2017/'),
        ann_file='annotations/instances_val2017.json',
        pipeline=test_pipeline,
        batch_shapes_cfg=batch_shapes_cfg))

test_dataloader = val_dataloader

val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + 'annotations/instances_val2017.json',
    metric='bbox')
test_evaluator = val_evaluator

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
|
|
@ -0,0 +1,9 @@
|
|||
_base_ = './yolov8_l_syncbn_fast_8xb16-500e_coco.py'

deepen_factor = 1.00
widen_factor = 1.25

model = dict(
    backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
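
Reviewer note: the short configs rely on `_base_` inheritance, where the child only states the fields it changes. The sketch below models that with a recursive dictionary merge. It assumes the usual MMEngine behaviour that nested dict fields from the child are merged into the base; `merge_cfg` is a hypothetical helper, not an MMEngine API, and special keys such as `_delete_` are not modelled.

```python
import copy


def merge_cfg(base, override):
    """Recursively merge ``override`` into a copy of ``base``."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_cfg(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Trimmed stand-in for the resolved "l" config.
base_l = dict(
    model=dict(
        backbone=dict(
            type='YOLOv8CSPDarknet',
            last_stage_out_channels=512,
            deepen_factor=1.0,
            widen_factor=1.0),
        bbox_head=dict(head_module=dict(widen_factor=1.0))))

# What the "x" config states explicitly.
child_x = dict(
    model=dict(
        backbone=dict(deepen_factor=1.0, widen_factor=1.25),
        bbox_head=dict(head_module=dict(widen_factor=1.25))))

merged_x = merge_cfg(base_l, child_x)
print(merged_x['model']['backbone'])
```

Because values such as `widen_factor` are already substituted into the base dictionaries when the base file is parsed, redefining the variable in the child is not enough on its own. That is consistent with the child configs passing the factors into `backbone`, `neck` and `bbox_head` again.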
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base_backbone import BaseBackbone
|
||||
from .csp_darknet import YOLOv5CSPDarknet, YOLOXCSPDarknet
|
||||
from .csp_darknet import YOLOv5CSPDarknet, YOLOv8CSPDarknet, YOLOXCSPDarknet
|
||||
from .csp_resnet import PPYOLOECSPResNet
|
||||
from .cspnext import CSPNeXt
|
||||
from .efficient_rep import YOLOv6CSPBep, YOLOv6EfficientRep
|
||||
|
@ -8,5 +8,6 @@ from .yolov7_backbone import YOLOv7Backbone
|
|||
|
||||
__all__ = [
|
||||
'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOv6CSPBep',
|
||||
'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet'
|
||||
'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet',
|
||||
'YOLOv8CSPDarknet'
|
||||
]
|
||||
|
|
|
@ -8,7 +8,7 @@ from mmdet.models.backbones.csp_darknet import CSPLayer, Focus
|
|||
from mmdet.utils import ConfigType, OptMultiConfig
|
||||
|
||||
from mmyolo.registry import MODELS
|
||||
from ..layers import SPPFBottleneck
|
||||
from ..layers import CSPLayerWithTwoConv, SPPFBottleneck
|
||||
from ..utils import make_divisible, make_round
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
@ -16,12 +16,10 @@ from .base_backbone import BaseBackbone
|
|||
@MODELS.register_module()
|
||||
class YOLOv5CSPDarknet(BaseBackbone):
|
||||
"""CSP-Darknet backbone used in YOLOv5.
|
||||
|
||||
Args:
|
||||
arch (str): Architecture of CSP-Darknet, from {P5, P6}.
|
||||
Defaults to P5.
|
||||
plugins (list[dict]): List of plugins for stages, each dict contains:
|
||||
|
||||
- cfg (dict, required): Cfg dict to build plugin.
|
||||
- stages (tuple[bool], optional): Stages to apply plugin, length
|
||||
should be same as 'num_stages'.
|
||||
|
@ -43,7 +41,6 @@ class YOLOv5CSPDarknet(BaseBackbone):
|
|||
and its variants only. Defaults to False.
|
||||
init_cfg (Union[dict,list[dict]], optional): Initialization config
|
||||
dict. Defaults to None.
|
||||
|
||||
Example:
|
||||
>>> from mmyolo.models import YOLOv5CSPDarknet
|
||||
>>> import torch
|
||||
|
@ -157,6 +154,140 @@ class YOLOv5CSPDarknet(BaseBackbone):
|
|||
super().init_weights()
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class YOLOv8CSPDarknet(BaseBackbone):
|
||||
"""CSP-Darknet backbone used in YOLOv8.
|
||||
|
||||
Args:
|
||||
arch (str): Architecture of CSP-Darknet, from {P5}.
|
||||
Defaults to P5.
|
||||
last_stage_out_channels (int): Final layer output channel.
|
||||
Defaults to 1024.
|
||||
plugins (list[dict]): List of plugins for stages, each dict contains:
|
||||
- cfg (dict, required): Cfg dict to build plugin.
|
||||
- stages (tuple[bool], optional): Stages to apply plugin, length
|
||||
should be same as 'num_stages'.
|
||||
deepen_factor (float): Depth multiplier, multiply number of
|
||||
blocks in CSP layer by this amount. Defaults to 1.0.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
channels in each layer by this amount. Defaults to 1.0.
|
||||
input_channels (int): Number of input image channels. Defaults to 3.
|
||||
out_indices (Tuple[int]): Output from which stages.
|
||||
Defaults to (2, 3, 4).
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval
|
||||
mode). -1 means not freezing any parameters. Defaults to -1.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Defaults to dict(type='SiLU', inplace=True).
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Defaults to False.
|
||||
init_cfg (Union[dict,list[dict]], optional): Initialization config
|
||||
dict. Defaults to None.
|
||||
|
||||
Example:
|
||||
>>> from mmyolo.models import YOLOv8CSPDarknet
|
||||
>>> import torch
|
||||
>>> model = YOLOv8CSPDarknet()
|
||||
>>> model.eval()
|
||||
>>> inputs = torch.rand(1, 3, 416, 416)
|
||||
>>> level_outputs = model(inputs)
|
||||
>>> for level_out in level_outputs:
|
||||
... print(tuple(level_out.shape))
|
||||
...
|
||||
(1, 256, 52, 52)
|
||||
(1, 512, 26, 26)
|
||||
(1, 1024, 13, 13)
|
||||
"""
|
||||
# From left to right:
|
||||
# in_channels, out_channels, num_blocks, add_identity, use_spp
|
||||
# the final out_channels will be set according to the param.
|
||||
arch_settings = {
|
||||
'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
|
||||
[256, 512, 6, True, False], [512, None, 3, True, True]],
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
arch: str = 'P5',
|
||||
last_stage_out_channels: int = 1024,
|
||||
plugins: Union[dict, List[dict]] = None,
|
||||
deepen_factor: float = 1.0,
|
||||
widen_factor: float = 1.0,
|
||||
input_channels: int = 3,
|
||||
out_indices: Tuple[int] = (2, 3, 4),
|
||||
frozen_stages: int = -1,
|
||||
norm_cfg: ConfigType = dict(
|
||||
type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
||||
norm_eval: bool = False,
|
||||
init_cfg: OptMultiConfig = None):
|
||||
self.arch_settings[arch][-1][1] = last_stage_out_channels
|
||||
super().__init__(
|
||||
self.arch_settings[arch],
|
||||
deepen_factor,
|
||||
widen_factor,
|
||||
input_channels=input_channels,
|
||||
out_indices=out_indices,
|
||||
plugins=plugins,
|
||||
frozen_stages=frozen_stages,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
norm_eval=norm_eval,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
def build_stem_layer(self) -> nn.Module:
|
||||
"""Build a stem layer."""
|
||||
return ConvModule(
|
||||
self.input_channels,
|
||||
make_divisible(self.arch_setting[0][0], self.widen_factor),
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
|
||||
"""Build a stage layer.
|
||||
|
||||
Args:
|
||||
stage_idx (int): The index of a stage layer.
|
||||
setting (list): The architecture setting of a stage layer.
|
||||
"""
|
||||
in_channels, out_channels, num_blocks, add_identity, use_spp = setting
|
||||
|
||||
in_channels = make_divisible(in_channels, self.widen_factor)
|
||||
out_channels = make_divisible(out_channels, self.widen_factor)
|
||||
num_blocks = make_round(num_blocks, self.deepen_factor)
|
||||
stage = []
|
||||
conv_layer = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
stage.append(conv_layer)
|
||||
csp_layer = CSPLayerWithTwoConv(
|
||||
out_channels,
|
||||
out_channels,
|
||||
num_blocks=num_blocks,
|
||||
add_identity=add_identity,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
stage.append(csp_layer)
|
||||
if use_spp:
|
||||
spp = SPPFBottleneck(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_sizes=5,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
stage.append(spp)
|
||||
return stage
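
Reviewer note: `__init__` writes `last_stage_out_channels` into `self.arch_settings[arch][-1][1]`, and `arch_settings` is a class attribute, so the nested list is shared by every instance. The sketch below reproduces the pattern with two hypothetical stand-in classes (`SharedSettings` and `CopiedSettings`, neither is part of MMYOLO) and shows a copy-based alternative. Whether the sharing matters in practice depends on how the base class uses the stored setting after construction, which is not verified here.

```python
import copy


class SharedSettings:
    """Mirrors the pattern in the diff: mutates the class-level list."""
    arch_settings = {'P5': [[64, 128], [512, None]]}

    def __init__(self, last_stage_out_channels=1024):
        self.arch_settings['P5'][-1][1] = last_stage_out_channels
        self.setting = self.arch_settings['P5']


class CopiedSettings:
    """Alternative: work on a private deep copy of the class-level list."""
    arch_settings = {'P5': [[64, 128], [512, None]]}

    def __init__(self, last_stage_out_channels=1024):
        setting = copy.deepcopy(self.arch_settings['P5'])
        setting[-1][1] = last_stage_out_channels
        self.setting = setting


a = SharedSettings(512)
b = SharedSettings(768)
print(a.setting[-1][1], b.setting[-1][1])  # both refer to the same list

c = CopiedSettings(512)
d = CopiedSettings(768)
print(c.setting[-1][1], d.setting[-1][1])  # each instance keeps its own value
```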
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class YOLOXCSPDarknet(BaseBackbone):
|
||||
"""CSP-Darknet backbone used in YOLOX.
|
||||
|
|
|
@ -4,11 +4,12 @@ from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
|
|||
from .yolov5_head import YOLOv5Head, YOLOv5HeadModule
|
||||
from .yolov6_head import YOLOv6Head, YOLOv6HeadModule
|
||||
from .yolov7_head import YOLOv7Head, YOLOv7HeadModule, YOLOv7p6HeadModule
|
||||
from .yolov8_head import YOLOv8Head, YOLOv8HeadModule
|
||||
from .yolox_head import YOLOXHead, YOLOXHeadModule
|
||||
|
||||
__all__ = [
|
||||
'YOLOv5Head', 'YOLOv6Head', 'YOLOXHead', 'YOLOv5HeadModule',
|
||||
'YOLOv6HeadModule', 'YOLOXHeadModule', 'RTMDetHead',
|
||||
'RTMDetSepBNHeadModule', 'YOLOv7Head', 'PPYOLOEHead', 'PPYOLOEHeadModule',
|
||||
'YOLOv7HeadModule', 'YOLOv7p6HeadModule'
|
||||
'YOLOv7HeadModule', 'YOLOv7p6HeadModule', 'YOLOv8Head', 'YOLOv8HeadModule'
|
||||
]
|
||||
|
|
|
@ -28,8 +28,8 @@ class PPYOLOEHeadModule(BaseModule):
|
|||
category.
|
||||
in_channels (int): Number of channels in the input feature map.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
channels in each layer by this amount. Default: 1.0.
|
||||
num_base_priors:int: The number of priors (points) at a point
|
||||
channels in each layer by this amount. Defaults to 1.0.
|
||||
num_base_priors (int): The number of priors (points) at a point
|
||||
on the feature grid.
|
||||
featmap_strides (Sequence[int]): Downsample factor of each feature map.
|
||||
Defaults to (8, 16, 32).
|
||||
|
|
|
@ -26,7 +26,7 @@ class RTMDetSepBNHeadModule(BaseModule):
|
|||
in_channels (int): Number of channels in the input feature map.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
channels in each layer by this amount. Defaults to 1.0.
|
||||
num_base_priors:int: The number of priors (points) at a point
|
||||
num_base_priors (int): The number of priors (points) at a point
|
||||
on the feature grid. Defaults to 1.
|
||||
feat_channels (int): Number of hidden channels. Used in child classes.
|
||||
Defaults to 256
|
||||
|
|
|
@ -42,8 +42,8 @@ class YOLOv5HeadModule(BaseModule):
|
|||
in_channels (Union[int, Sequence]): Number of channels in the input
|
||||
feature map.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
channels in each layer by this amount. Default: 1.0.
|
||||
num_base_priors:int: The number of priors (points) at a point
|
||||
channels in each layer by this amount. Defaults to 1.0.
|
||||
num_base_priors (int): The number of priors (points) at a point
|
||||
on the feature grid.
|
||||
featmap_strides (Sequence[int]): Downsample factor of each feature map.
|
||||
Defaults to (8, 16, 32).
|
||||
|
|
|
@ -29,7 +29,7 @@ class YOLOv6HeadModule(BaseModule):
|
|||
in_channels (Union[int, Sequence]): Number of channels in the input
|
||||
feature map.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
channels in each layer by this amount. Default: 1.0.
|
||||
channels in each layer by this amount. Defaults to 1.0.
|
||||
num_base_priors (int): The number of priors (points) at a point
|
||||
on the feature grid.
|
||||
featmap_strides (Sequence[int]): Downsample factor of each feature map.
|
||||
|
|
|
@ -0,0 +1,249 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmdet.models.utils import multi_apply
|
||||
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
|
||||
OptMultiConfig)
|
||||
from mmengine.model import BaseModule, bias_init_with_prob
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmyolo.registry import MODELS
|
||||
from ..utils import make_divisible
|
||||
from .yolov5_head import YOLOv5Head
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class YOLOv8HeadModule(BaseModule):
|
||||
"""YOLOv8HeadModule head module used in `YOLOv8`.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of categories excluding the background
|
||||
category.
|
||||
in_channels (Union[int, Sequence]): Number of channels in the input
|
||||
feature map.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
channels in each layer by this amount. Defaults to 1.0.
|
||||
num_base_priors (int): The number of priors (points) at a point
|
||||
on the feature grid. Defaults to 1.
|
||||
featmap_strides (Sequence[int]): Downsample factor of each feature map.
|
||||
Defaults to (8, 16, 32).
|
||||
reg_max (int): Number of bins in the integral set
|
||||
``{0, ..., reg_max-1}`` used for the DFL setting. Defaults to 16.
|
||||
norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
|
||||
layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
|
||||
act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
|
||||
Defaults to dict(type='SiLU', inplace=True).
|
||||
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
|
||||
list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes: int,
|
||||
in_channels: Union[int, Sequence],
|
||||
widen_factor: float = 1.0,
|
||||
num_base_priors: int = 1,
|
||||
featmap_strides: Sequence[int] = (8, 16, 32),
|
||||
reg_max: int = 16,
|
||||
norm_cfg: ConfigType = dict(
|
||||
type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
||||
init_cfg: OptMultiConfig = None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.num_classes = num_classes
|
||||
self.featmap_strides = featmap_strides
|
||||
self.num_levels = len(self.featmap_strides)
|
||||
self.num_base_priors = num_base_priors
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.in_channels = in_channels
|
||||
self.reg_max = reg_max
|
||||
|
||||
in_channels = []
|
||||
for channel in self.in_channels:
|
||||
channel = make_divisible(channel, widen_factor)
|
||||
in_channels.append(channel)
|
||||
self.in_channels = in_channels
|
||||
|
||||
self._init_layers()
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize weights of the head."""
|
||||
# Use prior in model initialization to improve stability
|
||||
super().init_weights()
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, torch.nn.Conv2d):
|
||||
m.reset_parameters()
|
||||
|
||||
bias_init = bias_init_with_prob(0.01)
|
||||
for conv_cls in self.cls_preds:
|
||||
conv_cls[-1].bias.data.fill_(bias_init)
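
Reviewer note: `bias_init_with_prob(0.01)` sets the classification bias so that the initial sigmoid score is close to the prior probability. The sketch below assumes the helper follows the usual formula `-log((1 - p) / p)`; `prior_prob_to_bias` and `sigmoid` are hypothetical stand-ins written with the standard library only.

```python
import math


def prior_prob_to_bias(prior_prob):
    """Bias whose sigmoid equals ``prior_prob`` (assumed formula)."""
    return float(-math.log((1 - prior_prob) / prior_prob))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


bias = prior_prob_to_bias(0.01)
print(bias, sigmoid(bias))
```

With a prior of 0.01 the bias is about -4.595, so every class starts with a score near 0.01 instead of 0.5.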
|
||||
|
||||
def _init_layers(self):
|
||||
"""Initialize conv layers in YOLOv8 head."""
|
||||
# Init decoupled head
|
||||
self.cls_preds = nn.ModuleList()
|
||||
self.reg_preds = nn.ModuleList()
|
||||
|
||||
reg_out_channels = max(
|
||||
(16, self.in_channels[0] // 4, self.reg_max * 4))
|
||||
cls_out_channels = max(self.in_channels[0], self.num_classes)
|
||||
|
||||
for i in range(self.num_levels):
|
||||
self.reg_preds.append(
|
||||
nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.in_channels[i],
|
||||
out_channels=reg_out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
ConvModule(
|
||||
in_channels=reg_out_channels,
|
||||
out_channels=reg_out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
nn.Conv2d(
|
||||
in_channels=reg_out_channels,
|
||||
out_channels=4 * self.reg_max,
|
||||
kernel_size=1)))
|
||||
self.cls_preds.append(
|
||||
nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.in_channels[i],
|
||||
out_channels=cls_out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
ConvModule(
|
||||
in_channels=cls_out_channels,
|
||||
out_channels=cls_out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
nn.Conv2d(
|
||||
in_channels=cls_out_channels,
|
||||
out_channels=self.num_classes,
|
||||
kernel_size=1)))
|
||||
|
||||
proj = torch.linspace(0, self.reg_max - 1,
|
||||
self.reg_max).view([1, self.reg_max, 1, 1])
|
||||
self.register_buffer('proj', proj, persistent=False)
|
||||
|
||||
def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
|
||||
"""Forward features from the upstream network.
|
||||
|
||||
Args:
|
||||
x (Tuple[Tensor]): Features from the upstream network, each is
|
||||
a 4D-tensor.
|
||||
Returns:
|
||||
Tuple[List]: A tuple of multi-level classification scores, bbox
|
||||
predictions
|
||||
"""
|
||||
assert len(x) == self.num_levels
|
||||
return multi_apply(self.forward_single, x, self.cls_preds,
|
||||
self.reg_preds)
|
||||
|
||||
def forward_single(self, x: torch.Tensor, cls_pred: nn.Sequential,
|
||||
reg_pred: nn.Sequential) -> Tuple:
|
||||
"""Forward feature of a single scale level."""
|
||||
b, _, h, w = x.shape
|
||||
cls_logit = cls_pred(x)
|
||||
bbox_dist_preds = reg_pred(x)
|
||||
if self.reg_max > 1:
|
||||
bbox_dist_preds = bbox_dist_preds.reshape(
|
||||
[-1, 4, self.reg_max, h * w]).permute(0, 2, 3, 1)
|
||||
# TODO: Test whether use matmul instead of conv can
|
||||
# speed up training.
|
||||
bbox_preds = F.conv2d(F.softmax(bbox_dist_preds, dim=1), self.proj)
|
||||
else:
|
||||
bbox_preds = bbox_dist_preds
|
||||
return cls_logit, bbox_preds
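
Reviewer note: in `forward_single` the regression output is reshaped to `(b, reg_max, h * w, 4)`, softmax is taken over the `reg_max` axis, and the 1x1 convolution with `proj = [0, 1, ..., reg_max - 1]` computes the expected bin index for each of the four box sides. The sketch below reproduces that expectation for a single side in plain Python. `softmax` and `decode_distance` are hypothetical helpers, and the later decoding of distances into boxes is not shown.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_distance(logits):
    """Expected bin index: sum_i i * softmax(logits)[i]."""
    probs = softmax(logits)
    return sum(i * p for i, p in enumerate(probs))


reg_max = 16
uniform = [0.0] * reg_max      # no preference between bins
peaked = [-50.0] * reg_max     # almost all mass on bin 3
peaked[3] = 50.0

print(decode_distance(uniform))
print(decode_distance(peaked))
```

Uniform logits decode to 7.5, the midpoint of the bins 0 to 15, while a sharply peaked distribution decodes to (almost exactly) the index of its peak.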
|
||||
|
||||
|
||||
# TODO Training mode is currently not supported
|
||||
@MODELS.register_module()
|
||||
class YOLOv8Head(YOLOv5Head):
|
||||
"""YOLOv8Head head used in `YOLOv8`.
|
||||
|
||||
Args:
|
||||
head_module(:obj:`ConfigDict` or dict): Base module used for YOLOv8Head
|
||||
prior_generator (dict): Points generator for feature maps
|
||||
in 2D points-based detectors.
|
||||
bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
|
||||
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
|
||||
loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
|
||||
loss_dfl (:obj:`ConfigDict` or dict): Config of Distribution Focal
|
||||
Loss.
|
||||
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
|
||||
anchor head. Defaults to None.
|
||||
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
|
||||
anchor head. Defaults to None.
|
||||
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
|
||||
list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
head_module: ConfigType,
|
||||
prior_generator: ConfigType = dict(
|
||||
type='mmdet.MlvlPointGenerator',
|
||||
offset=0.5,
|
||||
strides=[8, 16, 32]),
|
||||
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
|
||||
loss_cls: ConfigType = dict(
|
||||
type='mmdet.CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
reduction='sum',
|
||||
loss_weight=1.0),
|
||||
loss_bbox: ConfigType = dict(
|
||||
type='mmdet.GIoULoss', reduction='sum', loss_weight=5.0),
|
||||
loss_dfl=dict(
|
||||
type='mmdet.DistributionFocalLoss',
|
||||
reduction='mean',
|
||||
loss_weight=0.5 / 4),
|
||||
train_cfg: OptConfigType = None,
|
||||
test_cfg: OptConfigType = None,
|
||||
init_cfg: OptMultiConfig = None):
|
||||
super().__init__(
|
||||
head_module=head_module,
|
||||
prior_generator=prior_generator,
|
||||
bbox_coder=bbox_coder,
|
||||
loss_cls=loss_cls,
|
||||
loss_bbox=loss_bbox,
|
||||
train_cfg=train_cfg,
|
||||
test_cfg=test_cfg,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
def special_init(self):
|
||||
"""Since YOLO series algorithms will inherit from YOLOv5Head, but
|
||||
different algorithms have special initialization process.
|
||||
|
||||
The special_init function is designed to deal with this situation.
|
||||
"""
|
||||
pass
|
||||
|
||||
def loss_by_feat(
|
||||
self,
|
||||
cls_scores: Sequence[Tensor],
|
||||
bbox_preds: Sequence[Tensor],
|
||||
objectnesses: Sequence[Tensor],
|
||||
batch_gt_instances: Sequence[InstanceData],
|
||||
batch_img_metas: Sequence[dict],
|
||||
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
|
||||
raise NotImplementedError('Not implemented yet !')
|
|
@ -30,7 +30,7 @@ class YOLOXHeadModule(BaseModule):
|
|||
in_channels (Union[int, Sequence]): Number of channels in the input
|
||||
feature map.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
channels in each layer by this amount. Default: 1.0.
|
||||
channels in each layer by this amount. Defaults to 1.0.
|
||||
num_base_priors (int): The number of priors (points) at a point
|
||||
on the feature grid
|
||||
stacked_convs (int): Number of stacking convs of the head.
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .ema import ExpMomentumEMA
|
||||
from .yolo_bricks import (BepC3StageBlock, EELANBlock, EffectiveSELayer,
|
||||
from .yolo_bricks import (BepC3StageBlock, CSPLayerWithTwoConv,
|
||||
DarknetBottleneck, EELANBlock, EffectiveSELayer,
|
||||
ELANBlock, ImplicitA, ImplicitM,
|
||||
MaxPoolAndStrideConvBlock, PPYOLOEBasicBlock,
|
||||
RepStageBlock, RepVGGBlock, SPPFBottleneck,
|
||||
|
@ -10,5 +11,6 @@ __all__ = [
|
|||
'SPPFBottleneck', 'RepVGGBlock', 'RepStageBlock', 'ExpMomentumEMA',
|
||||
'ELANBlock', 'MaxPoolAndStrideConvBlock', 'SPPFCSPBlock',
|
||||
'PPYOLOEBasicBlock', 'EffectiveSELayer', 'TinyDownSampleBlock',
|
||||
'EELANBlock', 'ImplicitA', 'ImplicitM', 'BepC3StageBlock'
|
||||
'EELANBlock', 'ImplicitA', 'ImplicitM', 'BepC3StageBlock',
|
||||
'CSPLayerWithTwoConv', 'DarknetBottleneck'
|
||||
]
|
||||
|
|
|
@ -4,7 +4,10 @@ from typing import Optional, Sequence, Tuple, Union
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, MaxPool2d, build_norm_layer
|
||||
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, MaxPool2d,
|
||||
build_norm_layer)
|
||||
from mmdet.models.layers.csp_layer import \
|
||||
DarknetBottleneck as MMDET_DarknetBottleneck
|
||||
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.utils import digit_version
|
||||
|
@ -1352,7 +1355,7 @@ class RepStageBlock(nn.Module):
|
|||
"""Forward process.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input tensor.
|
||||
x (Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor.
|
||||
|
@ -1361,3 +1364,147 @@ class RepStageBlock(nn.Module):
|
|||
if self.block is not None:
|
||||
x = self.block(x)
|
||||
return x
|
||||
|
||||
|
||||
class DarknetBottleneck(MMDET_DarknetBottleneck):
|
||||
"""The basic bottleneck block used in Darknet.
|
||||
|
||||
Each ResBlock consists of two ConvModules and the input is added to the
|
||||
final output. Each ConvModule is composed of Conv, BN, and an activation.
|
||||
The first convLayer has filter size of k1Xk1 and the second one has the
|
||||
filter size of k2Xk2.
|
||||
|
||||
Note:
|
||||
This DarknetBottleneck differs slightly from MMDet's: the kernel
|
||||
size and padding of each conv can be configured.
|
||||
|
||||
Args:
|
||||
in_channels (int): The input channels of this Module.
|
||||
out_channels (int): The output channels of this Module.
|
||||
expansion (float): Ratio of hidden channels to output channels.
|
||||
Defaults to 0.5.
|
||||
kernel_size (Sequence[int]): The kernel size of the convolution.
|
||||
Defaults to (1, 3).
|
||||
padding (Sequence[int]): The padding size of the convolution.
|
||||
Defaults to (0, 1).
|
||||
add_identity (bool): Whether to add identity to the out.
|
||||
Defaults to True
|
||||
use_depthwise (bool): Whether to use depthwise separable convolution.
|
||||
Defaults to False
|
||||
conv_cfg (dict): Config dict for convolution layer. Default: None,
|
||||
which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Defaults to dict(type='SiLU', inplace=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
expansion: float = 0.5,
|
||||
kernel_size: Sequence[int] = (1, 3),
|
||||
padding: Sequence[int] = (0, 1),
|
||||
add_identity: bool = True,
|
||||
use_depthwise: bool = False,
|
||||
conv_cfg: OptConfigType = None,
|
||||
norm_cfg: ConfigType = dict(
|
||||
type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
||||
init_cfg: OptMultiConfig = None) -> None:
|
||||
super().__init__(in_channels, out_channels, init_cfg=init_cfg)
|
||||
hidden_channels = int(out_channels * expansion)
|
||||
conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
|
||||
assert isinstance(kernel_size, Sequence) and len(kernel_size) == 2
|
||||
|
||||
self.conv1 = ConvModule(
|
||||
in_channels,
|
||||
hidden_channels,
|
||||
kernel_size[0],
|
||||
padding=padding[0],
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.conv2 = conv(
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
kernel_size[1],
|
||||
stride=1,
|
||||
padding=padding[1],
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.add_identity = \
|
||||
add_identity and in_channels == out_channels
|
||||
|
||||
|
||||
class CSPLayerWithTwoConv(BaseModule):
|
||||
"""Cross Stage Partial Layer with 2 convolutions.
|
||||
|
||||
Args:
|
||||
in_channels (int): The input channels of the CSP layer.
|
||||
out_channels (int): The output channels of the CSP layer.
|
||||
expand_ratio (float): Ratio to adjust the number of channels of the
|
||||
hidden layer. Defaults to 0.5.
|
||||
num_blocks (int): Number of blocks. Defaults to 1.
|
||||
add_identity (bool): Whether to add identity in blocks.
|
||||
Defaults to True.
|
||||
conv_cfg (dict, optional): Config dict for convolution layer.
|
||||
Defaults to None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Defaults to dict(type='SiLU', inplace=True).
|
||||
init_cfg (:obj:`ConfigDict` or dict or list[dict] or
|
||||
list[:obj:`ConfigDict`], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
expand_ratio: float = 0.5,
|
||||
num_blocks: int = 1,
|
||||
add_identity: bool = True, # shortcut
|
||||
conv_cfg: OptConfigType = None,
|
||||
norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
||||
init_cfg: OptMultiConfig = None) -> None:
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.mid_channels = int(out_channels * expand_ratio)
|
||||
self.main_conv = ConvModule(
|
||||
in_channels,
|
||||
2 * self.mid_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.final_conv = ConvModule(
|
||||
(2 + num_blocks) * self.mid_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
DarknetBottleneck(
|
||||
self.mid_channels,
|
||||
self.mid_channels,
|
||||
expansion=1,
|
||||
kernel_size=(3, 3),
|
||||
padding=(1, 1),
|
||||
add_identity=add_identity,
|
||||
use_depthwise=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg) for _ in range(num_blocks))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Forward process."""
|
||||
x_main = self.main_conv(x)
|
||||
x_main = list(x_main.split((self.mid_channels, self.mid_channels), 1))
|
||||
x_main.extend(blocks(x_main[-1]) for blocks in self.blocks)
|
||||
return self.final_conv(torch.cat(x_main, 1))
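The `forward` above relies on two details that are easy to miss: `list.extend` consumes its generator one item at a time, so `x_main[-1]` inside the generator is always the output of the previous block, and the concatenated tensor therefore has `(2 + num_blocks) * mid_channels` channels, which is what `final_conv` expects. The following is a minimal, dependency-free sketch of both points. The helper names `chain_with_extend` and `final_conv_in_channels` are hypothetical and not part of the commit, and plain integers stand in for tensors.

```python
from typing import Callable, List, Sequence


def chain_with_extend(seed: Sequence[int],
                      fns: Sequence[Callable[[int], int]]) -> List[int]:
    """Mimic `x_main.extend(block(x_main[-1]) for block in self.blocks)`."""
    outputs = list(seed)
    # The generator is evaluated lazily while the list grows, so each
    # function receives the value appended by the previous one.
    outputs.extend(fn(outputs[-1]) for fn in fns)
    return outputs


def final_conv_in_channels(out_channels: int,
                           expand_ratio: float = 0.5,
                           num_blocks: int = 1) -> int:
    """Channel count of the tensor fed to `final_conv`."""
    mid_channels = int(out_channels * expand_ratio)
    # Two halves from the split of main_conv plus one output per block.
    return (2 + num_blocks) * mid_channels


# Three "blocks" that each add 10 to the previous output.
print(chain_with_extend([0, 1], [lambda v: v + 10] * 3))  # [0, 1, 11, 21, 31]
print(final_conv_in_channels(256, 0.5, 3))  # 640
```

The first result shows that the blocks are applied in sequence rather than all to the same split; the second matches the first argument passed to `final_conv` in `__init__`.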
|
||||
|
|
|
@ -5,9 +5,11 @@ from .ppyoloe_csppan import PPYOLOECSPPAFPN
|
|||
from .yolov5_pafpn import YOLOv5PAFPN
|
||||
from .yolov6_pafpn import YOLOv6CSPRepPAFPN, YOLOv6RepPAFPN
|
||||
from .yolov7_pafpn import YOLOv7PAFPN
|
||||
from .yolov8_pafpn import YOLOv8PAFPN
|
||||
from .yolox_pafpn import YOLOXPAFPN
|
||||
|
||||
__all__ = [
|
||||
'YOLOv5PAFPN', 'BaseYOLONeck', 'YOLOv6RepPAFPN', 'YOLOXPAFPN',
|
||||
'CSPNeXtPAFPN', 'YOLOv7PAFPN', 'PPYOLOECSPPAFPN', 'YOLOv6CSPRepPAFPN'
|
||||
'CSPNeXtPAFPN', 'YOLOv7PAFPN', 'PPYOLOECSPPAFPN', 'YOLOv6CSPRepPAFPN',
|
||||
'YOLOv8PAFPN'
|
||||
]
|
||||
|
|
|
@ -147,9 +147,8 @@ class BaseYOLONeck(BaseModule, metaclass=ABCMeta):
|
|||
self.out_channels = out_channels
|
||||
self.deepen_factor = deepen_factor
|
||||
self.widen_factor = widen_factor
|
||||
self.freeze_all = freeze_all
|
||||
self.upsample_feats_cat_first = upsample_feats_cat_first
|
||||
|
||||
self.freeze_all = freeze_all
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from mmdet.utils import ConfigType, OptMultiConfig
|
||||
|
||||
from mmyolo.registry import MODELS
|
||||
from .. import CSPLayerWithTwoConv
|
||||
from ..utils import make_divisible, make_round
|
||||
from .yolov5_pafpn import YOLOv5PAFPN
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class YOLOv8PAFPN(YOLOv5PAFPN):
|
||||
"""Path Aggregation Network used in YOLOv8.
|
||||
|
||||
Args:
|
||||
in_channels (List[int]): Number of input channels per scale.
|
||||
out_channels (Union[List[int], int]): Number of output channels (used at each scale).
|
||||
deepen_factor (float): Depth multiplier, multiply number of
|
||||
blocks in CSP layer by this amount. Defaults to 1.0.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
channels in each layer by this amount. Defaults to 1.0.
|
||||
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Defaults to 3.
|
||||
freeze_all (bool): Whether to freeze the model. Defaults to False.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Defaults to dict(type='SiLU', inplace=True).
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: List[int],
|
||||
out_channels: Union[List[int], int],
|
||||
deepen_factor: float = 1.0,
|
||||
widen_factor: float = 1.0,
|
||||
num_csp_blocks: int = 3,
|
||||
freeze_all: bool = False,
|
||||
norm_cfg: ConfigType = dict(
|
||||
type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
||||
init_cfg: OptMultiConfig = None):
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
deepen_factor=deepen_factor,
|
||||
widen_factor=widen_factor,
|
||||
num_csp_blocks=num_csp_blocks,
|
||||
freeze_all=freeze_all,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
def build_reduce_layer(self, idx: int) -> nn.Module:
|
||||
"""build reduce layer.
|
||||
|
||||
Args:
|
||||
idx (int): layer idx.
|
||||
|
||||
Returns:
|
||||
nn.Module: The reduce layer.
|
||||
"""
|
||||
return nn.Identity()
|
||||
|
||||
def build_top_down_layer(self, idx: int) -> nn.Module:
|
||||
"""build top down layer.
|
||||
|
||||
Args:
|
||||
idx (int): layer idx.
|
||||
|
||||
Returns:
|
||||
nn.Module: The top down layer.
|
||||
"""
|
||||
return CSPLayerWithTwoConv(
|
||||
make_divisible((self.in_channels[idx - 1] + self.in_channels[idx]),
|
||||
self.widen_factor),
|
||||
make_divisible(self.out_channels[idx - 1], self.widen_factor),
|
||||
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
|
||||
add_identity=False,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def build_bottom_up_layer(self, idx: int) -> nn.Module:
|
||||
"""build bottom up layer.
|
||||
|
||||
Args:
|
||||
idx (int): layer idx.
|
||||
|
||||
Returns:
|
||||
nn.Module: The bottom up layer.
|
||||
"""
|
||||
return CSPLayerWithTwoConv(
|
||||
make_divisible(
|
||||
(self.out_channels[idx] + self.out_channels[idx + 1]),
|
||||
self.widen_factor),
|
||||
make_divisible(self.out_channels[idx + 1], self.widen_factor),
|
||||
num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
|
||||
add_identity=False,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
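To make the channel arithmetic in `build_top_down_layer` and `build_bottom_up_layer` concrete, here is a small standalone sketch. It does not import mmyolo: `make_divisible` and `make_round` are re-implemented locally under the assumption that they round the scaled width up to a multiple of 8 and keep the scaled depth at 1 or more. The channel list and the scaling factors are illustrative values, not taken from the configs in this commit, and the `idx` values used are assumed to be the ones the base neck iterates over.

```python
import math
from typing import List, Tuple


def make_divisible(x: float, widen_factor: float = 1.0,
                   divisor: int = 8) -> int:
    # Assumed behaviour: scale, then round up to a multiple of `divisor`.
    return math.ceil(x * widen_factor / divisor) * divisor


def make_round(x: float, deepen_factor: float = 1.0) -> int:
    # Assumed behaviour: scale the block count but never go below 1.
    return max(round(x * deepen_factor), 1) if x > 1 else x


def top_down_io(in_channels: List[int], out_channels: List[int],
                widen_factor: float, idx: int) -> Tuple[int, int]:
    """(input, output) channels of the top-down CSP layer at `idx`."""
    return (make_divisible(in_channels[idx - 1] + in_channels[idx],
                           widen_factor),
            make_divisible(out_channels[idx - 1], widen_factor))


def bottom_up_io(out_channels: List[int], widen_factor: float,
                 idx: int) -> Tuple[int, int]:
    """(input, output) channels of the bottom-up CSP layer at `idx`."""
    return (make_divisible(out_channels[idx] + out_channels[idx + 1],
                           widen_factor),
            make_divisible(out_channels[idx + 1], widen_factor))


channels = [256, 512, 1024]  # illustrative, same list for in and out
print(top_down_io(channels, channels, 0.5, 2))  # (768, 256)
print(top_down_io(channels, channels, 0.5, 1))  # (384, 128)
print(bottom_up_io(channels, 0.5, 0))           # (384, 256)
print(bottom_up_io(channels, 0.5, 1))           # (768, 512)
print(make_round(3, 0.33))                      # 1
```

Because `build_reduce_layer` returns `nn.Identity()`, the top-down input width is the sum of two unreduced backbone widths, which is why the sum is taken before `make_divisible` is applied.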
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine.config import Config
|
||||
|
||||
from mmyolo.models import YOLOv8Head
|
||||
from mmyolo.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class TestYOLOv8Head(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.head_module = dict(
|
||||
type='YOLOv8HeadModule',
|
||||
num_classes=2,
|
||||
in_channels=[32, 64, 128],
|
||||
featmap_strides=[8, 16, 32])
|
||||
|
||||
def test_predict_by_feat(self):
|
||||
s = 256
|
||||
img_metas = [{
|
||||
'img_shape': (s, s, 3),
|
||||
'ori_shape': (s, s, 3),
|
||||
'scale_factor': (1.0, 1.0),
|
||||
}]
|
||||
test_cfg = Config(
|
||||
dict(
|
||||
multi_label=True,
|
||||
max_per_img=300,
|
||||
score_thr=0.01,
|
||||
nms=dict(type='nms', iou_threshold=0.65)))
|
||||
|
||||
head = YOLOv8Head(head_module=self.head_module, test_cfg=test_cfg)
|
||||
|
||||
feat = []
|
||||
for i in range(len(self.head_module['in_channels'])):
|
||||
in_channel = self.head_module['in_channels'][i]
|
||||
feat_size = self.head_module['featmap_strides'][i]
|
||||
feat.append(
|
||||
torch.rand(1, in_channel, s // feat_size, s // feat_size))
|
||||
|
||||
cls_scores, bbox_preds = head.forward(feat)
|
||||
head.predict_by_feat(
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
None,
|
||||
img_metas,
|
||||
cfg=test_cfg,
|
||||
rescale=True,
|
||||
with_nms=True)
|
||||
head.predict_by_feat(
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
None,
|
||||
img_metas,
|
||||
cfg=test_cfg,
|
||||
rescale=False,
|
||||
with_nms=False)
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
|
||||
from mmyolo.models import YOLOv8PAFPN
|
||||
from mmyolo.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class TestYOLOv8PAFPN(TestCase):
|
||||
|
||||
def test_YOLOv8PAFPN_forward(self):
|
||||
s = 64
|
||||
in_channels = [8, 16, 32]
|
||||
feat_sizes = [s // 2**i for i in range(4)]  # [64, 32, 16, 8]
|
||||
out_channels = [8, 16, 32]
|
||||
feats = [
|
||||
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
|
||||
for i in range(len(in_channels))
|
||||
]
|
||||
neck = YOLOv8PAFPN(in_channels=in_channels, out_channels=out_channels)
|
||||
outs = neck(feats)
|
||||
assert len(outs) == len(feats)
|
||||
for i in range(len(feats)):
|
||||
assert outs[i].shape[1] == out_channels[i]
|
||||
assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
convert_dict_s = {
|
||||
# backbone
|
||||
'model.0': 'backbone.stem',
|
||||
'model.1': 'backbone.stage1.0',
|
||||
'model.2': 'backbone.stage1.1',
|
||||
'model.3': 'backbone.stage2.0',
|
||||
'model.4': 'backbone.stage2.1',
|
||||
'model.5': 'backbone.stage3.0',
|
||||
'model.6': 'backbone.stage3.1',
|
||||
'model.7': 'backbone.stage4.0',
|
||||
'model.8': 'backbone.stage4.1',
|
||||
'model.9': 'backbone.stage4.2',
|
||||
|
||||
# neck
|
||||
'model.12': 'neck.top_down_layers.0',
|
||||
'model.15': 'neck.top_down_layers.1',
|
||||
'model.16': 'neck.downsample_layers.0',
|
||||
'model.18': 'neck.bottom_up_layers.0',
|
||||
'model.19': 'neck.downsample_layers.1',
|
||||
'model.21': 'neck.bottom_up_layers.1',
|
||||
|
||||
# Detector
|
||||
'model.22': 'bbox_head.head_module',
|
||||
}
|
||||
|
||||
|
||||
def convert(src, dst):
|
||||
"""Convert keys in pretrained YOLOv8 models to mmyolo style."""
|
||||
convert_dict = convert_dict_s
|
||||
|
||||
try:
|
||||
yolov8_model = torch.load(src)['model']
|
||||
blobs = yolov8_model.state_dict()
|
||||
except ModuleNotFoundError:
|
||||
raise RuntimeError(
|
||||
'This script must be placed under the ultralytics repo,'
|
||||
' because loading the official pretrained model needs'
|
||||
' `model.py` to build the model. '
|
||||
'It also requires hydra-core>=1.2.0 and thop>=0.1.1.')
|
||||
state_dict = OrderedDict()
|
||||
|
||||
for key, weight in blobs.items():
|
||||
num, module = key.split('.')[1:3]
|
||||
prefix = f'model.{num}'
|
||||
new_key = key.replace(prefix, convert_dict[prefix])
|
||||
|
||||
if '.m.' in new_key:
|
||||
new_key = new_key.replace('.m.', '.blocks.')
|
||||
new_key = new_key.replace('.cv', '.conv')
|
||||
elif 'bbox_head.head_module' in new_key:
|
||||
new_key = new_key.replace('.cv2', '.reg_preds')
|
||||
new_key = new_key.replace('.cv3', '.cls_preds')
|
||||
elif 'backbone.stage4.2' in new_key:
|
||||
new_key = new_key.replace('.cv', '.conv')
|
||||
else:
|
||||
new_key = new_key.replace('.cv1', '.main_conv')
|
||||
new_key = new_key.replace('.cv2', '.final_conv')
|
||||
|
||||
if 'bbox_head.head_module.dfl.conv.weight' == new_key:
|
||||
print('Drop "bbox_head.head_module.dfl.conv.weight", '
|
||||
'because it is useless')
|
||||
continue
|
||||
state_dict[new_key] = weight
|
||||
print(f'Convert {key} to {new_key}')
|
||||
|
||||
# save checkpoint
|
||||
checkpoint = dict()
|
||||
checkpoint['state_dict'] = state_dict
|
||||
torch.save(checkpoint, dst)
|
||||
|
||||
|
||||
# Note: This script must be placed under the YOLOv8 repo to run.
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Convert model keys')
|
||||
parser.add_argument(
|
||||
'--src', default='yolov8s.pt', help='src YOLOv8 model path')
|
||||
parser.add_argument('--dst', default='mmyolov8s.pth', help='save path')
|
||||
args = parser.parse_args()
|
||||
convert(args.src, args.dst)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
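The renaming rules in `convert` can be tried without a checkpoint or `torch`. The sketch below copies the branch logic onto plain strings. `PREFIX_MAP` is a four-entry subset of `convert_dict_s`, `rename_key` is a hypothetical helper name, and the sample keys are illustrative guesses at the shape of Ultralytics state-dict keys rather than keys read from a real file.

```python
from typing import Dict, Optional

# Subset of convert_dict_s, enough to exercise every branch.
PREFIX_MAP: Dict[str, str] = {
    'model.0': 'backbone.stem',
    'model.2': 'backbone.stage1.1',
    'model.9': 'backbone.stage4.2',
    'model.22': 'bbox_head.head_module',
}


def rename_key(key: str) -> Optional[str]:
    """Return the mmyolo-style key, or None if the weight is dropped."""
    num = key.split('.')[1]
    prefix = f'model.{num}'
    new_key = key.replace(prefix, PREFIX_MAP[prefix])

    if '.m.' in new_key:
        # Bottlenecks inside a CSP layer: m -> blocks, cvN -> convN.
        new_key = new_key.replace('.m.', '.blocks.')
        new_key = new_key.replace('.cv', '.conv')
    elif 'bbox_head.head_module' in new_key:
        new_key = new_key.replace('.cv2', '.reg_preds')
        new_key = new_key.replace('.cv3', '.cls_preds')
    elif 'backbone.stage4.2' in new_key:
        # SPPF block keeps numbered convs.
        new_key = new_key.replace('.cv', '.conv')
    else:
        new_key = new_key.replace('.cv1', '.main_conv')
        new_key = new_key.replace('.cv2', '.final_conv')

    if new_key == 'bbox_head.head_module.dfl.conv.weight':
        return None
    return new_key


for sample in ('model.0.conv.weight',
               'model.2.cv1.conv.weight',
               'model.2.m.0.cv1.conv.weight',
               'model.9.cv1.conv.weight',
               'model.22.cv2.0.0.conv.weight',
               'model.22.dfl.conv.weight'):
    print(sample, '->', rename_key(sample))
```

The order of the branches matters: the `.m.` check runs first so that bottleneck convs inside a CSP layer become `conv1`/`conv2` instead of `main_conv`/`final_conv`.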
|