[Feature] Support YOLOv8 inference (#445)

* Add backbone

* Improve code

* fix

* Add neck

* Add config

* Fix layer param

* Fix layer param

* Add head

* Add head

* Add model converter

* Add model converter

* Add model converter

* Improve code

* update

* update

* align test

* Improve code

* Improve code

* Improve code

* Fix lint

* Fix lint

* Fix lint

* Fix lint

* Improve code

* Improve code

* Improve code

* Add configs

* Improve doc

* Improve doc

* update

* Fix config

* update

* Fix config

* Fix config

* update

* Fix config epoch

* update

* Fix docstr

* Add UT

* Add UT

* Fix doc

* update

* Fix lint

* Fix doc

* Fix config name

* Improve config

* Improve default

* Add docstr

* Improve code

* Improve code

* Drop `bbox_head.head_module.dfl.conv.weight` when convert to mmyolo weight

* Delete useless code

* Add batch_shapes_cfg but not enable it.

* update

Co-authored-by: huanghaian <huanghaian@sensetime.com>
pull/456/head
HinGwenWoong 2023-01-11 19:01:06 +08:00 committed by GitHub
parent f1855ca618
commit 935d710f79
25 changed files with 1038 additions and 21 deletions

View File

@ -171,6 +171,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
- [x] [YOLOv6](configs/yolov6)
- [x] [YOLOv7](configs/yolov7)
- [x] [PPYOLOE](configs/ppyoloe)
- [ ] [YOLOv8](configs/yolov8) (Inference only)
</details>

View File

@ -190,6 +190,7 @@ MMYOLO usage is almost identical to MMDetection, all tutorials are interchangeable, and you can also
- [x] [YOLOv6](configs/yolov6)
- [x] [YOLOv7](configs/yolov7)
- [x] [PPYOLOE](configs/ppyoloe)
- [ ] [YOLOv8](configs/yolov8) (Inference only)
</details>

View File

@ -0,0 +1,23 @@
# YOLOv8
<!-- [ALGORITHM] -->
## Abstract
Ultralytics YOLOv8, developed by Ultralytics, is a cutting-edge, state-of-the-art (SOTA) model that builds upon the success of previous YOLO versions and introduces new features and improvements to further boost performance and flexibility. YOLOv8 is designed to be fast, accurate, and easy to use, making it an excellent choice for a wide range of object detection, image segmentation and image classification tasks.
## Results and models
### COCO
| Backbone | Arch | size | AMP | Mem (GB) | box AP | Config | Download |
| :------: | :--: | :--: | :-: | :------: | :----: | :-------------------------------------------------------------------------------------------------------: | :--------------------: |
| YOLOv8-n | P5 | 640 | Yes | x | 37.3 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolov8/yolov8_n_syncbn_8xb16-500e_coco.py) | [model](x) \| [log](x) |
| YOLOv8-s | P5 | 640 | Yes | x | 44.9 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolov8/yolov8_s_syncbn_8xb16-500e_coco.py) | [model](x) \| [log](x) |
| YOLOv8-m | P5 | 640 | Yes | x | 50.3 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolov8/yolov8_m_syncbn_8xb16-500e_coco.py) | [model](x) \| [log](x) |
| YOLOv8-l | P5 | 640 | Yes | x | 52.8 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolov8/yolov8_l_syncbn_8xb16-500e_coco.py) | [model](x) \| [log](x) |
| YOLOv8-x | P5 | 640 | Yes | x | 53.8 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/yolov8/yolov8_x_syncbn_8xb16-500e_coco.py) | [model](x) \| [log](x) |
**Note**: The AP values above were obtained by testing weights converted from the official YOLOv8 release. The [yolov8_to_mmyolo](https://github.com/open-mmlab/mmyolo/tree/dev/tools/model_converters/yolov8_to_mmyolo.py) script converts official YOLOv8 weights to the MMYOLO format.
## Citation

View File

@ -0,0 +1,21 @@
_base_ = './yolov8_m_syncbn_fast_8xb16-500e_coco.py'
deepen_factor = 1.00
widen_factor = 1.00
last_stage_out_channels = 512
model = dict(
backbone=dict(
last_stage_out_channels=last_stage_out_channels,
deepen_factor=deepen_factor,
widen_factor=widen_factor),
neck=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=[256, 512, last_stage_out_channels],
out_channels=[256, 512, last_stage_out_channels]),
bbox_head=dict(
head_module=dict(
widen_factor=widen_factor,
in_channels=[256, 512, last_stage_out_channels])))

View File

@ -0,0 +1,21 @@
_base_ = './yolov8_s_syncbn_fast_8xb16-500e_coco.py'
deepen_factor = 0.67
widen_factor = 0.75
last_stage_out_channels = 768
model = dict(
backbone=dict(
last_stage_out_channels=last_stage_out_channels,
deepen_factor=deepen_factor,
widen_factor=widen_factor),
neck=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=[256, 512, last_stage_out_channels],
out_channels=[256, 512, last_stage_out_channels]),
bbox_head=dict(
head_module=dict(
widen_factor=widen_factor,
in_channels=[256, 512, last_stage_out_channels])))

View File

@ -0,0 +1,9 @@
_base_ = './yolov8_s_syncbn_fast_8xb16-500e_coco.py'
deepen_factor = 0.33
widen_factor = 0.25
model = dict(
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))

View File

@ -0,0 +1,120 @@
_base_ = '../_base_/default_runtime.py'
# dataset settings
data_root = 'data/coco/'
dataset_type = 'YOLOv5CocoDataset'
# parameters that often need to be modified
num_classes = 80
img_scale = (640, 640) # height, width
deepen_factor = 0.33
widen_factor = 0.5
val_batch_size_per_gpu = 1
val_num_workers = 2
# persistent_workers must be False if num_workers is 0.
persistent_workers = True
strides = [8, 16, 32]
num_det_layers = 3
last_stage_out_channels = 1024
# Enabling cudnn_benchmark is recommended for single-scale
# training, since it can speed up training.
env_cfg = dict(cudnn_benchmark=True)
model = dict(
type='YOLODetector',
data_preprocessor=dict(
type='YOLOv5DetDataPreprocessor',
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True),
backbone=dict(
type='YOLOv8CSPDarknet',
arch='P5',
last_stage_out_channels=last_stage_out_channels,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True)),
neck=dict(
type='YOLOv8PAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=[256, 512, last_stage_out_channels],
out_channels=[256, 512, last_stage_out_channels],
num_csp_blocks=3,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(
type='YOLOv8Head',
head_module=dict(
type='YOLOv8HeadModule',
num_classes=num_classes,
in_channels=[256, 512, last_stage_out_channels],
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True),
featmap_strides=[8, 16, 32])),
test_cfg=dict(
multi_label=True,
nms_pre=30000,
score_thr=0.001,
nms=dict(type='nms', iou_threshold=0.7),
max_per_img=300))
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]
# Only used for validation.
# You can enable `batch_shapes_cfg`; in our test,
# YOLOv8-m scored 0.02 higher with it than without it.
batch_shapes_cfg = None
# batch_shapes_cfg = dict(
# type='BatchShapePolicy',
# batch_size=val_batch_size_per_gpu,
# img_size=img_scale[0],
# size_divisor=32,
# extra_pad_ratio=0.5)
val_dataloader = dict(
batch_size=val_batch_size_per_gpu,
num_workers=val_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
test_mode=True,
data_prefix=dict(img='val2017/'),
ann_file='annotations/instances_val2017.json',
pipeline=test_pipeline,
batch_shapes_cfg=batch_shapes_cfg))
test_dataloader = val_dataloader
val_evaluator = dict(
type='mmdet.CocoMetric',
proposal_nums=(100, 1, 10),
ann_file=data_root + 'annotations/instances_val2017.json',
metric='bbox')
test_evaluator = val_evaluator
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
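
The test pipeline above resizes each image with its aspect ratio preserved and then pads it to `img_scale` with the value 114. The geometry can be sketched in plain Python as follows. This is an illustration only, not MMYOLO's implementation: `letterbox_geometry` is a hypothetical helper, and it assumes that the keep-ratio step brings the longer side to the target size and that the padding is split evenly between the two sides (rounding in the real transforms may differ).

```python
from typing import Tuple


def letterbox_geometry(
    ori_hw: Tuple[int, int],
    target_hw: Tuple[int, int] = (640, 640),
) -> Tuple[Tuple[int, int], Tuple[int, int, int, int]]:
    """Return the resized (h, w) and the (top, bottom, left, right) padding."""
    ori_h, ori_w = ori_hw
    target_h, target_w = target_hw
    # One ratio for both axes keeps the aspect ratio unchanged.
    ratio = min(target_h / ori_h, target_w / ori_w)
    new_h, new_w = round(ori_h * ratio), round(ori_w * ratio)
    # Whatever is left on each axis is filled with padding (value 114
    # in the config); assumed here to be split evenly between the sides.
    pad_h, pad_w = target_h - new_h, target_w - new_w
    top, left = pad_h // 2, pad_w // 2
    return (new_h, new_w), (top, pad_h - top, left, pad_w - left)
```

Under these assumptions, a 1080x1920 (height x width) image is resized to 360x640 and receives 140 rows of padding at the top and at the bottom.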

View File

@ -0,0 +1,9 @@
_base_ = './yolov8_l_syncbn_fast_8xb16-500e_coco.py'
deepen_factor = 1.00
widen_factor = 1.25
model = dict(
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
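
The variant configs above override only `deepen_factor`, `widen_factor` and, for some variants, `last_stage_out_channels`. A minimal sketch of how these values translate into per-stage channel and block counts is shown below. The two helpers are written under the assumption that `make_divisible` rounds a scaled channel count up to a multiple of 8 and that `make_round` rounds a scaled block count while keeping at least one block; they are stand-ins that share names with the utilities imported in this commit, not copies of them. The variant letters are inferred from the factor values, and the base channel and block counts are taken from the `P5` entry of `arch_settings` in `YOLOv8CSPDarknet`.

```python
import math
from typing import Dict, List, Tuple


def make_divisible(channels: float, widen_factor: float = 1.0,
                   divisor: int = 8) -> int:
    """Scale a channel count and round it up to a multiple of `divisor`."""
    return math.ceil(channels * widen_factor / divisor) * divisor


def make_round(num_blocks: int, deepen_factor: float = 1.0) -> int:
    """Scale a block count, keeping at least one block."""
    if num_blocks <= 1:
        return num_blocks
    return max(round(num_blocks * deepen_factor), 1)


# variant -> (deepen_factor, widen_factor, last_stage_out_channels),
# as set (or inherited through `_base_`) in the configs of this commit.
VARIANTS: Dict[str, Tuple[float, float, int]] = {
    'n': (0.33, 0.25, 1024),
    's': (0.33, 0.50, 1024),
    'm': (0.67, 0.75, 768),
    'l': (1.00, 1.00, 512),
    'x': (1.00, 1.25, 512),
}


def stage_layout(variant: str) -> Tuple[List[int], List[int]]:
    """Return (output channels, block counts) of the four backbone stages."""
    deepen, widen, last = VARIANTS[variant]
    channels = [make_divisible(c, widen) for c in (128, 256, 512, last)]
    blocks = [make_round(n, deepen) for n in (3, 6, 6, 3)]
    return channels, blocks
```

Under these assumptions the `m` variant, for example, gets stage widths of 96, 192, 384 and 576 channels with 2, 4, 4 and 2 blocks.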

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_backbone import BaseBackbone
-from .csp_darknet import YOLOv5CSPDarknet, YOLOXCSPDarknet
+from .csp_darknet import YOLOv5CSPDarknet, YOLOv8CSPDarknet, YOLOXCSPDarknet
from .csp_resnet import PPYOLOECSPResNet
from .cspnext import CSPNeXt
from .efficient_rep import YOLOv6CSPBep, YOLOv6EfficientRep
@ -8,5 +8,6 @@ from .yolov7_backbone import YOLOv7Backbone
__all__ = [
'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOv6CSPBep',
-'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet'
+'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet',
+'YOLOv8CSPDarknet'
]

View File

@ -8,7 +8,7 @@ from mmdet.models.backbones.csp_darknet import CSPLayer, Focus
from mmdet.utils import ConfigType, OptMultiConfig
from mmyolo.registry import MODELS
-from ..layers import SPPFBottleneck
+from ..layers import CSPLayerWithTwoConv, SPPFBottleneck
from ..utils import make_divisible, make_round
from .base_backbone import BaseBackbone
@ -16,12 +16,10 @@ from .base_backbone import BaseBackbone
@MODELS.register_module()
class YOLOv5CSPDarknet(BaseBackbone):
"""CSP-Darknet backbone used in YOLOv5.
Args:
arch (str): Architecture of CSP-Darknet, from {P5, P6}.
Defaults to P5.
plugins (list[dict]): List of plugins for stages, each dict contains:
- cfg (dict, required): Cfg dict to build plugin.
- stages (tuple[bool], optional): Stages to apply plugin, length
should be same as 'num_stages'.
@ -43,7 +41,6 @@ class YOLOv5CSPDarknet(BaseBackbone):
and its variants only. Defaults to False.
init_cfg (Union[dict,list[dict]], optional): Initialization config
dict. Defaults to None.
Example:
>>> from mmyolo.models import YOLOv5CSPDarknet
>>> import torch
@ -157,6 +154,140 @@ class YOLOv5CSPDarknet(BaseBackbone):
super().init_weights()
@MODELS.register_module()
class YOLOv8CSPDarknet(BaseBackbone):
"""CSP-Darknet backbone used in YOLOv8.
Args:
arch (str): Architecture of CSP-Darknet, from {P5}.
Defaults to P5.
last_stage_out_channels (int): Final layer output channel.
Defaults to 1024.
plugins (list[dict]): List of plugins for stages, each dict contains:
- cfg (dict, required): Cfg dict to build plugin.
- stages (tuple[bool], optional): Stages to apply plugin, length
should be same as 'num_stages'.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
input_channels (int): Number of input image channels. Defaults to 3.
out_indices (Tuple[int]): Output from which stages.
Defaults to (2, 3, 4).
frozen_stages (int): Stages to be frozen (stop grad and set eval
mode). -1 means not freezing any parameters. Defaults to -1.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='SiLU', inplace=True).
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
init_cfg (Union[dict,list[dict]], optional): Initialization config
dict. Defaults to None.
Example:
>>> from mmyolo.models import YOLOv8CSPDarknet
>>> import torch
>>> model = YOLOv8CSPDarknet()
>>> model.eval()
>>> inputs = torch.rand(1, 3, 416, 416)
>>> level_outputs = model(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
...
(1, 256, 52, 52)
(1, 512, 26, 26)
(1, 1024, 13, 13)
"""
# From left to right:
# in_channels, out_channels, num_blocks, add_identity, use_spp
# the final out_channels will be set according to the param.
arch_settings = {
'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
[256, 512, 6, True, False], [512, None, 3, True, True]],
}
def __init__(self,
arch: str = 'P5',
last_stage_out_channels: int = 1024,
plugins: Union[dict, List[dict]] = None,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
input_channels: int = 3,
out_indices: Tuple[int] = (2, 3, 4),
frozen_stages: int = -1,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
norm_eval: bool = False,
init_cfg: OptMultiConfig = None):
self.arch_settings[arch][-1][1] = last_stage_out_channels
super().__init__(
self.arch_settings[arch],
deepen_factor,
widen_factor,
input_channels=input_channels,
out_indices=out_indices,
plugins=plugins,
frozen_stages=frozen_stages,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
norm_eval=norm_eval,
init_cfg=init_cfg)
def build_stem_layer(self) -> nn.Module:
"""Build a stem layer."""
return ConvModule(
self.input_channels,
make_divisible(self.arch_setting[0][0], self.widen_factor),
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
"""Build a stage layer.
Args:
stage_idx (int): The index of a stage layer.
setting (list): The architecture setting of a stage layer.
"""
in_channels, out_channels, num_blocks, add_identity, use_spp = setting
in_channels = make_divisible(in_channels, self.widen_factor)
out_channels = make_divisible(out_channels, self.widen_factor)
num_blocks = make_round(num_blocks, self.deepen_factor)
stage = []
conv_layer = ConvModule(
in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.append(conv_layer)
csp_layer = CSPLayerWithTwoConv(
out_channels,
out_channels,
num_blocks=num_blocks,
add_identity=add_identity,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.append(csp_layer)
if use_spp:
spp = SPPFBottleneck(
out_channels,
out_channels,
kernel_sizes=5,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.append(spp)
return stage
@MODELS.register_module()
class YOLOXCSPDarknet(BaseBackbone):
"""CSP-Darknet backbone used in YOLOX.

View File

@ -4,11 +4,12 @@ from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
from .yolov5_head import YOLOv5Head, YOLOv5HeadModule
from .yolov6_head import YOLOv6Head, YOLOv6HeadModule
from .yolov7_head import YOLOv7Head, YOLOv7HeadModule, YOLOv7p6HeadModule
+from .yolov8_head import YOLOv8Head, YOLOv8HeadModule
from .yolox_head import YOLOXHead, YOLOXHeadModule
__all__ = [
'YOLOv5Head', 'YOLOv6Head', 'YOLOXHead', 'YOLOv5HeadModule',
'YOLOv6HeadModule', 'YOLOXHeadModule', 'RTMDetHead',
'RTMDetSepBNHeadModule', 'YOLOv7Head', 'PPYOLOEHead', 'PPYOLOEHeadModule',
-'YOLOv7HeadModule', 'YOLOv7p6HeadModule'
+'YOLOv7HeadModule', 'YOLOv7p6HeadModule', 'YOLOv8Head', 'YOLOv8HeadModule'
]

View File

@ -28,8 +28,8 @@ class PPYOLOEHeadModule(BaseModule):
category.
in_channels (int): Number of channels in the input feature map.
widen_factor (float): Width multiplier, multiply number of
-channels in each layer by this amount. Default: 1.0.
-num_base_priors:int: The number of priors (points) at a point
+channels in each layer by this amount. Defaults to 1.0.
+num_base_priors (int): The number of priors (points) at a point
on the feature grid.
featmap_strides (Sequence[int]): Downsample factor of each feature map.
Defaults to (8, 16, 32).

View File

@ -26,7 +26,7 @@ class RTMDetSepBNHeadModule(BaseModule):
in_channels (int): Number of channels in the input feature map.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
-num_base_priors:int: The number of priors (points) at a point
+num_base_priors (int): The number of priors (points) at a point
on the feature grid. Defaults to 1.
feat_channels (int): Number of hidden channels. Used in child classes.
Defaults to 256

View File

@ -42,8 +42,8 @@ class YOLOv5HeadModule(BaseModule):
in_channels (Union[int, Sequence]): Number of channels in the input
feature map.
widen_factor (float): Width multiplier, multiply number of
-channels in each layer by this amount. Default: 1.0.
-num_base_priors:int: The number of priors (points) at a point
+channels in each layer by this amount. Defaults to 1.0.
+num_base_priors (int): The number of priors (points) at a point
on the feature grid.
featmap_strides (Sequence[int]): Downsample factor of each feature map.
Defaults to (8, 16, 32).

View File

@ -29,7 +29,7 @@ class YOLOv6HeadModule(BaseModule):
in_channels (Union[int, Sequence]): Number of channels in the input
feature map.
widen_factor (float): Width multiplier, multiply number of
-channels in each layer by this amount. Default: 1.0.
+channels in each layer by this amount. Defaults to 1.0.
num_base_priors: (int): The number of priors (points) at a point
on the feature grid.
featmap_strides (Sequence[int]): Downsample factor of each feature map.

View File

@ -0,0 +1,249 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmdet.models.utils import multi_apply
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
OptMultiConfig)
from mmengine.model import BaseModule, bias_init_with_prob
from mmengine.structures import InstanceData
from torch import Tensor
from mmyolo.registry import MODELS
from ..utils import make_divisible
from .yolov5_head import YOLOv5Head
@MODELS.register_module()
class YOLOv8HeadModule(BaseModule):
"""YOLOv8HeadModule head module used in `YOLOv8`.
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (Union[int, Sequence]): Number of channels in the input
feature map.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
num_base_priors (int): The number of priors (points) at a point
on the feature grid. Defaults to 1.
featmap_strides (Sequence[int]): Downsample factor of each feature map.
Defaults to (8, 16, 32).
reg_max (int): Max value of integral set :math: ``{0, ..., reg_max-1}``
in QFL setting. Defaults to 16.
norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
Defaults to dict(type='SiLU', inplace=True).
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
num_classes: int,
in_channels: Union[int, Sequence],
widen_factor: float = 1.0,
num_base_priors: int = 1,
featmap_strides: Sequence[int] = (8, 16, 32),
reg_max: int = 16,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg=init_cfg)
self.num_classes = num_classes
self.featmap_strides = featmap_strides
self.num_levels = len(self.featmap_strides)
self.num_base_priors = num_base_priors
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.in_channels = in_channels
self.reg_max = reg_max
in_channels = []
for channel in self.in_channels:
channel = make_divisible(channel, widen_factor)
in_channels.append(channel)
self.in_channels = in_channels
self._init_layers()
def init_weights(self):
"""Initialize weights of the head."""
# Use prior in model initialization to improve stability
super().init_weights()
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
m.reset_parameters()
bias_init = bias_init_with_prob(0.01)
for conv_cls in self.cls_preds:
conv_cls.bias.data.fill_(bias_init)
def _init_layers(self):
"""initialize conv layers in YOLOv8 head."""
# Init decoupled head
self.cls_preds = nn.ModuleList()
self.reg_preds = nn.ModuleList()
reg_out_channels = max(
(16, self.in_channels[0] // 4, self.reg_max * 4))
cls_out_channels = max(self.in_channels[0], self.num_classes)
for i in range(self.num_levels):
self.reg_preds.append(
nn.Sequential(
ConvModule(
in_channels=self.in_channels[i],
out_channels=reg_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
in_channels=reg_out_channels,
out_channels=reg_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(
in_channels=reg_out_channels,
out_channels=4 * self.reg_max,
kernel_size=1)))
self.cls_preds.append(
nn.Sequential(
ConvModule(
in_channels=self.in_channels[i],
out_channels=cls_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
in_channels=cls_out_channels,
out_channels=cls_out_channels,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(
in_channels=cls_out_channels,
out_channels=self.num_classes,
kernel_size=1)))
proj = torch.linspace(0, self.reg_max - 1,
self.reg_max).view([1, self.reg_max, 1, 1])
self.register_buffer('proj', proj, persistent=False)
def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
"""Forward features from the upstream network.
Args:
x (Tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
Tuple[List]: A tuple of multi-level classification scores and bbox
predictions.
"""
assert len(x) == self.num_levels
return multi_apply(self.forward_single, x, self.cls_preds,
self.reg_preds)
def forward_single(self, x: torch.Tensor, cls_pred: nn.Module,
reg_pred: nn.Module) -> Tuple:
"""Forward feature of a single scale level."""
b, _, h, w = x.shape
cls_logit = cls_pred(x)
bbox_dist_preds = reg_pred(x)
if self.reg_max > 1:
bbox_dist_preds = bbox_dist_preds.reshape(
[-1, 4, self.reg_max, h * w]).permute(0, 2, 3, 1)
# TODO: Test whether use matmul instead of conv can
# speed up training.
bbox_preds = F.conv2d(F.softmax(bbox_dist_preds, dim=1), self.proj)
else:
bbox_preds = bbox_dist_preds
return cls_logit, bbox_preds
# TODO Training mode is currently not supported
@MODELS.register_module()
class YOLOv8Head(YOLOv5Head):
"""YOLOv8Head head used in `YOLOv8`.
Args:
head_module (:obj:`ConfigDict` or dict): Base module used for YOLOv8Head.
prior_generator (dict): Points generator for feature maps
in 2D points-based detectors.
bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
loss_dfl (:obj:`ConfigDict` or dict): Config of Distribution Focal
Loss.
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
anchor head. Defaults to None.
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
anchor head. Defaults to None.
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
head_module: ConfigType,
prior_generator: ConfigType = dict(
type='mmdet.MlvlPointGenerator',
offset=0.5,
strides=[8, 16, 32]),
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
loss_cls: ConfigType = dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
loss_bbox: ConfigType = dict(
type='mmdet.GIoULoss', reduction='sum', loss_weight=5.0),
loss_dfl=dict(
type='mmdet.DistributionFocalLoss',
reduction='mean',
loss_weight=0.5 / 4),
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None):
super().__init__(
head_module=head_module,
prior_generator=prior_generator,
bbox_coder=bbox_coder,
loss_cls=loss_cls,
loss_bbox=loss_bbox,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg)
def special_init(self):
"""Since YOLO series algorithms will inherit from YOLOv5Head, but
different algorithms have special initialization process.
The special_init function is designed to deal with this situation.
"""
pass
def loss_by_feat(
self,
cls_scores: Sequence[Tensor],
bbox_preds: Sequence[Tensor],
objectnesses: Sequence[Tensor],
batch_gt_instances: Sequence[InstanceData],
batch_img_metas: Sequence[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
raise NotImplementedError('Not implemented yet !')
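
`forward_single` above turns the `4 * reg_max` regression channels into four distances by taking a softmax over the `reg_max` bins of each side and then a weighted sum with `proj = [0, 1, ..., reg_max - 1]`, that is, the expected bin index. The same computation for a single feature-map location can be sketched in plain Python. The helper names are hypothetical, no tensors are involved, and the side-major channel layout is an assumption read off the reshape to `[-1, 4, reg_max, h * w]`; how the decoded distances are later mapped to pixel coordinates is outside this sketch.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a flat sequence."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_distance(bin_logits: Sequence[float]) -> float:
    """Expected bin index under softmax(bin_logits)."""
    probs = softmax(bin_logits)
    return sum(index * prob for index, prob in enumerate(probs))


def decode_box_distances(logits: Sequence[float],
                         reg_max: int = 16) -> List[float]:
    """Decode 4 * reg_max logits of one location into 4 side distances."""
    assert len(logits) == 4 * reg_max
    # Assumed layout: all bins of side 0, then all bins of side 1, ...
    return [
        decode_distance(logits[side * reg_max:(side + 1) * reg_max])
        for side in range(4)
    ]
```

With uniform logits every side decodes to the midpoint `(reg_max - 1) / 2`, and a sharply peaked distribution decodes to roughly the index of its peak.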

View File

@ -30,7 +30,7 @@ class YOLOXHeadModule(BaseModule):
in_channels (Union[int, Sequence]): Number of channels in the input
feature map.
widen_factor (float): Width multiplier, multiply number of
-channels in each layer by this amount. Default: 1.0.
+channels in each layer by this amount. Defaults to 1.0.
num_base_priors (int): The number of priors (points) at a point
on the feature grid
stacked_convs (int): Number of stacking convs of the head.

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ema import ExpMomentumEMA
-from .yolo_bricks import (BepC3StageBlock, EELANBlock, EffectiveSELayer,
+from .yolo_bricks import (BepC3StageBlock, CSPLayerWithTwoConv,
+DarknetBottleneck, EELANBlock, EffectiveSELayer,
ELANBlock, ImplicitA, ImplicitM,
MaxPoolAndStrideConvBlock, PPYOLOEBasicBlock,
RepStageBlock, RepVGGBlock, SPPFBottleneck,
@ -10,5 +11,6 @@ __all__ = [
'SPPFBottleneck', 'RepVGGBlock', 'RepStageBlock', 'ExpMomentumEMA',
'ELANBlock', 'MaxPoolAndStrideConvBlock', 'SPPFCSPBlock',
'PPYOLOEBasicBlock', 'EffectiveSELayer', 'TinyDownSampleBlock',
-'EELANBlock', 'ImplicitA', 'ImplicitM', 'BepC3StageBlock'
+'EELANBlock', 'ImplicitA', 'ImplicitM', 'BepC3StageBlock',
+'CSPLayerWithTwoConv', 'DarknetBottleneck'
]

View File

@ -4,7 +4,10 @@ from typing import Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
-from mmcv.cnn import ConvModule, MaxPool2d, build_norm_layer
+from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, MaxPool2d,
+build_norm_layer)
+from mmdet.models.layers.csp_layer import \
+DarknetBottleneck as MMDET_DarknetBottleneck
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from mmengine.model import BaseModule
from mmengine.utils import digit_version
@ -1352,7 +1355,7 @@ class RepStageBlock(nn.Module):
"""Forward process.
Args:
-inputs (Tensor): The input tensor.
+x (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
@ -1361,3 +1364,147 @@ class RepStageBlock(nn.Module):
if self.block is not None:
x = self.block(x)
return x
class DarknetBottleneck(MMDET_DarknetBottleneck):
"""The basic bottleneck block used in Darknet.
Each ResBlock consists of two ConvModules and the input is added to the
final output. Each ConvModule is composed of Conv, BN, and an activation
layer. The first conv layer has a filter size of k1 x k1 and the second
one has a filter size of k2 x k2.
Note:
This DarknetBottleneck differs slightly from MMDet's: the kernel size
and padding can be set for each conv.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
expansion (float): Ratio of hidden channels to output channels.
Defaults to 0.5.
kernel_size (Sequence[int]): The kernel size of the convolution.
Defaults to (1, 3).
padding (Sequence[int]): The padding size of the convolution.
Defaults to (0, 1).
add_identity (bool): Whether to add identity to the out.
Defaults to True.
use_depthwise (bool): Whether to use depthwise separable convolution.
Defaults to False.
conv_cfg (dict): Config dict for convolution layer. Defaults to None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='SiLU', inplace=True).
"""
def __init__(self,
in_channels: int,
out_channels: int,
expansion: float = 0.5,
kernel_size: Sequence[int] = (1, 3),
padding: Sequence[int] = (0, 1),
add_identity: bool = True,
use_depthwise: bool = False,
conv_cfg: OptConfigType = None,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
init_cfg: OptMultiConfig = None) -> None:
super().__init__(in_channels, out_channels, init_cfg=init_cfg)
hidden_channels = int(out_channels * expansion)
conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
assert isinstance(kernel_size, Sequence) and len(kernel_size) == 2
self.conv1 = ConvModule(
in_channels,
hidden_channels,
kernel_size[0],
padding=padding[0],
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv2 = conv(
hidden_channels,
out_channels,
kernel_size[1],
stride=1,
padding=padding[1],
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.add_identity = \
add_identity and in_channels == out_channels
class CSPLayerWithTwoConv(BaseModule):
"""Cross Stage Partial Layer with 2 convolutions.
Args:
in_channels (int): The input channels of the CSP layer.
out_channels (int): The output channels of the CSP layer.
expand_ratio (float): Ratio to adjust the number of channels of the
hidden layer. Defaults to 0.5.
num_blocks (int): Number of blocks. Defaults to 1.
add_identity (bool): Whether to add identity in blocks.
Defaults to True.
conv_cfg (dict, optional): Config dict for convolution layer.
Defaults to None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='SiLU', inplace=True).
init_cfg (:obj:`ConfigDict` or dict or list[dict] or
list[:obj:`ConfigDict`], optional): Initialization config dict.
Defaults to None.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
expand_ratio: float = 0.5,
num_blocks: int = 1,
add_identity: bool = True, # shortcut
conv_cfg: OptConfigType = None,
norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
init_cfg: OptMultiConfig = None) -> None:
super().__init__(init_cfg=init_cfg)
self.mid_channels = int(out_channels * expand_ratio)
self.main_conv = ConvModule(
in_channels,
2 * self.mid_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.final_conv = ConvModule(
(2 + num_blocks) * self.mid_channels,
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.blocks = nn.ModuleList(
DarknetBottleneck(
self.mid_channels,
self.mid_channels,
expansion=1,
kernel_size=(3, 3),
padding=(1, 1),
add_identity=add_identity,
use_depthwise=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg) for _ in range(num_blocks))
def forward(self, x: Tensor) -> Tensor:
"""Forward process."""
x_main = self.main_conv(x)
x_main = list(x_main.split((self.mid_channels, self.mid_channels), 1))
x_main.extend(blocks(x_main[-1]) for blocks in self.blocks)
return self.final_conv(torch.cat(x_main, 1))
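
The channel bookkeeping in `CSPLayerWithTwoConv` follows from its `forward`: `main_conv` produces two halves of `mid_channels` each, every block consumes the output of the previous one, and all pieces are concatenated before `final_conv`. Because `list.extend` consumes the generator lazily, `x_main[-1]` is re-evaluated after each append, which is what chains the blocks. The sketch below mimics this with strings in place of tensors; both helper names are hypothetical.

```python
from typing import Callable, Dict, List, Sequence


def csp_two_conv_channels(out_channels: int, num_blocks: int,
                          expand_ratio: float = 0.5) -> Dict[str, int]:
    """Channel counts implied by the constructor of the layer above."""
    mid_channels = int(out_channels * expand_ratio)
    return {
        'mid_channels': mid_channels,
        'main_conv_out': 2 * mid_channels,
        # two halves from main_conv plus one output per block
        'final_conv_in': (2 + num_blocks) * mid_channels,
        'final_conv_out': out_channels,
    }


def chain_and_collect(halves: Sequence[str],
                      blocks: Sequence[Callable[[str], str]]) -> List[str]:
    """Same control flow as `forward`, with strings standing in for tensors."""
    feats = list(halves)
    # The generator is evaluated item by item, so feats[-1] is always the
    # output of the previous block (or the second half for the first block).
    feats.extend(block(feats[-1]) for block in blocks)
    return feats
```

For `out_channels=256` and `num_blocks=2`, the sketch gives `mid_channels=128` and a 512-channel input to `final_conv`.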

View File

@ -5,9 +5,11 @@ from .ppyoloe_csppan import PPYOLOECSPPAFPN
from .yolov5_pafpn import YOLOv5PAFPN
from .yolov6_pafpn import YOLOv6CSPRepPAFPN, YOLOv6RepPAFPN
from .yolov7_pafpn import YOLOv7PAFPN
+from .yolov8_pafpn import YOLOv8PAFPN
from .yolox_pafpn import YOLOXPAFPN
__all__ = [
'YOLOv5PAFPN', 'BaseYOLONeck', 'YOLOv6RepPAFPN', 'YOLOXPAFPN',
-'CSPNeXtPAFPN', 'YOLOv7PAFPN', 'PPYOLOECSPPAFPN', 'YOLOv6CSPRepPAFPN'
+'CSPNeXtPAFPN', 'YOLOv7PAFPN', 'PPYOLOECSPPAFPN', 'YOLOv6CSPRepPAFPN',
+'YOLOv8PAFPN'
]

View File

@ -147,9 +147,8 @@ class BaseYOLONeck(BaseModule, metaclass=ABCMeta):
self.out_channels = out_channels
self.deepen_factor = deepen_factor
self.widen_factor = widen_factor
self.freeze_all = freeze_all
self.upsample_feats_cat_first = upsample_feats_cat_first
self.freeze_all = freeze_all
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg

View File

@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union

import torch.nn as nn
from mmdet.utils import ConfigType, OptMultiConfig

from mmyolo.registry import MODELS
from .. import CSPLayerWithTwoConv
from ..utils import make_divisible, make_round
from .yolov5_pafpn import YOLOv5PAFPN


@MODELS.register_module()
class YOLOv8PAFPN(YOLOv5PAFPN):
    """Path Aggregation Network used in YOLOv8.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (List[int], int): Number of output channels
            (used at each scale).
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_csp_blocks (int): Number of bottlenecks in CSPLayer.
            Defaults to 3.
        freeze_all (bool): Whether to freeze the model. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: List[int],
                 out_channels: Union[List[int], int],
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 num_csp_blocks: int = 3,
                 freeze_all: bool = False,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            deepen_factor=deepen_factor,
            widen_factor=widen_factor,
            num_csp_blocks=num_csp_blocks,
            freeze_all=freeze_all,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=init_cfg)

    def build_reduce_layer(self, idx: int) -> nn.Module:
        """Build the reduce layer.

        Args:
            idx (int): Layer index.

        Returns:
            nn.Module: The reduce layer.
        """
        return nn.Identity()

    def build_top_down_layer(self, idx: int) -> nn.Module:
        """Build the top-down layer.

        Args:
            idx (int): Layer index.

        Returns:
            nn.Module: The top-down layer.
        """
        return CSPLayerWithTwoConv(
            make_divisible((self.in_channels[idx - 1] + self.in_channels[idx]),
                           self.widen_factor),
            make_divisible(self.out_channels[idx - 1], self.widen_factor),
            num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
            add_identity=False,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def build_bottom_up_layer(self, idx: int) -> nn.Module:
        """Build the bottom-up layer.

        Args:
            idx (int): Layer index.

        Returns:
            nn.Module: The bottom-up layer.
        """
        return CSPLayerWithTwoConv(
            make_divisible(
                (self.out_channels[idx] + self.out_channels[idx + 1]),
                self.widen_factor),
            make_divisible(self.out_channels[idx + 1], self.widen_factor),
            num_blocks=make_round(self.num_csp_blocks, self.deepen_factor),
            add_identity=False,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
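
To see which numbers `build_top_down_layer` passes to `CSPLayerWithTwoConv`, the following plain-Python sketch recomputes the three arguments. The bodies of `make_divisible` and `make_round` are not shown in this diff, so `assumed_make_divisible` and `assumed_make_round` are stand-ins written under an assumption: scaling followed by rounding up to a multiple of 8, and scaling followed by rounding with a floor of 1. `top_down_args` is a hypothetical helper, and the channel lists and factors are illustrative values rather than a config from this commit.

```python
import math
from typing import List, Tuple


def assumed_make_divisible(x: float, widen_factor: float = 1.0,
                           divisor: int = 8) -> int:
    """Assumption: scale by widen_factor, then round up to a multiple of 8."""
    return math.ceil(x * widen_factor / divisor) * divisor


def assumed_make_round(x: float, deepen_factor: float = 1.0) -> int:
    """Assumption: scale by deepen_factor, round, and keep at least 1."""
    return max(round(x * deepen_factor), 1) if x > 1 else x


def top_down_args(in_channels: List[int], out_channels: List[int], idx: int,
                  widen_factor: float, deepen_factor: float,
                  num_csp_blocks: int = 3) -> Tuple[int, int, int]:
    """Mirror the arguments assembled in build_top_down_layer."""
    # The layer input is the concatenation of two adjacent feature maps.
    layer_in = assumed_make_divisible(
        in_channels[idx - 1] + in_channels[idx], widen_factor)
    layer_out = assumed_make_divisible(out_channels[idx - 1], widen_factor)
    num_blocks = assumed_make_round(num_csp_blocks, deepen_factor)
    return layer_in, layer_out, num_blocks


# Illustrative values only.
print(top_down_args([256, 512, 1024], [256, 512, 1024], 2, 0.5, 0.33))
# prints (768, 256, 1)
```

Under these assumptions the width factor is applied to the summed channel count, which matches the sum of the individually scaled counts whenever both are already multiples of the divisor.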

View File

@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from mmengine.config import Config

from mmyolo.models import YOLOv8Head
from mmyolo.utils import register_all_modules

register_all_modules()


class TestYOLOv8Head(TestCase):

    def setUp(self):
        self.head_module = dict(
            type='YOLOv8HeadModule',
            num_classes=2,
            in_channels=[32, 64, 128],
            featmap_strides=[8, 16, 32])

    def test_predict_by_feat(self):
        s = 256
        img_metas = [{
            'img_shape': (s, s, 3),
            'ori_shape': (s, s, 3),
            'scale_factor': (1.0, 1.0),
        }]
        test_cfg = Config(
            dict(
                multi_label=True,
                max_per_img=300,
                score_thr=0.01,
                nms=dict(type='nms', iou_threshold=0.65)))

        head = YOLOv8Head(head_module=self.head_module, test_cfg=test_cfg)

        # Build one random feature map per stride: spatial size is s // stride.
        feat = []
        for i in range(len(self.head_module['in_channels'])):
            in_channel = self.head_module['in_channels'][i]
            feat_size = self.head_module['featmap_strides'][i]
            feat.append(
                torch.rand(1, in_channel, s // feat_size, s // feat_size))

        cls_scores, bbox_preds = head.forward(feat)
        head.predict_by_feat(
            cls_scores,
            bbox_preds,
            None,
            img_metas,
            cfg=test_cfg,
            rescale=True,
            with_nms=True)
        head.predict_by_feat(
            cls_scores,
            bbox_preds,
            None,
            img_metas,
            cfg=test_cfg,
            rescale=False,
            with_nms=False)
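
The feature maps built in this test follow a simple rule: each level has the configured channel count and a spatial size of the image size divided by the stride. A small plain-Python sketch of that rule, with `feature_shapes` as a hypothetical helper that is not part of the commit:

```python
from typing import List, Tuple


def feature_shapes(img_size: int, in_channels: List[int],
                   strides: List[int],
                   batch: int = 1) -> List[Tuple[int, int, int, int]]:
    """Hypothetical helper: (N, C, H, W) of each feature map fed to the head."""
    return [(batch, channels, img_size // stride, img_size // stride)
            for channels, stride in zip(in_channels, strides)]


# Same values as setUp() and test_predict_by_feat() above.
shapes = feature_shapes(256, [32, 64, 128], [8, 16, 32])
assert shapes == [(1, 32, 32, 32), (1, 64, 16, 16), (1, 128, 8, 8)]
```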

View File

@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmyolo.models import YOLOv8PAFPN
from mmyolo.utils import register_all_modules

register_all_modules()


class TestYOLOv8PAFPN(TestCase):

    def test_YOLOv8PAFPN_forward(self):
        s = 64
        in_channels = [8, 16, 32]
        # [64, 32, 16, 8]; only the first three sizes are used below.
        feat_sizes = [s // 2**i for i in range(4)]
        out_channels = [8, 16, 32]
        feats = [
            torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
            for i in range(len(in_channels))
        ]
        neck = YOLOv8PAFPN(in_channels=in_channels, out_channels=out_channels)
        outs = neck(feats)
        assert len(outs) == len(feats)
        for i in range(len(feats)):
            assert outs[i].shape[1] == out_channels[i]
            assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)

View File

@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict

import torch

convert_dict_s = {
    # backbone
    'model.0': 'backbone.stem',
    'model.1': 'backbone.stage1.0',
    'model.2': 'backbone.stage1.1',
    'model.3': 'backbone.stage2.0',
    'model.4': 'backbone.stage2.1',
    'model.5': 'backbone.stage3.0',
    'model.6': 'backbone.stage3.1',
    'model.7': 'backbone.stage4.0',
    'model.8': 'backbone.stage4.1',
    'model.9': 'backbone.stage4.2',

    # neck
    'model.12': 'neck.top_down_layers.0',
    'model.15': 'neck.top_down_layers.1',
    'model.16': 'neck.downsample_layers.0',
    'model.18': 'neck.bottom_up_layers.0',
    'model.19': 'neck.downsample_layers.1',
    'model.21': 'neck.bottom_up_layers.1',

    # Detector
    'model.22': 'bbox_head.head_module',
}


def convert(src, dst):
    """Convert keys in pretrained YOLOv8 models to mmyolo style."""
    convert_dict = convert_dict_s

    try:
        yolov8_model = torch.load(src)['model']
        blobs = yolov8_model.state_dict()
    except ModuleNotFoundError:
        raise RuntimeError(
            'This script must be placed under the ultralytics repo,'
            ' because loading the official pretrained model needs'
            ' `model.py` to build the model.'
            ' It also requires hydra-core>=1.2.0 and thop>=0.1.1.')
    state_dict = OrderedDict()

    for key, weight in blobs.items():
        # Keys look like `model.<num>.<module>...`; map the `model.<num>`
        # prefix to the matching mmyolo component.
        num, module = key.split('.')[1:3]
        prefix = f'model.{num}'
        new_key = key.replace(prefix, convert_dict[prefix])

        if '.m.' in new_key:
            new_key = new_key.replace('.m.', '.blocks.')
            new_key = new_key.replace('.cv', '.conv')
        elif 'bbox_head.head_module' in new_key:
            new_key = new_key.replace('.cv2', '.reg_preds')
            new_key = new_key.replace('.cv3', '.cls_preds')
        elif 'backbone.stage4.2' in new_key:
            new_key = new_key.replace('.cv', '.conv')
        else:
            new_key = new_key.replace('.cv1', '.main_conv')
            new_key = new_key.replace('.cv2', '.final_conv')

        if 'bbox_head.head_module.dfl.conv.weight' == new_key:
            print('Drop "bbox_head.head_module.dfl.conv.weight", '
                  'because it is useless')
            continue

        state_dict[new_key] = weight
        print(f'Convert {key} to {new_key}')

    # save checkpoint
    checkpoint = dict()
    checkpoint['state_dict'] = state_dict
    torch.save(checkpoint, dst)


# Note: This script must be placed under the YOLOv8 repo to run.
def main():
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument(
        '--src', default='yolov8s.pt', help='src YOLOv8 model path')
    parser.add_argument('--dst', default='mmyolov8s.pth', help='save path')
    args = parser.parse_args()
    convert(args.src, args.dst)


if __name__ == '__main__':
    main()
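
The renaming rules in `convert` can be exercised on plain strings, without loading a checkpoint. The sketch below is an illustration and not part of the commit: `rename_key` is a hypothetical helper that repeats the same branch order, `PREFIX_MAP` is a four-entry subset of `convert_dict_s`, and the sample keys are made up to hit each branch rather than copied from a real `yolov8s.pt`.

```python
from typing import Dict, Optional

# Subset of convert_dict_s, enough to reach every branch.
PREFIX_MAP: Dict[str, str] = {
    'model.0': 'backbone.stem',
    'model.2': 'backbone.stage1.1',
    'model.9': 'backbone.stage4.2',
    'model.22': 'bbox_head.head_module',
}

DROPPED_KEY = 'bbox_head.head_module.dfl.conv.weight'


def rename_key(key: str) -> Optional[str]:
    """Hypothetical helper: map one key, or return None if it is dropped."""
    num = key.split('.')[1]
    prefix = f'model.{num}'
    new_key = key.replace(prefix, PREFIX_MAP[prefix])

    if '.m.' in new_key:
        # Bottlenecks inside a CSP layer.
        new_key = new_key.replace('.m.', '.blocks.')
        new_key = new_key.replace('.cv', '.conv')
    elif 'bbox_head.head_module' in new_key:
        # Detection head: cv2 holds box regression, cv3 classification.
        new_key = new_key.replace('.cv2', '.reg_preds')
        new_key = new_key.replace('.cv3', '.cls_preds')
    elif 'backbone.stage4.2' in new_key:
        new_key = new_key.replace('.cv', '.conv')
    else:
        # The two 1x1 convolutions of a CSP layer.
        new_key = new_key.replace('.cv1', '.main_conv')
        new_key = new_key.replace('.cv2', '.final_conv')

    return None if new_key == DROPPED_KEY else new_key


assert rename_key('model.2.cv1.conv.weight') == (
    'backbone.stage1.1.main_conv.conv.weight')
assert rename_key('model.22.dfl.conv.weight') is None
```

Note that the order of the branches matters: a key containing `.m.` is handled by the first branch even when it belongs to a stage that a later branch also matches.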