[Feature] Support YOLOv7 P5 training (#243)

* support yolov7

* update dev

* update headmodule and convert

* update

* fix pipeline

* fix loss

* fix optimizer parameter groups

* refactor mosaic9

* refactor optim

* refactor loss

* refactor

* refactor

* refactor new

* support yolov7x

* refine

* support tiny

* refactor model

* refactor model

* support yolov7x inference

* support yolov7-tiny inference

* support yolov7-e inference

* refactor

* support yolov7-tiny train

* add docstr

* fix merge error

* fix merge error

* fix merge error

* fix lint

* fix lint

* fix lint

* fix UT

* update

* update
Haian Huang(深度眸) 2022-11-21 17:30:19 +08:00 committed by GitHub
parent 5e0599c825
commit 573ff033e4
30 changed files with 3257 additions and 416 deletions

View File

@ -0,0 +1,21 @@
_base_ = './yolov7_w-p6_syncbn_fast_8x16b-300e_coco.py'
model = dict(
backbone=dict(arch='D'),
neck=dict(
use_maxpool_in_downsample=True,
use_in_channels_in_downsample=True,
block_cfg=dict(
type='ELANBlock',
middle_ratio=0.4,
block_ratio=0.2,
num_blocks=6,
num_convs_in_block=1),
in_channels=[384, 768, 1152, 1536],
out_channels=[192, 384, 576, 768]),
bbox_head=dict(
head_module=dict(
in_channels=[192, 384, 576, 768],
main_out_channels=[384, 768, 1152, 1536],
aux_out_channels=[384, 768, 1152, 1536],
)))
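As a quick cross-check of the channel plumbing in this config: with arch='D' and the out_indices=(2, 3, 4, 5) inherited from the W-P6 base config, the backbone stages emit exactly the channels listed as neck in_channels above. A small sanity-check sketch (the stage channels are copied from arch_settings['D'] in this PR; nothing here is extra configuration):

d_stage_out_channels = [192, 384, 768, 1152, 1536]  # 'D' stage outputs, stages 1-5
out_indices = (2, 3, 4, 5)  # inherited from the W-P6 base config
# out index i selects stage i, i.e. arch_settings entry i - 1
assert [d_stage_out_channels[i - 1] for i in out_indices] == [384, 768, 1152, 1536]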

View File

@ -0,0 +1,19 @@
_base_ = './yolov7_w-p6_syncbn_fast_8x16b-300e_coco.py'
model = dict(
backbone=dict(arch='E'),
neck=dict(
use_maxpool_in_downsample=True,
use_in_channels_in_downsample=True,
block_cfg=dict(
type='ELANBlock',
middle_ratio=0.4,
block_ratio=0.2,
num_blocks=6,
num_convs_in_block=1),
in_channels=[320, 640, 960, 1280],
out_channels=[160, 320, 480, 640]),
bbox_head=dict(
head_module=dict(
in_channels=[160, 320, 480, 640],
main_out_channels=[320, 640, 960, 1280])))

View File

@ -0,0 +1,20 @@
_base_ = './yolov7_w-p6_syncbn_fast_8x16b-300e_coco.py'
model = dict(
backbone=dict(arch='E2E'),
neck=dict(
use_maxpool_in_downsample=True,
use_in_channels_in_downsample=True,
block_cfg=dict(
type='EELANBlock',
num_elan_block=2,
middle_ratio=0.4,
block_ratio=0.2,
num_blocks=6,
num_convs_in_block=1),
in_channels=[320, 640, 960, 1280],
out_channels=[160, 320, 480, 640]),
bbox_head=dict(
head_module=dict(
in_channels=[160, 320, 480, 640],
main_out_channels=[320, 640, 960, 1280])))

View File

@ -1,129 +0,0 @@
_base_ = '../_base_/default_runtime.py'
# dataset settings
data_root = 'data/coco/'
dataset_type = 'YOLOv5CocoDataset'
# parameters that often need to be modified
img_scale = (640, 640) # height, width
deepen_factor = 1.0
widen_factor = 1.0
max_epochs = 300
save_epoch_intervals = 10
train_batch_size_per_gpu = 16
train_num_workers = 8
val_batch_size_per_gpu = 1
val_num_workers = 2
# persistent_workers must be False if num_workers is 0.
persistent_workers = True
# only on Val
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
size_divisor=32,
extra_pad_ratio=0.5)
# different from yolov5
anchors = [[(12, 16), (19, 36), (40, 28)], [(36, 75), (76, 55), (72, 146)],
[(142, 110), (192, 243), (459, 401)]]
strides = [8, 16, 32]
# single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)
model = dict(
type='YOLODetector',
data_preprocessor=dict(
type='YOLOv5DetDataPreprocessor',
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True),
backbone=dict(
type='YOLOv7Backbone',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True)),
neck=dict(
type='YOLOv7PAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
upsample_feats_cat_first=False,
in_channels=[512, 1024, 1024],
out_channels=[128, 256, 512],
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(
type='YOLOv7Head',
head_module=dict(
type='YOLOv5HeadModule',
num_classes=80,
in_channels=[256, 512, 1024],
widen_factor=widen_factor,
featmap_strides=strides,
num_base_priors=3),
prior_generator=dict(
type='mmdet.YOLOAnchorGenerator',
base_sizes=anchors,
strides=strides)),
test_cfg=dict(
multi_label=True,
nms_pre=30000,
score_thr=0.001,
nms=dict(type='nms', iou_threshold=0.65),
max_per_img=300))
test_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args={{_base_.file_client_args}}),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]
val_dataloader = dict(
batch_size=val_batch_size_per_gpu,
num_workers=val_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
test_mode=True,
data_prefix=dict(img='val2017/'),
ann_file='annotations/instances_val2017.json',
pipeline=test_pipeline,
batch_shapes_cfg=batch_shapes_cfg))
test_dataloader = val_dataloader
val_evaluator = dict(
type='mmdet.CocoMetric',
proposal_nums=(100, 1, 10), # Can be accelerated
ann_file=data_root + 'annotations/instances_val2017.json',
metric='bbox')
test_evaluator = val_evaluator
# train_cfg = dict(
# type='EpochBasedTrainLoop',
# max_epochs=max_epochs,
# val_interval=save_epoch_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# randomness = dict(seed=1, deterministic=True)

View File

@ -0,0 +1,267 @@
_base_ = '../_base_/default_runtime.py'
# dataset settings
data_root = 'data/coco/'
dataset_type = 'YOLOv5CocoDataset'
# parameters that often need to be modified
img_scale = (640, 640) # height, width
max_epochs = 300
save_epoch_intervals = 10
train_batch_size_per_gpu = 16
train_num_workers = 8
# persistent_workers must be False if num_workers is 0.
persistent_workers = True
val_batch_size_per_gpu = 1
val_num_workers = 2
# only on Val
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
size_divisor=32,
extra_pad_ratio=0.5)
# different from yolov5
anchors = [
[(12, 16), (19, 36), (40, 28)], # P3/8
[(36, 75), (76, 55), (72, 146)], # P4/16
[(142, 110), (192, 243), (459, 401)] # P5/32
]
strides = [8, 16, 32]
num_det_layers = 3
num_classes = 80
# single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)
model = dict(
type='YOLODetector',
data_preprocessor=dict(
type='YOLOv5DetDataPreprocessor',
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True),
backbone=dict(
type='YOLOv7Backbone',
arch='L',
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True)),
neck=dict(
type='YOLOv7PAFPN',
block_cfg=dict(
type='ELANBlock',
middle_ratio=0.5,
block_ratio=0.25,
num_blocks=4,
num_convs_in_block=1),
upsample_feats_cat_first=False,
in_channels=[512, 1024, 1024],
# The real output channel will be multiplied by 2
out_channels=[128, 256, 512],
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(
type='YOLOv7Head',
head_module=dict(
type='YOLOv7HeadModule',
num_classes=80,
in_channels=[256, 512, 1024],
featmap_strides=strides,
num_base_priors=3),
prior_generator=dict(
type='mmdet.YOLOAnchorGenerator',
base_sizes=anchors,
strides=strides),
# scaled based on number of detection layers
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=0.3 * (num_classes / 80 * 3 / num_det_layers)),
loss_bbox=dict(
type='IoULoss',
iou_mode='ciou',
bbox_format='xywh',
reduction='mean',
loss_weight=0.05 * (3 / num_det_layers),
return_iou=True),
loss_obj=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=0.7 * ((img_scale[0] / 640)**2 * 3 / num_det_layers)),
obj_level_weights=[4., 1., 0.4],
# BatchYOLOv7Assigner params
prior_match_thr=4.,
simota_candidate_topk=10,
simota_iou_weight=3.0,
simota_cls_weight=1.0),
test_cfg=dict(
multi_label=True,
nms_pre=30000,
score_thr=0.001,
nms=dict(type='nms', iou_threshold=0.65),
max_per_img=300))
pre_transform = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='LoadAnnotations', with_bbox=True)
]
mosaic4_pipeline = [
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
max_translate_ratio=0.2, # note
scaling_ratio_range=(0.1, 2.0), # note
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114)),
]
mosaic9_pipeline = [
dict(
type='Mosaic9',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
max_translate_ratio=0.2, # note
scaling_ratio_range=(0.1, 2.0), # note
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114)),
]
randchoice_mosaic_pipeline = dict(
type='RandomChoice',
transforms=[mosaic4_pipeline, mosaic9_pipeline],
prob=[0.8, 0.2])
train_pipeline = [
*pre_transform,
randchoice_mosaic_pipeline,
dict(
type='YOLOv5MixUp',
alpha=8.0, # note
beta=8.0, # note
prob=0.15,
pre_transform=[*pre_transform, randchoice_mosaic_pipeline]),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
train_dataloader = dict(
batch_size=train_batch_size_per_gpu,
num_workers=train_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='yolov5_collate'), # FASTER
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=train_pipeline))
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]
val_dataloader = dict(
batch_size=val_batch_size_per_gpu,
num_workers=val_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
test_mode=True,
data_prefix=dict(img='val2017/'),
ann_file='annotations/instances_val2017.json',
pipeline=test_pipeline,
batch_shapes_cfg=batch_shapes_cfg))
test_dataloader = val_dataloader
param_scheduler = None
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=0.01,
momentum=0.937,
weight_decay=0.0005,
nesterov=True,
batch_size_per_gpu=train_batch_size_per_gpu),
constructor='YOLOv7OptimWrapperConstructor')
default_hooks = dict(
param_scheduler=dict(
type='YOLOv5ParamSchedulerHook',
scheduler_type='cosine',
lr_factor=0.1, # note
max_epochs=max_epochs),
checkpoint=dict(
type='CheckpointHook',
save_param_scheduler=False,
interval=1,
save_best='auto',
max_keep_ckpts=3))
val_evaluator = dict(
type='mmdet.CocoMetric',
proposal_nums=(100, 1, 10), # Can be accelerated
ann_file=data_root + 'annotations/instances_val2017.json',
metric='bbox')
test_evaluator = val_evaluator
train_cfg = dict(
type='EpochBasedTrainLoop',
max_epochs=max_epochs,
val_interval=save_epoch_intervals,
dynamic_intervals=[(270, 1)])
custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0001,
update_buffers=True,
strict_load=False,
priority=49)
]
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# randomness = dict(seed=1, deterministic=True)
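For reference, a config like the one above is normally consumed through MMEngine's Runner; the sketch below is illustrative only (the directory prefix and work_dir are assumptions, not part of this PR):

from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
'configs/yolov7/yolov7_l_syncbn_fast_8x16b-300e_coco.py')  # assumed location
cfg.work_dir = './work_dirs/yolov7_l'  # Runner requires a work_dir

runner = Runner.from_cfg(cfg)  # builds the model, dataloaders and hooks from the dict above
runner.train()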

View File

@ -0,0 +1,81 @@
_base_ = './yolov7_l_syncbn_fast_8x16b-300e_coco.py'
num_classes = _base_.num_classes
num_det_layers = _base_.num_det_layers
img_scale = _base_.img_scale
pre_transform = _base_.pre_transform
model = dict(
backbone=dict(
arch='Tiny', act_cfg=dict(type='LeakyReLU', negative_slope=0.1)),
neck=dict(
is_tiny_version=True,
in_channels=[128, 256, 512],
out_channels=[64, 128, 256],
block_cfg=dict(
_delete_=True, type='TinyDownSampleBlock', middle_ratio=0.25),
act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
use_repconv_outs=False),
bbox_head=dict(
head_module=dict(in_channels=[128, 256, 512]),
loss_cls=dict(loss_weight=0.5 *
(num_classes / 80 * 3 / num_det_layers)),
loss_obj=dict(loss_weight=1.0 *
((img_scale[0] / 640)**2 * 3 / num_det_layers))))
mosaic4_pipeline = [
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
max_translate_ratio=0.1, # change
scaling_ratio_range=(0.5, 1.6), # change
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114)),
]
mosaic9_pipeline = [
dict(
type='Mosaic9',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
max_translate_ratio=0.1, # change
scaling_ratio_range=(0.5, 1.6), # change
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114)),
]
randchoice_mosaic_pipeline = dict(
type='RandomChoice',
transforms=[mosaic4_pipeline, mosaic9_pipeline],
prob=[0.8, 0.2])
train_pipeline = [
*pre_transform,
randchoice_mosaic_pipeline,
dict(
type='YOLOv5MixUp',
alpha=8.0,
beta=8.0,
prob=0.05, # change
pre_transform=[*pre_transform, randchoice_mosaic_pipeline]),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
default_hooks = dict(param_scheduler=dict(lr_factor=0.01))
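The loss_weight expressions in the bbox_head above follow the same per-layer scaling rule as the base config; a quick arithmetic check (plain Python, no MMYOLO imports) shows the factors collapse to 1 for the default COCO setup:

num_classes, num_det_layers, img_scale = 80, 3, (640, 640)
loss_cls_weight = 0.5 * (num_classes / 80 * 3 / num_det_layers)  # 0.5
loss_obj_weight = 1.0 * ((img_scale[0] / 640) ** 2 * 3 / num_det_layers)  # 1.0
# A 1280x1280 input would multiply the objectness weight by (1280 / 640) ** 2 = 4.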

View File

@ -0,0 +1,52 @@
_base_ = './yolov7_l_syncbn_fast_8x16b-300e_coco.py'
img_scale = (1280, 1280) # height, width
num_classes = 80
# only on Val
batch_shapes_cfg = dict(img_size=img_scale[0], size_divisor=64)
anchors = [
[(19, 27), (44, 40), (38, 94)], # P3/8
[(96, 68), (86, 152), (180, 137)], # P4/16
[(140, 301), (303, 264), (238, 542)], # P5/32
[(436, 615), (739, 380), (925, 792)] # P6/64
]
strides = [8, 16, 32, 64]
num_det_layers = 4
model = dict(
backbone=dict(arch='W', out_indices=(2, 3, 4, 5)),
neck=dict(
in_channels=[256, 512, 768, 1024],
out_channels=[128, 256, 384, 512],
use_maxpool_in_downsample=False,
use_repconv_outs=False),
bbox_head=dict(
head_module=dict(
type='YOLOv7p6HeadModule',
in_channels=[128, 256, 384, 512],
featmap_strides=strides,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True)),
prior_generator=dict(base_sizes=anchors, strides=strides),
obj_level_weights=[4.0, 1.0, 0.25, 0.06]))
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]
val_dataloader = dict(
dataset=dict(pipeline=test_pipeline, batch_shapes_cfg=batch_shapes_cfg))
test_dataloader = val_dataloader

View File

@ -0,0 +1,15 @@
_base_ = './yolov7_l_syncbn_fast_8x16b-300e_coco.py'
model = dict(
backbone=dict(arch='X'),
neck=dict(
in_channels=[640, 1280, 1280],
out_channels=[160, 320, 640],
block_cfg=dict(
type='ELANBlock',
middle_ratio=0.4,
block_ratio=0.4,
num_blocks=3,
num_convs_in_block=2),
use_repconv_outs=False),
bbox_head=dict(head_module=dict(in_channels=[320, 640, 1280])))

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mix_img_transforms import Mosaic, YOLOv5MixUp, YOLOXMixUp
from .mix_img_transforms import Mosaic, Mosaic9, YOLOv5MixUp, YOLOXMixUp
from .transforms import (LetterResize, LoadAnnotations, YOLOv5HSVRandomAug,
YOLOv5KeepRatioResize, YOLOv5RandomAffine)
__all__ = [
'YOLOv5KeepRatioResize', 'LetterResize', 'Mosaic', 'YOLOXMixUp',
'YOLOv5MixUp', 'YOLOv5HSVRandomAug', 'LoadAnnotations',
'YOLOv5RandomAffine'
'YOLOv5RandomAffine', 'Mosaic9'
]

View File

@ -465,6 +465,279 @@ class Mosaic(BaseMixImageTransform):
return repr_str
@TRANSFORMS.register_module()
class Mosaic9(BaseMixImageTransform):
"""Mosaic9 augmentation.
Given 9 images, mosaic transform combines them into
one output image. The output image is composed of the parts from each sub-
image.
The mosaic transform steps are as follows:
1. Get the center image according to the index, and randomly
sample another 8 images from the custom dataset.
2. Randomly offset the image after Mosaic
Required Keys:
- img
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
- gt_bboxes_labels (np.int64) (optional)
- gt_ignore_flags (np.bool) (optional)
- mix_results (List[dict])
Modified Keys:
- img
- img_shape
- gt_bboxes (optional)
- gt_bboxes_labels (optional)
- gt_ignore_flags (optional)
Args:
img_scale (Sequence[int]): Image size after mosaic pipeline of single
image. The shape order should be (height, width).
Defaults to (640, 640).
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. In some dataset like MOT17, the gt bboxes
are allowed to cross the border of images. Therefore, we don't
need to clip the gt bboxes in these cases. Defaults to True.
pad_val (int): Pad value. Defaults to 114.
pre_transform(Sequence[dict]): Sequence of transform object or
config dict to be composed.
prob (float): Probability of applying this transformation.
Defaults to 1.0.
use_cached (bool): Whether to use cache. Defaults to False.
max_cached_images (int): The maximum length of the cache. The larger
the cache, the stronger the randomness of this transform. As a
rule of thumb, providing 5 caches for each image suffices for
randomness. Defaults to 50.
random_pop (bool): Whether to randomly pop a result from the cache
when the cache is full. If set to False, use FIFO popping method.
Defaults to True.
max_refetch (int): The maximum number of retry iterations for getting
valid results from the pipeline. If the number of iterations is
greater than `max_refetch`, but the results are still None, then the
iteration is terminated and an error is raised. Defaults to 15.
"""
def __init__(self,
img_scale: Tuple[int, int] = (640, 640),
bbox_clip_border: bool = True,
pad_val: Union[float, int] = 114.0,
pre_transform: Sequence[dict] = None,
prob: float = 1.0,
use_cached: bool = False,
max_cached_images: int = 50,
random_pop: bool = True,
max_refetch: int = 15):
assert isinstance(img_scale, tuple)
assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \
f'got {prob}.'
if use_cached:
assert max_cached_images >= 9, 'The length of cache must >= 9, ' \
f'but got {max_cached_images}.'
super().__init__(
pre_transform=pre_transform,
prob=prob,
use_cached=use_cached,
max_cached_images=max_cached_images,
random_pop=random_pop,
max_refetch=max_refetch)
self.img_scale = img_scale
self.bbox_clip_border = bbox_clip_border
self.pad_val = pad_val
# intermediate variables
self._current_img_shape = [0, 0]
self._center_img_shape = [0, 0]
self._previous_img_shape = [0, 0]
def get_indexes(self, dataset: Union[BaseDataset, list]) -> list:
"""Call function to collect indexes.
Args:
dataset (:obj:`Dataset` or list): The dataset or cached list.
Returns:
list: indexes.
"""
indexes = [random.randint(0, len(dataset)) for _ in range(8)]
return indexes
def mix_img_transform(self, results: dict) -> dict:
"""Mixed image data transformation.
Args:
results (dict): Result dict.
Returns:
results (dict): Updated result dict.
"""
assert 'mix_results' in results
mosaic_bboxes = []
mosaic_bboxes_labels = []
mosaic_ignore_flags = []
img_scale_h, img_scale_w = self.img_scale
if len(results['img'].shape) == 3:
mosaic_img = np.full(
(int(img_scale_h * 3), int(img_scale_w * 3), 3),
self.pad_val,
dtype=results['img'].dtype)
else:
mosaic_img = np.full((int(img_scale_h * 3), int(img_scale_w * 3)),
self.pad_val,
dtype=results['img'].dtype)
# index = 0 means the original image
# len(results['mix_results']) = 8
loc_strs = ('center', 'top', 'top_right', 'right', 'bottom_right',
'bottom', 'bottom_left', 'left', 'top_left')
results_all = [results, *results['mix_results']]
for index, results_patch in enumerate(results_all):
img_i = results_patch['img']
# keep_ratio resize
img_i_h, img_i_w = img_i.shape[:2]
scale_ratio_i = min(img_scale_h / img_i_h, img_scale_w / img_i_w)
img_i = mmcv.imresize(
img_i,
(int(img_i_w * scale_ratio_i), int(img_i_h * scale_ratio_i)))
paste_coord = self._mosaic_combine(loc_strs[index],
img_i.shape[:2])
padw, padh = paste_coord[:2]
x1, y1, x2, y2 = (max(x, 0) for x in paste_coord)
mosaic_img[y1:y2, x1:x2] = img_i[y1 - padh:, x1 - padw:]
gt_bboxes_i = results_patch['gt_bboxes']
gt_bboxes_labels_i = results_patch['gt_bboxes_labels']
gt_ignore_flags_i = results_patch['gt_ignore_flags']
gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i])
gt_bboxes_i.translate_([padw, padh])
mosaic_bboxes.append(gt_bboxes_i)
mosaic_bboxes_labels.append(gt_bboxes_labels_i)
mosaic_ignore_flags.append(gt_ignore_flags_i)
# Offset
offset_x = int(random.uniform(0, img_scale_w))
offset_y = int(random.uniform(0, img_scale_h))
mosaic_img = mosaic_img[offset_y:offset_y + 2 * img_scale_h,
offset_x:offset_x + 2 * img_scale_w]
mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
mosaic_bboxes.translate_([-offset_x, -offset_y])
mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0)
if self.bbox_clip_border:
mosaic_bboxes.clip_([2 * img_scale_h, 2 * img_scale_w])
else:
# remove outside bboxes
inside_inds = mosaic_bboxes.is_inside(
[2 * img_scale_h, 2 * img_scale_w]).numpy()
mosaic_bboxes = mosaic_bboxes[inside_inds]
mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds]
mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]
results['img'] = mosaic_img
results['img_shape'] = mosaic_img.shape
results['gt_bboxes'] = mosaic_bboxes
results['gt_bboxes_labels'] = mosaic_bboxes_labels
results['gt_ignore_flags'] = mosaic_ignore_flags
return results
def _mosaic_combine(self, loc: str,
img_shape_hw: Tuple[int, int]) -> Tuple[int, ...]:
"""Calculate global coordinate of mosaic image.
Args:
loc (str): Index for the sub-image.
img_shape_hw (Sequence[int]): Height and width of sub-image
Returns:
paste_coord (tuple): paste corner coordinate in mosaic image.
"""
assert loc in ('center', 'top', 'top_right', 'right', 'bottom_right',
'bottom', 'bottom_left', 'left', 'top_left')
img_scale_h, img_scale_w = self.img_scale
self._current_img_shape = img_shape_hw
current_img_h, current_img_w = self._current_img_shape
previous_img_h, previous_img_w = self._previous_img_shape
center_img_h, center_img_w = self._center_img_shape
if loc == 'center':
self._center_img_shape = self._current_img_shape
# xmin, ymin, xmax, ymax
paste_coord = img_scale_w, \
img_scale_h, \
img_scale_w + current_img_w, \
img_scale_h + current_img_h
elif loc == 'top':
paste_coord = img_scale_w, \
img_scale_h - current_img_h, \
img_scale_w + current_img_w, \
img_scale_h
elif loc == 'top_right':
paste_coord = img_scale_w + previous_img_w, \
img_scale_h - current_img_h, \
img_scale_w + previous_img_w + current_img_w, \
img_scale_h
elif loc == 'right':
paste_coord = img_scale_w + center_img_w, \
img_scale_h, \
img_scale_w + center_img_w + current_img_w, \
img_scale_h + current_img_h
elif loc == 'bottom_right':
paste_coord = img_scale_w + center_img_w, \
img_scale_h + previous_img_h, \
img_scale_w + center_img_w + current_img_w, \
img_scale_h + previous_img_h + current_img_h
elif loc == 'bottom':
paste_coord = img_scale_w + center_img_w - current_img_w, \
img_scale_h + center_img_h, \
img_scale_w + center_img_w, \
img_scale_h + center_img_h + current_img_h
elif loc == 'bottom_left':
paste_coord = img_scale_w + center_img_w - \
previous_img_w - current_img_w, \
img_scale_h + center_img_h, \
img_scale_w + center_img_w - previous_img_w, \
img_scale_h + center_img_h + current_img_h
elif loc == 'left':
paste_coord = img_scale_w - current_img_w, \
img_scale_h + center_img_h - current_img_h, \
img_scale_w, \
img_scale_h + center_img_h
elif loc == 'top_left':
paste_coord = img_scale_w - current_img_w, \
img_scale_h + center_img_h - \
previous_img_h - current_img_h, \
img_scale_w, \
img_scale_h + center_img_h - previous_img_h
self._previous_img_shape = self._current_img_shape
# xmin, ymin, xmax, ymax
return paste_coord
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(img_scale={self.img_scale}, '
repr_str += f'pad_val={self.pad_val}, '
repr_str += f'prob={self.prob})'
return repr_str
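# Geometry sketch for the Mosaic9 transform above, assuming the default
# img_scale=(640, 640); the numbers below are illustrative, not extra config:
canvas_hw = (3 * 640, 3 * 640)  # 1920 x 1920 canvas the nine resized patches are pasted onto
output_hw = (2 * 640, 2 * 640)  # 1280 x 1280 crop kept after the random offset
# offset_x / offset_y are sampled uniformly from [0, 640), so the crop always stays
# inside the canvas; bboxes are shifted by (-offset_x, -offset_y) and then clipped
# (or filtered) against the 1280 x 1280 output.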
@TRANSFORMS.register_module()
class YOLOv5MixUp(BaseMixImageTransform):
"""MixUp data augmentation for YOLOv5.

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .yolov5_optim_constructor import YOLOv5OptimizerConstructor
from .yolov7_optim_wrapper_constructor import YOLOv7OptimWrapperConstructor
__all__ = ['YOLOv5OptimizerConstructor']
__all__ = ['YOLOv5OptimizerConstructor', 'YOLOv7OptimWrapperConstructor']

View File

@ -120,6 +120,10 @@ class YOLOv5OptimizerConstructor:
# bias
optimizer_cfg['params'].append({'params': params_groups[2]})
print_log(
'Optimizer groups: %g .bias, %g conv.weight, %g other' %
(len(params_groups[2]), len(params_groups[0]), len(
params_groups[1])), 'current')
del params_groups
optimizer = OPTIMIZERS.build(optimizer_cfg)

View File

@ -0,0 +1,139 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch.nn as nn
from mmengine.dist import get_world_size
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.optim import OptimWrapper
from mmyolo.models.dense_heads.yolov7_head import ImplicitA, ImplicitM
from mmyolo.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
OPTIMIZERS)
# TODO: Consider merging into YOLOv5OptimizerConstructor
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class YOLOv7OptimWrapperConstructor:
"""YOLOv7 constructor for optimizer wrappers.
It has the following functions:
- divides the optimizer parameters into 3 groups:
Conv, Bias and BN/ImplicitA/ImplicitM
- supports `weight_decay` parameter adaptation based on
`batch_size_per_gpu`
Args:
optim_wrapper_cfg (dict): The config dict of the optimizer wrapper.
Positional fields are
- ``type``: class name of the OptimizerWrapper
- ``optimizer``: The configuration of optimizer.
Optional fields are
- any arguments of the corresponding optimizer wrapper type,
e.g., accumulative_counts, clip_grad, etc.
The positional fields of ``optimizer`` are
- `type`: class name of the optimizer.
Optional fields are
- any arguments of the corresponding optimizer type, e.g.,
lr, weight_decay, momentum, etc.
paramwise_cfg (dict, optional): Parameter-wise options. Must include
`base_total_batch_size` if not None. If the total input batch
is smaller than `base_total_batch_size`, the `weight_decay`
parameter will be kept unchanged, otherwise linear scaling.
Example:
>>> model = torch.nn.modules.Conv1d(1, 1, 1)
>>> optim_wrapper_cfg = dict(
>>> type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01,
>>> momentum=0.9, weight_decay=0.0001, batch_size_per_gpu=16))
>>> paramwise_cfg = dict(base_total_batch_size=64)
>>> optim_wrapper_builder = YOLOv7OptimWrapperConstructor(
>>> optim_wrapper_cfg, paramwise_cfg)
>>> optim_wrapper = optim_wrapper_builder(model)
"""
def __init__(self,
optim_wrapper_cfg: dict,
paramwise_cfg: Optional[dict] = None):
if paramwise_cfg is None:
paramwise_cfg = {'base_total_batch_size': 64}
assert 'base_total_batch_size' in paramwise_cfg
if not isinstance(optim_wrapper_cfg, dict):
raise TypeError('optim_wrapper_cfg should be a dict, '
f'but got {type(optim_wrapper_cfg)}')
assert 'optimizer' in optim_wrapper_cfg, (
'`optim_wrapper_cfg` must contain "optimizer" config')
self.optim_wrapper_cfg = optim_wrapper_cfg
self.optimizer_cfg = self.optim_wrapper_cfg.pop('optimizer')
self.base_total_batch_size = paramwise_cfg['base_total_batch_size']
def __call__(self, model: nn.Module) -> OptimWrapper:
if is_model_wrapper(model):
model = model.module
optimizer_cfg = self.optimizer_cfg.copy()
weight_decay = optimizer_cfg.pop('weight_decay', 0)
if 'batch_size_per_gpu' in optimizer_cfg:
batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
# No scaling if total_batch_size is less than
# base_total_batch_size, otherwise linear scaling.
total_batch_size = get_world_size() * batch_size_per_gpu
accumulate = max(
round(self.base_total_batch_size / total_batch_size), 1)
scale_factor = total_batch_size * \
accumulate / self.base_total_batch_size
if scale_factor != 1:
weight_decay *= scale_factor
print_log(f'Scaled weight_decay to {weight_decay}', 'current')
params_groups = [], [], []
for v in model.modules():
# no decay
# Caution: Coupling with model
if isinstance(v, (ImplicitA, ImplicitM)):
params_groups[0].append(v.implicit)
elif isinstance(v, nn.modules.batchnorm._NormBase):
params_groups[0].append(v.weight)
# apply decay
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
params_groups[1].append(v.weight) # apply decay
# biases, no decay
if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
params_groups[2].append(v.bias)
# Note: Make sure bias is in the last parameter group
optimizer_cfg['params'] = []
# conv
optimizer_cfg['params'].append({
'params': params_groups[1],
'weight_decay': weight_decay
})
# bn ...
optimizer_cfg['params'].append({'params': params_groups[0]})
# bias
optimizer_cfg['params'].append({'params': params_groups[2]})
print_log(
'Optimizer groups: %g .bias, %g conv.weight, %g other' %
(len(params_groups[2]), len(params_groups[1]), len(
params_groups[0])), 'current')
del params_groups
optimizer = OPTIMIZERS.build(optimizer_cfg)
optim_wrapper = OPTIM_WRAPPERS.build(
self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
return optim_wrapper
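To make the weight_decay adaptation concrete, here is the arithmetic for the 8 x 16 batch setup used by the YOLOv7 configs (the 8-GPU world size is an assumption for illustration):

base_total_batch_size = 64  # paramwise_cfg default above
batch_size_per_gpu, world_size = 16, 8
total_batch_size = world_size * batch_size_per_gpu  # 128
accumulate = max(round(base_total_batch_size / total_batch_size), 1)  # 1
scale_factor = total_batch_size * accumulate / base_total_batch_size  # 2.0
weight_decay = 0.0005 * scale_factor  # 0.001, applied because scale_factor != 1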

View File

@ -48,7 +48,7 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
In P6 model, n=5
Args:
arch_setting (dict): Architecture of BaseBackbone.
arch_setting (list): Architecture of BaseBackbone.
plugins (list[dict]): List of plugins for stages, each dict contains:
- cfg (dict, required): Cfg dict to build plugin.
@ -75,7 +75,7 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
"""
def __init__(self,
arch_setting: dict,
arch_setting: list,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
input_channels: int = 3,

View File

@ -1,12 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.models.backbones.csp_darknet import Focus
from mmdet.utils import ConfigType, OptMultiConfig
from mmyolo.registry import MODELS
from ..layers import ELANBlock, MaxPoolAndStrideConvBlock
from ..layers import MaxPoolAndStrideConvBlock
from .base_backbone import BaseBackbone
@ -15,8 +16,7 @@ class YOLOv7Backbone(BaseBackbone):
"""Backbone used in YOLOv7.
Args:
arch (str): Architecture of YOLOv7, from {P5, P6}.
Defaults to P5.
arch (str): Architecture of YOLOv7, from {Tiny, L, X, W, E, D, E2E}. Defaults to L.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
@ -40,28 +40,107 @@ class YOLOv7Backbone(BaseBackbone):
init_cfg (:obj:`ConfigDict` or dict or list[dict] or
list[:obj:`ConfigDict`]): Initialization config dict.
"""
_tiny_stage1_cfg = dict(type='TinyDownSampleBlock', middle_ratio=0.5)
_tiny_stage2_4_cfg = dict(type='TinyDownSampleBlock', middle_ratio=1.0)
_l_expand_channel_2x = dict(
type='ELANBlock',
middle_ratio=0.5,
block_ratio=0.5,
num_blocks=2,
num_convs_in_block=2)
_l_no_change_channel = dict(
type='ELANBlock',
middle_ratio=0.25,
block_ratio=0.25,
num_blocks=2,
num_convs_in_block=2)
_x_expand_channel_2x = dict(
type='ELANBlock',
middle_ratio=0.4,
block_ratio=0.4,
num_blocks=3,
num_convs_in_block=2)
_x_no_change_channel = dict(
type='ELANBlock',
middle_ratio=0.2,
block_ratio=0.2,
num_blocks=3,
num_convs_in_block=2)
_w_no_change_channel = dict(
type='ELANBlock',
middle_ratio=0.5,
block_ratio=0.5,
num_blocks=2,
num_convs_in_block=2)
_e_no_change_channel = dict(
type='ELANBlock',
middle_ratio=0.4,
block_ratio=0.4,
num_blocks=3,
num_convs_in_block=2)
_d_no_change_channel = dict(
type='ELANBlock',
middle_ratio=1 / 3,
block_ratio=1 / 3,
num_blocks=4,
num_convs_in_block=2)
_e2e_no_change_channel = dict(
type='EELANBlock',
num_elan_block=2,
middle_ratio=0.4,
block_ratio=0.4,
num_blocks=3,
num_convs_in_block=2)
# From left to right:
# in_channels, out_channels, ELAN mode
# in_channels, out_channels, Block_params
arch_settings = {
'P5': [[64, 128, 'expand_channel_2x'], [256, 512, 'expand_channel_2x'],
[512, 1024, 'expand_channel_2x'],
[1024, 1024, 'no_change_channel']]
'Tiny': [[64, 64, _tiny_stage1_cfg], [64, 128, _tiny_stage2_4_cfg],
[128, 256, _tiny_stage2_4_cfg],
[256, 512, _tiny_stage2_4_cfg]],
'L': [[64, 256, _l_expand_channel_2x],
[256, 512, _l_expand_channel_2x],
[512, 1024, _l_expand_channel_2x],
[1024, 1024, _l_no_change_channel]],
'X': [[80, 320, _x_expand_channel_2x],
[320, 640, _x_expand_channel_2x],
[640, 1280, _x_expand_channel_2x],
[1280, 1280, _x_no_change_channel]],
'W':
[[64, 128, _w_no_change_channel], [128, 256, _w_no_change_channel],
[256, 512, _w_no_change_channel], [512, 768, _w_no_change_channel],
[768, 1024, _w_no_change_channel]],
'E':
[[80, 160, _e_no_change_channel], [160, 320, _e_no_change_channel],
[320, 640, _e_no_change_channel], [640, 960, _e_no_change_channel],
[960, 1280, _e_no_change_channel]],
'D': [[96, 192,
_d_no_change_channel], [192, 384, _d_no_change_channel],
[384, 768, _d_no_change_channel],
[768, 1152, _d_no_change_channel],
[1152, 1536, _d_no_change_channel]],
'E2E': [[80, 160, _e2e_no_change_channel],
[160, 320, _e2e_no_change_channel],
[320, 640, _e2e_no_change_channel],
[640, 960, _e2e_no_change_channel],
[960, 1280, _e2e_no_change_channel]],
}
def __init__(self,
arch: str = 'P5',
plugins: Union[dict, List[dict]] = None,
arch: str = 'L',
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
input_channels: int = 3,
out_indices: Tuple[int] = (2, 3, 4),
frozen_stages: int = -1,
plugins: Union[dict, List[dict]] = None,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
norm_eval: bool = False,
init_cfg: OptMultiConfig = None):
assert arch in self.arch_settings.keys()
self.arch = arch
super().__init__(
self.arch_settings[arch],
deepen_factor,
@ -77,31 +156,57 @@ class YOLOv7Backbone(BaseBackbone):
def build_stem_layer(self) -> nn.Module:
"""Build a stem layer."""
stem = nn.Sequential(
ConvModule(
if self.arch in ['L', 'X']:
stem = nn.Sequential(
ConvModule(
3,
int(self.arch_setting[0][0] * self.widen_factor // 2),
3,
padding=1,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
int(self.arch_setting[0][0] * self.widen_factor // 2),
int(self.arch_setting[0][0] * self.widen_factor),
3,
padding=1,
stride=2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
int(self.arch_setting[0][0] * self.widen_factor),
int(self.arch_setting[0][0] * self.widen_factor),
3,
padding=1,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
elif self.arch == 'Tiny':
stem = nn.Sequential(
ConvModule(
3,
int(self.arch_setting[0][0] * self.widen_factor // 2),
3,
padding=1,
stride=2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
int(self.arch_setting[0][0] * self.widen_factor // 2),
int(self.arch_setting[0][0] * self.widen_factor),
3,
padding=1,
stride=2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
elif self.arch in ['W', 'E', 'D', 'E2E']:
stem = Focus(
3,
int(self.arch_setting[0][0] * self.widen_factor // 2),
3,
padding=1,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
int(self.arch_setting[0][0] * self.widen_factor // 2),
int(self.arch_setting[0][0] * self.widen_factor),
3,
padding=1,
stride=2,
kernel_size=3,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
int(self.arch_setting[0][0] * self.widen_factor),
int(self.arch_setting[0][0] * self.widen_factor),
3,
padding=1,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
act_cfg=self.act_cfg)
return stem
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
@ -111,14 +216,43 @@ class YOLOv7Backbone(BaseBackbone):
stage_idx (int): The index of a stage layer.
setting (list): The architecture setting of a stage layer.
"""
in_channels, out_channels, elan_mode = setting
in_channels, out_channels, stage_block_cfg = setting
in_channels = int(in_channels * self.widen_factor)
out_channels = int(out_channels * self.widen_factor)
stage_block_cfg = stage_block_cfg.copy()
stage_block_cfg.setdefault('norm_cfg', self.norm_cfg)
stage_block_cfg.setdefault('act_cfg', self.act_cfg)
stage_block_cfg['in_channels'] = in_channels
stage_block_cfg['out_channels'] = out_channels
stage = []
if stage_idx == 0:
pre_layer = ConvModule(
if self.arch in ['W', 'E', 'D', 'E2E']:
stage_block_cfg['in_channels'] = out_channels
elif self.arch in ['L', 'X']:
if stage_idx == 0:
stage_block_cfg['in_channels'] = out_channels // 2
downsample_layer = self._build_downsample_layer(
stage_idx, in_channels, out_channels)
stage.append(MODELS.build(stage_block_cfg))
if downsample_layer is not None:
stage.insert(0, downsample_layer)
return stage
def _build_downsample_layer(self, stage_idx: int, in_channels: int,
out_channels: int) -> Optional[nn.Module]:
"""Build a downsample layer pre stage."""
if self.arch in ['E', 'D', 'E2E']:
downsample_layer = MaxPoolAndStrideConvBlock(
in_channels,
out_channels,
use_in_channels_of_middle=True,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
elif self.arch == 'W':
downsample_layer = ConvModule(
in_channels,
out_channels,
3,
@ -126,24 +260,26 @@ class YOLOv7Backbone(BaseBackbone):
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
elan_layer = ELANBlock(
out_channels,
mode=elan_mode,
num_blocks=2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.extend([pre_layer, elan_layer])
else:
pre_layer = MaxPoolAndStrideConvBlock(
in_channels,
mode='reduce_channel_2x',
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
elan_layer = ELANBlock(
in_channels,
mode=elan_mode,
num_blocks=2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
stage.extend([pre_layer, elan_layer])
return stage
elif self.arch == 'Tiny':
if stage_idx != 0:
downsample_layer = nn.MaxPool2d(2, 2)
else:
downsample_layer = None
elif self.arch in ['L', 'X']:
if stage_idx == 0:
downsample_layer = ConvModule(
in_channels,
out_channels // 2,
3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
else:
downsample_layer = MaxPoolAndStrideConvBlock(
in_channels,
in_channels,
use_in_channels_of_middle=False,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
return downsample_layer
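A minimal smoke-test sketch for the new multi-arch backbone (the import path assumes the usual mmyolo package layout; the expected shapes follow from the Tiny arch settings and the default out_indices=(2, 3, 4)):

import torch
from mmyolo.models.backbones import YOLOv7Backbone

backbone = YOLOv7Backbone(
arch='Tiny', act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
backbone.eval()
feats = backbone(torch.randn(1, 3, 640, 640))
# should yield three feature maps at strides 8 / 16 / 32 with 128 / 256 / 512 channels,
# i.e. (1, 128, 80, 80), (1, 256, 40, 40), (1, 512, 20, 20)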

View File

@ -3,11 +3,12 @@ from .ppyoloe_head import PPYOLOEHead, PPYOLOEHeadModule
from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
from .yolov5_head import YOLOv5Head, YOLOv5HeadModule
from .yolov6_head import YOLOv6Head, YOLOv6HeadModule
from .yolov7_head import YOLOv7Head
from .yolov7_head import YOLOv7Head, YOLOv7HeadModule, YOLOv7p6HeadModule
from .yolox_head import YOLOXHead, YOLOXHeadModule
__all__ = [
'YOLOv5Head', 'YOLOv6Head', 'YOLOXHead', 'YOLOv5HeadModule',
'YOLOv6HeadModule', 'YOLOXHeadModule', 'RTMDetHead',
'RTMDetSepBNHeadModule', 'YOLOv7Head', 'PPYOLOEHead', 'PPYOLOEHeadModule'
'RTMDetSepBNHeadModule', 'YOLOv7Head', 'PPYOLOEHead', 'PPYOLOEHeadModule',
'YOLOv7HeadModule', 'YOLOv7p6HeadModule'
]

View File

@ -1,84 +1,202 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import math
from typing import List, Sequence, Tuple, Union
import torch
import torch.nn as nn
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
OptMultiConfig)
from mmcv.cnn import ConvModule
from mmdet.models.utils import multi_apply
from mmdet.utils import ConfigType, OptInstanceList
from mmengine.dist import get_dist_info
from mmengine.structures import InstanceData
from torch import Tensor
from mmyolo.registry import MODELS
from .yolov5_head import YOLOv5Head
from ..layers import ImplicitA, ImplicitM
from ..task_modules.assigners.batch_yolov7_assigner import BatchYOLOv7Assigner
from .yolov5_head import YOLOv5Head, YOLOv5HeadModule
@MODELS.register_module()
class YOLOv7HeadModule(YOLOv5HeadModule):
"""YOLOv7Head head module used in YOLOv7."""
def _init_layers(self):
"""initialize conv layers in YOLOv7 head."""
self.convs_pred = nn.ModuleList()
for i in range(self.num_levels):
conv_pred = nn.Sequential(
ImplicitA(self.in_channels[i]),
nn.Conv2d(self.in_channels[i],
self.num_base_priors * self.num_out_attrib, 1),
ImplicitM(self.num_base_priors * self.num_out_attrib),
)
self.convs_pred.append(conv_pred)
def init_weights(self):
"""Initialize the bias of YOLOv7 head."""
super(YOLOv5HeadModule, self).init_weights()
for mi, s in zip(self.convs_pred, self.featmap_strides): # from
mi = mi[1] # nn.Conv2d
b = mi.bias.data.view(3, -1)
# obj (8 objects per 640 image)
b.data[:, 4] += math.log(8 / (640 / s)**2)
b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
mi.bias.data = b.view(-1)
# TODO: to check
@MODELS.register_module()
class YOLOv7p6HeadModule(YOLOv5HeadModule):
"""YOLOv7Head head module used in YOLOv7."""
def __init__(self,
*args,
main_out_channels: Sequence[int] = [256, 512, 768, 1024],
aux_out_channels: Sequence[int] = [320, 640, 960, 1280],
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
**kwargs):
self.main_out_channels = main_out_channels
self.aux_out_channels = aux_out_channels
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
super().__init__(*args, **kwargs)
def _init_layers(self):
"""initialize conv layers in YOLOv7 head."""
self.main_convs_pred = nn.ModuleList()
self.aux_convs_pred = nn.ModuleList()
for i in range(self.num_levels):
conv_pred = nn.Sequential(
ConvModule(
self.in_channels[i],
self.main_out_channels[i],
3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ImplicitA(self.main_out_channels[i]),
nn.Conv2d(self.main_out_channels[i],
self.num_base_priors * self.num_out_attrib, 1),
ImplicitM(self.num_base_priors * self.num_out_attrib),
)
self.main_convs_pred.append(conv_pred)
aux_pred = nn.Sequential(
ConvModule(
self.in_channels[i],
self.aux_out_channels[i],
3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(self.aux_out_channels[i],
self.num_base_priors * self.num_out_attrib, 1))
self.aux_convs_pred.append(aux_pred)
def init_weights(self):
"""Initialize the bias of YOLOv5 head."""
super(YOLOv5HeadModule, self).init_weights()
for mi, aux, s in zip(self.main_convs_pred, self.aux_convs_pred,
self.featmap_strides): # from
mi = mi[2] # nn.Conv2d
b = mi.bias.data.view(3, -1)
# obj (8 objects per 640 image)
b.data[:, 4] += math.log(8 / (640 / s)**2)
b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
mi.bias.data = b.view(-1)
aux = aux[1] # nn.Conv2d
b = aux.bias.data.view(3, -1)
# obj (8 objects per 640 image)
b.data[:, 4] += math.log(8 / (640 / s)**2)
b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
aux.bias.data = b.view(-1)
def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
"""Forward features from the upstream network.
Args:
x (Tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
Tuple[List]: A tuple of multi-level classification scores, bbox
predictions, and objectnesses.
"""
assert len(x) == self.num_levels
return multi_apply(self.forward_single, x, self.main_convs_pred,
self.aux_convs_pred)
def forward_single(self, x: Tensor, convs: nn.Module,
aux_convs: nn.Module) \
-> Tuple[Union[Tensor, List], Union[Tensor, List],
Union[Tensor, List]]:
"""Forward feature of a single scale level."""
pred_map = convs(x)
bs, _, ny, nx = pred_map.shape
pred_map = pred_map.view(bs, self.num_base_priors, self.num_out_attrib,
ny, nx)
cls_score = pred_map[:, :, 5:, ...].reshape(bs, -1, ny, nx)
bbox_pred = pred_map[:, :, :4, ...].reshape(bs, -1, ny, nx)
objectness = pred_map[:, :, 4:5, ...].reshape(bs, -1, ny, nx)
if not self.training:
return cls_score, bbox_pred, objectness
else:
aux_pred_map = aux_convs(x)
aux_pred_map = aux_pred_map.view(bs, self.num_base_priors,
self.num_out_attrib, ny, nx)
aux_cls_score = aux_pred_map[:, :, 5:, ...].reshape(bs, -1, ny, nx)
aux_bbox_pred = aux_pred_map[:, :, :4, ...].reshape(bs, -1, ny, nx)
aux_objectness = aux_pred_map[:, :, 4:5,
...].reshape(bs, -1, ny, nx)
return [cls_score,
aux_cls_score], [bbox_pred, aux_bbox_pred
], [objectness, aux_objectness]
# Training mode is currently not supported
@MODELS.register_module()
class YOLOv7Head(YOLOv5Head):
"""YOLOv7Head head used in `YOLOv7 <https://arxiv.org/abs/2207.02696>`_.
Args:
head_module (nn.Module): Base module used for YOLOv7Head.
prior_generator (dict): Prior generator that generates anchors
for the multi-level feature maps.
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
anchor head. Defaults to None.
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
anchor head. Defaults to None.
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
list[dict], optional): Initialization config dict.
Defaults to None.
simota_candidate_topk (int): The candidate top-k used to
select the top-k IoUs for calculating dynamic-k in
BatchYOLOv7Assigner. Defaults to 10.
simota_iou_weight (float): The scale factor for regression
iou cost in BatchYOLOv7Assigner. Defaults to 3.0.
simota_cls_weight (float): The scale factor for classification
cost in BatchYOLOv7Assigner. Defaults to 1.0.
"""
def __init__(self,
head_module: nn.Module,
prior_generator: ConfigType = dict(
type='mmdet.YOLOAnchorGenerator',
base_sizes=[[(10, 13), (16, 30), (33, 23)],
[(30, 61), (62, 45), (59, 119)],
[(116, 90), (156, 198), (373, 326)]],
strides=[8, 16, 32]),
bbox_coder: ConfigType = dict(type='YOLOv5BBoxCoder'),
loss_cls: ConfigType = dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
loss_bbox: ConfigType = dict(
type='mmdet.GIoULoss', reduction='sum', loss_weight=5.0),
loss_obj: ConfigType = dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None):
super().__init__(
head_module=head_module,
prior_generator=prior_generator,
bbox_coder=bbox_coder,
loss_cls=loss_cls,
loss_bbox=loss_bbox,
loss_obj=loss_obj,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg)
def special_init(self):
"""Since YOLO series algorithms will inherit from YOLOv5Head, but
different algorithms have special initialization process.
The special_init function is designed to deal with this situation.
"""
pass
*args,
simota_candidate_topk: int = 10,
simota_iou_weight: float = 3.0,
simota_cls_weight: float = 1.0,
**kwargs):
super().__init__(*args, **kwargs)
self.assigner = BatchYOLOv7Assigner(
num_classes=self.num_classes,
num_base_priors=self.num_base_priors,
featmap_strides=self.featmap_strides,
prior_match_thr=self.prior_match_thr,
candidate_topk=simota_candidate_topk,
iou_weight=simota_iou_weight,
cls_weight=simota_cls_weight)
def loss_by_feat(
self,
cls_scores: Sequence[Tensor],
bbox_preds: Sequence[Tensor],
objectnesses: Sequence[Tensor],
batch_gt_instances: Sequence[InstanceData],
batch_img_metas: Sequence[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
@ -92,6 +210,9 @@ class YOLOv7Head(YOLOv5Head):
bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
level, each is a 4D-tensor, the channel number is
num_priors * 4.
objectnesses (Sequence[Tensor]): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, 1, H, W).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
@ -104,4 +225,122 @@ class YOLOv7Head(YOLOv5Head):
Returns:
dict[str, Tensor]: A dictionary of losses.
"""
raise NotImplementedError('Not implemented yet')
batch_size = cls_scores[0].shape[0]
device = cls_scores[0].device
loss_cls = torch.zeros(1, device=device)
loss_box = torch.zeros(1, device=device)
loss_obj = torch.zeros(1, device=device)
head_preds = self._merge_predict_results(bbox_preds, objectnesses,
cls_scores)
scaled_factors = [
torch.tensor(head_pred.shape, device=device)[[3, 2, 3, 2]]
for head_pred in head_preds
]
# 1. Convert gt to norm xywh format
# (num_base_priors, num_batch_gt, 7)
# 7 means (batch_idx, cls_id, x_norm, y_norm,
# w_norm, h_norm, prior_idx)
batch_targets_normed = self._convert_gt_to_norm_format(
batch_gt_instances, batch_img_metas)
assigner_results = self.assigner(
head_preds, batch_targets_normed,
batch_img_metas[0]['batch_input_shape'], self.priors_base_sizes,
self.grid_offset)
# mlvl means multi_level
mlvl_positive_infos = assigner_results['mlvl_positive_infos']
mlvl_priors = assigner_results['mlvl_priors']
mlvl_targets_normed = assigner_results['mlvl_targets_normed']
# calc losses
for i, head_pred in enumerate(head_preds):
batch_inds, prior_idx, grid_x, grid_y = mlvl_positive_infos[i].T
num_pred_positive = batch_inds.shape[0]
target_obj = torch.zeros_like(head_pred[..., 0])
# no positive samples
if num_pred_positive == 0:
loss_box += head_pred[..., :4].sum() * 0
loss_cls += head_pred[..., 5:].sum() * 0
loss_obj += self.loss_obj(
head_pred[..., 4], target_obj) * self.obj_level_weights[i]
continue
priors = mlvl_priors[i]
targets_normed = mlvl_targets_normed[i]
head_pred_positive = head_pred[batch_inds, prior_idx, grid_y,
grid_x]
# calc bbox loss
grid_xy = torch.stack([grid_x, grid_y], dim=1)
decoded_pred_bbox = self._decode_bbox_to_xywh(
head_pred_positive[:, :4], priors, grid_xy)
target_bbox_scaled = targets_normed[:, 2:6] * scaled_factors[i]
loss_box_i, iou = self.loss_bbox(decoded_pred_bbox,
target_bbox_scaled)
loss_box += loss_box_i
# calc obj loss
target_obj[batch_inds, prior_idx, grid_y,
grid_x] = iou.detach().clamp(0).type(target_obj.dtype)
loss_obj += self.loss_obj(head_pred[..., 4],
target_obj) * self.obj_level_weights[i]
# calc cls loss
if self.num_classes > 1:
pred_cls_scores = targets_normed[:, 1].long()
target_class = torch.full_like(
head_pred_positive[:, 5:], 0., device=device)
target_class[range(num_pred_positive), pred_cls_scores] = 1.
loss_cls += self.loss_cls(head_pred_positive[:, 5:],
target_class)
else:
loss_cls += head_pred_positive[:, 5:].sum() * 0
_, world_size = get_dist_info()
return dict(
loss_cls=loss_cls * batch_size * world_size,
loss_obj=loss_obj * batch_size * world_size,
loss_bbox=loss_box * batch_size * world_size)
def _merge_predict_results(self, bbox_preds: Sequence[Tensor],
objectnesses: Sequence[Tensor],
cls_scores: Sequence[Tensor]) -> List[Tensor]:
"""Merge predict output from 3 heads.
Args:
cls_scores (Sequence[Tensor]): Box scores for each scale level,
each is a 4D-tensor, the channel number is
num_priors * num_classes.
bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
level, each is a 4D-tensor, the channel number is
num_priors * 4.
objectnesses (Sequence[Tensor]): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, 1, H, W).
Returns:
List[Tensor]: Merged output.
"""
head_preds = []
for bbox_pred, objectness, cls_score in zip(bbox_preds, objectnesses,
cls_scores):
b, _, h, w = bbox_pred.shape
bbox_pred = bbox_pred.reshape(b, self.num_base_priors, -1, h, w)
objectness = objectness.reshape(b, self.num_base_priors, -1, h, w)
cls_score = cls_score.reshape(b, self.num_base_priors, -1, h, w)
head_pred = torch.cat([bbox_pred, objectness, cls_score],
dim=2).permute(0, 1, 3, 4, 2).contiguous()
head_preds.append(head_pred)
return head_preds
def _decode_bbox_to_xywh(self, bbox_pred, priors_base_sizes,
grid_xy) -> Tensor:
bbox_pred = bbox_pred.sigmoid()
pred_xy = bbox_pred[:, :2] * 2 - 0.5 + grid_xy
pred_wh = (bbox_pred[:, 2:] * 2)**2 * priors_base_sizes
decoded_bbox_pred = torch.cat((pred_xy, pred_wh), dim=-1)
return decoded_bbox_pred
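For intuition, _decode_bbox_to_xywh uses the YOLOv5-style parameterisation; a standalone numeric sketch with made-up values (units follow whatever priors and grid are passed in):

import torch

bbox_pred = torch.zeros(1, 4)  # raw (tx, ty, tw, th) for one positive sample
grid_xy = torch.tensor([[10.0, 7.0]])  # grid cell of the matched prior
prior_wh = torch.tensor([[142.0, 110.0]])  # matched prior size

p = bbox_pred.sigmoid()
pred_xy = p[:, :2] * 2 - 0.5 + grid_xy  # centre offset in (-0.5, 1.5) around the cell -> (10.5, 7.5)
pred_wh = (p[:, 2:] * 2) ** 2 * prior_wh  # width/height in (0, 4 x prior) -> (142.0, 110.0)
decoded = torch.cat((pred_xy, pred_wh), dim=-1)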

View File

@ -1,12 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ema import ExpMomentumEMA
from .yolo_bricks import (BepC3StageBlock, EffectiveSELayer, ELANBlock,
from .yolo_bricks import (BepC3StageBlock, EELANBlock, EffectiveSELayer,
ELANBlock, ImplicitA, ImplicitM,
MaxPoolAndStrideConvBlock, PPYOLOEBasicBlock,
RepStageBlock, RepVGGBlock, SPPFBottleneck,
SPPFCSPBlock)
SPPFCSPBlock, TinyDownSampleBlock)
__all__ = [
'SPPFBottleneck', 'RepVGGBlock', 'RepStageBlock', 'ExpMomentumEMA',
'ELANBlock', 'MaxPoolAndStrideConvBlock', 'SPPFCSPBlock',
'PPYOLOEBasicBlock', 'EffectiveSELayer', 'BepC3StageBlock'
'PPYOLOEBasicBlock', 'EffectiveSELayer', 'TinyDownSampleBlock',
'EELANBlock', 'ImplicitA', 'ImplicitM', 'BepC3StageBlock'
]

View File

@ -9,7 +9,6 @@ from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from mmengine.model import BaseModule
from mmengine.utils import digit_version
from torch import Tensor
from torch.nn.parameter import Parameter
from mmyolo.registry import MODELS
@ -32,6 +31,7 @@ else:
class SPPFBottleneck(BaseModule):
"""Spatial pyramid pooling - Fast (SPPF) layer for
YOLOv5, YOLOX and PPYOLOE by Glenn Jocher
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
@ -230,9 +230,9 @@ class RepVGGBlock(nn.Module):
def forward(self, inputs: Tensor) -> Tensor:
"""Forward process.
Args:
inputs (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
@ -271,9 +271,9 @@ class RepVGGBlock(nn.Module):
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
"""Pad 1x1 tensor to 3x3.
Args:
kernel1x1 (Tensor): The input 1x1 kernel need to be padded.
Returns:
Tensor: 3x3 kernel after padded.
"""
@ -288,6 +288,7 @@ class RepVGGBlock(nn.Module):
Args:
branch (nn.Module): The layer that needs to be equivalently
transformed, which can be nn.Sequential or nn.Batchnorm2d
Returns:
tuple: Equivalent kernel and bias
"""
@ -467,7 +468,7 @@ class BottleRep(nn.Module):
else:
self.shortcut = True
if adaptive_weight:
self.alpha = Parameter(torch.ones(1))
self.alpha = nn.Parameter(torch.ones(1))
else:
self.alpha = 1.0
@ -528,6 +529,7 @@ class EffectiveSELayer(nn.Module):
arxiv (https://arxiv.org/abs/1911.06667)
This code referenced to
https://github.com/youngwanLEE/CenterMask/blob/72147e8aae673fcaf4103ee90a6a6b73863e7fa1/maskrcnn_benchmark/modeling/backbone/vovnet.py#L108-L121 # noqa
Args:
channels (int): The input and output channels of this Module.
act_cfg (dict): Config dict for activation layer.
@ -556,13 +558,13 @@ class EffectiveSELayer(nn.Module):
class PPYOLOESELayer(nn.Module):
"""Squeeze-and-Excitation Attention Module for PPYOLOE.
There are some differences between the current implementation and
SELayer in mmdet:
1. For fast speed and avoiding double inference in ppyoloe,
use `F.adaptive_avg_pool2d` before PPYOLOESELayer.
2. Special ways to init weights.
3. Different convolution order.
Args:
feat_channels (int): The input (and output) channels of the SE layer.
norm_cfg (dict): Config dict for normalization layer.
@ -602,19 +604,21 @@ class PPYOLOESELayer(nn.Module):
return self.conv(feat * weight)
@MODELS.register_module()
class ELANBlock(BaseModule):
"""Efficient layer aggregation networks for YOLOv7.
- if mode is `reduce_channel_2x`, the output channel will be
reduced by a factor of 2
- if mode is `no_change_channel`, the output channel does not change.
- if mode is `expand_channel_2x`, the output channel will be
expanded by a factor of 2
Args:
in_channels (int): The input channels of this Module.
mode (str): Output channel mode. Defaults to `expand_channel_2x`.
out_channels (int): The out channels of this Module.
middle_ratio (float): The scaling ratio of the middle layer
based on the in_channels.
block_ratio (float): The scaling ratio of the block layer
based on the in_channels.
num_blocks (int): The number of blocks in the main branch.
Defaults to 2.
num_convs_in_block (int): The number of convs per block.
Defaults to 1.
conv_cfg (dict): Config dict for convolution layer. Defaults to None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
@ -627,37 +631,28 @@ class ELANBlock(BaseModule):
def __init__(self,
in_channels: int,
mode: str = 'expand_channel_2x',
out_channels: int,
middle_ratio: float,
block_ratio: float,
num_blocks: int = 2,
num_convs_in_block: int = 1,
conv_cfg: OptConfigType = None,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg=init_cfg)
assert num_blocks >= 1
assert num_convs_in_block >= 1
assert mode in ('expand_channel_2x', 'no_change_channel',
'reduce_channel_2x')
if mode == 'expand_channel_2x':
mid_channels = in_channels // 2
block_channels = mid_channels
final_conv_in_channels = 2 * in_channels
final_conv_out_channels = 2 * in_channels
elif mode == 'no_change_channel':
mid_channels = in_channels // 4
block_channels = mid_channels
final_conv_in_channels = in_channels
final_conv_out_channels = in_channels
else:
mid_channels = in_channels // 2
block_channels = mid_channels // 2
final_conv_in_channels = in_channels * 2
final_conv_out_channels = in_channels // 2
middle_channels = int(in_channels * middle_ratio)
block_channels = int(in_channels * block_ratio)
final_conv_in_channels = int(
num_blocks * block_channels) + 2 * middle_channels
self.main_conv = ConvModule(
in_channels,
mid_channels,
middle_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
@ -665,7 +660,7 @@ class ELANBlock(BaseModule):
self.short_conv = ConvModule(
in_channels,
mid_channels,
middle_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
@ -673,9 +668,9 @@ class ELANBlock(BaseModule):
self.blocks = nn.ModuleList()
for _ in range(num_blocks):
if mode == 'reduce_channel_2x':
if num_convs_in_block == 1:
internal_block = ConvModule(
mid_channels,
middle_channels,
block_channels,
3,
padding=1,
@ -683,29 +678,26 @@ class ELANBlock(BaseModule):
norm_cfg=norm_cfg,
act_cfg=act_cfg)
else:
internal_block = nn.Sequential(
ConvModule(
mid_channels,
block_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
block_channels,
block_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
mid_channels = block_channels
internal_block = []
for _ in range(num_convs_in_block):
internal_block.append(
ConvModule(
middle_channels,
block_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
middle_channels = block_channels
internal_block = nn.Sequential(*internal_block)
middle_channels = block_channels
self.blocks.append(internal_block)
self.final_conv = ConvModule(
final_conv_in_channels,
final_conv_out_channels,
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
@ -727,16 +719,38 @@ class ELANBlock(BaseModule):
return self.final_conv(x_final)
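# A quick sanity check of the ratio-based channel bookkeeping above
# (illustrative numbers, not part of the diff): the width fed into
# final_conv is num_blocks * block_channels + 2 * middle_channels.
# Assuming in_channels=256 with the block_cfg defaults used by the
# YOLOv7 PAFPN below (middle_ratio=0.5, block_ratio=0.25, num_blocks=4):
middle_channels = int(256 * 0.5)                                    # 128
block_channels = int(256 * 0.25)                                    # 64
final_conv_in_channels = 4 * block_channels + 2 * middle_channels   # 512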
@MODELS.register_module()
class EELANBlock(BaseModule):
"""Expand efficient layer aggregation networks for YOLOv7.
Args:
num_elan_block (int): The number of ELANBlock.
"""
def __init__(self, num_elan_block: int, **kwargs):
super().__init__()
assert num_elan_block >= 1
self.e_elan_blocks = nn.ModuleList()
for _ in range(num_elan_block):
self.e_elan_blocks.append(ELANBlock(**kwargs))
def forward(self, x: Tensor) -> Tensor:
outs = []
for elan_blocks in self.e_elan_blocks:
outs.append(elan_blocks(x))
return sum(outs)
class MaxPoolAndStrideConvBlock(BaseModule):
"""Max pooling and stride conv layer for YOLOv7.
- if mode is `reduce_channel_2x`, the output channel will
be reduced by a factor of 2
- if mode is `no_change_channel`, the output channel does not change.
Args:
in_channels (int): The input channels of this Module.
mode (str): Output channel mode. `reduce_channel_2x` or
`no_change_channel`. Defaults to `reduce_channel_2x`
out_channels (int): The out channels of this Module.
maxpool_kernel_sizes (int): kernel sizes of pooling layers.
Defaults to 2.
use_in_channels_of_middle (bool): Whether to calculate middle channels
based on in_channels. Defaults to False.
conv_cfg (dict): Config dict for convolution layer. Defaults to None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
@ -749,7 +763,9 @@ class MaxPoolAndStrideConvBlock(BaseModule):
def __init__(self,
in_channels: int,
mode: str = 'reduce_channel_2x',
out_channels: int,
maxpool_kernel_sizes: int = 2,
use_in_channels_of_middle: bool = False,
conv_cfg: OptConfigType = None,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
@ -757,33 +773,31 @@ class MaxPoolAndStrideConvBlock(BaseModule):
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg=init_cfg)
assert mode in ('no_change_channel', 'reduce_channel_2x')
if mode == 'reduce_channel_2x':
out_channels = in_channels // 2
else:
out_channels = in_channels
middle_channels = in_channels if use_in_channels_of_middle \
else out_channels // 2
self.maxpool_branches = nn.Sequential(
MaxPool2d(2, 2),
MaxPool2d(
kernel_size=maxpool_kernel_sizes, stride=maxpool_kernel_sizes),
ConvModule(
in_channels,
out_channels,
out_channels // 2,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.stride_conv_branches = nn.Sequential(
ConvModule(
in_channels,
out_channels,
middle_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
out_channels,
out_channels,
middle_channels,
out_channels // 2,
3,
stride=2,
padding=1,
@ -801,9 +815,96 @@ class MaxPoolAndStrideConvBlock(BaseModule):
return torch.cat([stride_conv_out, maxpool_out], dim=1)
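# Minimal usage sketch (hypothetical sizes, assuming the block is importable
# from mmyolo.models.layers): each branch emits out_channels // 2 maps at
# half resolution, so their concatenation restores out_channels.
import torch
from mmyolo.models.layers import MaxPoolAndStrideConvBlock

block = MaxPoolAndStrideConvBlock(64, 128)
assert block(torch.rand(1, 64, 32, 32)).shape == (1, 128, 16, 16)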
@MODELS.register_module()
class TinyDownSampleBlock(BaseModule):
"""Down sample layer for YOLOv7-tiny.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The out channels of this Module.
middle_ratio (float): The scaling ratio of the middle layer
based on the in_channels. Defaults to 1.0.
kernel_sizes (int, tuple[int]): Sequence or single kernel size of the
main convolutions. Defaults to 3.
conv_cfg (dict): Config dict for convolution layer. Defaults to None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='LeakyReLU', negative_slope=0.1).
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
middle_ratio: float = 1.0,
kernel_sizes: Union[int, Sequence[int]] = 3,
conv_cfg: OptConfigType = None,
norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='LeakyReLU', negative_slope=0.1),
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg)
middle_channels = int(in_channels * middle_ratio)
self.short_conv = ConvModule(
in_channels,
middle_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.main_convs = nn.ModuleList()
for i in range(3):
if i == 0:
self.main_convs.append(
ConvModule(
in_channels,
middle_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
else:
self.main_convs.append(
ConvModule(
middle_channels,
middle_channels,
kernel_sizes,
padding=(kernel_sizes - 1) // 2,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.final_conv = ConvModule(
middle_channels * 4,
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x) -> Tensor:
short_out = self.short_conv(x)
main_outs = []
for main_conv in self.main_convs:
main_out = main_conv(x)
main_outs.append(main_out)
x = main_out
return self.final_conv(torch.cat([*main_outs[::-1], short_out], dim=1))
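# Channel sketch for the block above (hypothetical sizes, middle_ratio
# defaulting to 1.0): the short branch and the three main convs each emit
# `middle_channels`, so final_conv sees 4 * middle_channels before
# projecting to out_channels; spatial size is unchanged by this block.
import torch
from mmyolo.models.layers import TinyDownSampleBlock

block = TinyDownSampleBlock(64, 128)
assert block(torch.rand(1, 64, 32, 32)).shape == (1, 128, 32, 32)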
@MODELS.register_module()
class SPPFCSPBlock(BaseModule):
"""Spatial pyramid pooling - Fast (SPPF) layer with CSP for
YOLOv7
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
@ -811,6 +912,8 @@ class SPPFCSPBlock(BaseModule):
Defaults to 0.5.
kernel_sizes (int, tuple[int]): Sequence or single kernel size of the
pooling layers. Defaults to 5.
is_tiny_version (bool): Whether this is the tiny variant of
SPPFCSPBlock used in YOLOv7-tiny. Defaults to False.
conv_cfg (dict): Config dict for convolution layer. Defaults to None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
@ -826,38 +929,50 @@ class SPPFCSPBlock(BaseModule):
out_channels: int,
expand_ratio: float = 0.5,
kernel_sizes: Union[int, Sequence[int]] = 5,
is_tiny_version: bool = False,
conv_cfg: OptConfigType = None,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg=init_cfg)
self.is_tiny_version = is_tiny_version
mid_channels = int(2 * out_channels * expand_ratio)
self.main_layers = nn.Sequential(
ConvModule(
if is_tiny_version:
self.main_layers = ConvModule(
in_channels,
mid_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
mid_channels,
mid_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
mid_channels,
mid_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
)
act_cfg=act_cfg)
else:
self.main_layers = nn.Sequential(
ConvModule(
in_channels,
mid_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
mid_channels,
mid_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
mid_channels,
mid_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
)
self.kernel_sizes = kernel_sizes
if isinstance(kernel_sizes, int):
@ -869,24 +984,33 @@ class SPPFCSPBlock(BaseModule):
for ks in kernel_sizes
])
self.fuse_layers = nn.Sequential(
ConvModule(
if is_tiny_version:
self.fuse_layers = ConvModule(
4 * mid_channels,
mid_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
mid_channels,
mid_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
act_cfg=act_cfg)
else:
self.fuse_layers = nn.Sequential(
ConvModule(
4 * mid_channels,
mid_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
mid_channels,
mid_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.short_layers = ConvModule(
self.short_layer = ConvModule(
in_channels,
mid_channels,
1,
@ -911,15 +1035,66 @@ class SPPFCSPBlock(BaseModule):
if isinstance(self.kernel_sizes, int):
y1 = self.poolings(x1)
y2 = self.poolings(y1)
x1 = self.fuse_layers(
torch.cat([x1] + [y1, y2, self.poolings(y2)], 1))
concat_list = [x1] + [y1, y2, self.poolings(y2)]
if self.is_tiny_version:
x1 = self.fuse_layers(torch.cat(concat_list[::-1], 1))
else:
x1 = self.fuse_layers(torch.cat(concat_list, 1))
else:
x1 = self.fuse_layers(
torch.cat([x1] + [m(x1) for m in self.poolings], 1))
x2 = self.short_layers(x)
concat_list = [x1] + [m(x1) for m in self.poolings]
if self.is_tiny_version:
x1 = self.fuse_layers(torch.cat(concat_list[::-1], 1))
else:
x1 = self.fuse_layers(torch.cat(concat_list, 1))
x2 = self.short_layer(x)
return self.final_conv(torch.cat((x1, x2), dim=1))
class ImplicitA(nn.Module):
"""Implicit add layer in YOLOv7.
Args:
in_channels (int): The input channels of this Module.
mean (float): Mean value of implicit module. Defaults to 0.
std (float): Std value of implicit module. Defaults to 0.02
"""
def __init__(self, in_channels: int, mean: float = 0., std: float = .02):
super().__init__()
self.implicit = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
nn.init.normal_(self.implicit, mean=mean, std=std)
def forward(self, x):
"""Forward process
Args:
x (Tensor): The input tensor.
"""
return self.implicit + x
class ImplicitM(nn.Module):
"""Implicit multiplier layer in YOLOv7.
Args:
in_channels (int): The input channels of this Module.
mean (float): Mean value of implicit module. Defaults to 1.
std (float): Std value of implicit module. Defaults to 0.02.
"""
def __init__(self, in_channels: int, mean: float = 1., std: float = .02):
super().__init__()
self.implicit = nn.Parameter(torch.ones(1, in_channels, 1, 1))
nn.init.normal_(self.implicit, mean=mean, std=std)
def forward(self, x):
"""Forward process
Args:
x (Tensor): The input tensor.
"""
return self.implicit * x
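# Sketch of how ImplicitA / ImplicitM are typically combined in the YOLOv7
# head (an assumption drawn from the "implicit knowledge" idea and from the
# `convs_pred.*.1` naming in the conversion script below, not a quote of the
# head code): the per-level prediction conv sits between an additive and a
# multiplicative implicit parameter.
import torch
import torch.nn as nn
from mmyolo.models.layers import ImplicitA, ImplicitM

in_c, num_out = 256, 3 * (5 + 80)             # hypothetical channel numbers
conv_pred = nn.Sequential(
    ImplicitA(in_c),                          # learnable additive prior
    nn.Conv2d(in_c, num_out, 1),              # per-level prediction conv
    ImplicitM(num_out))                       # learnable multiplicative prior
assert conv_pred(torch.rand(1, in_c, 20, 20)).shape == (1, num_out, 20, 20)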
@MODELS.register_module()
class PPYOLOEBasicBlock(nn.Module):
"""PPYOLOE Backbone BasicBlock.
@ -966,9 +1141,9 @@ class PPYOLOEBasicBlock(nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Forward process.
Args:
inputs (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
@ -1152,9 +1327,11 @@ class RepStageBlock(nn.Module):
block_cfg.update(
dict(in_channels=out_channels, out_channels=out_channels))
self.block = nn.Sequential(*(
MODELS.build(block_cfg)
for _ in range(num_blocks - 1))) if num_blocks > 1 else None
self.block = None
if num_blocks > 1:
self.block = nn.Sequential(*(MODELS.build(block_cfg)
for _ in range(num_blocks - 1)))
if bottle_block == BottleRep:
self.conv1 = BottleRep(
@ -1163,19 +1340,20 @@ class RepStageBlock(nn.Module):
block_cfg=block_cfg,
adaptive_weight=True)
num_blocks = num_blocks // 2
self.block = nn.Sequential(*(
BottleRep(
self.block = None
if num_blocks > 1:
self.block = nn.Sequential(*(BottleRep(
out_channels,
out_channels,
block_cfg=block_cfg,
adaptive_weight=True)
for _ in range(num_blocks - 1))) if num_blocks > 1 else None
adaptive_weight=True) for _ in range(num_blocks - 1)))
def forward(self, x: Tensor) -> Tensor:
"""Forward process.
Args:
inputs (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""

View File

@ -20,27 +20,31 @@ def bbox_overlaps(pred: torch.Tensor,
Implementation of paper `Enhancing Geometric Factors into
Model Learning and Inference for Object Detection and Instance
Segmentation <https://arxiv.org/abs/2005.03572>`_.
In the CIoU implementation of YOLOv5 and MMDetection, there is a slight
difference in the way the alpha parameter is computed.
mmdet version:
alpha = (ious > 0.5).float() * v / (1 - ious + v)
YOLOv5 version:
alpha = v / (v - ious + (1 + eps))
Args:
pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
or (x, y, w, h),shape (n, 4).
target (Tensor): Corresponding gt bboxes, shape (n, 4).
iou_mode (str): Options are "ciou".
iou_mode (str): Options are ('iou', 'ciou', 'giou', 'siou').
Defaults to "ciou".
bbox_format (str): Options are "xywh" and "xyxy".
Defaults to "xywh".
siou_theta (float): siou_theta for SIoU when calculating the shape cost.
Defaults to 4.0.
eps (float): Eps to avoid log(0).
Returns:
Tensor: shape (n,).
Tensor: shape (n, ).
"""
assert iou_mode in ('ciou', 'giou', 'siou')
assert iou_mode in ('iou', 'ciou', 'giou', 'siou')
assert bbox_format in ('xyxy', 'xywh')
if bbox_format == 'xywh':
pred = HorizontalBoxes.cxcywh_to_xyxy(pred)

View File

@ -6,8 +6,7 @@ from mmcv.cnn import ConvModule
from mmdet.utils import ConfigType, OptMultiConfig
from mmyolo.registry import MODELS
from ..layers import (ELANBlock, MaxPoolAndStrideConvBlock, RepVGGBlock,
SPPFCSPBlock)
from ..layers import MaxPoolAndStrideConvBlock, RepVGGBlock, SPPFCSPBlock
from .base_yolo_neck import BaseYOLONeck
@ -18,12 +17,21 @@ class YOLOv7PAFPN(BaseYOLONeck):
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale).
block_cfg (dict): Config dict for block.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
spp_expand_ratio (float): Expand ratio of SPPCSPBlock.
Defaults to 0.5.
is_tiny_version (bool): Whether this is the tiny version of the neck,
i.e. a YOLOv7-tiny model. Defaults to False.
use_maxpool_in_downsample (bool): Whether maxpooling is
used in downsample layers. Defaults to True.
use_in_channels_in_downsample (bool): Whether MaxPoolAndStrideConvBlock
computes its middle channels from the input channels. Defaults to False.
use_repconv_outs (bool): Whether to use `repconv` in the output
layer. Defaults to True.
upsample_feats_cat_first (bool): Whether the output features are
concat first after upsampling in the topdown module.
Defaults to True. Currently only YOLOv7 sets this to False.
@ -39,9 +47,19 @@ class YOLOv7PAFPN(BaseYOLONeck):
def __init__(self,
in_channels: List[int],
out_channels: List[int],
block_cfg: dict = dict(
type='ELANBlock',
middle_ratio=0.5,
block_ratio=0.25,
num_blocks=4,
num_convs_in_block=1),
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
spp_expand_ratio: float = 0.5,
is_tiny_version: bool = False,
use_maxpool_in_downsample: bool = True,
use_in_channels_in_downsample: bool = False,
use_repconv_outs: bool = True,
upsample_feats_cat_first: bool = False,
freeze_all: bool = False,
norm_cfg: ConfigType = dict(
@ -49,7 +67,15 @@ class YOLOv7PAFPN(BaseYOLONeck):
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
init_cfg: OptMultiConfig = None):
self.is_tiny_version = is_tiny_version
self.use_maxpool_in_downsample = use_maxpool_in_downsample
self.use_in_channels_in_downsample = use_in_channels_in_downsample
self.spp_expand_ratio = spp_expand_ratio
self.use_repconv_outs = use_repconv_outs
self.block_cfg = block_cfg
self.block_cfg.setdefault('norm_cfg', norm_cfg)
self.block_cfg.setdefault('act_cfg', act_cfg)
super().__init__(
in_channels=[
int(channel * widen_factor) for channel in in_channels
@ -74,11 +100,12 @@ class YOLOv7PAFPN(BaseYOLONeck):
Returns:
nn.Module: The reduce layer.
"""
if idx == 2:
if idx == len(self.in_channels) - 1:
layer = SPPFCSPBlock(
self.in_channels[idx],
self.out_channels[idx],
expand_ratio=self.spp_expand_ratio,
is_tiny_version=self.is_tiny_version,
kernel_sizes=5,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
@ -112,12 +139,10 @@ class YOLOv7PAFPN(BaseYOLONeck):
Returns:
nn.Module: The top down layer.
"""
return ELANBlock(
self.out_channels[idx - 1] * 2,
mode='reduce_channel_2x',
num_blocks=4,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
block_cfg = self.block_cfg.copy()
block_cfg['in_channels'] = self.out_channels[idx - 1] * 2
block_cfg['out_channels'] = self.out_channels[idx - 1]
return MODELS.build(block_cfg)
def build_downsample_layer(self, idx: int) -> nn.Module:
"""build downsample layer.
@ -128,11 +153,22 @@ class YOLOv7PAFPN(BaseYOLONeck):
Returns:
nn.Module: The downsample layer.
"""
return MaxPoolAndStrideConvBlock(
self.out_channels[idx],
mode='no_change_channel',
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if self.use_maxpool_in_downsample and not self.is_tiny_version:
return MaxPoolAndStrideConvBlock(
self.out_channels[idx],
self.out_channels[idx + 1],
use_in_channels_of_middle=self.use_in_channels_in_downsample,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
else:
return ConvModule(
self.out_channels[idx],
self.out_channels[idx + 1],
3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def build_bottom_up_layer(self, idx: int) -> nn.Module:
"""build bottom up layer.
@ -143,12 +179,10 @@ class YOLOv7PAFPN(BaseYOLONeck):
Returns:
nn.Module: The bottom up layer.
"""
return ELANBlock(
self.out_channels[idx + 1] * 2,
mode='reduce_channel_2x',
num_blocks=4,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
block_cfg = self.block_cfg.copy()
block_cfg['in_channels'] = self.out_channels[idx + 1] * 2
block_cfg['out_channels'] = self.out_channels[idx + 1]
return MODELS.build(block_cfg)
def build_out_layer(self, idx: int) -> nn.Module:
"""build out layer.
@ -159,9 +193,24 @@ class YOLOv7PAFPN(BaseYOLONeck):
Returns:
nn.Module: The out layer.
"""
return RepVGGBlock(
self.out_channels[idx],
self.out_channels[idx] * 2,
3,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if len(self.in_channels) == 4:
# P6
return nn.Identity()
out_channels = self.out_channels[idx] * 2
if self.use_repconv_outs:
return RepVGGBlock(
self.out_channels[idx],
out_channels,
3,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
else:
return ConvModule(
self.out_channels[idx],
out_channels,
3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)

View File

@ -0,0 +1,303 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_overlaps
def _cat_multi_level_tensor_in_place(*multi_level_tensor, place_hold_var):
for level_tensor in multi_level_tensor:
for i, var in enumerate(level_tensor):
if len(var) > 0:
level_tensor[i] = torch.cat(var, dim=0)
else:
level_tensor[i] = place_hold_var
class BatchYOLOv7Assigner(nn.Module):
"""Batch YOLOv7 Assigner.
It consists of two assigning steps:
1. YOLOv5 cross-grid sample assigning
2. SimOTA assigning
This code is referenced from
https://github.com/WongKinYiu/yolov7/blob/main/utils/loss.py.
"""
def __init__(self,
num_classes: int,
num_base_priors,
featmap_strides,
prior_match_thr: float = 4.0,
candidate_topk: int = 10,
iou_weight: float = 3.0,
cls_weight: float = 1.0):
super().__init__()
self.num_classes = num_classes
self.num_base_priors = num_base_priors
self.featmap_strides = featmap_strides
# yolov5 param
self.prior_match_thr = prior_match_thr
# simota param
self.candidate_topk = candidate_topk
self.iou_weight = iou_weight
self.cls_weight = cls_weight
@torch.no_grad()
def forward(self, pred_results, batch_targets_normed, batch_input_shape,
priors_base_sizes, grid_offset) -> dict:
# (num_base_priors, num_batch_gt, 7)
# 7 means (batch_idx, cls_id, x_norm, y_norm,
# w_norm, h_norm, prior_idx)
# mlvl means multi_level
if batch_targets_normed.shape[1] == 0:
# empty gt of batch
num_levels = len(pred_results)
return dict(
mlvl_positive_infos=[pred_results[0].new_empty(
(0, 4))] * num_levels,
mlvl_priors=[] * num_levels,
mlvl_targets_normed=[] * num_levels)
mlvl_positive_infos, mlvl_priors = self.yolov5_assigner(
pred_results, batch_targets_normed, priors_base_sizes, grid_offset)
mlvl_positive_infos, mlvl_priors, \
mlvl_targets_normed = self.simota_assigner(
pred_results, batch_targets_normed, mlvl_positive_infos,
mlvl_priors, batch_input_shape)
place_hold_var = batch_targets_normed.new_empty((0, 4))
_cat_multi_level_tensor_in_place(
mlvl_positive_infos,
mlvl_priors,
mlvl_targets_normed,
place_hold_var=place_hold_var)
return dict(
mlvl_positive_infos=mlvl_positive_infos,
mlvl_priors=mlvl_priors,
mlvl_targets_normed=mlvl_targets_normed)
def yolov5_assigner(self, pred_results, batch_targets_normed,
priors_base_sizes, grid_offset):
num_batch_gts = batch_targets_normed.shape[1]
assert num_batch_gts > 0
mlvl_positive_infos, mlvl_priors = [], []
scaled_factor = torch.ones(7, device=pred_results[0].device)
for i in range(len(pred_results)):  # level
priors_base_sizes_i = priors_base_sizes[i]
# (1, 1, feat_shape_w, feat_shape_h, feat_shape_w, feat_shape_h)
scaled_factor[2:6] = torch.tensor(
pred_results[i].shape)[[3, 2, 3, 2]]
# Scale batch_targets from the normalized 0-1 range to feature-map size.
# (num_base_priors, num_batch_gts, 7)
batch_targets_scaled = batch_targets_normed * scaled_factor
# Shape match
wh_ratio = batch_targets_scaled[...,
4:6] / priors_base_sizes_i[:, None]
match_inds = torch.max(
wh_ratio, 1. / wh_ratio).max(2)[0] < self.prior_match_thr
batch_targets_scaled = batch_targets_scaled[
match_inds] # (num_matched_target, 7)
# no gt bbox matches anchor
if batch_targets_scaled.shape[0] == 0:
mlvl_positive_infos.append(
batch_targets_scaled.new_empty((0, 4)))
mlvl_priors.append([])
continue
# Positive samples with additional neighbors
batch_targets_cxcy = batch_targets_scaled[:, 2:4]
grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
left, up = ((batch_targets_cxcy % 1 < 0.5) &
(batch_targets_cxcy > 1)).T
right, bottom = ((grid_xy % 1 < 0.5) & (grid_xy > 1)).T
offset_inds = torch.stack(
(torch.ones_like(left), left, up, right, bottom))
batch_targets_scaled = batch_targets_scaled.repeat(
(5, 1, 1))[offset_inds] # ()
retained_offsets = grid_offset.repeat(1, offset_inds.shape[1],
1)[offset_inds]
# batch_targets_scaled: (num_matched_target, 7)
# 7 means (batch_idx, cls_id, x_scaled,
# y_scaled, w_scaled, h_scaled, prior_idx)
# mlvl_positive_info: (num_matched_target, 4)
# 4 means (batch_idx, prior_idx, x_scaled, y_scaled)
mlvl_positive_info = batch_targets_scaled[:, [0, 6, 2, 3]]
mlvl_positive_info[:,
2:] = mlvl_positive_info[:,
2:] - retained_offsets
mlvl_positive_info[:, 2].clamp_(0, scaled_factor[2] - 1)
mlvl_positive_info[:, 3].clamp_(0, scaled_factor[3] - 1)
mlvl_positive_info = mlvl_positive_info.long()
priors_inds = mlvl_positive_info[:, 1]
mlvl_positive_infos.append(mlvl_positive_info)
mlvl_priors.append(priors_base_sizes_i[priors_inds])
return mlvl_positive_infos, mlvl_priors
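# Toy check of the YOLOv5 cross-grid rule reused above (illustration only):
# a target centred at grid coords (4.3, 7.8) on a 16x16 feature map keeps
# its own cell plus the "left" and "bottom" neighbour cells.
import torch
cxcy = torch.tensor([[4.3, 7.8]])
grid_xy = torch.tensor([16., 16.]) - cxcy
left, up = ((cxcy % 1 < 0.5) & (cxcy > 1)).T
right, bottom = ((grid_xy % 1 < 0.5) & (grid_xy > 1)).T
# left=[True], up=[False], right=[False], bottom=[True]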
def simota_assigner(self, pred_results, batch_targets_normed,
mlvl_positive_infos, mlvl_priors, batch_input_shape):
num_batch_gts = batch_targets_normed.shape[1]
assert num_batch_gts > 0
num_levels = len(mlvl_positive_infos)
mlvl_positive_infos_matched = [[] for _ in range(num_levels)]
mlvl_priors_matched = [[] for _ in range(num_levels)]
mlvl_targets_normed_matched = [[] for _ in range(num_levels)]
for batch_idx in range(pred_results[0].shape[0]):
# (num_batch_gt, 7)
# 7 means (batch_idx, cls_id, x_norm, y_norm,
# w_norm, h_norm, prior_idx)
targets_normed = batch_targets_normed[0]
# (num_gt, 7)
targets_normed = targets_normed[targets_normed[:, 0] == batch_idx]
num_gts = targets_normed.shape[0]
if num_gts == 0:
continue
_mlvl_decoderd_bboxes = []
_mlvl_obj_cls = []
_mlvl_priors = []
_mlvl_positive_infos = []
_from_which_layer = []
for i, head_pred in enumerate(pred_results):
# (num_matched_target, 4)
# 4 means (batch_idx, prior_idx, grid_x, grid_y)
_mlvl_positive_info = mlvl_positive_infos[i]
if _mlvl_positive_info.shape[0] == 0:
continue
idx = (_mlvl_positive_info[:, 0] == batch_idx)
_mlvl_positive_info = _mlvl_positive_info[idx]
_mlvl_positive_infos.append(_mlvl_positive_info)
priors = mlvl_priors[i][idx]
_mlvl_priors.append(priors)
_from_which_layer.append(
torch.ones(size=(_mlvl_positive_info.shape[0], )) * i)
# (n,85)
level_batch_idx, prior_ind, \
grid_x, grid_y = _mlvl_positive_info.T
pred_positive = head_pred[level_batch_idx, prior_ind, grid_y,
grid_x]
_mlvl_obj_cls.append(pred_positive[:, 4:])
# decoded
grid = torch.stack([grid_x, grid_y], dim=1)
pred_positive_cxcy = (pred_positive[:, :2].sigmoid() * 2. -
0.5 + grid) * self.featmap_strides[i]
pred_positive_wh = (pred_positive[:, 2:4].sigmoid() * 2) ** 2 \
* priors * self.featmap_strides[i]
pred_positive_xywh = torch.cat(
[pred_positive_cxcy, pred_positive_wh], dim=-1)
_mlvl_decoderd_bboxes.append(pred_positive_xywh)
# 1 calc pair_wise_iou_loss
_mlvl_decoderd_bboxes = torch.cat(_mlvl_decoderd_bboxes, dim=0)
num_pred_positive = _mlvl_decoderd_bboxes.shape[0]
if num_pred_positive == 0:
continue
# scaled xywh
batch_input_shape_wh = pred_results[0].new_tensor(
batch_input_shape[::-1]).repeat((1, 2))
targets_scaled_bbox = targets_normed[:, 2:6] * batch_input_shape_wh
targets_scaled_bbox = bbox_cxcywh_to_xyxy(targets_scaled_bbox)
_mlvl_decoderd_bboxes = bbox_cxcywh_to_xyxy(_mlvl_decoderd_bboxes)
pair_wise_iou = bbox_overlaps(targets_scaled_bbox,
_mlvl_decoderd_bboxes)
pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8)
# 2 calc pair_wise_cls_loss
_mlvl_obj_cls = torch.cat(_mlvl_obj_cls, dim=0).float().sigmoid()
_mlvl_positive_infos = torch.cat(_mlvl_positive_infos, dim=0)
_from_which_layer = torch.cat(_from_which_layer, dim=0)
_mlvl_priors = torch.cat(_mlvl_priors, dim=0)
gt_cls_per_image = (
F.one_hot(targets_normed[:, 1].to(torch.int64),
self.num_classes).float().unsqueeze(1).repeat(
1, num_pred_positive, 1))
# cls_score * obj
cls_preds_ = _mlvl_obj_cls[:, 1:]\
.unsqueeze(0)\
.repeat(num_gts, 1, 1) \
* _mlvl_obj_cls[:, 0:1]\
.unsqueeze(0).repeat(num_gts, 1, 1)
y = cls_preds_.sqrt_()
pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
torch.log(y / (1 - y)), gt_cls_per_image,
reduction='none').sum(-1)
del cls_preds_
# calc cost
cost = (
self.cls_weight * pair_wise_cls_loss +
self.iou_weight * pair_wise_iou_loss)
# num_gt, num_match_pred
matching_matrix = torch.zeros_like(cost)
top_k, _ = torch.topk(
pair_wise_iou,
min(self.candidate_topk, pair_wise_iou.shape[1]),
dim=1)
dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1)
# Select only topk matches per gt
for gt_idx in range(num_gts):
_, pos_idx = torch.topk(
cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
matching_matrix[gt_idx][pos_idx] = 1.0
del top_k, dynamic_ks
# Each prediction box can match at most one gt box,
# and if there are more than one,
# only the least costly one can be taken
anchor_matching_gt = matching_matrix.sum(0)
if (anchor_matching_gt > 1).sum() > 0:
_, cost_argmin = torch.min(
cost[:, anchor_matching_gt > 1], dim=0)
matching_matrix[:, anchor_matching_gt > 1] *= 0.0
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
targets_normed = targets_normed[matched_gt_inds]
_mlvl_positive_infos = _mlvl_positive_infos[fg_mask_inboxes]
_from_which_layer = _from_which_layer[fg_mask_inboxes]
_mlvl_priors = _mlvl_priors[fg_mask_inboxes]
# Rearranged in the order of the prediction layers
# to facilitate loss
for i in range(num_levels):
layer_idx = _from_which_layer == i
mlvl_positive_infos_matched[i].append(
_mlvl_positive_infos[layer_idx])
mlvl_priors_matched[i].append(_mlvl_priors[layer_idx])
mlvl_targets_normed_matched[i].append(
targets_normed[layer_idx])
results = mlvl_positive_infos_matched, \
mlvl_priors_matched, \
mlvl_targets_normed_matched
return results
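# Worked example of the dynamic-k rule implemented above (illustration only):
# each gt keeps k = clamp(int(sum of its top-10 IoUs), min=1) of its
# lowest-cost candidate predictions.
import torch
pair_wise_iou = torch.tensor([[0.9, 0.6, 0.4, 0.1]])      # 1 gt, 4 candidates
top_k, _ = torch.topk(pair_wise_iou, min(10, pair_wise_iou.shape[1]), dim=1)
dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1)       # tensor([2])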

View File

@ -9,7 +9,7 @@ from mmdet.structures.bbox import HorizontalBoxes
from mmdet.structures.mask import BitmapMasks
from mmyolo.datasets import YOLOv5CocoDataset
from mmyolo.datasets.transforms import Mosaic, YOLOv5MixUp, YOLOXMixUp
from mmyolo.datasets.transforms import Mosaic, Mosaic9, YOLOv5MixUp, YOLOXMixUp
from mmyolo.utils import register_all_modules
register_all_modules()
@ -108,6 +108,99 @@ class TestMosaic(unittest.TestCase):
self.assertTrue(results['gt_ignore_flags'].dtype == bool)
class TestMosaic9(unittest.TestCase):
def setUp(self):
"""Setup the data info which are used in every test method.
TestCase calls functions in this order: setUp() -> testMethod() ->
tearDown() -> cleanUp()
"""
rng = np.random.RandomState(0)
self.pre_transform = [
dict(
type='LoadImageFromFile',
file_client_args=dict(backend='disk')),
dict(type='LoadAnnotations', with_bbox=True)
]
self.dataset = YOLOv5CocoDataset(
data_prefix=dict(
img=osp.join(osp.dirname(__file__), '../../data')),
ann_file=osp.join(
osp.dirname(__file__), '../../data/coco_sample_color.json'),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=[])
self.results = {
'img':
np.random.random((224, 224, 3)),
'img_shape': (224, 224),
'gt_bboxes_labels':
np.array([1, 2, 3], dtype=np.int64),
'gt_bboxes':
np.array([[10, 10, 20, 20], [20, 20, 40, 40], [40, 40, 80, 80]],
dtype=np.float32),
'gt_ignore_flags':
np.array([0, 0, 1], dtype=bool),
'gt_masks':
BitmapMasks(rng.rand(3, 224, 224), height=224, width=224),
'dataset':
self.dataset
}
def test_transform(self):
# test assertion for invalid img_scale
with self.assertRaises(AssertionError):
transform = Mosaic9(img_scale=640)
# test assertion for invalid probability
with self.assertRaises(AssertionError):
transform = Mosaic9(prob=1.5)
# test assertion for invalid max_cached_images
with self.assertRaises(AssertionError):
transform = Mosaic9(use_cached=True, max_cached_images=1)
transform = Mosaic9(
img_scale=(10, 12), pre_transform=self.pre_transform)
results = transform(copy.deepcopy(self.results))
self.assertTrue(results['img'].shape[:2] == (20, 24))
self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
results['gt_bboxes'].shape[0])
self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
self.assertTrue(results['gt_bboxes'].dtype == np.float32)
self.assertTrue(results['gt_ignore_flags'].dtype == bool)
def test_transform_with_no_gt(self):
self.results['gt_bboxes'] = np.empty((0, 4), dtype=np.float32)
self.results['gt_bboxes_labels'] = np.empty((0, ), dtype=np.int64)
self.results['gt_ignore_flags'] = np.empty((0, ), dtype=bool)
transform = Mosaic9(
img_scale=(10, 12), pre_transform=self.pre_transform)
results = transform(copy.deepcopy(self.results))
self.assertIsInstance(results, dict)
self.assertTrue(results['img'].shape[:2] == (20, 24))
self.assertTrue(
results['gt_bboxes_labels'].shape[0] == results['gt_bboxes'].
shape[0] == results['gt_ignore_flags'].shape[0])
self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
self.assertTrue(results['gt_bboxes'].dtype == np.float32)
self.assertTrue(results['gt_ignore_flags'].dtype == bool)
def test_transform_with_box_list(self):
transform = Mosaic9(
img_scale=(10, 12), pre_transform=self.pre_transform)
results = copy.deepcopy(self.results)
results['gt_bboxes'] = HorizontalBoxes(results['gt_bboxes'])
results = transform(results)
self.assertTrue(results['img'].shape[:2] == (20, 24))
self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
results['gt_bboxes'].shape[0])
self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
self.assertTrue(results['gt_bboxes'].dtype == torch.float32)
self.assertTrue(results['gt_ignore_flags'].dtype == bool)
class TestYOLOv5MixUp(unittest.TestCase):
def setUp(self):

View File

@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase
import torch
import torch.nn as nn
from mmengine.optim import build_optim_wrapper
from mmyolo.engine import YOLOv7OptimWrapperConstructor
from mmyolo.utils import register_all_modules
register_all_modules()
class ExampleModel(nn.Module):
def __init__(self):
super().__init__()
self.param1 = nn.Parameter(torch.ones(1))
self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
self.bn = nn.BatchNorm2d(2)
class TestYOLOv7OptimWrapperConstructor(TestCase):
def setUp(self):
self.model = ExampleModel()
self.base_lr = 0.01
self.weight_decay = 0.0001
self.optim_wrapper_cfg = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=self.base_lr,
momentum=0.9,
weight_decay=self.weight_decay,
batch_size_per_gpu=16))
def test_init(self):
YOLOv7OptimWrapperConstructor(copy.deepcopy(self.optim_wrapper_cfg))
YOLOv7OptimWrapperConstructor(
copy.deepcopy(self.optim_wrapper_cfg),
paramwise_cfg={'base_total_batch_size': 64})
# `paramwise_cfg` must include `base_total_batch_size` if not None.
with self.assertRaises(AssertionError):
YOLOv7OptimWrapperConstructor(
copy.deepcopy(self.optim_wrapper_cfg), paramwise_cfg={'a': 64})
def test_build(self):
optim_wrapper = YOLOv7OptimWrapperConstructor(
copy.deepcopy(self.optim_wrapper_cfg))(
self.model)
# test param_groups
assert len(optim_wrapper.optimizer.param_groups) == 3
for i in range(3):
param_groups_i = optim_wrapper.optimizer.param_groups[i]
assert param_groups_i['lr'] == self.base_lr
if i == 0:
assert param_groups_i['weight_decay'] == self.weight_decay
else:
assert param_groups_i['weight_decay'] == 0
# test weight_decay linear scaling
optim_wrapper_cfg = copy.deepcopy(self.optim_wrapper_cfg)
optim_wrapper_cfg['optimizer']['batch_size_per_gpu'] = 128
optim_wrapper = YOLOv7OptimWrapperConstructor(optim_wrapper_cfg)(
self.model)
assert optim_wrapper.optimizer.param_groups[0][
'weight_decay'] == self.weight_decay * 2
# test without batch_size_per_gpu
optim_wrapper_cfg = copy.deepcopy(self.optim_wrapper_cfg)
optim_wrapper_cfg['optimizer'].pop('batch_size_per_gpu')
optim_wrapper = dict(
optim_wrapper_cfg, constructor='YOLOv7OptimWrapperConstructor')
optim_wrapper = build_optim_wrapper(self.model, optim_wrapper)
assert optim_wrapper.optimizer.param_groups[0][
'weight_decay'] == self.weight_decay
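# What the assertions above imply about the constructor (an inference from
# this test, not from its source): weight decay is scaled linearly with the
# per-GPU batch size relative to `base_total_batch_size` (64 by default) on
# a single GPU, and left unchanged when `batch_size_per_gpu` is not given.
base_total_batch_size = 64                              # assumed default
scaled_decay = 0.0001 * 128 / base_total_batch_size     # 0.0002, as asserted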

View File

@ -0,0 +1,154 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from mmyolo.models.backbones import YOLOv7Backbone
from mmyolo.utils import register_all_modules
from .utils import check_norm_state
register_all_modules()
class TestYOLOv7Backbone(TestCase):
def test_init(self):
# out_indices in range(len(arch_setting) + 1)
with pytest.raises(AssertionError):
YOLOv7Backbone(out_indices=(6, ))
with pytest.raises(ValueError):
# frozen_stages must in range(-1, len(arch_setting) + 1)
YOLOv7Backbone(frozen_stages=6)
def test_forward(self):
# Test YOLOv7Backbone-L with first stage frozen
frozen_stages = 1
model = YOLOv7Backbone(frozen_stages=frozen_stages)
model.init_weights()
model.train()
for mod in model.stem.modules():
for param in mod.parameters():
assert param.requires_grad is False
for i in range(1, frozen_stages + 1):
layer = getattr(model, f'stage{i}')
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for param in layer.parameters():
assert param.requires_grad is False
# Test YOLOv7Backbone-L with norm_eval=True
model = YOLOv7Backbone(norm_eval=True)
model.train()
assert check_norm_state(model.modules(), False)
# Test YOLOv7Backbone-L forward with widen_factor=0.25
model = YOLOv7Backbone(
widen_factor=0.25, out_indices=tuple(range(0, 5)))
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 5
assert feat[0].shape == torch.Size((1, 16, 32, 32))
assert feat[1].shape == torch.Size((1, 64, 16, 16))
assert feat[2].shape == torch.Size((1, 128, 8, 8))
assert feat[3].shape == torch.Size((1, 256, 4, 4))
assert feat[4].shape == torch.Size((1, 256, 2, 2))
# Test YOLOv7Backbone-L with plugins
model = YOLOv7Backbone(
widen_factor=0.25,
plugins=[
dict(
cfg=dict(
type='mmdet.DropBlock', drop_prob=0.1, block_size=3),
stages=(False, False, True, True)),
])
assert len(model.stage1) == 2
assert len(model.stage2) == 2
assert len(model.stage3) == 3 # +DropBlock
assert len(model.stage4) == 3 # +DropBlock
model.train()
imgs = torch.randn(1, 3, 128, 128)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 128, 16, 16))
assert feat[1].shape == torch.Size((1, 256, 8, 8))
assert feat[2].shape == torch.Size((1, 256, 4, 4))
# Test YOLOv7Backbone-X forward with widen_factor=0.25
model = YOLOv7Backbone(arch='X', widen_factor=0.25)
model.train()
imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 160, 8, 8))
assert feat[1].shape == torch.Size((1, 320, 4, 4))
assert feat[2].shape == torch.Size((1, 320, 2, 2))
# Test YOLOv7Backbone-tiny forward with widen_factor=0.25
model = YOLOv7Backbone(arch='Tiny', widen_factor=0.25)
model.train()
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 32, 8, 8))
assert feat[1].shape == torch.Size((1, 64, 4, 4))
assert feat[2].shape == torch.Size((1, 128, 2, 2))
# Test YOLOv7Backbone-w forward with widen_factor=0.25
model = YOLOv7Backbone(
arch='W', widen_factor=0.25, out_indices=(2, 3, 4, 5))
model.train()
imgs = torch.randn(1, 3, 128, 128)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size((1, 64, 16, 16))
assert feat[1].shape == torch.Size((1, 128, 8, 8))
assert feat[2].shape == torch.Size((1, 192, 4, 4))
assert feat[3].shape == torch.Size((1, 256, 2, 2))
# Test YOLOv7Backbone-D forward with widen_factor=0.25
model = YOLOv7Backbone(
arch='D', widen_factor=0.25, out_indices=(2, 3, 4, 5))
model.train()
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size((1, 96, 16, 16))
assert feat[1].shape == torch.Size((1, 192, 8, 8))
assert feat[2].shape == torch.Size((1, 288, 4, 4))
assert feat[3].shape == torch.Size((1, 384, 2, 2))
# Test YOLOv7Backbone-E forward with widen_factor=0.25
model = YOLOv7Backbone(
arch='E', widen_factor=0.25, out_indices=(2, 3, 4, 5))
model.train()
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size((1, 80, 16, 16))
assert feat[1].shape == torch.Size((1, 160, 8, 8))
assert feat[2].shape == torch.Size((1, 240, 4, 4))
assert feat[3].shape == torch.Size((1, 320, 2, 2))
# Test YOLOv7Backbone-E2E forward with widen_factor=0.25
model = YOLOv7Backbone(
arch='E2E', widen_factor=0.25, out_indices=(2, 3, 4, 5))
model.train()
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size((1, 80, 16, 16))
assert feat[1].shape == torch.Size((1, 160, 8, 8))
assert feat[2].shape == torch.Size((1, 240, 4, 4))
assert feat[3].shape == torch.Size((1, 320, 2, 2))

View File

@ -127,7 +127,7 @@ class TestYOLOv5Head(TestCase):
head = YOLOv5Head(head_module=self.head_module)
gt_instances = InstanceData(
bboxes=torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
labels=torch.LongTensor([1]))
labels=torch.LongTensor([0]))
one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds, objectnesses,
[gt_instances], img_metas)

View File

@ -0,0 +1,145 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmengine.config import Config
from mmengine.structures import InstanceData
from mmyolo.models.dense_heads import YOLOv7Head
from mmyolo.utils import register_all_modules
register_all_modules()
# TODO: Test YOLOv7p6HeadModule
class TestYOLOv7Head(TestCase):
def setUp(self):
self.head_module = dict(
type='YOLOv7HeadModule',
num_classes=2,
in_channels=[32, 64, 128],
featmap_strides=[8, 16, 32],
num_base_priors=3)
def test_predict_by_feat(self):
s = 256
img_metas = [{
'img_shape': (s, s, 3),
'ori_shape': (s, s, 3),
'scale_factor': (1.0, 1.0),
}]
test_cfg = Config(
dict(
multi_label=True,
max_per_img=300,
score_thr=0.01,
nms=dict(type='nms', iou_threshold=0.65)))
head = YOLOv7Head(head_module=self.head_module, test_cfg=test_cfg)
feat = []
for i in range(len(self.head_module['in_channels'])):
in_channel = self.head_module['in_channels'][i]
feat_size = self.head_module['featmap_strides'][i]
feat.append(
torch.rand(1, in_channel, s // feat_size, s // feat_size))
cls_scores, bbox_preds, objectnesses = head.forward(feat)
head.predict_by_feat(
cls_scores,
bbox_preds,
objectnesses,
img_metas,
cfg=test_cfg,
rescale=True,
with_nms=True)
head.predict_by_feat(
cls_scores,
bbox_preds,
objectnesses,
img_metas,
cfg=test_cfg,
rescale=False,
with_nms=False)
def test_loss_by_feat(self):
s = 256
img_metas = [{
'img_shape': (s, s, 3),
'batch_input_shape': (s, s),
'scale_factor': 1,
}]
head = YOLOv7Head(head_module=self.head_module)
feat = []
for i in range(len(self.head_module['in_channels'])):
in_channel = self.head_module['in_channels'][i]
feat_size = self.head_module['featmap_strides'][i]
feat.append(
torch.rand(1, in_channel, s // feat_size, s // feat_size))
cls_scores, bbox_preds, objectnesses = head.forward(feat)
# Test that empty ground truth encourages the network to predict
# background
gt_instances = InstanceData(
bboxes=torch.empty((0, 4)), labels=torch.LongTensor([]))
empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
objectnesses, [gt_instances],
img_metas)
# When there is no truth, cls and box loss should be zero and only the
# objectness loss should be non-zero.
empty_cls_loss = empty_gt_losses['loss_cls'].sum()
empty_box_loss = empty_gt_losses['loss_bbox'].sum()
empty_obj_loss = empty_gt_losses['loss_obj'].sum()
self.assertEqual(
empty_cls_loss.item(), 0,
'there should be no cls loss when there are no true boxes')
self.assertEqual(
empty_box_loss.item(), 0,
'there should be no box loss when there are no true boxes')
self.assertGreater(empty_obj_loss.item(), 0,
'objectness loss should be non-zero')
# When truth is non-empty then both cls and box loss should be nonzero
# for random inputs
head = YOLOv7Head(head_module=self.head_module)
gt_instances = InstanceData(
bboxes=torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
labels=torch.LongTensor([1]))
one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds, objectnesses,
[gt_instances], img_metas)
onegt_cls_loss = one_gt_losses['loss_cls'].sum()
onegt_box_loss = one_gt_losses['loss_bbox'].sum()
onegt_obj_loss = one_gt_losses['loss_obj'].sum()
self.assertGreater(onegt_cls_loss.item(), 0,
'cls loss should be non-zero')
self.assertGreater(onegt_box_loss.item(), 0,
'box loss should be non-zero')
self.assertGreater(onegt_obj_loss.item(), 0,
'obj loss should be non-zero')
# test num_class = 1
self.head_module['num_classes'] = 1
head = YOLOv7Head(head_module=self.head_module)
gt_instances = InstanceData(
bboxes=torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
labels=torch.LongTensor([0]))
cls_scores, bbox_preds, objectnesses = head.forward(feat)
one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds, objectnesses,
[gt_instances], img_metas)
onegt_cls_loss = one_gt_losses['loss_cls'].sum()
onegt_box_loss = one_gt_losses['loss_bbox'].sum()
onegt_obj_loss = one_gt_losses['loss_obj'].sum()
self.assertEqual(onegt_cls_loss.item(), 0,
'cls loss should be zero when num_classes == 1')
self.assertGreater(onegt_box_loss.item(), 0,
'box loss should be non-zero')
self.assertGreater(onegt_obj_loss.item(), 0,
'obj loss should be non-zero')

View File

@ -22,7 +22,8 @@ class TestSingleStageDetector(TestCase):
'yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py',
'yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py',
'yolox/yolox_tiny_8xb8-300e_coco.py',
'rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py'
'rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py',
'yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py'
])
def test_init(self, cfg_file):
model = get_detector_cfg(cfg_file)
@ -37,6 +38,7 @@ class TestSingleStageDetector(TestCase):
@parameterized.expand([
('yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py', ('cuda', 'cpu')),
('yolox/yolox_s_8xb8-300e_coco.py', ('cuda', 'cpu')),
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
('rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py', ('cuda', 'cpu'))
])
def test_forward_loss_mode(self, cfg_file, devices):
@ -47,6 +49,13 @@ class TestSingleStageDetector(TestCase):
model = get_detector_cfg(cfg_file)
model.backbone.init_cfg = None
if 'fast' in cfg_file:
model.data_preprocessor = dict(
type='mmdet.DetDataPreprocessor',
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True)
from mmdet.models import build_detector
assert all([device in ['cpu', 'cuda'] for device in devices])
@ -69,6 +78,7 @@ class TestSingleStageDetector(TestCase):
'cpu')),
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
('rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py', ('cuda', 'cpu'))
])
def test_forward_predict_mode(self, cfg_file, devices):
@ -100,6 +110,7 @@ class TestSingleStageDetector(TestCase):
'cpu')),
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
('rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py', ('cuda', 'cpu'))
])
def test_forward_tensor_mode(self, cfg_file, devices):

View File

@ -0,0 +1,79 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmcv.cnn import ConvModule
from mmyolo.models.necks import YOLOv7PAFPN
from mmyolo.utils import register_all_modules
register_all_modules()
class TestYOLOv7PAFPN(TestCase):
def test_forward(self):
# test P5
s = 64
in_channels = [8, 16, 32]
feat_sizes = [s // 2**i for i in range(4)]  # [64, 32, 16, 8]
out_channels = [8, 16, 32]
feats = [
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
for i in range(len(in_channels))
]
neck = YOLOv7PAFPN(in_channels=in_channels, out_channels=out_channels)
outs = neck(feats)
assert len(outs) == len(feats)
for i in range(len(feats)):
assert outs[i].shape[1] == out_channels[i] * 2
assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
# test is_tiny_version
neck = YOLOv7PAFPN(
in_channels=in_channels,
out_channels=out_channels,
is_tiny_version=True)
outs = neck(feats)
assert len(outs) == len(feats)
for i in range(len(feats)):
assert outs[i].shape[1] == out_channels[i] * 2
assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
# test use_in_channels_in_downsample
neck = YOLOv7PAFPN(
in_channels=in_channels,
out_channels=out_channels,
use_in_channels_in_downsample=True)
outs = neck(feats)
assert len(outs) == len(feats)
for i in range(len(feats)):
assert outs[i].shape[1] == out_channels[i] * 2
assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
# test use_repconv_outs is False
neck = YOLOv7PAFPN(
in_channels=in_channels,
out_channels=out_channels,
use_repconv_outs=False)
self.assertIsInstance(neck.out_layers[0], ConvModule)
# test P6
s = 64
in_channels = [8, 16, 32, 64]
feat_sizes = [s // 2**i for i in range(4)]
out_channels = [8, 16, 32, 64]
feats = [
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
for i in range(len(in_channels))
]
neck = YOLOv7PAFPN(in_channels=in_channels, out_channels=out_channels)
outs = neck(feats)
assert len(outs) == len(feats)
for i in range(len(feats)):
assert outs[i].shape[1] == out_channels[i]
assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)

View File

@ -1,10 +1,85 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import torch
convert_dict = {
convert_dict_tiny = {
# stem
'model.0': 'backbone.stem.0',
'model.1': 'backbone.stem.1',
# stage1 TinyDownSampleBlock
'model.2': 'backbone.stage1.0.short_conv',
'model.3': 'backbone.stage1.0.main_convs.0',
'model.4': 'backbone.stage1.0.main_convs.1',
'model.5': 'backbone.stage1.0.main_convs.2',
'model.7': 'backbone.stage1.0.final_conv',
# stage2 TinyDownSampleBlock
'model.9': 'backbone.stage2.1.short_conv',
'model.10': 'backbone.stage2.1.main_convs.0',
'model.11': 'backbone.stage2.1.main_convs.1',
'model.12': 'backbone.stage2.1.main_convs.2',
'model.14': 'backbone.stage2.1.final_conv',
# stage3 TinyDownSampleBlock
'model.16': 'backbone.stage3.1.short_conv',
'model.17': 'backbone.stage3.1.main_convs.0',
'model.18': 'backbone.stage3.1.main_convs.1',
'model.19': 'backbone.stage3.1.main_convs.2',
'model.21': 'backbone.stage3.1.final_conv',
# stage4 TinyDownSampleBlock
'model.23': 'backbone.stage4.1.short_conv',
'model.24': 'backbone.stage4.1.main_convs.0',
'model.25': 'backbone.stage4.1.main_convs.1',
'model.26': 'backbone.stage4.1.main_convs.2',
'model.28': 'backbone.stage4.1.final_conv',
# neck SPPCSPBlock
'model.29': 'neck.reduce_layers.2.short_layer',
'model.30': 'neck.reduce_layers.2.main_layers',
'model.35': 'neck.reduce_layers.2.fuse_layers',
'model.37': 'neck.reduce_layers.2.final_conv',
'model.38': 'neck.upsample_layers.0.0',
'model.40': 'neck.reduce_layers.1',
'model.42': 'neck.top_down_layers.0.short_conv',
'model.43': 'neck.top_down_layers.0.main_convs.0',
'model.44': 'neck.top_down_layers.0.main_convs.1',
'model.45': 'neck.top_down_layers.0.main_convs.2',
'model.47': 'neck.top_down_layers.0.final_conv',
'model.48': 'neck.upsample_layers.1.0',
'model.50': 'neck.reduce_layers.0',
'model.52': 'neck.top_down_layers.1.short_conv',
'model.53': 'neck.top_down_layers.1.main_convs.0',
'model.54': 'neck.top_down_layers.1.main_convs.1',
'model.55': 'neck.top_down_layers.1.main_convs.2',
'model.57': 'neck.top_down_layers.1.final_conv',
'model.58': 'neck.downsample_layers.0',
'model.60': 'neck.bottom_up_layers.0.short_conv',
'model.61': 'neck.bottom_up_layers.0.main_convs.0',
'model.62': 'neck.bottom_up_layers.0.main_convs.1',
'model.63': 'neck.bottom_up_layers.0.main_convs.2',
'model.65': 'neck.bottom_up_layers.0.final_conv',
'model.66': 'neck.downsample_layers.1',
'model.68': 'neck.bottom_up_layers.1.short_conv',
'model.69': 'neck.bottom_up_layers.1.main_convs.0',
'model.70': 'neck.bottom_up_layers.1.main_convs.1',
'model.71': 'neck.bottom_up_layers.1.main_convs.2',
'model.73': 'neck.bottom_up_layers.1.final_conv',
'model.74': 'neck.out_layers.0',
'model.75': 'neck.out_layers.1',
'model.76': 'neck.out_layers.2',
# head
'model.77.m.0': 'bbox_head.head_module.convs_pred.0.1',
'model.77.m.1': 'bbox_head.head_module.convs_pred.1.1',
'model.77.m.2': 'bbox_head.head_module.convs_pred.2.1'
}
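# Hedged sketch of how a prefix map like `convert_dict_tiny` can be applied
# to rename the keys of an official YOLOv7 checkpoint; `convert_keys` is an
# illustrative helper, not the script's actual function.
def convert_keys(src_state_dict, prefix_map):
    dst_state_dict = OrderedDict()
    for key, weight in src_state_dict.items():
        for src_prefix, dst_prefix in prefix_map.items():
            if key.startswith(src_prefix + '.'):
                # swap the official prefix for the mmyolo module path
                dst_state_dict[dst_prefix + key[len(src_prefix):]] = weight
                break
    return dst_state_dict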
convert_dict_l = {
# stem
'model.0': 'backbone.stem.0',
'model.1': 'backbone.stem.1',
@ -70,7 +145,7 @@ convert_dict = {
'model.51.cv4': 'neck.reduce_layers.2.main_layers.2',
'model.51.cv5': 'neck.reduce_layers.2.fuse_layers.0',
'model.51.cv6': 'neck.reduce_layers.2.fuse_layers.1',
'model.51.cv2': 'neck.reduce_layers.2.short_layers',
'model.51.cv2': 'neck.reduce_layers.2.short_layer',
'model.51.cv7': 'neck.reduce_layers.2.final_conv',
# neck
@ -140,11 +215,522 @@ convert_dict = {
'model.104.rbr_1x1.1': 'neck.out_layers.2.rbr_1x1.bn',
# head
'model.105.m': 'bbox_head.head_module.convs_pred'
'model.105.m.0': 'bbox_head.head_module.convs_pred.0.1',
'model.105.m.1': 'bbox_head.head_module.convs_pred.1.1',
'model.105.m.2': 'bbox_head.head_module.convs_pred.2.1'
}
convert_dict_x = {
# stem
'model.0': 'backbone.stem.0',
'model.1': 'backbone.stem.1',
'model.2': 'backbone.stem.2',
# stage1
# ConvModule
'model.3': 'backbone.stage1.0',
# ELANBlock expand_channel_2x
'model.4': 'backbone.stage1.1.short_conv',
'model.5': 'backbone.stage1.1.main_conv',
'model.6': 'backbone.stage1.1.blocks.0.0',
'model.7': 'backbone.stage1.1.blocks.0.1',
'model.8': 'backbone.stage1.1.blocks.1.0',
'model.9': 'backbone.stage1.1.blocks.1.1',
'model.10': 'backbone.stage1.1.blocks.2.0',
'model.11': 'backbone.stage1.1.blocks.2.1',
'model.13': 'backbone.stage1.1.final_conv',
# stage2
# MaxPoolBlock reduce_channel_2x
'model.15': 'backbone.stage2.0.maxpool_branches.1',
'model.16': 'backbone.stage2.0.stride_conv_branches.0',
'model.17': 'backbone.stage2.0.stride_conv_branches.1',
# ELANBlock expand_channel_2x
'model.19': 'backbone.stage2.1.short_conv',
'model.20': 'backbone.stage2.1.main_conv',
'model.21': 'backbone.stage2.1.blocks.0.0',
'model.22': 'backbone.stage2.1.blocks.0.1',
'model.23': 'backbone.stage2.1.blocks.1.0',
'model.24': 'backbone.stage2.1.blocks.1.1',
'model.25': 'backbone.stage2.1.blocks.2.0',
'model.26': 'backbone.stage2.1.blocks.2.1',
'model.28': 'backbone.stage2.1.final_conv',
# stage3
# MaxPoolBlock reduce_channel_2x
'model.30': 'backbone.stage3.0.maxpool_branches.1',
'model.31': 'backbone.stage3.0.stride_conv_branches.0',
'model.32': 'backbone.stage3.0.stride_conv_branches.1',
# ELANBlock expand_channel_2x
'model.34': 'backbone.stage3.1.short_conv',
'model.35': 'backbone.stage3.1.main_conv',
'model.36': 'backbone.stage3.1.blocks.0.0',
'model.37': 'backbone.stage3.1.blocks.0.1',
'model.38': 'backbone.stage3.1.blocks.1.0',
'model.39': 'backbone.stage3.1.blocks.1.1',
'model.40': 'backbone.stage3.1.blocks.2.0',
'model.41': 'backbone.stage3.1.blocks.2.1',
'model.43': 'backbone.stage3.1.final_conv',
# stage4
# MaxPoolBlock reduce_channel_2x
'model.45': 'backbone.stage4.0.maxpool_branches.1',
'model.46': 'backbone.stage4.0.stride_conv_branches.0',
'model.47': 'backbone.stage4.0.stride_conv_branches.1',
# ELANBlock no_change_channel
'model.49': 'backbone.stage4.1.short_conv',
'model.50': 'backbone.stage4.1.main_conv',
'model.51': 'backbone.stage4.1.blocks.0.0',
'model.52': 'backbone.stage4.1.blocks.0.1',
'model.53': 'backbone.stage4.1.blocks.1.0',
'model.54': 'backbone.stage4.1.blocks.1.1',
'model.55': 'backbone.stage4.1.blocks.2.0',
'model.56': 'backbone.stage4.1.blocks.2.1',
'model.58': 'backbone.stage4.1.final_conv',
# neck SPPCSPBlock
'model.59.cv1': 'neck.reduce_layers.2.main_layers.0',
'model.59.cv3': 'neck.reduce_layers.2.main_layers.1',
'model.59.cv4': 'neck.reduce_layers.2.main_layers.2',
'model.59.cv5': 'neck.reduce_layers.2.fuse_layers.0',
'model.59.cv6': 'neck.reduce_layers.2.fuse_layers.1',
'model.59.cv2': 'neck.reduce_layers.2.short_layer',
'model.59.cv7': 'neck.reduce_layers.2.final_conv',
# neck
'model.60': 'neck.upsample_layers.0.0',
'model.62': 'neck.reduce_layers.1',
# neck ELANBlock reduce_channel_2x
'model.64': 'neck.top_down_layers.0.short_conv',
'model.65': 'neck.top_down_layers.0.main_conv',
'model.66': 'neck.top_down_layers.0.blocks.0.0',
'model.67': 'neck.top_down_layers.0.blocks.0.1',
'model.68': 'neck.top_down_layers.0.blocks.1.0',
'model.69': 'neck.top_down_layers.0.blocks.1.1',
'model.70': 'neck.top_down_layers.0.blocks.2.0',
'model.71': 'neck.top_down_layers.0.blocks.2.1',
'model.73': 'neck.top_down_layers.0.final_conv',
'model.74': 'neck.upsample_layers.1.0',
'model.76': 'neck.reduce_layers.0',
# neck ELANBlock reduce_channel_2x
'model.78': 'neck.top_down_layers.1.short_conv',
'model.79': 'neck.top_down_layers.1.main_conv',
'model.80': 'neck.top_down_layers.1.blocks.0.0',
'model.81': 'neck.top_down_layers.1.blocks.0.1',
'model.82': 'neck.top_down_layers.1.blocks.1.0',
'model.83': 'neck.top_down_layers.1.blocks.1.1',
'model.84': 'neck.top_down_layers.1.blocks.2.0',
'model.85': 'neck.top_down_layers.1.blocks.2.1',
'model.87': 'neck.top_down_layers.1.final_conv',
# neck MaxPoolBlock no_change_channel
'model.89': 'neck.downsample_layers.0.maxpool_branches.1',
'model.90': 'neck.downsample_layers.0.stride_conv_branches.0',
'model.91': 'neck.downsample_layers.0.stride_conv_branches.1',
# neck ELANBlock reduce_channel_2x
'model.93': 'neck.bottom_up_layers.0.short_conv',
'model.94': 'neck.bottom_up_layers.0.main_conv',
'model.95': 'neck.bottom_up_layers.0.blocks.0.0',
'model.96': 'neck.bottom_up_layers.0.blocks.0.1',
'model.97': 'neck.bottom_up_layers.0.blocks.1.0',
'model.98': 'neck.bottom_up_layers.0.blocks.1.1',
'model.99': 'neck.bottom_up_layers.0.blocks.2.0',
'model.100': 'neck.bottom_up_layers.0.blocks.2.1',
'model.102': 'neck.bottom_up_layers.0.final_conv',
# neck MaxPoolBlock no_change_channel
'model.104': 'neck.downsample_layers.1.maxpool_branches.1',
'model.105': 'neck.downsample_layers.1.stride_conv_branches.0',
'model.106': 'neck.downsample_layers.1.stride_conv_branches.1',
# neck ELANBlock reduce_channel_2x
'model.108': 'neck.bottom_up_layers.1.short_conv',
'model.109': 'neck.bottom_up_layers.1.main_conv',
'model.110': 'neck.bottom_up_layers.1.blocks.0.0',
'model.111': 'neck.bottom_up_layers.1.blocks.0.1',
'model.112': 'neck.bottom_up_layers.1.blocks.1.0',
'model.113': 'neck.bottom_up_layers.1.blocks.1.1',
'model.114': 'neck.bottom_up_layers.1.blocks.2.0',
'model.115': 'neck.bottom_up_layers.1.blocks.2.1',
'model.117': 'neck.bottom_up_layers.1.final_conv',
# Conv
'model.118': 'neck.out_layers.0',
'model.119': 'neck.out_layers.1',
'model.120': 'neck.out_layers.2',
# head
'model.121.m.0': 'bbox_head.head_module.convs_pred.0.1',
'model.121.m.1': 'bbox_head.head_module.convs_pred.1.1',
'model.121.m.2': 'bbox_head.head_module.convs_pred.2.1'
}
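# yolov7-w6 (P6) official checkpoint keys -> MMYOLO module names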
convert_dict_w = {
# stem
'model.1': 'backbone.stem.conv',
# stage1
# ConvModule
'model.2': 'backbone.stage1.0',
# ELANBlock
'model.3': 'backbone.stage1.1.short_conv',
'model.4': 'backbone.stage1.1.main_conv',
'model.5': 'backbone.stage1.1.blocks.0.0',
'model.6': 'backbone.stage1.1.blocks.0.1',
'model.7': 'backbone.stage1.1.blocks.1.0',
'model.8': 'backbone.stage1.1.blocks.1.1',
'model.10': 'backbone.stage1.1.final_conv',
# stage2
'model.11': 'backbone.stage2.0',
# ELANBlock
'model.12': 'backbone.stage2.1.short_conv',
'model.13': 'backbone.stage2.1.main_conv',
'model.14': 'backbone.stage2.1.blocks.0.0',
'model.15': 'backbone.stage2.1.blocks.0.1',
'model.16': 'backbone.stage2.1.blocks.1.0',
'model.17': 'backbone.stage2.1.blocks.1.1',
'model.19': 'backbone.stage2.1.final_conv',
# stage3
'model.20': 'backbone.stage3.0',
# ELANBlock
'model.21': 'backbone.stage3.1.short_conv',
'model.22': 'backbone.stage3.1.main_conv',
'model.23': 'backbone.stage3.1.blocks.0.0',
'model.24': 'backbone.stage3.1.blocks.0.1',
'model.25': 'backbone.stage3.1.blocks.1.0',
'model.26': 'backbone.stage3.1.blocks.1.1',
'model.28': 'backbone.stage3.1.final_conv',
# stage4
'model.29': 'backbone.stage4.0',
# ELANBlock
'model.30': 'backbone.stage4.1.short_conv',
'model.31': 'backbone.stage4.1.main_conv',
'model.32': 'backbone.stage4.1.blocks.0.0',
'model.33': 'backbone.stage4.1.blocks.0.1',
'model.34': 'backbone.stage4.1.blocks.1.0',
'model.35': 'backbone.stage4.1.blocks.1.1',
'model.37': 'backbone.stage4.1.final_conv',
# stage5
'model.38': 'backbone.stage5.0',
# ELANBlock
'model.39': 'backbone.stage5.1.short_conv',
'model.40': 'backbone.stage5.1.main_conv',
'model.41': 'backbone.stage5.1.blocks.0.0',
'model.42': 'backbone.stage5.1.blocks.0.1',
'model.43': 'backbone.stage5.1.blocks.1.0',
'model.44': 'backbone.stage5.1.blocks.1.1',
'model.46': 'backbone.stage5.1.final_conv',
# neck SPPCSPBlock
'model.47.cv1': 'neck.reduce_layers.3.main_layers.0',
'model.47.cv3': 'neck.reduce_layers.3.main_layers.1',
'model.47.cv4': 'neck.reduce_layers.3.main_layers.2',
'model.47.cv5': 'neck.reduce_layers.3.fuse_layers.0',
'model.47.cv6': 'neck.reduce_layers.3.fuse_layers.1',
'model.47.cv2': 'neck.reduce_layers.3.short_layer',
'model.47.cv7': 'neck.reduce_layers.3.final_conv',
# neck
'model.48': 'neck.upsample_layers.0.0',
'model.50': 'neck.reduce_layers.2',
# neck ELANBlock
'model.52': 'neck.top_down_layers.0.short_conv',
'model.53': 'neck.top_down_layers.0.main_conv',
'model.54': 'neck.top_down_layers.0.blocks.0',
'model.55': 'neck.top_down_layers.0.blocks.1',
'model.56': 'neck.top_down_layers.0.blocks.2',
'model.57': 'neck.top_down_layers.0.blocks.3',
'model.59': 'neck.top_down_layers.0.final_conv',
'model.60': 'neck.upsample_layers.1.0',
'model.62': 'neck.reduce_layers.1',
# neck ELANBlock reduce_channel_2x
'model.64': 'neck.top_down_layers.1.short_conv',
'model.65': 'neck.top_down_layers.1.main_conv',
'model.66': 'neck.top_down_layers.1.blocks.0',
'model.67': 'neck.top_down_layers.1.blocks.1',
'model.68': 'neck.top_down_layers.1.blocks.2',
'model.69': 'neck.top_down_layers.1.blocks.3',
'model.71': 'neck.top_down_layers.1.final_conv',
'model.72': 'neck.upsample_layers.2.0',
'model.74': 'neck.reduce_layers.0',
'model.76': 'neck.top_down_layers.2.short_conv',
'model.77': 'neck.top_down_layers.2.main_conv',
'model.78': 'neck.top_down_layers.2.blocks.0',
'model.79': 'neck.top_down_layers.2.blocks.1',
'model.80': 'neck.top_down_layers.2.blocks.2',
'model.81': 'neck.top_down_layers.2.blocks.3',
'model.83': 'neck.top_down_layers.2.final_conv',
'model.84': 'neck.downsample_layers.0',
# neck ELANBlock
'model.86': 'neck.bottom_up_layers.0.short_conv',
'model.87': 'neck.bottom_up_layers.0.main_conv',
'model.88': 'neck.bottom_up_layers.0.blocks.0',
'model.89': 'neck.bottom_up_layers.0.blocks.1',
'model.90': 'neck.bottom_up_layers.0.blocks.2',
'model.91': 'neck.bottom_up_layers.0.blocks.3',
'model.93': 'neck.bottom_up_layers.0.final_conv',
'model.94': 'neck.downsample_layers.1',
# neck ELANBlock reduce_channel_2x
'model.96': 'neck.bottom_up_layers.1.short_conv',
'model.97': 'neck.bottom_up_layers.1.main_conv',
'model.98': 'neck.bottom_up_layers.1.blocks.0',
'model.99': 'neck.bottom_up_layers.1.blocks.1',
'model.100': 'neck.bottom_up_layers.1.blocks.2',
'model.101': 'neck.bottom_up_layers.1.blocks.3',
'model.103': 'neck.bottom_up_layers.1.final_conv',
'model.104': 'neck.downsample_layers.2',
# neck ELANBlock reduce_channel_2x
'model.106': 'neck.bottom_up_layers.2.short_conv',
'model.107': 'neck.bottom_up_layers.2.main_conv',
'model.108': 'neck.bottom_up_layers.2.blocks.0',
'model.109': 'neck.bottom_up_layers.2.blocks.1',
'model.110': 'neck.bottom_up_layers.2.blocks.2',
'model.111': 'neck.bottom_up_layers.2.blocks.3',
'model.113': 'neck.bottom_up_layers.2.final_conv',
'model.114': 'bbox_head.head_module.main_convs_pred.0.0',
'model.115': 'bbox_head.head_module.main_convs_pred.1.0',
'model.116': 'bbox_head.head_module.main_convs_pred.2.0',
'model.117': 'bbox_head.head_module.main_convs_pred.3.0',
# head
'model.118.m.0': 'bbox_head.head_module.main_convs_pred.0.2',
'model.118.m.1': 'bbox_head.head_module.main_convs_pred.1.2',
'model.118.m.2': 'bbox_head.head_module.main_convs_pred.2.2',
'model.118.m.3': 'bbox_head.head_module.main_convs_pred.3.2'
}
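# yolov7-e6 (P6) official checkpoint keys -> MMYOLO module names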
convert_dict_e = {
# stem
'model.1': 'backbone.stem.conv',
# stage1
'model.2.cv1': 'backbone.stage1.0.stride_conv_branches.0',
'model.2.cv2': 'backbone.stage1.0.stride_conv_branches.1',
'model.2.cv3': 'backbone.stage1.0.maxpool_branches.1',
# ELANBlock
'model.3': 'backbone.stage1.1.short_conv',
'model.4': 'backbone.stage1.1.main_conv',
'model.5': 'backbone.stage1.1.blocks.0.0',
'model.6': 'backbone.stage1.1.blocks.0.1',
'model.7': 'backbone.stage1.1.blocks.1.0',
'model.8': 'backbone.stage1.1.blocks.1.1',
'model.9': 'backbone.stage1.1.blocks.2.0',
'model.10': 'backbone.stage1.1.blocks.2.1',
'model.12': 'backbone.stage1.1.final_conv',
# stage2
'model.13.cv1': 'backbone.stage2.0.stride_conv_branches.0',
'model.13.cv2': 'backbone.stage2.0.stride_conv_branches.1',
'model.13.cv3': 'backbone.stage2.0.maxpool_branches.1',
# ELANBlock
'model.14': 'backbone.stage2.1.short_conv',
'model.15': 'backbone.stage2.1.main_conv',
'model.16': 'backbone.stage2.1.blocks.0.0',
'model.17': 'backbone.stage2.1.blocks.0.1',
'model.18': 'backbone.stage2.1.blocks.1.0',
'model.19': 'backbone.stage2.1.blocks.1.1',
'model.20': 'backbone.stage2.1.blocks.2.0',
'model.21': 'backbone.stage2.1.blocks.2.1',
'model.23': 'backbone.stage2.1.final_conv',
# stage3
'model.24.cv1': 'backbone.stage3.0.stride_conv_branches.0',
'model.24.cv2': 'backbone.stage3.0.stride_conv_branches.1',
'model.24.cv3': 'backbone.stage3.0.maxpool_branches.1',
# ELANBlock
'model.25': 'backbone.stage3.1.short_conv',
'model.26': 'backbone.stage3.1.main_conv',
'model.27': 'backbone.stage3.1.blocks.0.0',
'model.28': 'backbone.stage3.1.blocks.0.1',
'model.29': 'backbone.stage3.1.blocks.1.0',
'model.30': 'backbone.stage3.1.blocks.1.1',
'model.31': 'backbone.stage3.1.blocks.2.0',
'model.32': 'backbone.stage3.1.blocks.2.1',
'model.34': 'backbone.stage3.1.final_conv',
# stage4
'model.35.cv1': 'backbone.stage4.0.stride_conv_branches.0',
'model.35.cv2': 'backbone.stage4.0.stride_conv_branches.1',
'model.35.cv3': 'backbone.stage4.0.maxpool_branches.1',
# ELANBlock
'model.36': 'backbone.stage4.1.short_conv',
'model.37': 'backbone.stage4.1.main_conv',
'model.38': 'backbone.stage4.1.blocks.0.0',
'model.39': 'backbone.stage4.1.blocks.0.1',
'model.40': 'backbone.stage4.1.blocks.1.0',
'model.41': 'backbone.stage4.1.blocks.1.1',
'model.42': 'backbone.stage4.1.blocks.2.0',
'model.43': 'backbone.stage4.1.blocks.2.1',
'model.45': 'backbone.stage4.1.final_conv',
# stage5
'model.46.cv1': 'backbone.stage5.0.stride_conv_branches.0',
'model.46.cv2': 'backbone.stage5.0.stride_conv_branches.1',
'model.46.cv3': 'backbone.stage5.0.maxpool_branches.1',
# ELANBlock
'model.47': 'backbone.stage5.1.short_conv',
'model.48': 'backbone.stage5.1.main_conv',
'model.49': 'backbone.stage5.1.blocks.0.0',
'model.50': 'backbone.stage5.1.blocks.0.1',
'model.51': 'backbone.stage5.1.blocks.1.0',
'model.52': 'backbone.stage5.1.blocks.1.1',
'model.53': 'backbone.stage5.1.blocks.2.0',
'model.54': 'backbone.stage5.1.blocks.2.1',
'model.56': 'backbone.stage5.1.final_conv',
# neck SPPCSPBlock
'model.57.cv1': 'neck.reduce_layers.3.main_layers.0',
'model.57.cv3': 'neck.reduce_layers.3.main_layers.1',
'model.57.cv4': 'neck.reduce_layers.3.main_layers.2',
'model.57.cv5': 'neck.reduce_layers.3.fuse_layers.0',
'model.57.cv6': 'neck.reduce_layers.3.fuse_layers.1',
'model.57.cv2': 'neck.reduce_layers.3.short_layer',
'model.57.cv7': 'neck.reduce_layers.3.final_conv',
# neck
'model.58': 'neck.upsample_layers.0.0',
'model.60': 'neck.reduce_layers.2',
# neck ELANBlock
'model.62': 'neck.top_down_layers.0.short_conv',
'model.63': 'neck.top_down_layers.0.main_conv',
'model.64': 'neck.top_down_layers.0.blocks.0',
'model.65': 'neck.top_down_layers.0.blocks.1',
'model.66': 'neck.top_down_layers.0.blocks.2',
'model.67': 'neck.top_down_layers.0.blocks.3',
'model.68': 'neck.top_down_layers.0.blocks.4',
'model.69': 'neck.top_down_layers.0.blocks.5',
'model.71': 'neck.top_down_layers.0.final_conv',
'model.72': 'neck.upsample_layers.1.0',
'model.74': 'neck.reduce_layers.1',
# neck ELANBlock
'model.76': 'neck.top_down_layers.1.short_conv',
'model.77': 'neck.top_down_layers.1.main_conv',
'model.78': 'neck.top_down_layers.1.blocks.0',
'model.79': 'neck.top_down_layers.1.blocks.1',
'model.80': 'neck.top_down_layers.1.blocks.2',
'model.81': 'neck.top_down_layers.1.blocks.3',
'model.82': 'neck.top_down_layers.1.blocks.4',
'model.83': 'neck.top_down_layers.1.blocks.5',
'model.85': 'neck.top_down_layers.1.final_conv',
'model.86': 'neck.upsample_layers.2.0',
'model.88': 'neck.reduce_layers.0',
'model.90': 'neck.top_down_layers.2.short_conv',
'model.91': 'neck.top_down_layers.2.main_conv',
'model.92': 'neck.top_down_layers.2.blocks.0',
'model.93': 'neck.top_down_layers.2.blocks.1',
'model.94': 'neck.top_down_layers.2.blocks.2',
'model.95': 'neck.top_down_layers.2.blocks.3',
'model.96': 'neck.top_down_layers.2.blocks.4',
'model.97': 'neck.top_down_layers.2.blocks.5',
'model.99': 'neck.top_down_layers.2.final_conv',
'model.100.cv1': 'neck.downsample_layers.0.stride_conv_branches.0',
'model.100.cv2': 'neck.downsample_layers.0.stride_conv_branches.1',
'model.100.cv3': 'neck.downsample_layers.0.maxpool_branches.1',
# neck ELANBlock
'model.102': 'neck.bottom_up_layers.0.short_conv',
'model.103': 'neck.bottom_up_layers.0.main_conv',
'model.104': 'neck.bottom_up_layers.0.blocks.0',
'model.105': 'neck.bottom_up_layers.0.blocks.1',
'model.106': 'neck.bottom_up_layers.0.blocks.2',
'model.107': 'neck.bottom_up_layers.0.blocks.3',
'model.108': 'neck.bottom_up_layers.0.blocks.4',
'model.109': 'neck.bottom_up_layers.0.blocks.5',
'model.111': 'neck.bottom_up_layers.0.final_conv',
'model.112.cv1': 'neck.downsample_layers.1.stride_conv_branches.0',
'model.112.cv2': 'neck.downsample_layers.1.stride_conv_branches.1',
'model.112.cv3': 'neck.downsample_layers.1.maxpool_branches.1',
# neck ELANBlock
'model.114': 'neck.bottom_up_layers.1.short_conv',
'model.115': 'neck.bottom_up_layers.1.main_conv',
'model.116': 'neck.bottom_up_layers.1.blocks.0',
'model.117': 'neck.bottom_up_layers.1.blocks.1',
'model.118': 'neck.bottom_up_layers.1.blocks.2',
'model.119': 'neck.bottom_up_layers.1.blocks.3',
'model.120': 'neck.bottom_up_layers.1.blocks.4',
'model.121': 'neck.bottom_up_layers.1.blocks.5',
'model.123': 'neck.bottom_up_layers.1.final_conv',
'model.124.cv1': 'neck.downsample_layers.2.stride_conv_branches.0',
'model.124.cv2': 'neck.downsample_layers.2.stride_conv_branches.1',
'model.124.cv3': 'neck.downsample_layers.2.maxpool_branches.1',
# neck ELANBlock
'model.126': 'neck.bottom_up_layers.2.short_conv',
'model.127': 'neck.bottom_up_layers.2.main_conv',
'model.128': 'neck.bottom_up_layers.2.blocks.0',
'model.129': 'neck.bottom_up_layers.2.blocks.1',
'model.130': 'neck.bottom_up_layers.2.blocks.2',
'model.131': 'neck.bottom_up_layers.2.blocks.3',
'model.132': 'neck.bottom_up_layers.2.blocks.4',
'model.133': 'neck.bottom_up_layers.2.blocks.5',
'model.135': 'neck.bottom_up_layers.2.final_conv',
'model.136': 'bbox_head.head_module.main_convs_pred.0.0',
'model.137': 'bbox_head.head_module.main_convs_pred.1.0',
'model.138': 'bbox_head.head_module.main_convs_pred.2.0',
'model.139': 'bbox_head.head_module.main_convs_pred.3.0',
# head
'model.140.m.0': 'bbox_head.head_module.main_convs_pred.0.2',
'model.140.m.1': 'bbox_head.head_module.main_convs_pred.1.2',
'model.140.m.2': 'bbox_head.head_module.main_convs_pred.2.2',
'model.140.m.3': 'bbox_head.head_module.main_convs_pred.3.2'
}
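# Map each official checkpoint filename to the key-conversion dict of its variant.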
convert_dicts = {
'yolov7-tiny.pt': convert_dict_tiny,
'yolov7-w6.pt': convert_dict_w,
'yolov7-e6.pt': convert_dict_e,
'yolov7.pt': convert_dict_l,
'yolov7x.pt': convert_dict_x
}
def convert(src, dst):
src_key = osp.basename(src)
convert_dict = convert_dicts[src_key]
num_levels = 3
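# indexes[0]: first layer id whose keys need sub-module level matching
# (RepConv / detect head, e.g. 'model.N.m.0');
# indexes[1]: layer ids (SPPCSP / downsample blocks) whose keys look like
# 'model.N.cvX' and are matched on that prefix instead of plain 'model.N'.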
if src_key == 'yolov7.pt':
indexes = [102, 51]
in_channels = [256, 512, 1024]
elif src_key == 'yolov7x.pt':
indexes = [121, 59]
in_channels = [320, 640, 1280]
elif src_key == 'yolov7-tiny.pt':
indexes = [77, 1000]  # 1000 is a placeholder so that no layer id matches indexes[1]
in_channels = [128, 256, 512]
elif src_key == 'yolov7-w6.pt':
indexes = [118, 47]
in_channels = [256, 512, 768, 1024]
num_levels = 4
elif src_key == 'yolov7-e6.pt':
indexes = [140, [2, 13, 24, 35, 46, 57, 100, 112, 124]]
in_channels = [320, 640, 960, 1280]
num_levels = 4
if isinstance(indexes[1], int):
indexes[1] = [indexes[1]]
"""Convert keys in detectron pretrained YOLOv7 models to mmyolo style."""
try:
yolov7_model = torch.load(src)['model'].float()  # official ckpts store the fp16 nn.Module under 'model'
@@ -161,23 +747,40 @@ def convert(src, dst):
continue
num, module = key.split('.')[1:3]
if int(num) < 102 and int(num) != 51:
if int(num) < indexes[0] and int(num) not in indexes[1]:
prefix = f'model.{num}'
new_key = key.replace(prefix, convert_dict[prefix])
state_dict[new_key] = weight
print(f'Convert {key} to {new_key}')
elif int(num) < 105 and int(num) != 51:
strs_key = key.split('.')[:4]
new_key = key.replace('.'.join(strs_key),
convert_dict['.'.join(strs_key)])
state_dict[new_key] = weight
print(f'Convert {key} to {new_key}')
else:
elif int(num) in indexes[1]:
strs_key = key.split('.')[:3]
new_key = key.replace('.'.join(strs_key),
convert_dict['.'.join(strs_key)])
state_dict[new_key] = weight
print(f'Convert {key} to {new_key}')
else:
strs_key = key.split('.')[:4]
new_key = key.replace('.'.join(strs_key),
convert_dict['.'.join(strs_key)])
state_dict[new_key] = weight
print(f'Convert {key} to {new_key}')
# Add ImplicitA and ImplicitM
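# The re-created implicit layers get identity initial values: zeros for the
# additive ImplicitA and ones for the multiplicative ImplicitM;
# 3 * 85 = num_base_priors * (num_classes + 5) on COCO.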
for i in range(num_levels):
if num_levels == 3:
implicit_a = f'bbox_head.head_module.' \
f'convs_pred.{i}.0.implicit'
state_dict[implicit_a] = torch.zeros((1, in_channels[i], 1, 1))
implicit_m = f'bbox_head.head_module.' \
f'convs_pred.{i}.2.implicit'
state_dict[implicit_m] = torch.ones((1, 3 * 85, 1, 1))
else:
implicit_a = f'bbox_head.head_module.' \
f'main_convs_pred.{i}.1.implicit'
state_dict[implicit_a] = torch.zeros((1, in_channels[i], 1, 1))
implicit_m = f'bbox_head.head_module.' \
f'main_convs_pred.{i}.3.implicit'
state_dict[implicit_m] = torch.ones((1, 3 * 85, 1, 1))
# save checkpoint
checkpoint = dict()
@@ -189,8 +792,8 @@ def convert(src, dst):
def main():
parser = argparse.ArgumentParser(description='Convert model keys')
parser.add_argument(
'--src', default='yolov7.pt', help='src yolov7 model path')
parser.add_argument('--dst', default='mm_yolov7l.pt', help='save path')
'src', default='yolov7.pt', help='src yolov7 model path')
parser.add_argument('dst', default='mm_yolov7l.pt', help='save path')
args = parser.parse_args()
convert(args.src, args.dst)
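# Example usage with the new positional arguments (illustrative paths; the
# converter is assumed to live at tools/model_converters/yolov7_to_mmyolo.py):
#   python tools/model_converters/yolov7_to_mmyolo.py yolov7-w6.pt mm_yolov7w6.pt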