[Feature] Support YOLOv6 training ()

* init v6 loss

* init v6s train

* Add train pipeline

* Add lr scheduler

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix detach bug

* fix detach bug

* update

* Add stop aug hook

* Add save best ckpt

* update

* Add PipelineSwitchHook

* Fix train pipeline stage 2

* update

* Fix train pipeline

* update

* fix stage2 randomaffine bug

* update

* update

* clean

* clean

* update letterResize param

* add v6affine config

* add v6 randomaffine

* update v6 config

* update

* update

* update

* update

* update config param

* update

* update

* refactor iou loss & rm v6affine

* update

* rm dfl

* add v6 300 epoch config

* Factor batch atss assigner

* Format code

* Format code

* Roll back

* Refactor dist_calculator

* Refactor select_candidates_in_gts

* Refactor select_highest_overlaps

* Refactor iou_calculator

* Refactor all code

* Improve docstr

* Improve code

* clean config

* add nano tiny config

* pre-commit

* Refactor

* Improve code

* Improve naming and link

* Add UT

* pre commit

* Add UT

* Add UT

* Improve code, using mmdet.BboxOverlaps2D for all iou calculation

* Improve code, using mmdet.BboxOverlaps2D for all iou calculation

* Improve code

* pre commit

* pre commit

* Add UT

* fix config

* pre commit

* Improve code

* Improve code

* Improve code

* Improve code

* [Refactor] YOLOv6 BatchATSSAssigner ()

* Factor batch atss assigner

* Format code

* Format code

* Roll back

* Refactor dist_calculator

* Refactor select_candidates_in_gts

* Refactor select_highest_overlaps

* Refactor iou_calculator

* Refactor all code

* Improve docstr

* Improve code

* Improve code

* Improve naming and link

* Add UT

* pre commit

* Add UT

* Add UT

* Improve code, using mmdet.BboxOverlaps2D for all iou calculation

* Improve code, using mmdet.BboxOverlaps2D for all iou calculation

* Improve code

* pre commit

* Fix conflicts

* Improve code

* Improve code

* Improve code

* Improve code

* Improve code

* Improve code

* add utils.py, order the input param

* Improve docstr

* Fix lint

* Improve param mapping

* Improve param mapping

* Improve naming

* assigner return dict

* update

* update config

* update config

* Fix

* Fix UT

* Improve UT

* Improve naming

* Improve coding

* pre commit

* pre commit

* pre commit

* Fix ci

* Improve naming

* Improve coding

* Fix training iou calculate error

* Improve naming

* Improve naming

* Improve type hint

* fix lint

* fix conflicts

* fix UT

* Improve type hint

* Improve naming

* Improve coding

* Improve coding

* Fix UT

* Refactor SIoU

* Pre commit

* Fix

* Improve ciou

* Improve ciou

* refactor varifocal

* Improve ciou

* Improve ciou

* Improve siou

* Improve type hint

* Improve siou

* Improve siou

* Fix lint

* refactor varifocal

* fix iou bug

* fix siou and loss_cls bug

* update

* update

* add scope

* update

* update

* Improve func `gt_instances_preprocess`

* support deploy mode

* Improve func `gt_instances_preprocess`

* Improve func `gt_instances_preprocess`

* Improve func `gt_instances_preprocess`

* Improve func `bbox_overlaps`

* Improve coding

* Improve bbox_overlaps

* Delete useless code

* add yolov6 deploy mode hook

* fix lint

* Add common attributes to reduce calculation

* Improve code

* Improve code

* Fix bug

* Fix bug

* update

* add readme

* update readme

* update readme url

Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>
pull/249/head
wanghonglie 2022-11-02 20:23:25 +08:00 committed by Haian Huang(深度眸)
parent 177eb4ea13
commit 980e908618
28 changed files with 1694 additions and 111 deletions

View File

@@ -16,9 +16,17 @@ For years, YOLO series have been de facto industry-level standard for efficient
### COCO
| Backbone | Arch | size | SyncBN | AMP | Mem (GB) | box AP | Config | Download |
| :------: | :--: | :--: | :----: | :-: | :------: | :----: | :---------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| YOLOv6-n | P5 | 640 | Yes | Yes | 6.04 | 36.2 | [config](../yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco/yolov6_n_syncbn_fast_8xb32-400e_coco_20221030_202726-d99b2e82.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco/yolov6_n_syncbn_fast_8xb32-400e_coco_20221030_202726.log.json) |
| YOLOv6-t | P5 | 640 | Yes | Yes | 8.13 | 41.0 | [config](../yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755-cf0d278f.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755.log.json) |
| YOLOv6-s | P5 | 640 | Yes | Yes | 8.88 | 43.7 | [config](../yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221030_202704-2ba343db.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221030_202704.log.json) |
**Note**:
1. We don't support training just yet. But you can use the `tools/model_converters/yolov6_to_mmyolo.py` script to convert the official weight.
1. The performance is unstable and may fluctuate by about 0.3 mAP.
2. YOLOv6-m/l/x will be supported in a later version.
3. If users need the 300-epoch weights, they can train with the 300-epoch configs we provide, or convert the official weights with the [converter script](../../tools/model_converters/).
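
For reference, launching training with one of the provided configs follows the standard MMEngine flow. A minimal sketch (the `work_dir` value is illustrative, not prescribed by this PR):

```python
from mmengine.config import Config
from mmengine.runner import Runner

from mmyolo.utils import register_all_modules

# Register MMYOLO modules so the registry can resolve config types.
register_all_modules()

cfg = Config.fromfile('configs/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py')
cfg.work_dir = './work_dirs/yolov6_s'  # illustrative output directory

runner = Runner.from_cfg(cfg)
runner.train()
```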
## Citation

View File

@@ -0,0 +1,13 @@
_base_ = './yolov6_s_syncbn_fast_8xb32-300e_coco.py'
deepen_factor = 0.33
widen_factor = 0.25
model = dict(
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
bbox_head=dict(
head_module=dict(widen_factor=widen_factor),
loss_bbox=dict(iou_mode='siou')))
default_hooks = dict(param_scheduler=dict(lr_factor=0.02))

View File

@@ -0,0 +1,13 @@
_base_ = './yolov6_s_syncbn_fast_8xb32-400e_coco.py'
deepen_factor = 0.33
widen_factor = 0.25
model = dict(
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
bbox_head=dict(
head_module=dict(widen_factor=widen_factor),
loss_bbox=dict(iou_mode='siou')))
default_hooks = dict(param_scheduler=dict(lr_factor=0.02))
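
The nano variant above only overrides `deepen_factor`/`widen_factor`; everything else is inherited from the s-model config. As a rough sketch of how `widen_factor` maps to channel widths (assuming a `make_divisible` helper that rounds up to a multiple of 8, as in `mmyolo.models.utils`):

```python
import math

def make_divisible(channels: float, widen_factor: float, divisor: int = 8) -> int:
    # Scale the base channel count, rounding up to a multiple of `divisor`.
    return math.ceil(channels * widen_factor / divisor) * divisor

# widen_factor=0.25 shrinks the base [256, 512, 1024] backbone channels to:
print([make_divisible(c, 0.25) for c in (256, 512, 1024)])  # [64, 128, 256]
```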

View File

@@ -1,39 +0,0 @@
# Training mode is currently not supported
_base_ = '../yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py'
max_epochs = 400
train_batch_size_per_gpu = 32
deepen_factor = _base_.deepen_factor
widen_factor = _base_.widen_factor
model = dict(
backbone=dict(
_delete_=True,
type='YOLOv6EfficientRep',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='ReLU', inplace=True)),
neck=dict(
_delete_=True,
type='YOLOv6RepPAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=[256, 512, 1024],
out_channels=[128, 256, 512],
num_csp_blocks=12,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='ReLU', inplace=True),
),
bbox_head=dict(
_delete_=True,
type='YOLOv6Head',
head_module=dict(
type='YOLOv6HeadModule',
num_classes=80,
in_channels=[128, 256, 512],
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True),
featmap_strides=[8, 16, 32])),
train_cfg=None)

View File

@@ -0,0 +1,29 @@
_base_ = './yolov6_s_syncbn_fast_8xb32-400e_coco.py'
max_epochs = 300
num_last_epochs = 15
default_hooks = dict(
param_scheduler=dict(
type='YOLOv5ParamSchedulerHook',
scheduler_type='cosine',
lr_factor=0.01,
max_epochs=max_epochs))
custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0001,
update_buffers=True,
strict_load=False,
priority=49),
dict(
type='mmdet.PipelineSwitchHook',
switch_epoch=max_epochs - num_last_epochs,
switch_pipeline=_base_.train_pipeline_stage2)
]
train_cfg = dict(
max_epochs=max_epochs,
dynamic_intervals=[(max_epochs - num_last_epochs, 1)])
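
The second-stage switch is driven by `mmdet.PipelineSwitchHook`, which swaps the training pipeline once `switch_epoch` is reached. A minimal sketch of the idea (not the exact mmdet implementation):

```python
from mmcv.transforms import Compose
from mmengine.hooks import Hook


class SimplePipelineSwitchHook(Hook):
    """Sketch: replace the train dataset pipeline at a given epoch."""

    def __init__(self, switch_epoch, switch_pipeline):
        self.switch_epoch = switch_epoch
        self.switch_pipeline = switch_pipeline
        self._switched = False

    def before_train_epoch(self, runner):
        if runner.epoch >= self.switch_epoch and not self._switched:
            # Swap the dataset pipeline; later epochs use stage-2 transforms.
            runner.train_dataloader.dataset.pipeline = Compose(
                self.switch_pipeline)
            self._switched = True
```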

View File

@@ -0,0 +1,250 @@
_base_ = '../_base_/default_runtime.py'
# dataset settings
data_root = 'data/coco/'
dataset_type = 'YOLOv5CocoDataset'
num_last_epochs = 15
max_epochs = 400
num_classes = 80
# parameters that often need to be modified
img_scale = (640, 640) # height, width
deepen_factor = 0.33
widen_factor = 0.5
save_epoch_intervals = 10
train_batch_size_per_gpu = 32
train_num_workers = 8
val_batch_size_per_gpu = 1
val_num_workers = 2
# persistent_workers must be False if num_workers is 0.
persistent_workers = True
# only on Val
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
size_divisor=32,
extra_pad_ratio=0.5)
# single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)
model = dict(
type='YOLODetector',
data_preprocessor=dict(
type='YOLOv5DetDataPreprocessor',
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True),
backbone=dict(
type='YOLOv6EfficientRep',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='ReLU', inplace=True)),
neck=dict(
type='YOLOv6RepPAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=[256, 512, 1024],
out_channels=[128, 256, 512],
num_csp_blocks=12,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='ReLU', inplace=True),
),
bbox_head=dict(
type='YOLOv6Head',
head_module=dict(
type='YOLOv6HeadModule',
num_classes=num_classes,
in_channels=[128, 256, 512],
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True),
featmap_strides=[8, 16, 32]),
loss_bbox=dict(
type='IoULoss',
iou_mode='giou',
bbox_format='xyxy',
reduction='mean',
loss_weight=2.5,
return_iou=False)),
train_cfg=dict(
initial_epoch=4,
initial_assigner=dict(
type='BatchATSSAssigner',
num_classes=num_classes,
topk=9,
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
assigner=dict(
type='BatchTaskAlignedAssigner',
num_classes=num_classes,
topk=13,
alpha=1,
beta=6),
),
test_cfg=dict(
multi_label=True,
nms_pre=30000,
score_thr=0.001,
nms=dict(type='nms', iou_threshold=0.65),
max_per_img=300))
# The training pipeline of YOLOv6 is basically the same as YOLOv5.
# The difference is that Mosaic and RandomAffine are disabled in the last 15 epochs. # noqa
pre_transform = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='LoadAnnotations', with_bbox=True)
]
train_pipeline = [
*pre_transform,
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_translate_ratio=0.1,
scaling_ratio_range=(0.5, 1.5),
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114),
max_shear_degree=0.0),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
train_pipeline_stage2 = [
*pre_transform,
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=True,
pad_val=dict(img=114)),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_translate_ratio=0.1,
scaling_ratio_range=(0.5, 1.5),
max_shear_degree=0.0,
),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
train_dataloader = dict(
batch_size=train_batch_size_per_gpu,
num_workers=train_num_workers,
collate_fn=dict(type='yolov5_collate'),
persistent_workers=persistent_workers,
pin_memory=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='annotations/instances_train2017.json',
data_prefix=dict(img='train2017/'),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=train_pipeline))
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]
val_dataloader = dict(
batch_size=val_batch_size_per_gpu,
num_workers=val_num_workers,
persistent_workers=persistent_workers,
pin_memory=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
test_mode=True,
data_prefix=dict(img='val2017/'),
ann_file='annotations/instances_val2017.json',
pipeline=test_pipeline,
batch_shapes_cfg=batch_shapes_cfg))
test_dataloader = val_dataloader
# Optimizer and learning rate scheduler of YOLOv6 are basically the same as YOLOv5. # noqa
# The difference is that the scheduler_type of YOLOv6 is cosine.
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=0.01,
momentum=0.937,
weight_decay=0.0005,
nesterov=True,
batch_size_per_gpu=train_batch_size_per_gpu),
constructor='YOLOv5OptimizerConstructor')
default_hooks = dict(
param_scheduler=dict(
type='YOLOv5ParamSchedulerHook',
scheduler_type='cosine',
lr_factor=0.01,
max_epochs=max_epochs),
checkpoint=dict(
type='CheckpointHook',
interval=save_epoch_intervals,
max_keep_ckpts=3,
save_best='auto'))
custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0001,
update_buffers=True,
strict_load=False,
priority=49),
dict(
type='mmdet.PipelineSwitchHook',
switch_epoch=max_epochs - num_last_epochs,
switch_pipeline=train_pipeline_stage2)
]
val_evaluator = dict(
type='mmdet.CocoMetric',
proposal_nums=(100, 1, 10),
ann_file=data_root + 'annotations/instances_val2017.json',
metric='bbox')
test_evaluator = val_evaluator
train_cfg = dict(
type='EpochBasedTrainLoop',
max_epochs=max_epochs,
val_interval=save_epoch_intervals,
dynamic_intervals=[(max_epochs - num_last_epochs, 1)])
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
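
With `scheduler_type='cosine'`, the per-epoch learning-rate factor decays from 1 to `lr_factor` along YOLOv5's one-cycle cosine curve. A sketch of that schedule, assuming the same lambda YOLOv5 uses:

```python
import math

def cosine_lr_factor(epoch: int, max_epochs: int = 400,
                     lr_factor: float = 0.01) -> float:
    # Cosine decay from 1.0 (epoch 0) down to lr_factor (final epoch).
    return ((1 - math.cos(epoch * math.pi / max_epochs)) / 2) * (lr_factor - 1) + 1

print(cosine_lr_factor(0))    # 1.0
print(cosine_lr_factor(200))  # ~0.505 halfway through
print(cosine_lr_factor(400))  # 0.01 at the final epoch
```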

View File

@@ -0,0 +1,12 @@
_base_ = './yolov6_s_syncbn_fast_8xb32-300e_coco.py'
deepen_factor = 0.33
widen_factor = 0.375
model = dict(
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
bbox_head=dict(
type='YOLOv6Head',
head_module=dict(widen_factor=widen_factor),
loss_bbox=dict(iou_mode='siou')))

View File

@@ -0,0 +1,12 @@
_base_ = './yolov6_s_syncbn_fast_8xb32-400e_coco.py'
deepen_factor = 0.33
widen_factor = 0.375
model = dict(
backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
bbox_head=dict(
type='YOLOv6Head',
head_module=dict(widen_factor=widen_factor),
loss_bbox=dict(iou_mode='siou')))

View File

@@ -11,7 +11,7 @@ from mmengine.logging import print_log
from mmengine.utils import ProgressBar, scandir
from mmyolo.registry import VISUALIZERS
from mmyolo.utils import register_all_modules
from mmyolo.utils import register_all_modules, switch_to_deploy
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.tiff', '.webp')
@@ -30,7 +30,11 @@ def parse_args():
parser.add_argument(
'--show', action='store_true', help='Show the detection results')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='bbox score threshold')
'--deploy',
action='store_true',
help='Switch model to deployment mode')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='Bbox score threshold')
args = parser.parse_args()
return args
@@ -42,6 +46,9 @@ def main(args):
# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device)
if args.deploy:
switch_to_deploy(model)
# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta

View File

@@ -1,5 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .switch_to_deploy_hook import SwitchToDeployHook
from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
from .yolox_mode_switch_hook import YOLOXModeSwitchHook
__all__ = ['YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook']
__all__ = [
'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook'
]

View File

@@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmyolo.registry import HOOKS
from mmyolo.utils import switch_to_deploy
@HOOKS.register_module()
class SwitchToDeployHook(Hook):
"""Switch to deploy mode before testing.
This hook converts the multi-branch structure of the training network
(high accuracy) to the single-branch structure of the testing network
(faster inference and lower memory usage).
"""
def before_test_epoch(self, runner: Runner):
switch_to_deploy(runner.model)

View File

@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union
from typing import Sequence, Tuple, Union
import torch
import torch.nn as nn
@@ -7,11 +7,13 @@ from mmcv.cnn import ConvModule
from mmdet.models.utils import multi_apply
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
OptMultiConfig)
from mmengine import MessageHub
from mmengine.dist import get_dist_info
from mmengine.model import BaseModule, bias_init_with_prob
from mmengine.structures import InstanceData
from torch import Tensor
from mmyolo.registry import MODELS
from mmyolo.registry import MODELS, TASK_UTILS
from ..utils import make_divisible
from .yolov5_head import YOLOv5Head
@@ -72,19 +74,6 @@ class YOLOv6HeadModule(BaseModule):
self._init_layers()
def init_weights(self):
"""Initialize weights of the head."""
# Use prior in model initialization to improve stability
super().init_weights()
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
m.reset_parameters()
bias_init = bias_init_with_prob(0.01)
for conv_cls in self.cls_preds:
conv_cls.bias.data.fill_(bias_init)
def _init_layers(self):
"""initialize conv layers in YOLOv6 head."""
# Init decouple head
@@ -132,7 +121,18 @@
out_channels=self.num_base_priors * 4,
kernel_size=1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
def init_weights(self):
super().init_weights()
bias_init = bias_init_with_prob(0.01)
for conv in self.cls_preds:
conv.bias.data.fill_(bias_init)
conv.weight.data.fill_(0.)
for conv in self.reg_preds:
conv.bias.data.fill_(1.0)
conv.weight.data.fill_(0.)
def forward(self, x: Tensor) -> Tensor:
"""Forward features from the upstream network.
Args:
@@ -146,10 +146,10 @@
return multi_apply(self.forward_single, x, self.stems, self.cls_convs,
self.cls_preds, self.reg_convs, self.reg_preds)
def forward_single(self, x: torch.Tensor, stem: nn.ModuleList,
def forward_single(self, x: Tensor, stem: nn.ModuleList,
cls_conv: nn.ModuleList, cls_pred: nn.ModuleList,
reg_conv: nn.ModuleList,
reg_pred: nn.ModuleList) -> torch.Tensor:
reg_pred: nn.ModuleList) -> Tuple[Tensor, Tensor]:
"""Forward feature of a single scale level."""
y = stem(x)
cls_x = y
@@ -192,12 +192,20 @@ class YOLOv6Head(YOLOv5Head):
strides=[8, 16, 32]),
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
loss_cls: ConfigType = dict(
type='mmdet.CrossEntropyLoss',
type='mmdet.VarifocalLoss',
use_sigmoid=True,
alpha=0.75,
gamma=2.0,
iou_weighted=True,
reduction='sum',
loss_weight=1.0),
loss_bbox: ConfigType = dict(
type='mmdet.GIoULoss', reduction='sum', loss_weight=5.0),
type='IoULoss',
iou_mode='giou',
bbox_format='xyxy',
reduction='mean',
loss_weight=2.5,
return_iou=False),
loss_obj: ConfigType = dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
@@ -217,13 +225,27 @@
test_cfg=test_cfg,
init_cfg=init_cfg)
self.loss_bbox = MODELS.build(loss_bbox)
self.loss_cls = MODELS.build(loss_cls)
def special_init(self):
"""Since YOLO series algorithms will inherit from YOLOv5Head, but
different algorithms have special initialization process.
The special_init function is designed to deal with this situation.
"""
pass
if self.train_cfg:
self.initial_epoch = self.train_cfg['initial_epoch']
self.initial_assigner = TASK_UTILS.build(
self.train_cfg.initial_assigner)
self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
# Add common attributes to reduce calculation
self.featmap_sizes = None
self.mlvl_priors = None
self.num_level_priors = None
self.flatten_priors = None
self.stride_tensor = None
def loss_by_feat(
self,
@@ -254,4 +276,148 @@
Returns:
dict[str, Tensor]: A dictionary of losses.
"""
raise NotImplementedError('Not implemented yet')
# get epoch information from message hub
message_hub = MessageHub.get_current_instance()
current_epoch = message_hub.get_info('epoch')
num_imgs = len(batch_img_metas)
if batch_gt_instances_ignore is None:
batch_gt_instances_ignore = [None] * num_imgs
current_featmap_sizes = [
cls_score.shape[2:] for cls_score in cls_scores
]
# If the feature map shapes have changed, regenerate the priors
if current_featmap_sizes != self.featmap_sizes:
self.featmap_sizes = current_featmap_sizes
self.mlvl_priors = self.prior_generator.grid_priors(
self.featmap_sizes,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device,
with_stride=True)
self.num_level_priors = [len(n) for n in self.mlvl_priors]
self.flatten_priors = torch.cat(self.mlvl_priors, dim=0)
self.stride_tensor = self.flatten_priors[..., [2]]
# gt info
gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)
gt_labels = gt_info[:, :, :1]
gt_bboxes = gt_info[:, :, 1:] # xyxy
pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
# pred info
flatten_cls_preds = [
cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.num_classes)
for cls_pred in cls_scores
]
flatten_pred_bboxes = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]
flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
flatten_pred_bboxes = self.bbox_coder.decode(
self.flatten_priors[..., :2], flatten_pred_bboxes,
self.flatten_priors[..., 2])
pred_scores = torch.sigmoid(flatten_cls_preds)
if current_epoch < self.initial_epoch:
assigned_result = self.initial_assigner(
flatten_pred_bboxes.detach(), self.flatten_priors,
self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
else:
assigned_result = self.assigner(flatten_pred_bboxes.detach(),
pred_scores.detach(),
self.flatten_priors, gt_labels,
gt_bboxes, pad_bbox_flag)
assigned_bboxes = assigned_result['assigned_bboxes']
assigned_scores = assigned_result['assigned_scores']
fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
# cls loss
with torch.cuda.amp.autocast(enabled=False):
loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores)
# rescale bbox
assigned_bboxes /= self.stride_tensor
flatten_pred_bboxes /= self.stride_tensor
# TODO: Add all_reduce makes training more stable
assigned_scores_sum = assigned_scores.sum()
if assigned_scores_sum > 0:
loss_cls /= assigned_scores_sum
# select positive samples mask
num_pos = fg_mask_pre_prior.sum()
if num_pos > 0:
# when num_pos > 0, assigned_scores_sum will be > 0, so the loss_bbox
# will not report an error
# iou loss
prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
pred_bboxes_pos = torch.masked_select(
flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
assigned_bboxes_pos = torch.masked_select(
assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
bbox_weight = torch.masked_select(
assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
loss_bbox = self.loss_bbox(
pred_bboxes_pos,
assigned_bboxes_pos,
weight=bbox_weight,
avg_factor=assigned_scores_sum)
else:
loss_bbox = flatten_pred_bboxes.sum() * 0
_, world_size = get_dist_info()
return dict(
loss_cls=loss_cls * world_size, loss_bbox=loss_bbox * world_size)
@staticmethod
def gt_instances_preprocess(batch_gt_instances: Tensor,
batch_size: int) -> Tensor:
"""Split batch_gt_instances with batch size, from [all_gt_bboxes, 6]
to.
[batch_size, number_gt, 5]. If some shape of single batch smaller than
gt bbox len, then using [-1., 0., 0., 0., 0.] to fill.
Args:
batch_gt_instances (Sequence[Tensor]): Ground truth
instances for whole batch, shape [all_gt_bboxes, 6]
batch_size (int): Batch size.
Returns:
Tensor: batch gt instances data, shape [batch_size, number_gt, 5]
"""
# split batch gt instance [all_gt_bboxes, 6] ->
# [batch_size, number_gt_each_batch, 5]
batch_instance_list = []
max_gt_bbox_len = 0
for i in range(batch_size):
single_batch_instance = \
batch_gt_instances[batch_gt_instances[:, 0] == i, :]
single_batch_instance = single_batch_instance[:, 1:]
batch_instance_list.append(single_batch_instance)
if len(single_batch_instance) > max_gt_bbox_len:
max_gt_bbox_len = len(single_batch_instance)
# pad with [-1., 0., 0., 0., 0.] if an image has fewer
# gt bboxes than max_gt_bbox_len
for index, gt_instance in enumerate(batch_instance_list):
if gt_instance.shape[0] >= max_gt_bbox_len:
continue
fill_tensor = batch_gt_instances.new_full(
[max_gt_bbox_len - gt_instance.shape[0], 5], 0)
fill_tensor[:, 0] = -1.
batch_instance_list[index] = torch.cat(
(batch_instance_list[index], fill_tensor), dim=0)
return torch.stack(batch_instance_list)
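
A small worked example of `gt_instances_preprocess` (assuming `YOLOv6Head` is importable from `mmyolo.models.dense_heads`; the box values are made up):

```python
import torch

from mmyolo.models.dense_heads import YOLOv6Head  # assumed import path

# Flattened gt format: [img_idx, label, x1, y1, x2, y2];
# image 0 has two boxes, image 1 has one.
batch_gt_instances = torch.tensor([
    [0., 3., 10., 10., 50., 50.],
    [0., 7., 20., 20., 60., 60.],
    [1., 1., 30., 30., 70., 70.],
])

out = YOLOv6Head.gt_instances_preprocess(batch_gt_instances, batch_size=2)
print(out.shape)  # torch.Size([2, 2, 5])
print(out[1, 1])  # tensor([-1., 0., 0., 0., 0.]) -- padded row for image 1
```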

View File

@@ -294,7 +294,7 @@
"""
if branch is None:
return 0, 0
if isinstance(branch, nn.Sequential):
if isinstance(branch, ConvModule):
kernel = branch.conv.weight
running_mean = branch.bn.running_mean
running_var = branch.bn.running_var
@@ -302,7 +302,7 @@
beta = branch.bn.bias
eps = branch.bn.eps
else:
assert isinstance(branch, nn.BatchNorm2d)
assert isinstance(branch, (nn.SyncBatchNorm, nn.BatchNorm2d))
if not hasattr(self, 'id_tensor'):
input_dim = self.in_channels // self.groups
kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
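
These isinstance fixes matter because deploy-mode switching folds each conv+BN branch into a single kernel and bias. A sketch of the standard RepVGG fusion math (not the exact `_fuse_bn_tensor` code):

```python
import torch
import torch.nn as nn

def fuse_conv_bn(kernel: torch.Tensor, bn: nn.BatchNorm2d):
    """Fold BN statistics into an equivalent conv kernel and bias.

    y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
      = conv'(x) + bias', with conv' = kernel * gamma / std.
    """
    std = (bn.running_var + bn.eps).sqrt()
    fused_kernel = kernel * (bn.weight / std).reshape(-1, 1, 1, 1)
    fused_bias = bn.bias - bn.running_mean * bn.weight / std
    return fused_kernel, fused_bias
```

In deploy mode the fused 3x3 kernel, the zero-padded 1x1 kernel, and the identity branch are summed into one convolution, which is why the branch types (`ConvModule` vs. a bare `nn.SyncBatchNorm`/`nn.BatchNorm2d` identity) must be matched correctly.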

View File

@@ -10,18 +10,17 @@ from mmdet.structures.bbox import HorizontalBoxes
from mmyolo.registry import MODELS
# TODO: unify all code
def bbox_overlaps(pred: torch.Tensor,
target: torch.Tensor,
iou_mode: str = 'ciou',
bbox_format: str = 'xywh',
is_aligned: bool = False,
siou_theta: float = 4.0,
eps: float = 1e-7) -> torch.Tensor:
r"""Calculate overlap between two set of bboxes.
`Implementation of paper `Enhancing Geometric Factors into
Model Learning and Inference for Object Detection and Instance
Segmentation <https://arxiv.org/abs/2005.03572>`_.
In the CIoU implementation of YOLOv5 and mmdetection, there is a slight
In the CIoU implementation of YOLOv5 and MMDetection, there is a slight
difference in the way the alpha parameter is computed.
mmdet version:
alpha = (ious > 0.5).float() * v / (1 - ious + v)
@@ -35,27 +34,36 @@ def bbox_overlaps(pred: torch.Tensor,
Defaults to "ciou".
bbox_format (str): Options are "xywh" and "xyxy".
Defaults to "xywh".
is_aligned (bool): If True, pred and target are treated as aligned pairs of boxes. Defaults to False.
siou_theta (float): siou_theta for SIoU when calculate shape cost.
Defaults to 4.0.
eps (float): Eps to avoid log(0).
Returns:
Tensor: shape (n,).
"""
assert iou_mode in ('ciou', )
assert iou_mode in ('ciou', 'giou', 'siou')
assert bbox_format in ('xyxy', 'xywh')
if bbox_format == 'xywh':
pred = HorizontalBoxes.cxcywh_to_xyxy(pred)
target = HorizontalBoxes.cxcywh_to_xyxy(target)
# overlap
lt = torch.max(pred[:, :2], target[:, :2])
rb = torch.min(pred[:, 2:], target[:, 2:])
wh = (rb - lt).clamp(min=0)
overlap = wh[:, 0] * wh[:, 1]
bbox1_x1, bbox1_y1 = pred[:, 0], pred[:, 1]
bbox1_x2, bbox1_y2 = pred[:, 2], pred[:, 3]
bbox2_x1, bbox2_y1 = target[:, 0], target[:, 1]
bbox2_x2, bbox2_y2 = target[:, 2], target[:, 3]
# union
ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
union = ap + ag - overlap + eps
# Overlap
overlap = (torch.min(bbox1_x2, bbox2_x2) -
torch.max(bbox1_x1, bbox2_x1)).clamp(0) * \
(torch.min(bbox1_y2, bbox2_y2) -
torch.max(bbox1_y1, bbox2_y1)).clamp(0)
# Union
w1, h1 = bbox1_x2 - bbox1_x1, bbox1_y2 - bbox1_y1
w2, h2 = bbox2_x2 - bbox2_x1, bbox2_y2 - bbox2_y1
union = (w1 * h1) + (w2 * h2) - overlap + eps
h1 = bbox1_y2 - bbox1_y1 + eps
h2 = bbox2_y2 - bbox2_y1 + eps
# IoU
ious = overlap / union
@@ -65,32 +73,78 @@
enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
cw = enclose_wh[:, 0]
ch = enclose_wh[:, 1]
enclose_w = enclose_wh[:, 0] # cw
enclose_h = enclose_wh[:, 1] # ch
c2 = cw**2 + ch**2 + eps
if iou_mode == 'ciou':
# CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )
b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
b2_x1, b2_y1 = target[:, 0], target[:, 1]
b2_x2, b2_y2 = target[:, 2], target[:, 3]
# calculate enclose area (c^2)
enclose_area = enclose_w**2 + enclose_h**2 + eps
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
# calculate ρ^2(b_pred,b_gt):
# euclidean distance between b_pred(bbox2) and b_gt(bbox1)
# center point, because bbox format is xyxy -> left-top xy and
# right-bottom xy, so we divide by 4 to get the center distance.
rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
rho2_right_item = ((bbox2_y1 + bbox2_y2) -
(bbox1_y1 + bbox1_y2))**2 / 4
rho2 = rho2_left_item + rho2_right_item # rho^2 (ρ^2)
left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
rho2 = left + right
# Width and height ratio (v)
wh_ratio = (4 / (math.pi**2)) * torch.pow(
torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
factor = 4 / math.pi**2
v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
with torch.no_grad():
alpha = wh_ratio / (wh_ratio - ious + (1 + eps))
with torch.no_grad():
alpha = v / (v - ious + (1 + eps))
# CIoU
ious = ious - ((rho2 / enclose_area) + (alpha * wh_ratio))
# CIoU
cious = ious - (rho2 / c2 + alpha * v)
return cious.clamp(min=-1.0, max=1.0)
elif iou_mode == 'giou':
# GIoU = IoU - ( (A_c - union) / A_c )
convex_area = enclose_w * enclose_h + eps # convex area (A_c)
ious = ious - (convex_area - union) / convex_area
elif iou_mode == 'siou':
# SIoU: https://arxiv.org/pdf/2205.12740.pdf
# SIoU = IoU - ( (Distance Cost + Shape Cost) / 2 )
# calculate sigma (σ):
# euclidean distance between bbox2(pred) and bbox1(gt) center point,
# sigma_cw = b_cx_gt - b_cx
sigma_cw = (bbox2_x1 + bbox2_x2) / 2 - (bbox1_x1 + bbox1_x2) / 2 + eps
# sigma_ch = b_cy_gt - b_cy
sigma_ch = (bbox2_y1 + bbox2_y2) / 2 - (bbox1_y1 + bbox1_y2) / 2 + eps
# sigma = √( (sigma_cw ** 2) + (sigma_ch ** 2) )
sigma = torch.pow(sigma_cw**2 + sigma_ch**2, 0.5)
# choose the smaller angle: use sin(alpha), or sin(beta) if alpha > pi/4
sin_alpha = torch.abs(sigma_ch) / sigma
sin_beta = torch.abs(sigma_cw) / sigma
sin_alpha = torch.where(sin_alpha <= math.sin(math.pi / 4), sin_alpha,
sin_beta)
# Angle cost = 1 - 2 * ( sin^2 ( arcsin(x) - (pi / 4) ) )
angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
# Distance cost = Σ_(t=x,y) (1 - e ^ (- γ ρ_t))
rho_x = (sigma_cw / enclose_w)**2 # ρ_x
rho_y = (sigma_ch / enclose_h)**2 # ρ_y
gamma = 2 - angle_cost # γ
distance_cost = (1 - torch.exp(-1 * gamma * rho_x)) + (
1 - torch.exp(-1 * gamma * rho_y))
# Shape cost = Ω = Σ_(t=w,h) ( ( 1 - ( e ^ (-ω_t) ) ) ^ θ )
omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2) # ω_w
omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2) # ω_h
shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w),
siou_theta) + torch.pow(
1 - torch.exp(-1 * omiga_h), siou_theta)
ious = ious - ((distance_cost + shape_cost) * 0.5)
return ious.clamp(min=-1.0, max=1.0)
@MODELS.register_module()
@@ -118,7 +172,7 @@
return_iou: bool = True):
super().__init__()
assert bbox_format in ('xywh', 'xyxy')
assert iou_mode in ('ciou', )
assert iou_mode in ('ciou', 'siou', 'giou')
self.iou_mode = iou_mode
self.bbox_format = bbox_format
self.eps = eps
@@ -131,7 +185,7 @@
pred: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.Tensor] = None,
avg_factor: Optional[str] = None,
avg_factor: Optional[float] = None,
reduction_override: Optional[Union[str, bool]] = None
) -> Tuple[Union[torch.Tensor, torch.Tensor], torch.Tensor]:
"""Forward function.
@@ -155,11 +209,8 @@
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if weight is not None and weight.dim() > 1:
# TODO: remove this in the future
# reduce the weight of shape (n, 4) to (n,) to match the
# giou_loss of shape (n,)
assert weight.shape == pred.shape
weight = weight.mean(-1)
iou = bbox_overlaps(
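
A tiny numeric check of the `giou` branch above: GIoU = IoU - (A_c - union) / A_c, where A_c is the area of the smallest enclosing box. For two unit boxes with a one-unit gap (assuming `bbox_overlaps` is importable from `mmyolo.models.losses`):

```python
import torch

from mmyolo.models.losses import bbox_overlaps  # assumed import path

pred = torch.tensor([[0., 0., 1., 1.]])    # unit box at the origin
target = torch.tensor([[2., 0., 3., 1.]])  # unit box shifted right, no overlap

# overlap = 0, union = 2, enclosing box = 3 x 1 => A_c = 3
# IoU = 0, so GIoU = 0 - (3 - 2) / 3 = -1/3
giou = bbox_overlaps(pred, target, iou_mode='giou', bbox_format='xyxy')
print(giou)  # tensor([-0.3333]) (up to eps)
```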

View File

@@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .assigners import BatchATSSAssigner, BatchTaskAlignedAssigner
from .coders import YOLOv5BBoxCoder, YOLOXBBoxCoder
__all__ = ['YOLOv5BBoxCoder', 'YOLOXBBoxCoder']
__all__ = [
'YOLOv5BBoxCoder', 'YOLOXBBoxCoder', 'BatchATSSAssigner',
'BatchTaskAlignedAssigner'
]

View File

@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .batch_atss_assigner import BatchATSSAssigner
from .batch_task_aligned_assigner import BatchTaskAlignedAssigner
from .utils import (select_candidates_in_gts, select_highest_overlaps,
yolov6_iou_calculator)
__all__ = [
'BatchATSSAssigner', 'BatchTaskAlignedAssigner',
'select_candidates_in_gts', 'select_highest_overlaps',
'yolov6_iou_calculator'
]

View File

@@ -0,0 +1,339 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.utils import ConfigType
from torch import Tensor
from mmyolo.registry import TASK_UTILS
from .utils import (select_candidates_in_gts, select_highest_overlaps,
yolov6_iou_calculator)
def bbox_center_distance(bboxes: Tensor,
priors: Tensor) -> Tuple[Tensor, Tensor]:
"""Compute the center distance between bboxes and priors.
Args:
bboxes (Tensor): Shape (n, 4) for bbox, "xyxy" format.
priors (Tensor): Shape (num_priors, 4) for priors, "xyxy" format.
Returns:
distances (Tensor): Center distances between bboxes and priors,
shape (n, num_priors).
priors_points (Tensor): Priors cx cy points,
shape (num_priors, 2).
"""
bbox_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
bbox_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
bbox_points = torch.stack((bbox_cx, bbox_cy), dim=1)
priors_cx = (priors[:, 0] + priors[:, 2]) / 2.0
priors_cy = (priors[:, 1] + priors[:, 3]) / 2.0
priors_points = torch.stack((priors_cx, priors_cy), dim=1)
distances = (bbox_points[:, None, :] -
priors_points[None, :, :]).pow(2).sum(-1).sqrt()
return distances, priors_points
@TASK_UTILS.register_module()
class BatchATSSAssigner(nn.Module):
"""Assign a batch of corresponding gt bboxes or background to each prior.
This code is based on
https://github.com/meituan/YOLOv6/blob/main/yolov6/assigners/atss_assigner.py
Each proposal will be assigned with `0` or a positive integer
indicating the ground truth index.
- 0: negative sample, no assigned gt
- positive integer: positive sample, index (1-based) of assigned gt
Args:
num_classes (int): Number of classes.
iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou
calculator. Defaults to ``dict(type='BboxOverlaps2D')``
topk (int): number of priors selected in each level
"""
def __init__(
self,
num_classes: int,
iou_calculator: ConfigType = dict(type='mmdet.BboxOverlaps2D'),
topk: int = 9):
super().__init__()
self.num_classes = num_classes
self.iou_calculator = TASK_UTILS.build(iou_calculator)
self.topk = topk
@torch.no_grad()
def forward(self, pred_bboxes: Tensor, priors: Tensor,
num_level_priors: List, gt_labels: Tensor, gt_bboxes: Tensor,
pad_bbox_flag: Tensor) -> dict:
"""Assign gt to priors.
The assignment is done in following steps
1. compute iou between all prior (prior of all pyramid levels) and gt
2. compute center distance between all prior and gt
3. on each pyramid level, for each gt, select the k priors whose
centers are closest to the gt center, so we select k*l priors in
total as candidates for each gt
4. get the corresponding iou for these candidates, and compute the
mean and std, set mean + std as the iou threshold
5. select these candidates whose iou are greater than or equal to
the threshold as positive
6. limit the positive sample's center in gt
Args:
pred_bboxes (Tensor): Predicted bounding boxes,
shape(batch_size, num_priors, 4)
priors (Tensor): Model priors, shape(num_priors, 4)
num_level_priors (List): Number of bboxes in each level, len(3)
gt_labels (Tensor): Ground truth label,
shape(batch_size, num_gt, 1)
gt_bboxes (Tensor): Ground truth bbox,
shape(batch_size, num_gt, 4)
pad_bbox_flag (Tensor): Ground truth bbox mask,
1 means bbox, 0 means no bbox,
shape(batch_size, num_gt, 1)
Returns:
assigned_result (dict): Assigned result
'assigned_labels' (Tensor): shape(batch_size, num_gt)
'assigned_bboxes' (Tensor): shape(batch_size, num_gt, 4)
'assigned_scores' (Tensor):
shape(batch_size, num_gt, number_classes)
'fg_mask_pre_prior' (Tensor): shape(bs, num_gt)
"""
# generate priors
cell_half_size = priors[:, 2:] * 2.5
priors_gen = torch.zeros_like(priors)
priors_gen[:, :2] = priors[:, :2] - cell_half_size
priors_gen[:, 2:] = priors[:, :2] + cell_half_size
priors = priors_gen
batch_size = gt_bboxes.size(0)
num_gt, num_priors = gt_bboxes.size(1), priors.size(0)
assigned_result = {
'assigned_labels':
gt_bboxes.new_full([batch_size, num_priors], self.num_classes),
'assigned_bboxes':
gt_bboxes.new_full([batch_size, num_priors, 4], 0),
'assigned_scores':
gt_bboxes.new_full([batch_size, num_priors, self.num_classes], 0),
'fg_mask_pre_prior':
gt_bboxes.new_full([batch_size, num_priors], 0)
}
if num_gt == 0:
return assigned_result
# compute iou between all prior (prior of all pyramid levels) and gt
overlaps = self.iou_calculator(gt_bboxes.reshape([-1, 4]), priors)
overlaps = overlaps.reshape([batch_size, -1, num_priors])
# compute center distance between all prior and gt
distances, priors_points = bbox_center_distance(
gt_bboxes.reshape([-1, 4]), priors)
distances = distances.reshape([batch_size, -1, num_priors])
# Selecting candidates based on the center distance
is_in_candidate, candidate_idxs = self.select_topk_candidates(
distances, num_level_priors, pad_bbox_flag)
# get the corresponding iou for these candidates, and compute the
# mean and std, set mean + std as the iou threshold
overlaps_thr_per_gt, iou_candidates = self.threshold_calculator(
is_in_candidate, candidate_idxs, overlaps, num_priors, batch_size,
num_gt)
# select candidates iou >= threshold as positive
is_pos = torch.where(
iou_candidates > overlaps_thr_per_gt.repeat([1, 1, num_priors]),
is_in_candidate, torch.zeros_like(is_in_candidate))
is_in_gts = select_candidates_in_gts(priors_points, gt_bboxes)
pos_mask = is_pos * is_in_gts * pad_bbox_flag
# if an anchor box is assigned to multiple gts,
# the one with the highest IoU will be selected.
gt_idx_pre_prior, fg_mask_pre_prior, pos_mask = \
select_highest_overlaps(pos_mask, overlaps, num_gt)
# assigned target
assigned_labels, assigned_bboxes, assigned_scores = self.get_targets(
gt_labels, gt_bboxes, gt_idx_pre_prior, fg_mask_pre_prior,
num_priors, batch_size, num_gt)
# soft label with iou
if pred_bboxes is not None:
ious = yolov6_iou_calculator(gt_bboxes, pred_bboxes) * pos_mask
ious = ious.max(axis=-2)[0].unsqueeze(-1)
assigned_scores *= ious
assigned_result['assigned_labels'] = assigned_labels.long()
assigned_result['assigned_bboxes'] = assigned_bboxes
assigned_result['assigned_scores'] = assigned_scores
assigned_result['fg_mask_pre_prior'] = fg_mask_pre_prior.bool()
return assigned_result
def select_topk_candidates(self, distances: Tensor,
num_level_priors: List[int],
pad_bbox_flag: Tensor) -> Tuple[Tensor, Tensor]:
"""Selecting candidates based on the center distance.
Args:
distances (Tensor): Distance between all bbox and gt,
shape(batch_size, num_gt, num_priors)
num_level_priors (List[int]): Number of bboxes in each level,
len(3)
pad_bbox_flag (Tensor): Ground truth bbox mask,
shape(batch_size, num_gt, 1)
Return:
is_in_candidate_list (Tensor): Flags showing whether each level
has topk candidates, shape(batch_size, num_gt, num_priors)
candidate_idxs (Tensor): Candidates index,
shape(batch_size, num_gt, num_gt)
"""
is_in_candidate_list = []
candidate_idxs = []
start_idx = 0
distances_dtype = distances.dtype
distances = torch.split(distances, num_level_priors, dim=-1)
pad_bbox_flag = pad_bbox_flag.repeat(1, 1, self.topk).bool()
for distances_per_level, priors_per_level in zip(
distances, num_level_priors):
# on each pyramid level, for each gt,
# select k bbox whose center are closest to the gt center
end_index = start_idx + priors_per_level
selected_k = min(self.topk, priors_per_level)
_, topk_idxs_per_level = distances_per_level.topk(
selected_k, dim=-1, largest=False)
candidate_idxs.append(topk_idxs_per_level + start_idx)
topk_idxs_per_level = torch.where(
pad_bbox_flag, topk_idxs_per_level,
torch.zeros_like(topk_idxs_per_level))
is_in_candidate = F.one_hot(topk_idxs_per_level,
priors_per_level).sum(dim=-2)
is_in_candidate = torch.where(is_in_candidate > 1,
torch.zeros_like(is_in_candidate),
is_in_candidate)
is_in_candidate_list.append(is_in_candidate.to(distances_dtype))
start_idx = end_index
is_in_candidate_list = torch.cat(is_in_candidate_list, dim=-1)
candidate_idxs = torch.cat(candidate_idxs, dim=-1)
return is_in_candidate_list, candidate_idxs
@staticmethod
def threshold_calculator(is_in_candidate: List, candidate_idxs: Tensor,
overlaps: Tensor, num_priors: int,
batch_size: int,
num_gt: int) -> Tuple[Tensor, Tensor]:
"""Get corresponding iou for the these candidates, and compute the mean
and std, set mean + std as the iou threshold.
Args:
is_in_candidate (Tensor): Flags showing whether each level has
topk candidates, shape(batch_size, num_gt, num_priors).
candidate_idxs (Tensor): Candidates index,
shape(batch_size, num_gt, num_gt)
overlaps (Tensor): Overlaps area,
shape(batch_size, num_gt, num_priors).
num_priors (int): Number of priors.
batch_size (int): Batch size.
num_gt (int): Number of ground truth.
Return:
overlaps_thr_per_gt (Tensor): Overlap threshold of
per ground truth, shape(batch_size, num_gt, 1).
candidate_overlaps (Tensor): Candidate overlaps,
shape(batch_size, num_gt, num_priors).
"""
batch_size_num_gt = batch_size * num_gt
candidate_overlaps = torch.where(is_in_candidate > 0, overlaps,
torch.zeros_like(overlaps))
candidate_idxs = candidate_idxs.reshape([batch_size_num_gt, -1])
assist_indexes = num_priors * torch.arange(
batch_size_num_gt, device=candidate_idxs.device)
assist_indexes = assist_indexes[:, None]
flatten_indexes = candidate_idxs + assist_indexes
candidate_overlaps_reshape = candidate_overlaps.reshape(
-1)[flatten_indexes]
candidate_overlaps_reshape = candidate_overlaps_reshape.reshape(
[batch_size, num_gt, -1])
overlaps_mean_per_gt = candidate_overlaps_reshape.mean(
axis=-1, keepdim=True)
overlaps_std_per_gt = candidate_overlaps_reshape.std(
axis=-1, keepdim=True)
overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
return overlaps_thr_per_gt, candidate_overlaps
def get_targets(self, gt_labels: Tensor, gt_bboxes: Tensor,
assigned_gt_inds: Tensor, fg_mask_pre_prior: Tensor,
num_priors: int, batch_size: int,
num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
"""Get target info.
Args:
gt_labels (Tensor): Ground truth labels,
shape(batch_size, num_gt, 1)
gt_bboxes (Tensor): Ground truth bboxes,
shape(batch_size, num_gt, 4)
assigned_gt_inds (Tensor): Assigned ground truth indexes,
shape(batch_size, num_priors)
fg_mask_pre_prior (Tensor): Force ground truth matching mask,
shape(batch_size, num_priors)
num_priors (int): Number of priors.
batch_size (int): Batch size.
num_gt (int): Number of ground truth.
Return:
assigned_labels (Tensor): Assigned labels,
shape(batch_size, num_priors)
assigned_bboxes (Tensor): Assigned bboxes,
shape(batch_size, num_priors, 4)
assigned_scores (Tensor): Assigned scores,
shape(batch_size, num_priors, num_classes)
"""
# assigned target labels
batch_index = torch.arange(
batch_size, dtype=gt_labels.dtype, device=gt_labels.device)
batch_index = batch_index[..., None]
assigned_gt_inds = (assigned_gt_inds + batch_index * num_gt).long()
assigned_labels = gt_labels.flatten()[assigned_gt_inds.flatten()]
assigned_labels = assigned_labels.reshape([batch_size, num_priors])
assigned_labels = torch.where(
fg_mask_pre_prior > 0, assigned_labels,
torch.full_like(assigned_labels, self.num_classes))
# assigned target boxes
assigned_bboxes = gt_bboxes.reshape([-1,
4])[assigned_gt_inds.flatten()]
assigned_bboxes = assigned_bboxes.reshape([batch_size, num_priors, 4])
# assigned target scores
assigned_scores = F.one_hot(assigned_labels.long(),
self.num_classes + 1).float()
assigned_scores = assigned_scores[:, :, :self.num_classes]
return assigned_labels, assigned_bboxes, assigned_scores
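
The heart of the ATSS rule is step 4 above: each gt's IoU threshold is mean + std over its center-closest candidates. A quick illustration of what `threshold_calculator` computes, with made-up IoUs:

```python
import torch

# IoUs of the k*l center-closest candidates for one gt (made-up values).
candidate_ious = torch.tensor([0.1, 0.2, 0.3, 0.6, 0.7])

iou_thr = candidate_ious.mean() + candidate_ious.std()
print(iou_thr)  # ~0.639: only the 0.7 candidate stays positive
```

Candidates passing the threshold must additionally have their prior center inside the gt box (`select_candidates_in_gts`).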

View File

@@ -0,0 +1,298 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmyolo.registry import TASK_UTILS
from .utils import (select_candidates_in_gts, select_highest_overlaps,
yolov6_iou_calculator)
@TASK_UTILS.register_module()
class BatchTaskAlignedAssigner(nn.Module):
"""This code referenced to
https://github.com/meituan/YOLOv6/blob/main/yolov6/
assigners/tal_assigner.py.
Batch Task aligned assigner base on the paper:
`TOOD: Task-aligned One-stage Object Detection.
<https://arxiv.org/abs/2108.07755>`_.
Assign a corresponding gt bboxes or background to a batch of
predicted bboxes. Each bbox will be assigned with `0` or a
positive integer indicating the ground truth index.
- 0: negative sample, no assigned gt
- positive integer: positive sample, index (1-based) of assigned gt
Args:
num_classes (int): Number of classes
topk (int): Number of bboxes selected in each level
alpha (float): Hyper-parameters related to alignment_metrics.
Defaults to 1.0
beta (float): Hyper-parameters related to alignment_metrics.
Defaults to 6.
eps (float): Eps to avoid log(0). Defaults to 1e-7
"""
def __init__(self,
num_classes: int,
topk: int = 13,
alpha: float = 1.0,
beta: float = 6.0,
eps: float = 1e-7):
super().__init__()
self.num_classes = num_classes
self.topk = topk
self.alpha = alpha
self.beta = beta
self.eps = eps
@torch.no_grad()
def forward(
self,
pred_bboxes: Tensor,
pred_scores: Tensor,
priors: Tensor,
gt_labels: Tensor,
gt_bboxes: Tensor,
pad_bbox_flag: Tensor,
) -> dict:
"""Assign gt to bboxes.
The assignment is done in following steps
1. compute alignment metric between all bbox (bbox of all pyramid
levels) and gt
2. select top-k bbox as candidates for each gt
3. limit the positive sample's center in gt (because the anchor-free
detector only can predict positive distance)
Args:
pred_bboxes (Tensor): Predict bboxes,
shape(batch_size, num_priors, 4)
pred_scores (Tensor): Scores of predict bboxes,
shape(batch_size, num_priors, num_classes)
priors (Tensor): Model priors, shape (num_priors, 4)
gt_labels (Tensor): Ground truth labels,
shape(batch_size, num_gt, 1)
gt_bboxes (Tensor): Ground truth bboxes,
shape(batch_size, num_gt, 4)
pad_bbox_flag (Tensor): Ground truth bbox mask,
1 means bbox, 0 means no bbox,
shape(batch_size, num_gt, 1)
Returns:
assigned_result (dict) Assigned result:
assigned_labels (Tensor): Assigned labels,
shape(batch_size, num_priors)
assigned_bboxes (Tensor): Assigned boxes,
shape(batch_size, num_priors, 4)
assigned_scores (Tensor): Assigned scores,
shape(batch_size, num_priors, num_classes)
fg_mask_pre_prior (Tensor): Force ground truth matching mask,
shape(batch_size, num_priors)
"""
# (num_priors, 4) -> (num_priors, 2)
priors = priors[:, :2]
batch_size = pred_scores.size(0)
num_gt = gt_bboxes.size(1)
assigned_result = {
'assigned_labels':
gt_bboxes.new_full(pred_scores[..., 0].shape, self.num_classes),
'assigned_bboxes':
gt_bboxes.new_full(pred_bboxes.shape, 0),
'assigned_scores':
gt_bboxes.new_full(pred_scores.shape, 0),
'fg_mask_pre_prior':
gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
}
if num_gt == 0:
return assigned_result
pos_mask, alignment_metrics, overlaps = self.get_pos_mask(
pred_bboxes, pred_scores, priors, gt_labels, gt_bboxes,
pad_bbox_flag, batch_size, num_gt)
(assigned_gt_idxs, fg_mask_pre_prior,
pos_mask) = select_highest_overlaps(pos_mask, overlaps, num_gt)
# assigned target
assigned_labels, assigned_bboxes, assigned_scores = self.get_targets(
gt_labels, gt_bboxes, assigned_gt_idxs, fg_mask_pre_prior,
batch_size, num_gt)
# normalize
alignment_metrics *= pos_mask
pos_align_metrics = alignment_metrics.max(axis=-1, keepdim=True)[0]
pos_overlaps = (overlaps * pos_mask).max(axis=-1, keepdim=True)[0]
norm_align_metric = (
alignment_metrics * pos_overlaps /
(pos_align_metrics + self.eps)).max(-2)[0].unsqueeze(-1)
assigned_scores = assigned_scores * norm_align_metric
assigned_result['assigned_labels'] = assigned_labels
assigned_result['assigned_bboxes'] = assigned_bboxes
assigned_result['assigned_scores'] = assigned_scores
assigned_result['fg_mask_pre_prior'] = fg_mask_pre_prior.bool()
return assigned_result
def get_pos_mask(self, pred_bboxes: Tensor, pred_scores: Tensor,
priors: Tensor, gt_labels: Tensor, gt_bboxes: Tensor,
pad_bbox_flag: Tensor, batch_size: int,
num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
"""Get possible mask.
Args:
pred_bboxes (Tensor): Predict bboxes,
shape(batch_size, num_priors, 4)
pred_scores (Tensor): Scores of predict bbox,
shape(batch_size, num_priors, num_classes)
priors (Tensor): Model priors, shape (num_priors, 2)
gt_labels (Tensor): Ground truth labels,
shape(batch_size, num_gt, 1)
gt_bboxes (Tensor): Ground truth bboxes,
shape(batch_size, num_gt, 4)
pad_bbox_flag (Tensor): Ground truth bbox mask,
1 means bbox, 0 means no bbox,
shape(batch_size, num_gt, 1)
batch_size (int): Batch size.
num_gt (int): Number of ground truth.
Returns:
pos_mask (Tensor): Positive sample mask,
shape(batch_size, num_gt, num_priors)
alignment_metrics (Tensor): Alignment metrics,
shape(batch_size, num_gt, num_priors)
overlaps (Tensor): Overlaps of gt_bboxes and pred_bboxes,
shape(batch_size, num_gt, num_priors)
"""
# Compute alignment metric between all bbox and gt
alignment_metrics, overlaps = \
self.get_box_metrics(pred_bboxes, pred_scores, gt_labels,
gt_bboxes, batch_size, num_gt)
# get is_in_gts mask
is_in_gts = select_candidates_in_gts(priors, gt_bboxes)
# get topk_metric mask
topk_metric = self.select_topk_candidates(
alignment_metrics * is_in_gts,
topk_mask=pad_bbox_flag.repeat([1, 1, self.topk]).bool())
# merge all mask to a final mask
pos_mask = topk_metric * is_in_gts * pad_bbox_flag
return pos_mask, alignment_metrics, overlaps
def get_box_metrics(self, pred_bboxes: Tensor, pred_scores: Tensor,
gt_labels: Tensor, gt_bboxes: Tensor, batch_size: int,
num_gt: int) -> Tuple[Tensor, Tensor]:
"""Compute alignment metric between all bbox and gt.
Args:
pred_bboxes (Tensor): Predict bboxes,
shape(batch_size, num_priors, 4)
pred_scores (Tensor): Scores of predict bbox,
shape(batch_size, num_priors, num_classes)
gt_labels (Tensor): Ground truth labels,
shape(batch_size, num_gt, 1)
gt_bboxes (Tensor): Ground truth bboxes,
shape(batch_size, num_gt, 4)
batch_size (int): Batch size.
num_gt (int): Number of ground truth.
Returns:
alignment_metrics (Tensor): Align metric,
shape(batch_size, num_gt, num_priors)
overlaps (Tensor): Overlaps, shape(batch_size, num_gt, num_priors)
"""
pred_scores = pred_scores.permute(0, 2, 1)
gt_labels = gt_labels.to(torch.long)
idx = torch.zeros([2, batch_size, num_gt], dtype=torch.long)
idx[0] = torch.arange(end=batch_size).view(-1, 1).repeat(1, num_gt)
idx[1] = gt_labels.squeeze(-1)
bbox_scores = pred_scores[idx[0], idx[1]]
overlaps = yolov6_iou_calculator(gt_bboxes, pred_bboxes)
alignment_metrics = bbox_scores.pow(self.alpha) * overlaps.pow(
self.beta)
return alignment_metrics, overlaps
def select_topk_candidates(self,
alignment_gt_metrics: Tensor,
using_largest_topk: bool = True,
topk_mask: Optional[Tensor] = None) -> Tensor:
"""Compute alignment metric between all bbox and gt.
Args:
alignment_gt_metrics (Tensor): Alignment metric of gt candidates,
shape(batch_size, num_gt, num_priors)
using_largest_topk (bool): Controls whether to use the largest or
smallest elements.
topk_mask (Tensor): Topk mask,
shape(batch_size, num_gt, self.topk)
Returns:
Tensor: Topk candidates mask,
shape(batch_size, num_gt, num_priors)
"""
num_priors = alignment_gt_metrics.shape[-1]
topk_metrics, topk_idxs = torch.topk(
alignment_gt_metrics,
self.topk,
axis=-1,
largest=using_largest_topk)
if topk_mask is None:
topk_mask = (topk_metrics.max(axis=-1, keepdim=True) >
self.eps).tile([1, 1, self.topk])
topk_idxs = torch.where(topk_mask, topk_idxs,
torch.zeros_like(topk_idxs))
is_in_topk = F.one_hot(topk_idxs, num_priors).sum(axis=-2)
is_in_topk = torch.where(is_in_topk > 1, torch.zeros_like(is_in_topk),
is_in_topk)
return is_in_topk.to(alignment_gt_metrics.dtype)
def get_targets(self, gt_labels: Tensor, gt_bboxes: Tensor,
assigned_gt_idxs: Tensor, fg_mask_pre_prior: Tensor,
batch_size: int,
num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
"""Get assigner info.
Args:
gt_labels (Tensor): Ground truth labels,
shape(batch_size, num_gt, 1)
gt_bboxes (Tensor): Ground truth bboxes,
shape(batch_size, num_gt, 4)
assigned_gt_idxs (Tensor): Assigned ground truth indexes,
shape(batch_size, num_priors)
fg_mask_pre_prior (Tensor): Force ground truth matching mask,
shape(batch_size, num_priors)
batch_size (int): Batch size.
num_gt (int): Number of ground truth.
Returns:
assigned_labels (Tensor): Assigned labels,
shape(batch_size, num_priors)
assigned_bboxes (Tensor): Assigned bboxes,
shape(batch_size, num_priors, 4)
assigned_scores (Tensor): Assigned scores,
shape(batch_size, num_priors, num_classes)
"""
# assigned target labels
batch_ind = torch.arange(
end=batch_size, dtype=torch.int64, device=gt_labels.device)[...,
None]
assigned_gt_idxs = assigned_gt_idxs + batch_ind * num_gt
assigned_labels = gt_labels.long().flatten()[assigned_gt_idxs]
# assigned target boxes
assigned_bboxes = gt_bboxes.reshape([-1, 4])[assigned_gt_idxs]
# assigned target scores
assigned_labels[assigned_labels < 0] = 0
assigned_scores = F.one_hot(assigned_labels, self.num_classes)
force_gt_scores_mask = fg_mask_pre_prior[:, :, None].repeat(
1, 1, self.num_classes)
assigned_scores = torch.where(force_gt_scores_mask > 0,
assigned_scores,
torch.full_like(assigned_scores, 0))
return assigned_labels, assigned_bboxes, assigned_scores
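
The alignment metric computed in `get_box_metrics` is t = s^alpha * u^beta, where s is the classification score on the gt class and u the IoU with the gt box. With the defaults alpha=1, beta=6, localization dominates; a quick illustration with made-up numbers:

```python
import torch

scores = torch.tensor([0.9, 0.5])  # cls score on the gt class
ious = torch.tensor([0.5, 0.9])    # IoU with the gt box
alpha, beta = 1.0, 6.0

align = scores.pow(alpha) * ious.pow(beta)
print(align)  # tensor([0.0141, 0.2657]): the well-localized box wins
```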

View File

@@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
def select_candidates_in_gts(priors_points: Tensor,
gt_bboxes: Tensor,
eps: float = 1e-9) -> Tensor:
"""Select the positive priors' center in gt.
Args:
priors_points (Tensor): Model priors points,
shape(num_priors, 2)
gt_bboxes (Tensor): Ground truth bboxes,
shape(batch_size, num_gt, 4)
eps (float): Default to 1e-9.
Return:
(Tensor): shape(batch_size, num_gt, num_priors)
"""
batch_size, num_gt, _ = gt_bboxes.size()
gt_bboxes = gt_bboxes.reshape([-1, 4])
priors_number = priors_points.size(0)
priors_points = priors_points.unsqueeze(0).repeat(batch_size * num_gt, 1,
1)
# calculate the left, top, right, bottom distance between positive
# prior center and gt side
gt_bboxes_lt = gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, priors_number, 1)
gt_bboxes_rb = gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, priors_number, 1)
bbox_deltas = torch.cat(
[priors_points - gt_bboxes_lt, gt_bboxes_rb - priors_points], dim=-1)
bbox_deltas = bbox_deltas.reshape([batch_size, num_gt, priors_number, -1])
return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
def select_highest_overlaps(pos_mask: Tensor, overlaps: Tensor,
num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
"""If an anchor box is assigned to multiple gts, the one with the highest
iou will be selected.
Args:
pos_mask (Tensor): The assigned positive sample mask,
shape(batch_size, num_gt, num_priors)
overlaps (Tensor): IoU between all bbox and ground truth,
shape(batch_size, num_gt, num_priors)
num_gt (int): Number of ground truth.
Return:
gt_idx_pre_prior (Tensor): Target ground truth index,
shape(batch_size, num_priors)
fg_mask_pre_prior (Tensor): Force matching ground truth,
shape(batch_size, num_priors)
pos_mask (Tensor): The assigned positive sample mask,
shape(batch_size, num_gt, num_priors)
"""
fg_mask_pre_prior = pos_mask.sum(axis=-2)
    # Make sure each positive sample matches only one gt: the one with
    # the largest IoU
if fg_mask_pre_prior.max() > 1:
mask_multi_gts = (fg_mask_pre_prior.unsqueeze(1) > 1).repeat(
[1, num_gt, 1])
index = overlaps.argmax(axis=1)
is_max_overlaps = F.one_hot(index, num_gt)
is_max_overlaps = \
is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype)
pos_mask = torch.where(mask_multi_gts, is_max_overlaps, pos_mask)
fg_mask_pre_prior = pos_mask.sum(axis=-2)
gt_idx_pre_prior = pos_mask.argmax(axis=-2)
return gt_idx_pre_prior, fg_mask_pre_prior, pos_mask
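To see the deduplication at work, here is an illustrative case (toy tensors, function assumed in scope) where prior 0 initially matches both gts; `select_highest_overlaps` keeps only gt 0, which has the higher IoU for that prior:

import torch

pos_mask = torch.tensor([[[1., 1., 0.],
                          [1., 0., 1.]]])  # (batch_size=1, num_gt=2, num_priors=3)
overlaps = torch.tensor([[[0.6, 0.5, 0.0],
                          [0.3, 0.0, 0.7]]])
gt_idx, fg_mask, pos_mask = select_highest_overlaps(pos_mask, overlaps, 2)
print(gt_idx)   # tensor([[0, 0, 1]])
print(fg_mask)  # tensor([[1., 1., 1.]])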
# TODO: 'mmdet.BboxOverlaps2D' causes gradient inconsistency here,
# which will be investigated and fixed in a later version.
def yolov6_iou_calculator(bbox1: Tensor,
bbox2: Tensor,
eps: float = 1e-9) -> Tensor:
"""Calculate iou for batch.
Args:
bbox1 (Tensor): shape(batch size, num_gt, 4)
bbox2 (Tensor): shape(batch size, num_priors, 4)
eps (float): Default to 1e-9.
Return:
(Tensor): IoU, shape(size, num_gt, num_priors)
"""
bbox1 = bbox1.unsqueeze(2) # [N, M1, 4] -> [N, M1, 1, 4]
bbox2 = bbox2.unsqueeze(1) # [N, M2, 4] -> [N, 1, M2, 4]
    # split predicted and gt bboxes into top-left and bottom-right corners
bbox1_x1y1, bbox1_x2y2 = bbox1[:, :, :, 0:2], bbox1[:, :, :, 2:4]
bbox2_x1y1, bbox2_x2y2 = bbox2[:, :, :, 0:2], bbox2[:, :, :, 2:4]
# calculate overlap area
overlap = (torch.minimum(bbox1_x2y2, bbox2_x2y2) -
torch.maximum(bbox1_x1y1, bbox2_x1y1)).clip(0).prod(-1)
# calculate bbox area
bbox1_area = (bbox1_x2y2 - bbox1_x1y1).clip(0).prod(-1)
bbox2_area = (bbox2_x2y2 - bbox2_x1y1).clip(0).prod(-1)
union = bbox1_area + bbox2_area - overlap + eps
return overlap / union
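As a sanity check on shapes and values (illustrative only): the overlap of [0, 0, 10, 10] with [5, 5, 15, 15] is a 5x5 square, so IoU = 25 / (100 + 100 - 25) ≈ 0.143, and a box against itself gives 1.0.

import torch

gt = torch.tensor([[[0., 0., 10., 10.]]])     # (batch_size=1, num_gt=1, 4)
preds = torch.tensor([[[5., 5., 15., 15.],
                       [0., 0., 10., 10.]]])  # (batch_size=1, num_priors=2, 4)
print(yolov6_iou_calculator(gt, preds))       # shape (1, 1, 2)
# tensor([[[0.1429, 1.0000]]])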

View File

@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .misc import switch_to_deploy
from .setup_env import register_all_modules
__all__ = ['register_all_modules', 'collect_env']
__all__ = ['register_all_modules', 'collect_env', 'switch_to_deploy']

View File

@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmyolo.models import RepVGGBlock
def switch_to_deploy(model):
"""Model switch to deploy status."""
for layer in model.modules():
if isinstance(layer, RepVGGBlock):
layer.switch_to_deploy()
    print('Switched model to deploy mode.')
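A hedged usage sketch (names assumed, not from the diff): `model` is any already-built nn.Module containing `RepVGGBlock` layers, e.g. a YOLOv6 detector. The fusion is one-way, so call it only for inference:

from mmyolo.utils import switch_to_deploy

switch_to_deploy(model)  # reparameterizes every RepVGGBlock in place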

View File

@@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
from mmyolo.engine.hooks import SwitchToDeployHook
from mmyolo.models import RepVGGBlock
from mmyolo.utils import register_all_modules
register_all_modules()
class TestSwitchToDeployHook(TestCase):
def test(self):
runner = Mock()
runner.model = RepVGGBlock(256, 256)
hook = SwitchToDeployHook()
self.assertFalse(runner.model.deploy)
# test after change mode
hook.before_test_epoch(runner)
self.assertTrue(runner.model.deploy)

View File

@@ -20,7 +20,7 @@ class TestSingleStageDetector(TestCase):
@parameterized.expand([
'yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py',
'yolov6/yolov6_s_syncbn_8xb32-400e_coco.py',
'yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py',
'yolox/yolox_tiny_8xb8-300e_coco.py',
'rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py'
])
@@ -67,7 +67,7 @@ class TestSingleStageDetector(TestCase):
@parameterized.expand([
('yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py', ('cuda',
'cpu')),
('yolov6/yolov6_s_syncbn_8xb32-400e_coco.py', ('cuda', 'cpu')),
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
('rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py', ('cuda', 'cpu'))
])
@@ -98,7 +98,7 @@ class TestSingleStageDetector(TestCase):
@parameterized.expand([
('yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py', ('cuda',
'cpu')),
('yolov6/yolov6_s_syncbn_8xb32-400e_coco.py', ('cuda', 'cpu')),
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
('rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py', ('cuda', 'cpu'))
])

View File

@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@@ -0,0 +1,175 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmyolo.models.task_modules.assigners import BatchATSSAssigner
class TestBatchATSSAssigner(TestCase):
def test_batch_atss_assigner(self):
num_classes = 2
batch_size = 2
batch_atss_assigner = BatchATSSAssigner(
topk=3,
iou_calculator=dict(type='mmdet.BboxOverlaps2D'),
num_classes=num_classes)
priors = torch.FloatTensor([
[4., 4., 8., 8.],
[12., 4., 8., 8.],
[20., 4., 8., 8.],
[28., 4., 8., 8.],
]).repeat(21, 1)
gt_bboxes = torch.FloatTensor([
[0, 0, 60, 93],
[229, 0, 532, 157],
]).unsqueeze(0).repeat(batch_size, 1, 1)
gt_labels = torch.LongTensor([
[0],
[11],
]).unsqueeze(0).repeat(batch_size, 1, 1)
num_level_bboxes = [64, 16, 4]
pad_bbox_flag = torch.FloatTensor([
[1],
[0],
]).unsqueeze(0).repeat(batch_size, 1, 1)
pred_bboxes = torch.FloatTensor([
[-4., -4., 12., 12.],
[4., -4., 20., 12.],
[12., -4., 28., 12.],
[20., -4., 36., 12.],
]).unsqueeze(0).repeat(batch_size, 21, 1)
batch_assign_result = batch_atss_assigner.forward(
pred_bboxes, priors, num_level_bboxes, gt_labels, gt_bboxes,
pad_bbox_flag)
assigned_labels = batch_assign_result['assigned_labels']
assigned_bboxes = batch_assign_result['assigned_bboxes']
assigned_scores = batch_assign_result['assigned_scores']
fg_mask_pre_prior = batch_assign_result['fg_mask_pre_prior']
self.assertEqual(assigned_labels.shape, torch.Size([batch_size, 84]))
self.assertEqual(assigned_bboxes.shape, torch.Size([batch_size, 84,
4]))
self.assertEqual(assigned_scores.shape,
torch.Size([batch_size, 84, num_classes]))
self.assertEqual(fg_mask_pre_prior.shape, torch.Size([batch_size, 84]))
def test_batch_atss_assigner_with_empty_gt(self):
"""Test corner case where an image might have no true detections."""
num_classes = 2
batch_size = 2
batch_atss_assigner = BatchATSSAssigner(
topk=3,
iou_calculator=dict(type='mmdet.BboxOverlaps2D'),
num_classes=num_classes)
priors = torch.FloatTensor([
[4., 4., 8., 8.],
[12., 4., 8., 8.],
[20., 4., 8., 8.],
[28., 4., 8., 8.],
]).repeat(21, 1)
num_level_bboxes = [64, 16, 4]
pad_bbox_flag = torch.FloatTensor([
[1],
[0],
]).unsqueeze(0).repeat(batch_size, 1, 1)
pred_bboxes = torch.FloatTensor([
[-4., -4., 12., 12.],
[4., -4., 20., 12.],
[12., -4., 28., 12.],
[20., -4., 36., 12.],
]).unsqueeze(0).repeat(batch_size, 21, 1)
gt_bboxes = torch.empty(batch_size, 2, 4)
gt_labels = torch.empty(batch_size, 2, 1)
batch_assign_result = batch_atss_assigner.forward(
pred_bboxes, priors, num_level_bboxes, gt_labels, gt_bboxes,
pad_bbox_flag)
assigned_labels = batch_assign_result['assigned_labels']
assigned_bboxes = batch_assign_result['assigned_bboxes']
assigned_scores = batch_assign_result['assigned_scores']
fg_mask_pre_prior = batch_assign_result['fg_mask_pre_prior']
self.assertEqual(assigned_labels.shape, torch.Size([batch_size, 84]))
self.assertEqual(assigned_bboxes.shape, torch.Size([batch_size, 84,
4]))
self.assertEqual(assigned_scores.shape,
torch.Size([batch_size, 84, num_classes]))
self.assertEqual(fg_mask_pre_prior.shape, torch.Size([batch_size, 84]))
def test_batch_atss_assigner_with_empty_boxes(self):
"""Test corner case where a network might predict no boxes."""
num_classes = 2
batch_size = 2
batch_atss_assigner = BatchATSSAssigner(
topk=3,
iou_calculator=dict(type='mmdet.BboxOverlaps2D'),
num_classes=num_classes)
priors = torch.empty(84, 4)
gt_bboxes = torch.FloatTensor([
[0, 0, 60, 93],
[229, 0, 532, 157],
]).unsqueeze(0).repeat(batch_size, 1, 1)
gt_labels = torch.LongTensor([
[0],
[11],
]).unsqueeze(0).repeat(batch_size, 1, 1)
num_level_bboxes = [64, 16, 4]
pad_bbox_flag = torch.FloatTensor([[1], [0]]).unsqueeze(0).repeat(
batch_size, 1, 1)
pred_bboxes = torch.FloatTensor([
[-4., -4., 12., 12.],
[4., -4., 20., 12.],
[12., -4., 28., 12.],
[20., -4., 36., 12.],
]).unsqueeze(0).repeat(batch_size, 21, 1)
batch_assign_result = batch_atss_assigner.forward(
pred_bboxes, priors, num_level_bboxes, gt_labels, gt_bboxes,
pad_bbox_flag)
assigned_labels = batch_assign_result['assigned_labels']
assigned_bboxes = batch_assign_result['assigned_bboxes']
assigned_scores = batch_assign_result['assigned_scores']
fg_mask_pre_prior = batch_assign_result['fg_mask_pre_prior']
self.assertEqual(assigned_labels.shape, torch.Size([batch_size, 84]))
self.assertEqual(assigned_bboxes.shape, torch.Size([batch_size, 84,
4]))
self.assertEqual(assigned_scores.shape,
torch.Size([batch_size, 84, num_classes]))
self.assertEqual(fg_mask_pre_prior.shape, torch.Size([batch_size, 84]))
def test_batch_atss_assigner_with_empty_boxes_and_gt(self):
"""Test corner case where a network might predict no boxes and no
gt."""
num_classes = 2
batch_size = 2
batch_atss_assigner = BatchATSSAssigner(
topk=3,
iou_calculator=dict(type='mmdet.BboxOverlaps2D'),
num_classes=num_classes)
priors = torch.empty(84, 4)
gt_bboxes = torch.empty(batch_size, 2, 4)
gt_labels = torch.empty(batch_size, 2, 1)
num_level_bboxes = [64, 16, 4]
pad_bbox_flag = torch.empty(batch_size, 2, 1)
pred_bboxes = torch.empty(batch_size, 84, 4)
batch_assign_result = batch_atss_assigner.forward(
pred_bboxes, priors, num_level_bboxes, gt_labels, gt_bboxes,
pad_bbox_flag)
assigned_labels = batch_assign_result['assigned_labels']
assigned_bboxes = batch_assign_result['assigned_bboxes']
assigned_scores = batch_assign_result['assigned_scores']
fg_mask_pre_prior = batch_assign_result['fg_mask_pre_prior']
self.assertEqual(assigned_labels.shape, torch.Size([batch_size, 84]))
self.assertEqual(assigned_bboxes.shape, torch.Size([batch_size, 84,
4]))
self.assertEqual(assigned_scores.shape,
torch.Size([batch_size, 84, num_classes]))
self.assertEqual(fg_mask_pre_prior.shape, torch.Size([batch_size, 84]))

View File

@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmyolo.models.task_modules.assigners import BatchTaskAlignedAssigner
class TestBatchTaskAlignedAssigner(TestCase):
def test_batch_task_aligned_assigner(self):
batch_size = 2
num_classes = 4
assigner = BatchTaskAlignedAssigner(
num_classes=num_classes, alpha=1, beta=6, topk=13, eps=1e-9)
pred_scores = torch.FloatTensor([
[0.1, 0.2],
[0.2, 0.3],
[0.3, 0.4],
[0.4, 0.5],
]).unsqueeze(0).repeat(batch_size, 21, 1)
priors = torch.FloatTensor([
[0, 0, 4., 4.],
[0, 0, 12., 4.],
[0, 0, 20., 4.],
[0, 0, 28., 4.],
]).repeat(21, 1)
gt_bboxes = torch.FloatTensor([
[0, 0, 60, 93],
[229, 0, 532, 157],
]).unsqueeze(0).repeat(batch_size, 1, 1)
gt_labels = torch.LongTensor([[0], [1]
]).unsqueeze(0).repeat(batch_size, 1, 1)
pad_bbox_flag = torch.FloatTensor([[1], [0]]).unsqueeze(0).repeat(
batch_size, 1, 1)
pred_bboxes = torch.FloatTensor([
[-4., -4., 12., 12.],
[4., -4., 20., 12.],
[12., -4., 28., 12.],
[20., -4., 36., 12.],
]).unsqueeze(0).repeat(batch_size, 21, 1)
assign_result = assigner.forward(pred_bboxes, pred_scores, priors,
gt_labels, gt_bboxes, pad_bbox_flag)
assigned_labels = assign_result['assigned_labels']
assigned_bboxes = assign_result['assigned_bboxes']
assigned_scores = assign_result['assigned_scores']
fg_mask_pre_prior = assign_result['fg_mask_pre_prior']
self.assertEqual(assigned_labels.shape, torch.Size([batch_size, 84]))
self.assertEqual(assigned_bboxes.shape, torch.Size([batch_size, 84,
4]))
self.assertEqual(assigned_scores.shape,
torch.Size([batch_size, 84, num_classes]))
self.assertEqual(fg_mask_pre_prior.shape, torch.Size([batch_size, 84]))
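For context on what these shapes exercise: the task-aligned assigner ranks priors by an alignment metric, assumed here to follow TOOD's `t = s**alpha * u**beta` with classification score `s` and IoU `u` (alpha=1, beta=6 above). A quick worked number shows how strongly beta=6 weights localization:

score, iou = 0.5, 0.9
print(score**1 * iou**6)  # ~0.266: even an IoU of 0.9 roughly halves the metric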

View File

@@ -27,6 +27,10 @@ def parse_args():
help='dump predictions to a pickle file for offline evaluation')
parser.add_argument(
'--show', action='store_true', help='show prediction results')
parser.add_argument(
'--deploy',
action='store_true',
help='Switch model to deployment mode')
parser.add_argument(
'--show-dir',
help='directory where painted images will be saved. '
@@ -85,6 +89,9 @@ def main():
if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args)
if args.deploy:
cfg.custom_hooks.append(dict(type='SwitchToDeployHook'))
# Dump predictions
if args.out is not None:
assert args.out.endswith(('.pkl', '.pickle')), \
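With this flag in place, deployment-mode evaluation needs no config edits: assuming the standard MMEngine test entrypoint, a run such as `python tools/test.py ${CONFIG} ${CHECKPOINT} --deploy` appends `SwitchToDeployHook` to `custom_hooks`, so RepVGG blocks are fused before the test epoch starts.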