[FEATURE] Support YOLOv6 v3.0 face detection ()

* support the usage of the WIDERFace dataset

* add YOLOv6FaceHead

* add YOLOv6 face detection configs

* add a checkpoint conversion script

* add a face visualizer

* fix a bug in YOLOv6CSPBep initialization
Qingren 2023-08-23 15:17:25 +08:00 committed by GitHub
parent 8c4d9dc503
commit 6b52f81c07
13 changed files with 1050 additions and 7 deletions
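For anyone trying the new configs, here is a minimal launch sketch (not part of this commit; the config path is an assumption based on the files added below):

# Minimal training sketch via MMEngine (assumed config path; adjust to the
# actual location of the new WIDERFace configs in your checkout).
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/yolov6/yolov6_v3_s_syncbn_fast_8xb32-300e_widerface.py')
cfg.work_dir = './work_dirs/yolov6_face'  # hypothetical output directory
runner = Runner.from_cfg(cfg)
runner.train()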

View File

@ -0,0 +1,27 @@
_base_ = './yolov6_v3_m_syncbn_fast_8xb32-300e_widerface.py'
# ======================= Possible modified parameters =======================
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 1
# The scaling factor that controls the width of the network structure
widen_factor = 1
# ============================== Unmodified in most cases ===================
model = dict(
backbone=dict(
use_cspsppf=False,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
block_cfg=dict(
type='ConvWrapper',
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
act_cfg=dict(type='SiLU', inplace=True)),
neck=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
block_cfg=dict(
type='ConvWrapper',
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
block_act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(head_module=dict(reg_max=16, widen_factor=widen_factor)))

View File

@ -0,0 +1,20 @@
_base_ = './yolov6_v3_s_syncbn_fast_8xb32-300e_widerface.py'
# ======================= Possible modified parameters =======================
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.67
# The scaling factor that controls the width of the network structure
widen_factor = 0.75
# -----train val related-----
affine_scale = 0.9 # YOLOv5RandomAffine scaling ratio
# ============================== Unmodified in most cases ===================
model = dict(
backbone=dict(
use_cspsppf=False,
deepen_factor=deepen_factor,
widen_factor=widen_factor),
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
bbox_head=dict(head_module=dict(reg_max=16, widen_factor=widen_factor)))

View File

@ -0,0 +1,279 @@
_base_ = ['../../_base_/default_runtime.py', '../../_base_/det_p5_tta.py']
# ======================= Frequently modified parameters =====================
# -----data related-----
data_root = 'data/WIDERFace/' # Root path of data
num_classes = 1 # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 32
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True
# -----train val related-----
# Base learning rate for optim_wrapper
base_lr = 0.01
max_epochs = 300 # Maximum training epochs
num_last_epochs = 15 # Last epoch number to switch training pipeline
# ======================= Possible modified parameters =======================
# -----data related-----
img_scale = (640, 640) # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv6WIDERFaceDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2
# Config of batch shapes. Only used during validation.
# Batch shapes are disabled if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
size_divisor=32,
extra_pad_ratio=0.5)
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.25
# -----train val related-----
affine_scale = 0.5 # YOLOv5RandomAffine scaling ratio
lr_factor = 0.01 # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_epoch_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# cudnn_benchmark speeds up fixed-size (single-scale) training,
# so it is recommended to turn it on.
env_cfg = dict(cudnn_benchmark=True)
# ============================== Unmodified in most cases ===================
model = dict(
type='YOLODetector',
data_preprocessor=dict(
type='YOLOv5DetDataPreprocessor',
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True),
backbone=dict(
type='YOLOv6EfficientRep',
out_indices=[1, 2, 3, 4],
use_cspsppf=True,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='ReLU', inplace=True)),
neck=dict(
type='YOLOv6RepBiPAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=[128, 256, 512, 1024],
out_channels=[128, 256, 512],
num_csp_blocks=12,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='ReLU', inplace=True),
),
bbox_head=dict(
type='YOLOv6FaceHead',
head_module=dict(
type='YOLOv6FaceHeadModule',
num_classes=num_classes,
in_channels=[128, 256, 512],
stemout_channels=[128, 256, 512],
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True),
featmap_strides=[8, 16, 32]),
loss_bbox=dict(
type='IoULoss',
iou_mode='siou',
bbox_format='xyxy',
reduction='mean',
loss_weight=2.5,
return_iou=False)),
train_cfg=dict(
initial_epoch=4,
initial_assigner=dict(
type='BatchATSSAssigner',
num_classes=num_classes,
topk=9,
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
assigner=dict(
type='BatchTaskAlignedAssigner',
num_classes=num_classes,
topk=13,
alpha=1,
beta=6),
),
test_cfg=dict(
multi_label=True,
nms_pre=30000,
score_thr=0.4,
nms=dict(type='nms', iou_threshold=0.45),
max_per_img=1000))
# The training pipeline of YOLOv6 is basically the same as YOLOv5.
# The difference is that Mosaic and RandomAffine are disabled in the last 15 epochs. # noqa
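# The switch itself is performed by the `mmdet.PipelineSwitchHook` registered
# in `custom_hooks` below, which swaps in `train_pipeline_stage2` at epoch
# `max_epochs - num_last_epochs`.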
pre_transform = [
dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
dict(type='LoadAnnotations', with_bbox=True)
]
train_pipeline = [
*pre_transform,
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_translate_ratio=0.1,
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
# img_scale is (width, height)
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114),
max_shear_degree=0.0),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
train_pipeline_stage2 = [
*pre_transform,
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=True,
pad_val=dict(img=114)),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_translate_ratio=0.1,
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
max_shear_degree=0.0,
),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    persistent_workers=persistent_workers,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=True),
batch_sampler=dict(type='AspectRatioBatchSampler'),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='train.txt',
data_prefix=dict(img='WIDER_train'),
filter_cfg=dict(filter_empty_gt=True, bbox_min_size=17, min_size=32),
pipeline=train_pipeline))
test_pipeline = [
dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]
val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='val.txt',
        data_prefix=dict(img='WIDER_val'),
        test_mode=True,
        batch_shapes_cfg=batch_shapes_cfg,
        pipeline=test_pipeline))
test_dataloader = val_dataloader
# Optimizer and learning rate scheduler of YOLOv6 are basically the same as YOLOv5. # noqa
# The difference is that the scheduler_type of YOLOv6 is cosine.
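# With scheduler_type='cosine', the per-epoch factor roughly follows YOLOv5's
# schedule ((1 - cos(pi * t / max_epochs)) / 2) * (lr_factor - 1) + 1, decaying
# from 1.0 to lr_factor (an assumption based on the upstream YOLOv5 schedule).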
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=base_lr,
momentum=0.937,
weight_decay=weight_decay,
nesterov=True,
batch_size_per_gpu=train_batch_size_per_gpu),
constructor='YOLOv5OptimizerConstructor')
default_hooks = dict(
param_scheduler=dict(
type='YOLOv5ParamSchedulerHook',
scheduler_type='cosine',
lr_factor=lr_factor,
max_epochs=max_epochs),
checkpoint=dict(
type='CheckpointHook',
interval=save_epoch_intervals,
max_keep_ckpts=max_keep_ckpts,
save_best='auto'))
custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0001,
update_buffers=True,
strict_load=False,
priority=49),
dict(
type='mmdet.PipelineSwitchHook',
switch_epoch=max_epochs - num_last_epochs,
switch_pipeline=train_pipeline_stage2)
]
val_evaluator = dict(
# TODO: support WiderFace-Evaluation for easy, medium, hard cases
type='mmdet.VOCMetric',
metric='mAP',
eval_mode='11points')
test_evaluator = val_evaluator
train_cfg = dict(
type='EpochBasedTrainLoop',
max_epochs=max_epochs,
val_interval=save_epoch_intervals,
dynamic_intervals=[(max_epochs - num_last_epochs, 1)])
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='FaceVisualizer', vis_backends=vis_backends, name='visualizer')
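As a quick illustration of how the width multiplier thins the network (assuming plain multiplication, as in the face head module added below):

# Illustrative only: nominal channels scaled by widen_factor = 0.25,
# mirroring the `int(i * widen_factor)` arithmetic used in the face head.
widen_factor = 0.25
in_channels = [128, 256, 512, 1024]
print([int(c * widen_factor) for c in in_channels])  # [32, 64, 128, 256]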

View File

@ -0,0 +1,25 @@
_base_ = ['./yolov6_v3_n_syncbn_fast_8xb32-300e_widerface.py']
# ======================= Possible modified parameters =======================
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.70
# The scaling factor that controls the width of the network structure
widen_factor = 0.50
model = dict(
backbone=dict(
type='YOLOv6CSPBep',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
block_cfg=dict(type='RepVGGBlock'),
hidden_ratio=0.5,
act_cfg=dict(type='ReLU', inplace=True)),
neck=dict(
type='YOLOv6CSPRepBiPAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
block_cfg=dict(type='RepVGGBlock'),
block_act_cfg=dict(type='ReLU', inplace=True),
hidden_ratio=0.5),
bbox_head=dict(
type='YOLOv6FaceHead',
head_module=dict(stemout_channels=256, widen_factor=widen_factor),
loss_bbox=dict(type='IoULoss', iou_mode='giou')))

View File

@ -6,9 +6,10 @@ from .yolov5_coco import YOLOv5CocoDataset
 from .yolov5_crowdhuman import YOLOv5CrowdHumanDataset
 from .yolov5_dota import YOLOv5DOTADataset
 from .yolov5_voc import YOLOv5VOCDataset
+from .yolov6_widerface import YOLOv6WIDERFaceDataset
 __all__ = [
     'YOLOv5CocoDataset', 'YOLOv5VOCDataset', 'BatchShapePolicy',
     'yolov5_collate', 'YOLOv5CrowdHumanDataset', 'YOLOv5DOTADataset',
-    'PoseCocoDataset'
+    'PoseCocoDataset', 'YOLOv6WIDERFaceDataset'
 ]

View File

@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.datasets import WIDERFaceDataset
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from ..registry import DATASETS
@DATASETS.register_module()
class YOLOv6WIDERFaceDataset(BatchShapePolicyDataset, WIDERFaceDataset):
"""Dataset for YOLOv6 WIDERFace Dataset."""
pass

View File

@ -32,6 +32,9 @@ class YOLOv6EfficientRep(BaseBackbone):
Defaults to (2, 3, 4).
frozen_stages (int): Stages to be frozen (stop grad and set eval
mode). -1 means not freezing any parameters. Defaults to -1.
use_cspsppf (bool): Whether to use CSPSPPFBottleneck. It is only valid
when `use_spp`=True, i.e. it may be used in the last stage of the
backbone. Defaults to False.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to dict(type='BN', requires_grad=True).
act_cfg (dict): Config dict for activation layer.
@ -188,6 +191,9 @@ class YOLOv6CSPBep(YOLOv6EfficientRep):
Defaults to (2, 3, 4).
frozen_stages (int): Stages to be frozen (stop grad and set eval
mode). -1 means not freezing any parameters. Defaults to -1.
use_cspsppf (bool): Whether to use CSPSPPFBottleneck. It is only valid
when `use_spp`=True, i.e. it may be used in the last stage of the
backbone. Defaults to False.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to dict(type='BN', requires_grad=True).
act_cfg (dict): Config dict for activation layer.
@ -239,7 +245,6 @@ class YOLOv6CSPBep(YOLOv6EfficientRep):
                  block_cfg: ConfigType = dict(type='ConvWrapper'),
                  init_cfg: OptMultiConfig = None):
         self.hidden_ratio = hidden_ratio
-        self.use_cspsppf = use_cspsppf
         super().__init__(
             arch=arch,
             deepen_factor=deepen_factor,
@ -248,6 +253,7 @@ class YOLOv6CSPBep(YOLOv6EfficientRep):
             out_indices=out_indices,
             plugins=plugins,
             frozen_stages=frozen_stages,
+            use_cspsppf=use_cspsppf,
             norm_cfg=norm_cfg,
             act_cfg=act_cfg,
             norm_eval=norm_eval,

View File

@ -6,6 +6,7 @@ from .rtmdet_rotated_head import (RTMDetRotatedHead,
RTMDetRotatedSepBNHeadModule)
from .yolov5_head import YOLOv5Head, YOLOv5HeadModule
from .yolov5_ins_head import YOLOv5InsHead, YOLOv5InsHeadModule
from .yolov6_face_head import YOLOv6FaceHead, YOLOv6FaceHeadModule
from .yolov6_head import YOLOv6Head, YOLOv6HeadModule
from .yolov7_head import YOLOv7Head, YOLOv7HeadModule, YOLOv7p6HeadModule
from .yolov8_head import YOLOv8Head, YOLOv8HeadModule
@ -13,10 +14,11 @@ from .yolox_head import YOLOXHead, YOLOXHeadModule
 from .yolox_pose_head import YOLOXPoseHead, YOLOXPoseHeadModule
 __all__ = [
-    'YOLOv5Head', 'YOLOv6Head', 'YOLOXHead', 'YOLOv5HeadModule',
-    'YOLOv6HeadModule', 'YOLOXHeadModule', 'RTMDetHead',
-    'RTMDetSepBNHeadModule', 'YOLOv7Head', 'PPYOLOEHead', 'PPYOLOEHeadModule',
-    'YOLOv7HeadModule', 'YOLOv7p6HeadModule', 'YOLOv8Head', 'YOLOv8HeadModule',
+    'YOLOv5Head', 'YOLOv6Head', 'YOLOv6FaceHead', 'YOLOXHead',
+    'YOLOv5HeadModule', 'YOLOv6HeadModule', 'YOLOv6FaceHeadModule',
+    'YOLOXHeadModule', 'RTMDetHead', 'RTMDetSepBNHeadModule', 'YOLOv7Head',
+    'PPYOLOEHead', 'PPYOLOEHeadModule', 'YOLOv7HeadModule',
+    'YOLOv7p6HeadModule', 'YOLOv8Head', 'YOLOv8HeadModule',
     'RTMDetRotatedHead', 'RTMDetRotatedSepBNHeadModule', 'RTMDetInsSepBNHead',
     'RTMDetInsSepBNHeadModule', 'YOLOv5InsHead', 'YOLOv5InsHeadModule',
     'YOLOXPoseHead', 'YOLOXPoseHeadModule'

View File

@ -0,0 +1,434 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.models.utils import filter_scores_and_topk
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
OptMultiConfig)
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor
from mmyolo.registry import MODELS
from .yolov6_head import YOLOv6Head, YOLOv6HeadModule
@MODELS.register_module()
class YOLOv6FaceHeadModule(YOLOv6HeadModule):
"""YOLOv6FaceHead head module used in `YOLOv6.
<https://arxiv.org/pdf/2209.02976>`_.
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (Union[int, Sequence]): Number of channels in the input
feature map.
stemout_channels (Union[int, Sequence]): Number of channels of the
feature map output by stem module.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
        num_base_priors (int): The number of priors (points) at a point
            on the feature grid.
        reg_max (int): Max value of the integral set ``{0, ..., reg_max}``
            used for the distribution-style box regression. Defaults to 0.
        featmap_strides (Sequence[int]): Downsample factor of each feature
            map. Defaults to [8, 16, 32].
norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
Defaults to None.
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
list[dict], optional): Initialization config dict.
Defaults to None.
"""
    def __init__(self,
                 num_classes: int,
                 in_channels: Union[int, Sequence],
                 stemout_channels: Optional[Union[int, Sequence]] = None,
                 widen_factor: float = 1.0,
                 num_base_priors: int = 1,
                 reg_max: int = 0,
                 featmap_strides: Sequence[int] = (8, 16, 32),
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 init_cfg: OptMultiConfig = None):
        # Fall back to the input channels when no stem width is specified.
        if stemout_channels is None:
            stemout_channels = in_channels
        if isinstance(stemout_channels, int):
            num_levels = len(featmap_strides)
            self.stemout_channels = [int(stemout_channels * widen_factor)
                                     ] * num_levels
        else:
            self.stemout_channels = [
                int(i * widen_factor) for i in stemout_channels
            ]
super().__init__(
num_classes=num_classes,
in_channels=in_channels,
widen_factor=widen_factor,
num_base_priors=num_base_priors,
reg_max=reg_max,
featmap_strides=featmap_strides,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
init_cfg=init_cfg)
def _init_layers(self):
"""initialize conv layers in YOLOv6 head."""
# Init decouple head
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
self.cls_preds = nn.ModuleList()
self.reg_preds = nn.ModuleList()
self.stems = nn.ModuleList()
if self.reg_max > 1:
proj = torch.arange(
self.reg_max + self.num_base_priors, dtype=torch.float)
self.register_buffer('proj', proj, persistent=False)
for i in range(self.num_levels):
self.stems.append(
ConvModule(
in_channels=self.in_channels[i],
out_channels=self.stemout_channels[i],
kernel_size=1,
stride=1,
padding=1 // 2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.cls_convs.append(
ConvModule(
in_channels=self.stemout_channels[i],
out_channels=self.stemout_channels[i],
kernel_size=3,
stride=1,
padding=3 // 2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.reg_convs.append(
ConvModule(
in_channels=self.stemout_channels[i],
out_channels=self.stemout_channels[i],
kernel_size=3,
stride=1,
padding=3 // 2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.cls_preds.append(
nn.Conv2d(
in_channels=self.stemout_channels[i],
out_channels=self.num_base_priors * self.num_classes,
kernel_size=1))
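            # The regression branch predicts 4 box sides (as a distribution
            # over `reg_max + num_base_priors` bins when reg_max > 1) plus
            # 10 extra channels for the 5 facial landmarks (x, y each).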
self.reg_preds.append(
nn.Conv2d(
in_channels=self.stemout_channels[i],
out_channels=(self.num_base_priors + self.reg_max) * 4 +
10,
kernel_size=1))
def forward_single(self, x: Tensor, stem: nn.Module, cls_conv: nn.Module,
cls_pred: nn.Module, reg_conv: nn.Module,
reg_pred: nn.Module) -> Tuple[Tensor, Tensor]:
"""Forward feature of a single scale level."""
b, _, h, w = x.shape
y = stem(x)
cls_x = y
reg_x = y
cls_feat = cls_conv(cls_x)
reg_feat = reg_conv(reg_x)
cls_score = cls_pred(cls_feat)
bbox_dist_preds = reg_pred(reg_feat)
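        # Split the raw prediction: the last 10 channels are the landmark
        # offsets, the rest are the box (distribution) logits.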
keypoint_preds = bbox_dist_preds[:, -10:, :, :]
bbox_dist_preds = bbox_dist_preds[:, :-10, :, :]
if self.reg_max > 1:
bbox_dist_preds = bbox_dist_preds.reshape(
[-1, 4, self.reg_max + self.num_base_priors,
h * w]).permute(0, 3, 1, 2)
# TODO: The get_flops script cannot handle the situation of
# matmul, and needs to be fixed later
# bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
bbox_preds = bbox_dist_preds.softmax(3).matmul(
self.proj.view([-1, 1])).squeeze(-1)
bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
else:
bbox_preds = bbox_dist_preds
if self.training:
return cls_score, bbox_preds, bbox_dist_preds, keypoint_preds
else:
return cls_score, bbox_preds, keypoint_preds
@MODELS.register_module()
class YOLOv6FaceHead(YOLOv6Head):
"""YOLOv6FaceHead head used in `YOLOv6.
<https://arxiv.org/pdf/2209.02976>`_.
Args:
head_module(ConfigType): Base module used for YOLOv6Head
prior_generator(dict): Points generator feature maps
in 2D points-based detectors.
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
anchor head. Defaults to None.
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
anchor head. Defaults to None.
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
head_module: ConfigType,
prior_generator: ConfigType = dict(
type='mmdet.MlvlPointGenerator',
offset=0.5,
strides=[8, 16, 32]),
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
loss_cls: ConfigType = dict(
type='mmdet.VarifocalLoss',
use_sigmoid=True,
alpha=0.75,
gamma=2.0,
iou_weighted=True,
reduction='sum',
loss_weight=1.0),
loss_bbox: ConfigType = dict(
type='IoULoss',
iou_mode='giou',
bbox_format='xyxy',
reduction='mean',
loss_weight=2.5,
return_iou=False),
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None):
super().__init__(
head_module=head_module,
prior_generator=prior_generator,
bbox_coder=bbox_coder,
loss_cls=loss_cls,
loss_bbox=loss_bbox,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg)
def predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
keypoint_preds: List[Tensor],
objectnesses: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = True,
with_nms: bool = True) -> List[InstanceData]:
"""Transform a batch of output features extracted by the head into
bbox results.
Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
            keypoint_preds (list[Tensor]): Face keypoints for bboxes
                in all scale levels, each is a 4D-tensor, has shape
                (batch_size, 10, H, W).
objectnesses (list[Tensor], Optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, 1, H, W).
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
list[:obj:`InstanceData`]: Object detection results of each image
after the post process. Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
            - keypoints (Tensor): Has a shape (num_instances, 10), five
              facial landmarks arranged as (x0, y0, ..., x4, y4).
"""
assert len(cls_scores) == len(bbox_preds) \
and len(keypoint_preds) == len(bbox_preds)
if objectnesses is None:
with_objectnesses = False
else:
with_objectnesses = True
assert len(cls_scores) == len(objectnesses)
cfg = self.test_cfg if cfg is None else cfg
cfg = copy.deepcopy(cfg)
multi_label = cfg.multi_label
multi_label &= self.num_classes > 1
cfg.multi_label = multi_label
num_imgs = len(batch_img_metas)
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
# If the shape does not change, use the previous mlvl_priors
if featmap_sizes != self.featmap_sizes:
self.mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device)
self.featmap_sizes = featmap_sizes
flatten_priors = torch.cat(self.mlvl_priors)
mlvl_strides = [
flatten_priors.new_full(
(featmap_size.numel() * self.num_base_priors, ), stride) for
featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
]
flatten_stride = torch.cat(mlvl_strides)
# flatten cls_scores, bbox_preds, keypoint_preds and objectness
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.num_classes)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]
flatten_keypoint_preds = [
keypoint_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 10)
for keypoint_pred in keypoint_preds
]
flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
flatten_decoded_bboxes = self.bbox_coder.decode(
flatten_priors[None], flatten_bbox_preds, flatten_stride)
flatten_keypoint_preds = torch.cat(flatten_keypoint_preds, dim=1)
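        # Landmarks are predicted as stride-scaled offsets from each prior
        # point, so decode by multiplying with the stride and adding the
        # prior center once per landmark.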
flatten_keypoint_preds = flatten_keypoint_preds * \
flatten_stride[None, :, None] + \
flatten_priors.repeat(1, 5)
if with_objectnesses:
flatten_objectness = [
objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
for objectness in objectnesses
]
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
else:
flatten_objectness = [None for _ in range(num_imgs)]
results_list = []
for (bboxes, scores, keypoints, objectness,
img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
flatten_keypoint_preds, flatten_objectness,
batch_img_metas):
ori_shape = img_meta['ori_shape']
scale_factor = img_meta['scale_factor']
if 'pad_param' in img_meta:
pad_param = img_meta['pad_param']
else:
pad_param = None
score_thr = cfg.get('score_thr', -1)
# yolox_style does not require the following operations
            if objectness is not None and score_thr > 0 and not cfg.get(
                    'yolox_style', False):
                conf_inds = objectness > score_thr
                bboxes = bboxes[conf_inds, :]
                scores = scores[conf_inds, :]
                keypoints = keypoints[conf_inds, :]
                objectness = objectness[conf_inds]
if objectness is not None:
# conf = obj_conf * cls_conf
scores *= objectness[:, None]
            if scores.shape[0] == 0:
                empty_results = InstanceData()
                empty_results.bboxes = bboxes
                empty_results.scores = scores[:, 0]
                empty_results.labels = scores[:, 0].int()
                empty_results.keypoints = keypoints
                results_list.append(empty_results)
                continue
nms_pre = cfg.get('nms_pre', 100000)
if cfg.multi_label is False:
scores, labels = scores.max(1, keepdim=True)
scores, _, keep_idxs, results = filter_scores_and_topk(
scores,
score_thr,
nms_pre,
results=dict(labels=labels[:, 0]))
labels = results['labels']
else:
scores, labels, keep_idxs, _ = filter_scores_and_topk(
scores, score_thr, nms_pre)
results = InstanceData(
scores=scores,
labels=labels,
bboxes=bboxes[keep_idxs],
keypoints=keypoints[keep_idxs])
if rescale:
if pad_param is not None:
results.bboxes -= results.bboxes.new_tensor([
pad_param[2], pad_param[0], pad_param[2], pad_param[0]
])
results.keypoints -= results.keypoints.new_tensor(
[pad_param[2], pad_param[0]]).repeat(5)
results.bboxes /= results.bboxes.new_tensor(
scale_factor).repeat((1, 2))
results.keypoints /= results.keypoints.new_tensor(
scale_factor).repeat((1, 5))
if cfg.get('yolox_style', False):
# do not need max_per_img
cfg.max_per_img = len(results)
results = self._bbox_post_process(
results=results,
cfg=cfg,
rescale=False,
with_nms=with_nms,
img_meta=img_meta)
results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
results.bboxes[:, 1::2].clamp_(0, ori_shape[0])
results_list.append(results)
return results_list
def loss_by_feat(
self,
cls_scores: Sequence[Tensor],
bbox_preds: Sequence[Tensor],
bbox_dist_preds: Sequence[Tensor],
keypoint_preds: List[Tensor],
batch_gt_instances: Sequence[InstanceData],
batch_img_metas: Sequence[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        # TODO: add a keypoint loss; keypoint_preds are currently unused.
        return super().loss_by_feat(cls_scores, bbox_preds, bbox_dist_preds,
                                    batch_gt_instances, batch_img_metas,
                                    batch_gt_instances_ignore)

View File

@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .collect_env import collect_env
+from .face_visualizer import FaceVisualizer
 from .misc import is_metainfo_lower, switch_to_deploy
 from .setup_env import register_all_modules
 __all__ = [
     'register_all_modules', 'collect_env', 'switch_to_deploy',
-    'is_metainfo_lower'
+    'is_metainfo_lower', 'FaceVisualizer'
 ]

View File

@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from mmdet.visualization import DetLocalVisualizer
from mmengine.structures import InstanceData
from mmyolo.registry import VISUALIZERS
@VISUALIZERS.register_module()
class FaceVisualizer(DetLocalVisualizer):
def __init__(self,
name: str = 'visualizer',
image: Optional[np.ndarray] = None,
vis_backends: Optional[Dict] = None,
save_dir: Optional[str] = None,
bbox_color: Optional[Union[str, Tuple[int]]] = None,
text_color: Optional[Union[str,
Tuple[int]]] = (200, 200, 200),
mask_color: Optional[Union[str, Tuple[int]]] = None,
keypoint_color: Optional[Union[str,
Tuple[int]]] = ('blue',
'green', 'red',
'cyan',
'yellow'),
line_width: Union[int, float] = 3,
alpha: float = 0.8) -> None:
super().__init__(name, image, vis_backends, save_dir, bbox_color,
text_color, mask_color, line_width, alpha)
self.keypoint_color = keypoint_color
    def _draw_instances(self, image: np.ndarray, instances: InstanceData,
                        classes: Optional[List[str]],
                        palette: Optional[List[tuple]]) -> np.ndarray:
super()._draw_instances(image, instances, classes, palette)
if 'keypoints' in instances:
keypoints = instances.keypoints
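            # Each row stores 5 landmarks flattened as (x0, y0, ..., x4, y4);
            # draw one color-coded set of points per landmark index.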
for i in range(5):
self.draw_points(
positions=keypoints[:, i * 2:(i + 1) * 2],
colors=self.keypoint_color[i],
sizes=5)
return self.get_image()

View File

@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmengine.config import Config
from mmyolo.models.dense_heads import YOLOv6FaceHead
from mmyolo.utils import register_all_modules
register_all_modules()
class TestYOLOv6FaceHead(TestCase):
def setUp(self):
self.head_module = dict(
type='YOLOv6FaceHeadModule',
num_classes=2,
in_channels=[32, 64, 128],
stemout_channels=64,
featmap_strides=[8, 16, 32])
def test_predict_by_feat(self):
s = 256
img_metas = [{
'img_shape': (s, s, 3),
'ori_shape': (s, s, 3),
'scale_factor': (1.0, 1.0),
}]
test_cfg = Config(
dict(
multi_label=True,
max_per_img=300,
score_thr=0.01,
nms=dict(type='nms', iou_threshold=0.65)))
head = YOLOv6FaceHead(head_module=self.head_module, test_cfg=test_cfg)
head.eval()
feat = []
for i in range(len(self.head_module['in_channels'])):
in_channel = self.head_module['in_channels'][i]
feat_size = self.head_module['featmap_strides'][i]
feat.append(
torch.rand(1, in_channel, s // feat_size, s // feat_size))
cls_scores, bbox_preds, keypoint_preds = head.forward(feat)
head.predict_by_feat(
cls_scores,
bbox_preds,
keypoint_preds,
None,
img_metas,
cfg=test_cfg,
rescale=True,
with_nms=True)
head.predict_by_feat(
cls_scores,
bbox_preds,
keypoint_preds,
None,
img_metas,
cfg=test_cfg,
rescale=False,
with_nms=False)

View File

@ -0,0 +1,127 @@
import argparse
from collections import OrderedDict
import torch
def convert(src, dst):
import sys
sys.path.append('yolov6')
try:
ckpt = torch.load(src, map_location=torch.device('cpu'))
except ModuleNotFoundError:
        raise RuntimeError(
            'This script must be placed under the meituan/YOLOv6 repo,'
            ' because loading the official pretrained model needs'
            ' some Python files to build the model.')
# The saved model is the model before reparameterization
model = ckpt['ema' if ckpt.get('ema') else 'model'].float()
new_state_dict = OrderedDict()
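    # Rename meituan/YOLOv6 keys to mmyolo ones: 'ERBlock_*' -> backbone
    # stages, 'Rep_p*'/'Rep_n*'/'Bifusion*' -> neck layers, 'detect' ->
    # 'bbox_head.head_module', and '.cv'/'.m.' -> '.conv'/'.block.'.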
for k, v in model.state_dict().items():
name = k
if 'detect' in k:
if 'proj' in k:
continue
name = k.replace('detect', 'bbox_head.head_module')
if k.find('anchors') >= 0 or k.find('anchor_grid') >= 0:
continue
if 'ERBlock_2' in k:
name = k.replace('ERBlock_2', 'stage1.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'ERBlock_3' in k:
name = k.replace('ERBlock_3', 'stage2.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'ERBlock_4' in k:
name = k.replace('ERBlock_4', 'stage3.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'ERBlock_5' in k:
name = k.replace('ERBlock_5', 'stage4.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
if 'stage4.0.2' in name:
name = name.replace('stage4.0.2', 'stage4.1')
name = name.replace('cv', 'conv')
elif 'reduce_layer0' in k:
name = k.replace('reduce_layer0', 'reduce_layers.2')
elif 'Rep_p4' in k:
name = k.replace('Rep_p4', 'top_down_layers.0.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'reduce_layer1' in k:
name = k.replace('reduce_layer1', 'top_down_layers.0.1')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'Rep_p3' in k:
name = k.replace('Rep_p3', 'top_down_layers.1')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'Bifusion0' in k:
name = k.replace('Bifusion0', 'upsample_layers.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
if '.upsample_transpose.' in k:
name = name.replace('.upsample_transpose.', '.')
elif 'Bifusion1' in k:
name = k.replace('Bifusion1', 'upsample_layers.1')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
if '.upsample_transpose.' in k:
name = name.replace('.upsample_transpose.', '.')
elif 'Rep_n3' in k:
name = k.replace('Rep_n3', 'bottom_up_layers.0')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'Rep_n4' in k:
name = k.replace('Rep_n4', 'bottom_up_layers.1')
if '.cv' in k:
name = name.replace('.cv', '.conv')
if '.m.' in k:
name = name.replace('.m.', '.block.')
elif 'downsample2' in k:
name = k.replace('downsample2', 'downsample_layers.0')
elif 'downsample1' in k:
name = k.replace('downsample1', 'downsample_layers.1')
new_state_dict[name] = v
data = {'state_dict': new_state_dict}
torch.save(data, dst)
# Note: This script must be placed under the yolov6 repo to run.
def main():
parser = argparse.ArgumentParser(description='Convert model keys')
parser.add_argument(
'--src', default='yolov6s.pt', help='src yolov6 model path')
parser.add_argument('--dst', default='mmyolov6.pt', help='save path')
args = parser.parse_args()
convert(args.src, args.dst)
if __name__ == '__main__':
main()
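# Example invocation (hypothetical file name), run from a meituan/YOLOv6
# checkout:
#   python yolov6_v3_face_to_mmyolo.py --src yolov6s.pt --dst mmyolov6.pt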