mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Support YOLOv6 training (#183)
* init v6 loss * init v6s train * Add train pipeline * Add lr scheduler * update * update * update * update * update * update * update * update * update * fix detach bug * fix detach bug * update * Add stop aug hook * Add save best ckpt * update * Add PipelineSwitchHook * Fix train pipeline stage 2 * update * Fix train pipeline * update * fix stage2 randomaffine bug update update clean clean * update letterResize param * add v6affine config * add v6 randomaffine * update v6 config * update * update * update * update * update config param * update * update * refactor iou loss % rm v6affine * update * rm dfl * add v6 300 epoch config * Factor batch atss assigner * Format code * Format code * Roll back * Refactor dist_calculator * Refactor select_candidates_in_gts * Refactor select_highest_overlaps * Refactor iou_calculator * Refactor all code * Improve docstr * Improve code * clean config * add nano tiny config * pre-commit * Refactor * Improve code * Improve naming and link * Add UT * pre commit * Add UT * Add UT * Improve code, using mmdet.BboxOverlaps2D for all iou calculation * Improve code, using mmdet.BboxOverlaps2D for all iou calculation * Improve code * pre commit * pre commit * Add UT * fix config * pre commit * Improve code * Improve code * Improve code * Improve code * [Refactor] YOLOv6 BatchATSSAssigner (#179) * Factor batch atss assigner * Format code * Format code * Roll back * Refactor dist_calculator * Refactor select_candidates_in_gts * Refactor select_highest_overlaps * Refactor iou_calculator * Refactor all code * Improve docstr * Improve code * Improve code * Improve naming and link * Add UT * pre commit * Add UT * Add UT * Improve code, using mmdet.BboxOverlaps2D for all iou calculation * Improve code, using mmdet.BboxOverlaps2D for all iou calculation * Improve code * pre commit * Fix conflicts * Improve code * Improve code * Improve code * Improve code * Improve code * Improve code * add utils.py, order the input param * Improve docstr * Fix 
lint * Improve param mapping * Improve param mapping * Improve naming * assigner return dict * update * update config * update config * Fix * Fix UT * Improve UT * Improve naming * Improve coding * pre commit * pre commit * pre commit * Fix ci * Improve naming * Improve coding * Fix training iou calculate error * Improve naming * Improve naming * Improve type hint * fix lint * fix conflicts * fix UT * Improve type hint * Improve naming * Improve coding * Improve coding * Fix UT * Refactor SIoU * Pre commit * Fix * Improve ciou * Improve ciou * refactor varifocal * Improve ciou * Improve ciou * Improve siou * Improve type hint * Improve siou * Improve siou * Fix lint * refactor varifocal * fix iou bug * fix siou and loss_cls bug * update * update * add scope * update * update * Improve func `gt_instances_preprocess` * support deploy mode * Improve func `gt_instances_preprocess` * Improve func `gt_instances_preprocess` * Improve func `gt_instances_preprocess` * Improve func `bbox_overlaps` * Improve coding * Improve bbox_overlaps * Delete useless code * add yolov6 deploy mode hook * fix lint * Add common attributes to reduce calculation * Improve code * Improve code * Fix bug * Fix bug * update * add readme * update readme * update readme url Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>pull/249/head
parent 177eb4ea13
commit 980e908618
demo
mmyolo
engine/hooks
models
dense_heads
layers
losses
task_modules
utils
tests
test_engine/test_hooks
test_models
test_detectors
test_task_modules
tools
|
@ -16,9 +16,17 @@ For years, YOLO series have been de facto industry-level standard for efficient

### COCO

| Backbone | Arch | size | SyncBN | AMP | Mem (GB) | box AP | Config | Download |
| :------: | :--: | :--: | :----: | :-: | :------: | :----: | :---------------------------------------------------------: | :------: |
| YOLOv6-n | P5 | 640 | Yes | Yes | 6.04 | 36.2 | [config](../yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco/yolov6_n_syncbn_fast_8xb32-400e_coco_20221030_202726-d99b2e82.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_n_syncbn_fast_8xb32-400e_coco/yolov6_n_syncbn_fast_8xb32-400e_coco_20221030_202726.log.json) |
| YOLOv6-t | P5 | 640 | Yes | Yes | 8.13 | 41.0 | [config](../yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755-cf0d278f.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_t_syncbn_fast_8xb32-400e_coco/yolov6_t_syncbn_fast_8xb32-400e_coco_20221030_143755.log.json) |
| YOLOv6-s | P5 | 640 | Yes | Yes | 8.88 | 43.7 | [config](../yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221030_202704-2ba343db.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221030_202704.log.json) |

**Note**:

1. We don't support training just yet. But you can use the `tools/model_converters/yolov6_to_mmyolo.py` script to convert the official weight.
1. The performance is unstable and may fluctuate by about 0.3 mAP.
2. YOLOv6-m, l, and x will be supported in a later version.
3. Users who need the 300-epoch weights can either train with the 300-epoch configs we provide or convert the official weights with the [converter script](../../tools/model_converters/).

## Citation
@ -0,0 +1,13 @@
_base_ = './yolov6_s_syncbn_fast_8xb32-300e_coco.py'

deepen_factor = 0.33
widen_factor = 0.25

model = dict(
    backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    bbox_head=dict(
        head_module=dict(widen_factor=widen_factor),
        loss_bbox=dict(iou_mode='siou')))

default_hooks = dict(param_scheduler=dict(lr_factor=0.02))
@ -0,0 +1,13 @@
_base_ = './yolov6_s_syncbn_fast_8xb32-400e_coco.py'

deepen_factor = 0.33
widen_factor = 0.25

model = dict(
    backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    bbox_head=dict(
        head_module=dict(widen_factor=widen_factor),
        loss_bbox=dict(iou_mode='siou')))

default_hooks = dict(param_scheduler=dict(lr_factor=0.02))
@ -1,39 +0,0 @@
# Training mode is currently not supported

_base_ = '../yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py'
max_epochs = 400
train_batch_size_per_gpu = 32

deepen_factor = _base_.deepen_factor
widen_factor = _base_.widen_factor
model = dict(
    backbone=dict(
        _delete_=True,
        type='YOLOv6EfficientRep',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='ReLU', inplace=True)),
    neck=dict(
        _delete_=True,
        type='YOLOv6RepPAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=[256, 512, 1024],
        out_channels=[128, 256, 512],
        num_csp_blocks=12,
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='ReLU', inplace=True),
    ),
    bbox_head=dict(
        _delete_=True,
        type='YOLOv6Head',
        head_module=dict(
            type='YOLOv6HeadModule',
            num_classes=80,
            in_channels=[128, 256, 512],
            widen_factor=widen_factor,
            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg=dict(type='SiLU', inplace=True),
            featmap_strides=[8, 16, 32])),
    train_cfg=None)
@ -0,0 +1,29 @@
_base_ = './yolov6_s_syncbn_fast_8xb32-400e_coco.py'

max_epochs = 300
num_last_epochs = 15

default_hooks = dict(
    param_scheduler=dict(
        type='YOLOv5ParamSchedulerHook',
        scheduler_type='cosine',
        lr_factor=0.01,
        max_epochs=max_epochs))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        strict_load=False,
        priority=49),
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=max_epochs - num_last_epochs,
        switch_pipeline=_base_.train_pipeline_stage2)
]

train_cfg = dict(
    max_epochs=max_epochs,
    dynamic_intervals=[(max_epochs - num_last_epochs, 1)])
|
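How `switch_epoch` and `dynamic_intervals` interact in the 300-epoch config can be sketched in plain Python. This is only an illustration under stated assumptions: `validation_epochs` is a hypothetical helper, the base interval of 10 is assumed to come from `save_epoch_intervals` in the 400-epoch base config, and the exact boundary handling inside MMEngine may differ by one epoch.

```python
def validation_epochs(max_epochs, num_last_epochs, base_interval):
    """List the 1-based epochs after which validation would run.

    Assumption: the interval drops to 1 once the epoch number reaches
    ``max_epochs - num_last_epochs``, mirroring
    ``dynamic_intervals=[(max_epochs - num_last_epochs, 1)]``.
    """
    milestone = max_epochs - num_last_epochs
    epochs = []
    for epoch in range(1, max_epochs + 1):
        interval = 1 if epoch >= milestone else base_interval
        if epoch % interval == 0:
            epochs.append(epoch)
    return epochs


max_epochs = 300
num_last_epochs = 15

# Epoch at which PipelineSwitchHook would swap in train_pipeline_stage2.
switch_epoch = max_epochs - num_last_epochs

val_epochs = validation_epochs(max_epochs, num_last_epochs, base_interval=10)
print(switch_epoch, len(val_epochs), val_epochs[-3:])
```

Under these assumptions, validation is sparse during the Mosaic stage and runs every epoch once the stage-2 pipeline is active, which is presumably where the best checkpoint is most likely to appear.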
|
@ -0,0 +1,250 @@
_base_ = '../_base_/default_runtime.py'

# dataset settings
data_root = 'data/coco/'
dataset_type = 'YOLOv5CocoDataset'

num_last_epochs = 15
max_epochs = 400
num_classes = 80

# parameters that often need to be modified
img_scale = (640, 640)  # height, width
deepen_factor = 0.33
widen_factor = 0.5
save_epoch_intervals = 10
train_batch_size_per_gpu = 32
train_num_workers = 8
val_batch_size_per_gpu = 1
val_num_workers = 2

# persistent_workers must be False if num_workers is 0.
persistent_workers = True

# Only used for validation
batch_shapes_cfg = dict(
    type='BatchShapePolicy',
    batch_size=val_batch_size_per_gpu,
    img_size=img_scale[0],
    size_divisor=32,
    extra_pad_ratio=0.5)

# For single-scale training, enabling cudnn_benchmark is recommended
# because it can speed up training.
env_cfg = dict(cudnn_benchmark=True)

model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        type='YOLOv5DetDataPreprocessor',
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        bgr_to_rgb=True),
    backbone=dict(
        type='YOLOv6EfficientRep',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='ReLU', inplace=True)),
    neck=dict(
        type='YOLOv6RepPAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=[256, 512, 1024],
        out_channels=[128, 256, 512],
        num_csp_blocks=12,
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='ReLU', inplace=True),
    ),
    bbox_head=dict(
        type='YOLOv6Head',
        head_module=dict(
            type='YOLOv6HeadModule',
            num_classes=num_classes,
            in_channels=[128, 256, 512],
            widen_factor=widen_factor,
            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg=dict(type='SiLU', inplace=True),
            featmap_strides=[8, 16, 32]),
        loss_bbox=dict(
            type='IoULoss',
            iou_mode='giou',
            bbox_format='xyxy',
            reduction='mean',
            loss_weight=2.5,
            return_iou=False)),
    train_cfg=dict(
        initial_epoch=4,
        initial_assigner=dict(
            type='BatchATSSAssigner',
            num_classes=num_classes,
            topk=9,
            iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
        assigner=dict(
            type='BatchTaskAlignedAssigner',
            num_classes=num_classes,
            topk=13,
            alpha=1,
            beta=6),
    ),
    test_cfg=dict(
        multi_label=True,
        nms_pre=30000,
        score_thr=0.001,
        nms=dict(type='nms', iou_threshold=0.65),
        max_per_img=300))

# The training pipeline of YOLOv6 is basically the same as YOLOv5.
# The difference is that Mosaic and RandomAffine will be turned off in the last 15 epochs.  # noqa
pre_transform = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='LoadAnnotations', with_bbox=True)
]

train_pipeline = [
    *pre_transform,
    dict(
        type='Mosaic',
        img_scale=img_scale,
        pad_val=114.0,
        pre_transform=pre_transform),
    dict(
        type='YOLOv5RandomAffine',
        max_rotate_degree=0.0,
        max_translate_ratio=0.1,
        scaling_ratio_range=(0.5, 1.5),
        border=(-img_scale[0] // 2, -img_scale[1] // 2),
        border_val=(114, 114, 114),
        max_shear_degree=0.0),
    dict(type='YOLOv5HSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                   'flip_direction'))
]

train_pipeline_stage2 = [
    *pre_transform,
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=True,
        pad_val=dict(img=114)),
    dict(
        type='YOLOv5RandomAffine',
        max_rotate_degree=0.0,
        max_translate_ratio=0.1,
        scaling_ratio_range=(0.5, 1.5),
        max_shear_degree=0.0,
    ),
    dict(type='YOLOv5HSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                   'flip_direction'))
]

train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    collate_fn=dict(type='yolov5_collate'),
    persistent_workers=persistent_workers,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        pipeline=train_pipeline))

test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param'))
]

val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(img='val2017/'),
        ann_file='annotations/instances_val2017.json',
        pipeline=test_pipeline,
        batch_shapes_cfg=batch_shapes_cfg))

test_dataloader = val_dataloader

# Optimizer and learning rate scheduler of YOLOv6 are basically the same as YOLOv5.  # noqa
# The difference is that the scheduler_type of YOLOv6 is cosine.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD',
        lr=0.01,
        momentum=0.937,
        weight_decay=0.0005,
        nesterov=True,
        batch_size_per_gpu=train_batch_size_per_gpu),
    constructor='YOLOv5OptimizerConstructor')

default_hooks = dict(
    param_scheduler=dict(
        type='YOLOv5ParamSchedulerHook',
        scheduler_type='cosine',
        lr_factor=0.01,
        max_epochs=max_epochs),
    checkpoint=dict(
        type='CheckpointHook',
        interval=save_epoch_intervals,
        max_keep_ckpts=3,
        save_best='auto'))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        strict_load=False,
        priority=49),
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=max_epochs - num_last_epochs,
        switch_pipeline=train_pipeline_stage2)
]

val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + 'annotations/instances_val2017.json',
    metric='bbox')
test_evaluator = val_evaluator

train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=save_epoch_intervals,
    dynamic_intervals=[(max_epochs - num_last_epochs, 1)])
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
|
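The effect of `scheduler_type='cosine'` with `lr_factor=0.01` can be sketched as below. This assumes the hook applies the usual YOLOv5-style cosine decay, where the learning-rate multiplier falls from 1 to `lr_factor` over `max_epochs`; `cosine_lr_scale` is a hypothetical name, and warmup is not modelled.

```python
import math


def cosine_lr_scale(epoch, max_epochs, lr_factor):
    """Multiplier applied to the base learning rate at ``epoch``.

    Assumed form: a half cosine that goes from 1.0 at epoch 0
    down to ``lr_factor`` at ``max_epochs``.
    """
    progress = (1 - math.cos(epoch * math.pi / max_epochs)) / 2  # 0 -> 1
    return progress * (lr_factor - 1) + 1


base_lr = 0.01      # optimizer lr in the config above
max_epochs = 400
lr_factor = 0.01

for epoch in (0, 200, 400):
    scale = cosine_lr_scale(epoch, max_epochs, lr_factor)
    print(epoch, base_lr * scale)
```

With these values the final learning rate would be roughly `0.01 * 0.01 = 1e-4`; the nano configs raise `lr_factor` to 0.02, which under the same assumption doubles the final rate.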
|
@ -0,0 +1,12 @@
_base_ = './yolov6_s_syncbn_fast_8xb32-300e_coco.py'

deepen_factor = 0.33
widen_factor = 0.375

model = dict(
    backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    bbox_head=dict(
        type='YOLOv6Head',
        head_module=dict(widen_factor=widen_factor),
        loss_bbox=dict(iou_mode='siou')))
@ -0,0 +1,12 @@
_base_ = './yolov6_s_syncbn_fast_8xb32-400e_coco.py'

deepen_factor = 0.33
widen_factor = 0.375

model = dict(
    backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
    bbox_head=dict(
        type='YOLOv6Head',
        head_module=dict(widen_factor=widen_factor),
        loss_bbox=dict(iou_mode='siou')))
|
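The n, t, and s configs differ mainly in `widen_factor` (0.25, 0.375, and 0.5). A rough sketch of what that means for the head's `in_channels=[128, 256, 512]` follows. It assumes channel counts are multiplied by `widen_factor` and rounded up to a multiple of 8; `scale_channels` is a stand-in written for this illustration, not code copied from the repository.

```python
import math


def scale_channels(channels, widen_factor, divisor=8):
    """Scale a channel count and round it up to a multiple of ``divisor``.

    Hypothetical helper; the rounding rule is an assumption.
    """
    return math.ceil(channels * widen_factor / divisor) * divisor


base_channels = [128, 256, 512]

widths = {
    name: [scale_channels(c, factor) for c in base_channels]
    for name, factor in (('n', 0.25), ('t', 0.375), ('s', 0.5))
}
print(widths)
```

Because `deepen_factor` stays at 0.33 in all three configs, the variants would share the same depth and differ only in width under this reading.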
|
@ -11,7 +11,7 @@ from mmengine.logging import print_log
|
|||
from mmengine.utils import ProgressBar, scandir
|
||||
|
||||
from mmyolo.registry import VISUALIZERS
|
||||
from mmyolo.utils import register_all_modules
|
||||
from mmyolo.utils import register_all_modules, switch_to_deploy
|
||||
|
||||
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
|
||||
'.tiff', '.webp')
|
||||
|
@ -30,7 +30,11 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
'--show', action='store_true', help='Show the detection results')
|
||||
parser.add_argument(
|
||||
'--score-thr', type=float, default=0.3, help='bbox score threshold')
|
||||
'--deploy',
|
||||
action='store_true',
|
||||
help='Switch model to deployment mode')
|
||||
parser.add_argument(
|
||||
'--score-thr', type=float, default=0.3, help='Bbox score threshold')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
@ -42,6 +46,9 @@ def main(args):
|
|||
# build the model from a config file and a checkpoint file
|
||||
model = init_detector(args.config, args.checkpoint, device=args.device)
|
||||
|
||||
if args.deploy:
|
||||
switch_to_deploy(model)
|
||||
|
||||
# init visualizer
|
||||
visualizer = VISUALIZERS.build(model.cfg.visualizer)
|
||||
visualizer.dataset_meta = model.dataset_meta
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .switch_to_deploy_hook import SwitchToDeployHook
|
||||
from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
|
||||
from .yolox_mode_switch_hook import YOLOXModeSwitchHook
|
||||
|
||||
__all__ = ['YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook']
|
||||
__all__ = [
|
||||
'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmengine.hooks import Hook
from mmengine.runner import Runner

from mmyolo.registry import HOOKS
from mmyolo.utils import switch_to_deploy


@HOOKS.register_module()
class SwitchToDeployHook(Hook):
    """Switch to deploy mode before testing.

    This hook converts the multi-branch structure of the training network
    (high performance) to the single-path structure of the testing network
    (fast speed and memory saving).
    """

    def before_test_epoch(self, runner: Runner):
        switch_to_deploy(runner.model)
|
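One ingredient of such a deploy-mode switch is folding a BatchNorm layer into the convolution that precedes it, which is consistent with the `RepVGGBlock` code in this commit reading `running_mean`, `running_var`, and `eps` from each branch. The sketch below shows the arithmetic for a single scalar channel only; `fold_bn` and `conv_bn` are hypothetical helpers, and a real block additionally merges the 1x1 and identity branches into the 3x3 kernel.

```python
import math


def conv_bn(x, weight, gamma, beta, running_mean, running_var, eps=1e-5):
    """Training-time form: a (scalar) conv followed by BatchNorm."""
    std = math.sqrt(running_var + eps)
    return gamma * (weight * x - running_mean) / std + beta


def fold_bn(weight, gamma, beta, running_mean, running_var, eps=1e-5):
    """Deploy-time form: return (fused_weight, fused_bias) of one conv."""
    std = math.sqrt(running_var + eps)
    return weight * gamma / std, beta - running_mean * gamma / std


# Illustrative values only, not taken from any trained model.
params = dict(
    weight=0.8, gamma=1.5, beta=-0.2, running_mean=0.3, running_var=2.0)
fused_w, fused_b = fold_bn(**params)

x = 1.7
print(conv_bn(x, **params), fused_w * x + fused_b)
```

Both expressions produce the same output, so the fused layer can replace the conv-plus-BN pair at test time without changing predictions.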
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence, Union
|
||||
from typing import Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -7,11 +7,13 @@ from mmcv.cnn import ConvModule
|
|||
from mmdet.models.utils import multi_apply
|
||||
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
|
||||
OptMultiConfig)
|
||||
from mmengine import MessageHub
|
||||
from mmengine.dist import get_dist_info
|
||||
from mmengine.model import BaseModule, bias_init_with_prob
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmyolo.registry import MODELS
|
||||
from mmyolo.registry import MODELS, TASK_UTILS
|
||||
from ..utils import make_divisible
|
||||
from .yolov5_head import YOLOv5Head
|
||||
|
||||
|
@ -72,19 +74,6 @@ class YOLOv6HeadModule(BaseModule):
|
|||
|
||||
self._init_layers()
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize weights of the head."""
|
||||
# Use prior in model initialization to improve stability
|
||||
super().init_weights()
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, torch.nn.Conv2d):
|
||||
m.reset_parameters()
|
||||
|
||||
bias_init = bias_init_with_prob(0.01)
|
||||
for conv_cls in self.cls_preds:
|
||||
conv_cls.bias.data.fill_(bias_init)
|
||||
|
||||
def _init_layers(self):
|
||||
"""initialize conv layers in YOLOv6 head."""
|
||||
# Init decouple head
|
||||
|
@ -132,7 +121,18 @@ class YOLOv6HeadModule(BaseModule):
|
|||
out_channels=self.num_base_priors * 4,
|
||||
kernel_size=1))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def init_weights(self):
|
||||
super().init_weights()
|
||||
bias_init = bias_init_with_prob(0.01)
|
||||
for conv in self.cls_preds:
|
||||
conv.bias.data.fill_(bias_init)
|
||||
conv.weight.data.fill_(0.)
|
||||
|
||||
for conv in self.reg_preds:
|
||||
conv.bias.data.fill_(1.0)
|
||||
conv.weight.data.fill_(0.)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Forward features from the upstream network.
|
||||
|
||||
Args:
|
||||
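The new `init_weights` zeroes the classification weights and sets the bias from `bias_init_with_prob(0.01)`. A minimal sketch of why that matters, assuming the helper uses the standard prior-probability formula `-log((1 - p) / p)`; `bias_for_prior` and `sigmoid` are names made up for this example.

```python
import math


def bias_for_prior(prior_prob):
    """Bias whose sigmoid equals ``prior_prob`` (assumed formula)."""
    return -math.log((1 - prior_prob) / prior_prob)


def sigmoid(x):
    return 1 / (1 + math.exp(-x))


bias = bias_for_prior(0.01)

# With all conv weights set to 0, the logit is just the bias,
# so every location starts with the same class score.
initial_score = sigmoid(0.0 * 123.0 + bias)
print(bias, initial_score)
```

Under this assumption every prior starts by predicting a foreground probability of about 0.01, which keeps the classification loss from being dominated by the many background locations in the first iterations.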
|
@ -146,10 +146,10 @@ class YOLOv6HeadModule(BaseModule):
|
|||
return multi_apply(self.forward_single, x, self.stems, self.cls_convs,
|
||||
self.cls_preds, self.reg_convs, self.reg_preds)
|
||||
|
||||
def forward_single(self, x: torch.Tensor, stem: nn.ModuleList,
|
||||
def forward_single(self, x: Tensor, stem: nn.ModuleList,
|
||||
cls_conv: nn.ModuleList, cls_pred: nn.ModuleList,
|
||||
reg_conv: nn.ModuleList,
|
||||
reg_pred: nn.ModuleList) -> torch.Tensor:
|
||||
reg_pred: nn.ModuleList) -> Tuple[Tensor, Tensor]:
|
||||
"""Forward feature of a single scale level."""
|
||||
y = stem(x)
|
||||
cls_x = y
|
||||
|
@ -192,12 +192,20 @@ class YOLOv6Head(YOLOv5Head):
|
|||
strides=[8, 16, 32]),
|
||||
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
|
||||
loss_cls: ConfigType = dict(
|
||||
type='mmdet.CrossEntropyLoss',
|
||||
type='mmdet.VarifocalLoss',
|
||||
use_sigmoid=True,
|
||||
alpha=0.75,
|
||||
gamma=2.0,
|
||||
iou_weighted=True,
|
||||
reduction='sum',
|
||||
loss_weight=1.0),
|
||||
loss_bbox: ConfigType = dict(
|
||||
type='mmdet.GIoULoss', reduction='sum', loss_weight=5.0),
|
||||
type='IoULoss',
|
||||
iou_mode='giou',
|
||||
bbox_format='xyxy',
|
||||
reduction='mean',
|
||||
loss_weight=2.5,
|
||||
return_iou=False),
|
||||
loss_obj: ConfigType = dict(
|
||||
type='mmdet.CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
|
@ -217,13 +225,27 @@ class YOLOv6Head(YOLOv5Head):
|
|||
test_cfg=test_cfg,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
self.loss_bbox = MODELS.build(loss_bbox)
|
||||
self.loss_cls = MODELS.build(loss_cls)
|
||||
|
||||
def special_init(self):
|
"""YOLO series algorithms inherit from YOLOv5Head, but different
algorithms have their own special initialization process.
|
||||
|
||||
The special_init function is designed to deal with this situation.
|
||||
"""
|
||||
pass
|
||||
if self.train_cfg:
|
||||
self.initial_epoch = self.train_cfg['initial_epoch']
|
||||
self.initial_assigner = TASK_UTILS.build(
|
||||
self.train_cfg.initial_assigner)
|
||||
self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
|
||||
|
||||
# Add common attributes to reduce calculation
|
||||
self.featmap_sizes = None
|
||||
self.mlvl_priors = None
|
||||
self.num_level_priors = None
|
||||
self.flatten_priors = None
|
||||
self.stride_tensor = None
|
||||
|
||||
def loss_by_feat(
|
||||
self,
|
||||
|
@ -254,4 +276,148 @@ class YOLOv6Head(YOLOv5Head):
|
|||
Returns:
|
||||
dict[str, Tensor]: A dictionary of losses.
|
||||
"""
|
||||
raise NotImplementedError('Not implemented yet!')
|
||||
|
||||
# get epoch information from message hub
|
||||
message_hub = MessageHub.get_current_instance()
|
||||
current_epoch = message_hub.get_info('epoch')
|
||||
|
||||
num_imgs = len(batch_img_metas)
|
||||
if batch_gt_instances_ignore is None:
|
||||
batch_gt_instances_ignore = [None] * num_imgs
|
||||
|
||||
current_featmap_sizes = [
|
||||
cls_score.shape[2:] for cls_score in cls_scores
|
||||
]
|
||||
# If the feature map sizes have changed, regenerate the priors
|
||||
if current_featmap_sizes != self.featmap_sizes:
|
||||
self.featmap_sizes = current_featmap_sizes
|
||||
|
||||
self.mlvl_priors = self.prior_generator.grid_priors(
|
||||
self.featmap_sizes,
|
||||
dtype=cls_scores[0].dtype,
|
||||
device=cls_scores[0].device,
|
||||
with_stride=True)
|
||||
|
||||
self.num_level_priors = [len(n) for n in self.mlvl_priors]
|
||||
self.flatten_priors = torch.cat(self.mlvl_priors, dim=0)
|
||||
self.stride_tensor = self.flatten_priors[..., [2]]
|
||||
|
||||
# gt info
|
||||
gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)
|
||||
gt_labels = gt_info[:, :, :1]
|
||||
gt_bboxes = gt_info[:, :, 1:] # xyxy
|
||||
pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
|
||||
|
||||
# pred info
|
||||
flatten_cls_preds = [
|
||||
cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
|
||||
self.num_classes)
|
||||
for cls_pred in cls_scores
|
||||
]
|
||||
|
||||
flatten_pred_bboxes = [
|
||||
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
|
||||
for bbox_pred in bbox_preds
|
||||
]
|
||||
|
||||
flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
|
||||
flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
|
||||
flatten_pred_bboxes = self.bbox_coder.decode(
|
||||
self.flatten_priors[..., :2], flatten_pred_bboxes,
|
||||
self.flatten_priors[..., 2])
|
||||
pred_scores = torch.sigmoid(flatten_cls_preds)
|
||||
|
||||
if current_epoch < self.initial_epoch:
|
||||
assigned_result = self.initial_assigner(
|
||||
flatten_pred_bboxes.detach(), self.flatten_priors,
|
||||
self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
|
||||
else:
|
||||
assigned_result = self.assigner(flatten_pred_bboxes.detach(),
|
||||
pred_scores.detach(),
|
||||
self.flatten_priors, gt_labels,
|
||||
gt_bboxes, pad_bbox_flag)
|
||||
|
||||
assigned_bboxes = assigned_result['assigned_bboxes']
|
||||
assigned_scores = assigned_result['assigned_scores']
|
||||
fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
|
||||
|
||||
# cls loss
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores)
|
||||
|
||||
# rescale bbox
|
||||
assigned_bboxes /= self.stride_tensor
|
||||
flatten_pred_bboxes /= self.stride_tensor
|
||||
|
||||
# TODO: Adding all_reduce would make training more stable
|
||||
assigned_scores_sum = assigned_scores.sum()
|
||||
if assigned_scores_sum > 0:
|
||||
loss_cls /= assigned_scores_sum
|
||||
|
||||
# select positive samples mask
|
||||
num_pos = fg_mask_pre_prior.sum()
|
||||
if num_pos > 0:
|
||||
# when num_pos > 0, assigned_scores_sum is also > 0, so loss_bbox
# will not raise an error
|
||||
# iou loss
|
||||
prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
|
||||
pred_bboxes_pos = torch.masked_select(
|
||||
flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
|
||||
assigned_bboxes_pos = torch.masked_select(
|
||||
assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
|
||||
bbox_weight = torch.masked_select(
|
||||
assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
|
||||
loss_bbox = self.loss_bbox(
|
||||
pred_bboxes_pos,
|
||||
assigned_bboxes_pos,
|
||||
weight=bbox_weight,
|
||||
avg_factor=assigned_scores_sum)
|
||||
else:
|
||||
loss_bbox = flatten_pred_bboxes.sum() * 0
|
||||
|
||||
_, world_size = get_dist_info()
|
||||
return dict(
|
||||
loss_cls=loss_cls * world_size, loss_bbox=loss_bbox * world_size)
|
||||
|
||||
@staticmethod
|
||||
def gt_instances_preprocess(batch_gt_instances: Tensor,
|
||||
batch_size: int) -> Tensor:
|
"""Split batch_gt_instances by image, from [all_gt_bboxes, 6] to
[batch_size, number_gt, 5].

If an image has fewer gt bboxes than the largest count in the batch, the
missing rows are filled with [-1., 0., 0., 0., 0.].
|
||||
|
||||
Args:
|
||||
batch_gt_instances (Tensor): Ground truth
|
||||
instances for whole batch, shape [all_gt_bboxes, 6]
|
||||
batch_size (int): Batch size.
|
||||
|
||||
Returns:
|
||||
Tensor: batch gt instances data, shape [batch_size, number_gt, 5]
|
||||
"""
|
||||
|
||||
# split batch gt instance [all_gt_bboxes, 6] ->
|
||||
# [batch_size, number_gt_each_batch, 5]
|
||||
batch_instance_list = []
|
||||
max_gt_bbox_len = 0
|
||||
for i in range(batch_size):
|
||||
single_batch_instance = \
|
||||
batch_gt_instances[batch_gt_instances[:, 0] == i, :]
|
||||
single_batch_instance = single_batch_instance[:, 1:]
|
||||
batch_instance_list.append(single_batch_instance)
|
||||
if len(single_batch_instance) > max_gt_bbox_len:
|
||||
max_gt_bbox_len = len(single_batch_instance)
|
||||
|
||||
# pad with [-1., 0., 0., 0., 0.] when an image has fewer
# than max_gt_bbox_len gt bboxes
|
||||
for index, gt_instance in enumerate(batch_instance_list):
|
||||
if gt_instance.shape[0] >= max_gt_bbox_len:
|
||||
continue
|
||||
fill_tensor = batch_gt_instances.new_full(
|
||||
[max_gt_bbox_len - gt_instance.shape[0], 5], 0)
|
||||
fill_tensor[:, 0] = -1.
|
||||
batch_instance_list[index] = torch.cat(
|
||||
(batch_instance_list[index], fill_tensor), dim=0)
|
||||
|
||||
return torch.stack(batch_instance_list)
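The padding logic of `gt_instances_preprocess`, together with the `pad_bbox_flag` computed in `loss_by_feat`, can be illustrated without PyTorch. This is a list-based sketch under stated assumptions: `pad_gt_instances` is a hypothetical stand-in, each input row is assumed to be `[img_idx, label, x1, y1, x2, y2]`, and the sample boxes are invented.

```python
def pad_gt_instances(rows, batch_size):
    """Group rows by image index and pad every image to the same length.

    Padding rows are [-1., 0., 0., 0., 0.]: label -1 and an empty box.
    """
    per_image = [[row[1:] for row in rows if row[0] == i]
                 for i in range(batch_size)]
    max_len = max((len(boxes) for boxes in per_image), default=0)
    pad_row = [-1.0, 0.0, 0.0, 0.0, 0.0]
    return [
        boxes + [list(pad_row) for _ in range(max_len - len(boxes))]
        for boxes in per_image
    ]


rows = [
    [0, 3.0, 10.0, 10.0, 50.0, 60.0],
    [0, 7.0, 20.0, 30.0, 40.0, 80.0],
    [1, 1.0, 5.0, 5.0, 25.0, 35.0],
]
padded = pad_gt_instances(rows, batch_size=3)

# Same idea as `gt_bboxes.sum(-1, keepdim=True) > 0`:
# real boxes have a positive coordinate sum, padded rows sum to 0.
pad_flags = [[float(sum(box[1:]) > 0) for box in image] for image in padded]
print(pad_flags)
```

The flags let the assigners ignore padded rows, so images with different numbers of objects can be processed as one dense batch.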
|
||||
|
|
|
@ -294,7 +294,7 @@ class RepVGGBlock(nn.Module):
|
|||
"""
|
||||
if branch is None:
|
||||
return 0, 0
|
||||
if isinstance(branch, nn.Sequential):
|
||||
if isinstance(branch, ConvModule):
|
||||
kernel = branch.conv.weight
|
||||
running_mean = branch.bn.running_mean
|
||||
running_var = branch.bn.running_var
|
||||
|
@ -302,7 +302,7 @@ class RepVGGBlock(nn.Module):
|
|||
beta = branch.bn.bias
|
||||
eps = branch.bn.eps
|
||||
else:
|
||||
assert isinstance(branch, nn.BatchNorm2d)
|
||||
assert isinstance(branch, (nn.SyncBatchNorm, nn.BatchNorm2d))
|
||||
if not hasattr(self, 'id_tensor'):
|
||||
input_dim = self.in_channels // self.groups
|
||||
kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
|
||||
|
|
|
@ -10,18 +10,17 @@ from mmdet.structures.bbox import HorizontalBoxes
|
|||
from mmyolo.registry import MODELS
|
||||
|
||||
|
||||
# TODO: unify all code
|
||||
def bbox_overlaps(pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
iou_mode: str = 'ciou',
|
||||
bbox_format: str = 'xywh',
|
||||
is_aligned: bool = False,
|
||||
siou_theta: float = 4.0,
|
||||
eps: float = 1e-7) -> torch.Tensor:
|
||||
r"""Calculate overlap between two sets of bboxes.
|
||||
Implementation of paper `Enhancing Geometric Factors into
|
||||
Model Learning and Inference for Object Detection and Instance
|
||||
Segmentation <https://arxiv.org/abs/2005.03572>`_.
|
||||
In the CIoU implementation of YOLOv5 and mmdetection, there is a slight
|
||||
In the CIoU implementation of YOLOv5 and MMDetection, there is a slight
|
||||
difference in the way the alpha parameter is computed.
|
||||
mmdet version:
|
||||
alpha = (ious > 0.5).float() * v / (1 - ious + v)
|
||||
|
@ -35,27 +34,36 @@ def bbox_overlaps(pred: torch.Tensor,
|
|||
Defaults to "ciou".
|
||||
bbox_format (str): Options are "xywh" and "xyxy".
|
||||
Defaults to "xywh".
|
||||
is_aligned (bool):
|
||||
siou_theta (float): Theta used by SIoU when calculating the shape cost.
|
||||
Defaults to 4.0.
|
||||
eps (float): Eps to avoid log(0).
|
||||
Returns:
|
||||
Tensor: shape (n,).
|
||||
"""
|
||||
assert iou_mode in ('ciou', )
|
||||
assert iou_mode in ('ciou', 'giou', 'siou')
|
||||
assert bbox_format in ('xyxy', 'xywh')
|
||||
if bbox_format == 'xywh':
|
||||
pred = HorizontalBoxes.cxcywh_to_xyxy(pred)
|
||||
target = HorizontalBoxes.cxcywh_to_xyxy(target)
|
||||
|
||||
# overlap
|
||||
lt = torch.max(pred[:, :2], target[:, :2])
|
||||
rb = torch.min(pred[:, 2:], target[:, 2:])
|
||||
wh = (rb - lt).clamp(min=0)
|
||||
overlap = wh[:, 0] * wh[:, 1]
|
||||
bbox1_x1, bbox1_y1 = pred[:, 0], pred[:, 1]
|
||||
bbox1_x2, bbox1_y2 = pred[:, 2], pred[:, 3]
|
||||
bbox2_x1, bbox2_y1 = target[:, 0], target[:, 1]
|
||||
bbox2_x2, bbox2_y2 = target[:, 2], target[:, 3]
|
||||
|
||||
# union
|
||||
ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
|
||||
ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
|
||||
union = ap + ag - overlap + eps
|
||||
# Overlap
|
||||
overlap = (torch.min(bbox1_x2, bbox2_x2) -
|
||||
torch.max(bbox1_x1, bbox2_x1)).clamp(0) * \
|
||||
(torch.min(bbox1_y2, bbox2_y2) -
|
||||
torch.max(bbox1_y1, bbox2_y1)).clamp(0)
|
||||
|
||||
# Union
|
||||
w1, h1 = bbox1_x2 - bbox1_x1, bbox1_y2 - bbox1_y1
|
||||
w2, h2 = bbox2_x2 - bbox2_x1, bbox2_y2 - bbox2_y1
|
||||
union = (w1 * h1) + (w2 * h2) - overlap + eps
|
||||
|
||||
h1 = bbox1_y2 - bbox1_y1 + eps
|
||||
h2 = bbox2_y2 - bbox2_y1 + eps
|
||||
|
||||
# IoU
|
||||
ious = overlap / union
|
||||
|
@ -65,32 +73,78 @@ def bbox_overlaps(pred: torch.Tensor,
|
|||
enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
|
||||
enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
|
||||
|
||||
cw = enclose_wh[:, 0]
|
||||
ch = enclose_wh[:, 1]
|
||||
enclose_w = enclose_wh[:, 0] # cw
|
||||
enclose_h = enclose_wh[:, 1] # ch
|
||||
|
||||
c2 = cw**2 + ch**2 + eps
|
||||
if iou_mode == 'ciou':
|
||||
# CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )
|
||||
|
||||
b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
|
||||
b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
|
||||
b2_x1, b2_y1 = target[:, 0], target[:, 1]
|
||||
b2_x2, b2_y2 = target[:, 2], target[:, 3]
|
||||
# squared diagonal of the enclosing box (c^2)
|
||||
enclose_area = enclose_w**2 + enclose_h**2 + eps
|
||||
|
||||
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
|
||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
||||
# calculate ρ^2(b_pred,b_gt):
|
||||
# euclidean distance between b_pred(bbox1) and b_gt(bbox2)
|
||||
# center point, because bbox format is xyxy -> left-top xy and
|
||||
# right-bottom xy, so need to / 4 to get center point.
|
||||
rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
|
||||
rho2_right_item = ((bbox2_y1 + bbox2_y2) -
|
||||
(bbox1_y1 + bbox1_y2))**2 / 4
|
||||
rho2 = rho2_left_item + rho2_right_item # rho^2 (ρ^2)
|
||||
|
||||
left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
|
||||
right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
|
||||
rho2 = left + right
|
||||
# Width and height ratio (v)
|
||||
wh_ratio = (4 / (math.pi**2)) * torch.pow(
|
||||
torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
||||
|
||||
factor = 4 / math.pi**2
|
||||
v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
||||
with torch.no_grad():
|
||||
alpha = wh_ratio / (wh_ratio - ious + (1 + eps))
|
||||
|
||||
with torch.no_grad():
|
||||
alpha = v / (v - ious + (1 + eps))
|
||||
# CIoU
|
||||
ious = ious - ((rho2 / enclose_area) + (alpha * wh_ratio))
|
||||
|
||||
# CIoU
|
||||
cious = ious - (rho2 / c2 + alpha * v)
|
||||
return cious.clamp(min=-1.0, max=1.0)
|
||||
elif iou_mode == 'giou':
|
||||
# GIoU = IoU - ( (A_c - union) / A_c )
|
||||
convex_area = enclose_w * enclose_h + eps # convex area (A_c)
|
||||
ious = ious - (convex_area - union) / convex_area
|
||||
|
||||
elif iou_mode == 'siou':
|
||||
# SIoU: https://arxiv.org/pdf/2205.12740.pdf
|
||||
# SIoU = IoU - ( (Distance Cost + Shape Cost) / 2 )
|
||||
|
||||
# calculate sigma (σ):
|
||||
# euclidean distance between bbox2(gt) and bbox1(pred) center points,
|
||||
# sigma_cw = b_cx_gt - b_cx
|
||||
sigma_cw = (bbox2_x1 + bbox2_x2) / 2 - (bbox1_x1 + bbox1_x2) / 2 + eps
|
||||
# sigma_ch = b_cy_gt - b_cy
|
||||
sigma_ch = (bbox2_y1 + bbox2_y2) / 2 - (bbox1_y1 + bbox1_y2) / 2 + eps
|
||||
# sigma = √( (sigma_cw ** 2) + (sigma_ch ** 2) )
|
||||
sigma = torch.pow(sigma_cw**2 + sigma_ch**2, 0.5)
|
||||
|
||||
# choose the smaller of the two angles, expressed as sin(alpha)
|
||||
sin_alpha = torch.abs(sigma_ch) / sigma
|
||||
sin_beta = torch.abs(sigma_cw) / sigma
|
||||
sin_alpha = torch.where(sin_alpha <= math.sin(math.pi / 4), sin_alpha,
|
||||
sin_beta)
|
||||
|
||||
# Angle cost = 1 - 2 * ( sin^2 ( arcsin(x) - (pi / 4) ) )
|
||||
angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
|
||||
|
||||
# Distance cost = Σ_(t=x,y) (1 - e ^ (- γ ρ_t))
|
||||
rho_x = (sigma_cw / enclose_w)**2 # ρ_x
|
||||
rho_y = (sigma_ch / enclose_h)**2 # ρ_y
|
||||
gamma = 2 - angle_cost # γ
|
||||
distance_cost = (1 - torch.exp(-1 * gamma * rho_x)) + (
|
||||
1 - torch.exp(-1 * gamma * rho_y))
|
||||
|
||||
# Shape cost = Ω = Σ_(t=w,h) ( ( 1 - ( e ^ (-ω_t) ) ) ^ θ )
|
||||
omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2) # ω_w
|
||||
omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2) # ω_h
|
||||
shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w),
|
||||
siou_theta) + torch.pow(
|
||||
1 - torch.exp(-1 * omiga_h), siou_theta)
|
||||
|
||||
ious = ious - ((distance_cost + shape_cost) * 0.5)
|
||||
|
||||
return ious.clamp(min=-1.0, max=1.0)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
@ -118,7 +172,7 @@ class IoULoss(nn.Module):
|
|||
return_iou: bool = True):
|
||||
super().__init__()
|
||||
assert bbox_format in ('xywh', 'xyxy')
|
||||
assert iou_mode in ('ciou', )
|
||||
assert iou_mode in ('ciou', 'siou', 'giou')
|
||||
self.iou_mode = iou_mode
|
||||
self.bbox_format = bbox_format
|
||||
self.eps = eps
|
||||
|
@ -131,7 +185,7 @@ class IoULoss(nn.Module):
|
|||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
weight: Optional[torch.Tensor] = None,
|
||||
avg_factor: Optional[str] = None,
|
||||
avg_factor: Optional[float] = None,
|
||||
reduction_override: Optional[Union[str, bool]] = None
|
||||
) -> Tuple[Union[torch.Tensor, torch.Tensor], torch.Tensor]:
|
||||
"""Forward function.
|
||||
|
@ -155,11 +209,8 @@ class IoULoss(nn.Module):
|
|||
assert reduction_override in (None, 'none', 'mean', 'sum')
|
||||
reduction = (
|
||||
reduction_override if reduction_override else self.reduction)
|
||||
|
||||
if weight is not None and weight.dim() > 1:
|
||||
# TODO: remove this in the future
|
||||
# reduce the weight of shape (n, 4) to (n,) to match the
|
||||
# giou_loss of shape (n,)
|
||||
assert weight.shape == pred.shape
|
||||
weight = weight.mean(-1)
|
||||
|
||||
iou = bbox_overlaps(
|
||||
|
|
|
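The GIoU and CIoU branches above can be checked by hand on a single pair of boxes. The following is a scalar pure-Python sketch under stated assumptions, not the tensor code from this diff: `giou_ciou` is a hypothetical helper, boxes are plain `(x1, y1, x2, y2)` tuples, and `eps` is added to the heights inside `atan` to mirror the `+ eps` on `h1` and `h2`.

```python
import math
from typing import Sequence, Tuple


def giou_ciou(pred: Sequence[float],
              target: Sequence[float],
              eps: float = 1e-7) -> Tuple[float, float, float]:
    """Return (iou, giou, ciou) for one pair of xyxy boxes."""
    x1, y1, x2, y2 = pred
    tx1, ty1, tx2, ty2 = target

    # Overlap and union
    overlap = (max(0.0, min(x2, tx2) - max(x1, tx1)) *
               max(0.0, min(y2, ty2) - max(y1, ty1)))
    w1, h1 = x2 - x1, y2 - y1
    w2, h2 = tx2 - tx1, ty2 - ty1
    union = w1 * h1 + w2 * h2 - overlap + eps
    iou = overlap / union

    # Smallest box enclosing both inputs
    enclose_w = max(x2, tx2) - min(x1, tx1)
    enclose_h = max(y2, ty2) - min(y1, ty1)

    # GIoU = IoU - (A_c - union) / A_c
    convex_area = enclose_w * enclose_h + eps
    giou = iou - (convex_area - union) / convex_area

    # CIoU = IoU - (rho^2 / c^2 + alpha * v)
    c2 = enclose_w**2 + enclose_h**2 + eps  # squared enclosing diagonal
    rho2 = (((tx1 + tx2) - (x1 + x2))**2 +
            ((ty1 + ty2) - (y1 + y2))**2) / 4  # squared center distance
    v = (4 / math.pi**2) * (math.atan(w2 / (h2 + eps)) -
                            math.atan(w1 / (h1 + eps)))**2
    alpha = v / (v - iou + (1 + eps))
    ciou = iou - (rho2 / c2 + alpha * v)
    return iou, giou, ciou
```

For two unit squares two units apart, such as `(0, 0, 1, 1)` and `(2, 0, 3, 1)`, the IoU is zero while GIoU and CIoU are both negative, which is the property that keeps a gradient signal for non-overlapping boxes.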
@ -1,4 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .assigners import BatchATSSAssigner, BatchTaskAlignedAssigner
|
||||
from .coders import YOLOv5BBoxCoder, YOLOXBBoxCoder
|
||||
|
||||
__all__ = ['YOLOv5BBoxCoder', 'YOLOXBBoxCoder']
|
||||
__all__ = [
|
||||
'YOLOv5BBoxCoder', 'YOLOXBBoxCoder', 'BatchATSSAssigner',
|
||||
'BatchTaskAlignedAssigner'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .batch_atss_assigner import BatchATSSAssigner
|
||||
from .batch_task_aligned_assigner import BatchTaskAlignedAssigner
|
||||
from .utils import (select_candidates_in_gts, select_highest_overlaps,
|
||||
yolov6_iou_calculator)
|
||||
|
||||
__all__ = [
|
||||
'BatchATSSAssigner', 'BatchTaskAlignedAssigner',
|
||||
'select_candidates_in_gts', 'select_highest_overlaps',
|
||||
'yolov6_iou_calculator'
|
||||
]
|
|
@ -0,0 +1,339 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmdet.utils import ConfigType
|
||||
from torch import Tensor
|
||||
|
||||
from mmyolo.registry import TASK_UTILS
|
||||
from .utils import (select_candidates_in_gts, select_highest_overlaps,
|
||||
yolov6_iou_calculator)
|
||||
|
||||
|
||||
def bbox_center_distance(bboxes: Tensor,
|
||||
priors: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""Compute the center distance between bboxes and priors.
|
||||
|
||||
Args:
|
||||
bboxes (Tensor): Shape (n, 4) for bbox, "xyxy" format.
|
||||
priors (Tensor): Shape (num_priors, 4) for priors, "xyxy" format.
|
||||
|
||||
Returns:
|
||||
distances (Tensor): Center distances between bboxes and priors,
|
||||
shape (num_priors, n).
|
||||
priors_points (Tensor): Priors cx cy points,
|
||||
shape (num_priors, 2).
|
||||
"""
|
||||
bbox_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
|
||||
bbox_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
|
||||
bbox_points = torch.stack((bbox_cx, bbox_cy), dim=1)
|
||||
|
||||
priors_cx = (priors[:, 0] + priors[:, 2]) / 2.0
|
||||
priors_cy = (priors[:, 1] + priors[:, 3]) / 2.0
|
||||
priors_points = torch.stack((priors_cx, priors_cy), dim=1)
|
||||
|
||||
distances = (bbox_points[:, None, :] -
|
||||
priors_points[None, :, :]).pow(2).sum(-1).sqrt()
|
||||
|
||||
return distances, priors_points
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class BatchATSSAssigner(nn.Module):
|
||||
"""Assign a batch of corresponding gt bboxes or background to each prior.
|
||||
|
||||
This code is based on
|
||||
https://github.com/meituan/YOLOv6/blob/main/yolov6/assigners/atss_assigner.py
|
||||
|
||||
Each proposal will be assigned with `0` or a positive integer
|
||||
indicating the ground truth index.
|
||||
|
||||
- 0: negative sample, no assigned gt
|
||||
- positive integer: positive sample, index (1-based) of assigned gt
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of classes.
|
||||
iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou
|
||||
calculator. Defaults to ``dict(type='mmdet.BboxOverlaps2D')``
|
||||
topk (int): number of priors selected in each level
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: int,
|
||||
iou_calculator: ConfigType = dict(type='mmdet.BboxOverlaps2D'),
|
||||
topk: int = 9):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.iou_calculator = TASK_UTILS.build(iou_calculator)
|
||||
self.topk = topk
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, pred_bboxes: Tensor, priors: Tensor,
|
||||
num_level_priors: List, gt_labels: Tensor, gt_bboxes: Tensor,
|
||||
pad_bbox_flag: Tensor) -> dict:
|
||||
"""Assign gt to priors.
|
||||
|
||||
The assignment is done in the following steps
|
||||
|
||||
1. compute iou between all prior (prior of all pyramid levels) and gt
|
||||
2. compute center distance between all prior and gt
|
||||
3. on each pyramid level, for each gt, select the k priors whose
|
||||
centers are closest to the gt center, giving k*l candidates in
|
||||
total for each gt (l is the number of levels)
|
||||
4. get the corresponding iou for these candidates, compute the
|
||||
mean and std, and set mean + std as the iou threshold
|
||||
5. select the candidates whose iou is greater than the threshold
|
||||
as positive
|
||||
|
||||
6. limit the positive samples' centers to lie inside the gt
|
||||
|
||||
Args:
|
||||
pred_bboxes (Tensor): Predicted bounding boxes,
|
||||
shape(batch_size, num_priors, 4)
|
||||
priors (Tensor): Model priors, shape(num_priors, 4)
|
||||
num_level_priors (List): Number of bboxes in each level, len(3)
|
||||
gt_labels (Tensor): Ground truth label,
|
||||
shape(batch_size, num_gt, 1)
|
||||
gt_bboxes (Tensor): Ground truth bbox,
|
||||
shape(batch_size, num_gt, 4)
|
||||
pad_bbox_flag (Tensor): Ground truth bbox mask,
|
||||
1 means bbox, 0 means no bbox,
|
||||
shape(batch_size, num_gt, 1)
|
||||
Returns:
|
||||
assigned_result (dict): Assigned result
|
||||
'assigned_labels' (Tensor): shape(batch_size, num_priors)
|
||||
'assigned_bboxes' (Tensor): shape(batch_size, num_priors, 4)
|
||||
'assigned_scores' (Tensor):
|
||||
shape(batch_size, num_priors, num_classes)
|
||||
'fg_mask_pre_prior' (Tensor): shape(batch_size, num_priors)
|
||||
"""
|
||||
# generate priors
|
||||
cell_half_size = priors[:, 2:] * 2.5
|
||||
priors_gen = torch.zeros_like(priors)
|
||||
priors_gen[:, :2] = priors[:, :2] - cell_half_size
|
||||
priors_gen[:, 2:] = priors[:, :2] + cell_half_size
|
||||
priors = priors_gen
|
||||
|
||||
batch_size = gt_bboxes.size(0)
|
||||
num_gt, num_priors = gt_bboxes.size(1), priors.size(0)
|
||||
|
||||
assigned_result = {
|
||||
'assigned_labels':
|
||||
gt_bboxes.new_full([batch_size, num_priors], self.num_classes),
|
||||
'assigned_bboxes':
|
||||
gt_bboxes.new_full([batch_size, num_priors, 4], 0),
|
||||
'assigned_scores':
|
||||
gt_bboxes.new_full([batch_size, num_priors, self.num_classes], 0),
|
||||
'fg_mask_pre_prior':
|
||||
gt_bboxes.new_full([batch_size, num_priors], 0)
|
||||
}
|
||||
|
||||
if num_gt == 0:
|
||||
return assigned_result
|
||||
|
||||
# compute iou between all prior (prior of all pyramid levels) and gt
|
||||
overlaps = self.iou_calculator(gt_bboxes.reshape([-1, 4]), priors)
|
||||
overlaps = overlaps.reshape([batch_size, -1, num_priors])
|
||||
|
||||
# compute center distance between all prior and gt
|
||||
distances, priors_points = bbox_center_distance(
|
||||
gt_bboxes.reshape([-1, 4]), priors)
|
||||
distances = distances.reshape([batch_size, -1, num_priors])
|
||||
|
||||
# Selecting candidates based on the center distance
|
||||
is_in_candidate, candidate_idxs = self.select_topk_candidates(
|
||||
distances, num_level_priors, pad_bbox_flag)
|
||||
|
||||
# get the corresponding iou for these candidates, and compute the
|
||||
# mean and std, set mean + std as the iou threshold
|
||||
overlaps_thr_per_gt, iou_candidates = self.threshold_calculator(
|
||||
is_in_candidate, candidate_idxs, overlaps, num_priors, batch_size,
|
||||
num_gt)
|
||||
|
||||
# select candidates with iou > threshold as positive
|
||||
is_pos = torch.where(
|
||||
iou_candidates > overlaps_thr_per_gt.repeat([1, 1, num_priors]),
|
||||
is_in_candidate, torch.zeros_like(is_in_candidate))
|
||||
|
||||
is_in_gts = select_candidates_in_gts(priors_points, gt_bboxes)
|
||||
pos_mask = is_pos * is_in_gts * pad_bbox_flag
|
||||
|
||||
# if an anchor box is assigned to multiple gts,
|
||||
# the one with the highest IoU will be selected.
|
||||
gt_idx_pre_prior, fg_mask_pre_prior, pos_mask = \
|
||||
select_highest_overlaps(pos_mask, overlaps, num_gt)
|
||||
|
||||
# assigned target
|
||||
assigned_labels, assigned_bboxes, assigned_scores = self.get_targets(
|
||||
gt_labels, gt_bboxes, gt_idx_pre_prior, fg_mask_pre_prior,
|
||||
num_priors, batch_size, num_gt)
|
||||
|
||||
# soft label with iou
|
||||
if pred_bboxes is not None:
|
||||
ious = yolov6_iou_calculator(gt_bboxes, pred_bboxes) * pos_mask
|
||||
ious = ious.max(axis=-2)[0].unsqueeze(-1)
|
||||
assigned_scores *= ious
|
||||
|
||||
assigned_result['assigned_labels'] = assigned_labels.long()
|
||||
assigned_result['assigned_bboxes'] = assigned_bboxes
|
||||
assigned_result['assigned_scores'] = assigned_scores
|
||||
assigned_result['fg_mask_pre_prior'] = fg_mask_pre_prior.bool()
|
||||
return assigned_result
|
||||
|
||||
def select_topk_candidates(self, distances: Tensor,
|
||||
num_level_priors: List[int],
|
||||
pad_bbox_flag: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""Selecting candidates based on the center distance.
|
||||
|
||||
Args:
|
||||
distances (Tensor): Distance between all bbox and gt,
|
||||
shape(batch_size, num_gt, num_priors)
|
||||
num_level_priors (List[int]): Number of bboxes in each level,
|
||||
len(3)
|
||||
pad_bbox_flag (Tensor): Ground truth bbox mask,
|
||||
shape(batch_size, num_gt, 1)
|
||||
|
||||
Return:
|
||||
is_in_candidate_list (Tensor): Flag showing whether each prior is
|
||||
a top-k candidate, shape(batch_size, num_gt, num_priors)
|
||||
candidate_idxs (Tensor): Candidates index,
|
||||
shape(batch_size, num_gt, topk * num_levels)
|
||||
"""
|
||||
is_in_candidate_list = []
|
||||
candidate_idxs = []
|
||||
start_idx = 0
|
||||
|
||||
distances_dtype = distances.dtype
|
||||
distances = torch.split(distances, num_level_priors, dim=-1)
|
||||
pad_bbox_flag = pad_bbox_flag.repeat(1, 1, self.topk).bool()
|
||||
|
||||
for distances_per_level, priors_per_level in zip(
|
||||
distances, num_level_priors):
|
||||
# on each pyramid level, for each gt,
|
||||
# select k bbox whose center are closest to the gt center
|
||||
end_index = start_idx + priors_per_level
|
||||
selected_k = min(self.topk, priors_per_level)
|
||||
|
||||
_, topk_idxs_per_level = distances_per_level.topk(
|
||||
selected_k, dim=-1, largest=False)
|
||||
candidate_idxs.append(topk_idxs_per_level + start_idx)
|
||||
|
||||
topk_idxs_per_level = torch.where(
|
||||
pad_bbox_flag, topk_idxs_per_level,
|
||||
torch.zeros_like(topk_idxs_per_level))
|
||||
|
||||
is_in_candidate = F.one_hot(topk_idxs_per_level,
|
||||
priors_per_level).sum(dim=-2)
|
||||
is_in_candidate = torch.where(is_in_candidate > 1,
|
||||
torch.zeros_like(is_in_candidate),
|
||||
is_in_candidate)
|
||||
is_in_candidate_list.append(is_in_candidate.to(distances_dtype))
|
||||
|
||||
start_idx = end_index
|
||||
|
||||
is_in_candidate_list = torch.cat(is_in_candidate_list, dim=-1)
|
||||
candidate_idxs = torch.cat(candidate_idxs, dim=-1)
|
||||
|
||||
return is_in_candidate_list, candidate_idxs
|
||||
|
||||
@staticmethod
|
||||
def threshold_calculator(is_in_candidate: Tensor, candidate_idxs: Tensor,
|
||||
overlaps: Tensor, num_priors: int,
|
||||
batch_size: int,
|
||||
num_gt: int) -> Tuple[Tensor, Tensor]:
|
||||
"""Get the corresponding iou for these candidates, and compute the mean
|
||||
and std, set mean + std as the iou threshold.
|
||||
|
||||
Args:
|
||||
is_in_candidate (Tensor): Flag showing whether each prior is
|
||||
a top-k candidate, shape(batch_size, num_gt, num_priors).
|
||||
candidate_idxs (Tensor): Candidates index,
|
||||
shape(batch_size, num_gt, topk * num_levels)
|
||||
overlaps (Tensor): Overlaps area,
|
||||
shape(batch_size, num_gt, num_priors).
|
||||
num_priors (int): Number of priors.
|
||||
batch_size (int): Batch size.
|
||||
num_gt (int): Number of ground truth.
|
||||
|
||||
Return:
|
||||
overlaps_thr_per_gt (Tensor): Overlap threshold of
|
||||
per ground truth, shape(batch_size, num_gt, 1).
|
||||
candidate_overlaps (Tensor): Candidate overlaps,
|
||||
shape(batch_size, num_gt, num_priors).
|
||||
"""
|
||||
|
||||
batch_size_num_gt = batch_size * num_gt
|
||||
candidate_overlaps = torch.where(is_in_candidate > 0, overlaps,
|
||||
torch.zeros_like(overlaps))
|
||||
candidate_idxs = candidate_idxs.reshape([batch_size_num_gt, -1])
|
||||
|
||||
assist_indexes = num_priors * torch.arange(
|
||||
batch_size_num_gt, device=candidate_idxs.device)
|
||||
assist_indexes = assist_indexes[:, None]
|
||||
flatten_indexes = candidate_idxs + assist_indexes
|
||||
|
||||
candidate_overlaps_reshape = candidate_overlaps.reshape(
|
||||
-1)[flatten_indexes]
|
||||
candidate_overlaps_reshape = candidate_overlaps_reshape.reshape(
|
||||
[batch_size, num_gt, -1])
|
||||
|
||||
overlaps_mean_per_gt = candidate_overlaps_reshape.mean(
|
||||
axis=-1, keepdim=True)
|
||||
overlaps_std_per_gt = candidate_overlaps_reshape.std(
|
||||
axis=-1, keepdim=True)
|
||||
overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
|
||||
|
||||
return overlaps_thr_per_gt, candidate_overlaps
|
||||
|
||||
def get_targets(self, gt_labels: Tensor, gt_bboxes: Tensor,
|
||||
assigned_gt_inds: Tensor, fg_mask_pre_prior: Tensor,
|
||||
num_priors: int, batch_size: int,
|
||||
num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""Get target info.
|
||||
|
||||
Args:
|
||||
gt_labels (Tensor): Ground truth labels,
|
||||
shape(batch_size, num_gt, 1)
|
||||
gt_bboxes (Tensor): Ground truth bboxes,
|
||||
shape(batch_size, num_gt, 4)
|
||||
assigned_gt_inds (Tensor): Assigned ground truth indexes,
|
||||
shape(batch_size, num_priors)
|
||||
fg_mask_pre_prior (Tensor): Force ground truth matching mask,
|
||||
shape(batch_size, num_priors)
|
||||
num_priors (int): Number of priors.
|
||||
batch_size (int): Batch size.
|
||||
num_gt (int): Number of ground truth.
|
||||
|
||||
Return:
|
||||
assigned_labels (Tensor): Assigned labels,
|
||||
shape(batch_size, num_priors)
|
||||
assigned_bboxes (Tensor): Assigned bboxes,
|
||||
shape(batch_size, num_priors, 4)
|
||||
assigned_scores (Tensor): Assigned scores,
|
||||
shape(batch_size, num_priors, num_classes)
|
||||
"""
|
||||
|
||||
# assigned target labels
|
||||
batch_index = torch.arange(
|
||||
batch_size, dtype=gt_labels.dtype, device=gt_labels.device)
|
||||
batch_index = batch_index[..., None]
|
||||
assigned_gt_inds = (assigned_gt_inds + batch_index * num_gt).long()
|
||||
assigned_labels = gt_labels.flatten()[assigned_gt_inds.flatten()]
|
||||
assigned_labels = assigned_labels.reshape([batch_size, num_priors])
|
||||
assigned_labels = torch.where(
|
||||
fg_mask_pre_prior > 0, assigned_labels,
|
||||
torch.full_like(assigned_labels, self.num_classes))
|
||||
|
||||
# assigned target boxes
|
||||
assigned_bboxes = gt_bboxes.reshape([-1,
|
||||
4])[assigned_gt_inds.flatten()]
|
||||
assigned_bboxes = assigned_bboxes.reshape([batch_size, num_priors, 4])
|
||||
|
||||
# assigned target scores
|
||||
assigned_scores = F.one_hot(assigned_labels.long(),
|
||||
self.num_classes + 1).float()
|
||||
assigned_scores = assigned_scores[:, :, :self.num_classes]
|
||||
|
||||
return assigned_labels, assigned_bboxes, assigned_scores
|
|
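The core of the ATSS rule in `BatchATSSAssigner` (top-k candidates by center distance, then a `mean + std` IoU threshold) can be sketched for one gt on one pyramid level. This is a simplified pure-Python illustration, not the batched implementation above: `atss_positive_indices` is a hypothetical helper, and the center-inside-gt check and the highest-overlap tie-breaking are deliberately left out. The sample standard deviation is used, assuming the default (unbiased) behaviour of `Tensor.std`.

```python
import statistics
from typing import List, Sequence


def atss_positive_indices(ious: Sequence[float],
                          distances: Sequence[float],
                          topk: int = 9) -> List[int]:
    """Return indices of priors selected as positive for a single gt.

    ious[i] and distances[i] are the IoU and the center distance
    between prior i and the gt.
    """
    # Step 1: the k priors whose centers are closest to the gt center.
    order = sorted(range(len(distances)), key=lambda i: distances[i])
    candidates = order[:min(topk, len(order))]
    if not candidates:
        return []

    # Step 2: adaptive threshold = mean + std of the candidate IoUs.
    cand_ious = [ious[i] for i in candidates]
    mean = statistics.mean(cand_ious)
    std = statistics.stdev(cand_ious) if len(cand_ious) > 1 else 0.0
    threshold = mean + std

    # Step 3: keep candidates whose IoU is strictly above the threshold.
    return [i for i in candidates if ious[i] > threshold]
```

With `ious = [0.1, 0.1, 0.1, 0.9, 0.95]`, `distances = [1, 2, 3, 4, 50]` and `topk=4`, only prior 3 is kept: prior 4 has the highest IoU but is too far away to become a candidate.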
@ -0,0 +1,298 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from mmyolo.registry import TASK_UTILS
|
||||
from .utils import (select_candidates_in_gts, select_highest_overlaps,
|
||||
yolov6_iou_calculator)
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class BatchTaskAlignedAssigner(nn.Module):
|
||||
"""This code is based on
|
||||
https://github.com/meituan/YOLOv6/blob/main/yolov6/
|
||||
assigners/tal_assigner.py.
|
||||
Batch task-aligned assigner based on the paper:
|
||||
`TOOD: Task-aligned One-stage Object Detection.
|
||||
<https://arxiv.org/abs/2108.07755>`_.
|
||||
Assign corresponding gt bboxes or background to a batch of
|
||||
predicted bboxes. Each bbox will be assigned with `0` or a
|
||||
positive integer indicating the ground truth index.
|
||||
- 0: negative sample, no assigned gt
|
||||
- positive integer: positive sample, index (1-based) of assigned gt
|
||||
Args:
|
||||
num_classes (int): Number of classes.
|
||||
topk (int): number of bbox selected in each level
|
||||
alpha (float): Hyper-parameters related to alignment_metrics.
|
||||
Defaults to 1.0
|
||||
beta (float): Hyper-parameters related to alignment_metrics.
|
||||
Defaults to 6.
|
||||
eps (float): Eps to avoid division by zero. Defaults to 1e-7
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes: int,
|
||||
topk: int = 13,
|
||||
alpha: float = 1.0,
|
||||
beta: float = 6.0,
|
||||
eps: float = 1e-7):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.topk = topk
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.eps = eps
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
pred_bboxes: Tensor,
|
||||
pred_scores: Tensor,
|
||||
priors: Tensor,
|
||||
gt_labels: Tensor,
|
||||
gt_bboxes: Tensor,
|
||||
pad_bbox_flag: Tensor,
|
||||
) -> dict:
|
||||
"""Assign gt to bboxes.
|
||||
|
||||
The assignment is done in the following steps
|
||||
1. compute alignment metric between all bbox (bbox of all pyramid
|
||||
levels) and gt
|
||||
2. select top-k bbox as candidates for each gt
|
||||
3. limit the positive samples' centers to lie inside the gt (because
|
||||
the anchor-free detector can only predict positive distances)
|
||||
Args:
|
||||
pred_bboxes (Tensor): Predicted bboxes,
|
||||
shape(batch_size, num_priors, 4)
|
||||
pred_scores (Tensor): Scores of predicted bboxes,
|
||||
shape(batch_size, num_priors, num_classes)
|
||||
priors (Tensor): Model priors, shape (num_priors, 4)
|
||||
gt_labels (Tensor): Ground truth labels,
|
||||
shape(batch_size, num_gt, 1)
|
||||
gt_bboxes (Tensor): Ground truth bboxes,
|
||||
shape(batch_size, num_gt, 4)
|
||||
pad_bbox_flag (Tensor): Ground truth bbox mask,
|
||||
1 means bbox, 0 means no bbox,
|
||||
shape(batch_size, num_gt, 1)
|
||||
Returns:
|
||||
assigned_result (dict) Assigned result:
|
||||
assigned_labels (Tensor): Assigned labels,
|
||||
shape(batch_size, num_priors)
|
||||
assigned_bboxes (Tensor): Assigned boxes,
|
||||
shape(batch_size, num_priors, 4)
|
||||
assigned_scores (Tensor): Assigned scores,
|
||||
shape(batch_size, num_priors, num_classes)
|
||||
fg_mask_pre_prior (Tensor): Force ground truth matching mask,
|
||||
shape(batch_size, num_priors)
|
||||
"""
|
||||
# (num_priors, 4) -> (num_priors, 2)
|
||||
priors = priors[:, :2]
|
||||
|
||||
batch_size = pred_scores.size(0)
|
||||
num_gt = gt_bboxes.size(1)
|
||||
|
||||
assigned_result = {
|
||||
'assigned_labels':
|
||||
gt_bboxes.new_full(pred_scores[..., 0].shape, self.num_classes),
|
||||
'assigned_bboxes':
|
||||
gt_bboxes.new_full(pred_bboxes.shape, 0),
|
||||
'assigned_scores':
|
||||
gt_bboxes.new_full(pred_scores.shape, 0),
|
||||
'fg_mask_pre_prior':
|
||||
gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
|
||||
}
|
||||
|
||||
if num_gt == 0:
|
||||
return assigned_result
|
||||
|
||||
pos_mask, alignment_metrics, overlaps = self.get_pos_mask(
|
||||
pred_bboxes, pred_scores, priors, gt_labels, gt_bboxes,
|
||||
pad_bbox_flag, batch_size, num_gt)
|
||||
|
||||
(assigned_gt_idxs, fg_mask_pre_prior,
|
||||
pos_mask) = select_highest_overlaps(pos_mask, overlaps, num_gt)
|
||||
|
||||
# assigned target
|
||||
assigned_labels, assigned_bboxes, assigned_scores = self.get_targets(
|
||||
gt_labels, gt_bboxes, assigned_gt_idxs, fg_mask_pre_prior,
|
||||
batch_size, num_gt)
|
||||
|
||||
# normalize
|
||||
alignment_metrics *= pos_mask
|
||||
pos_align_metrics = alignment_metrics.max(axis=-1, keepdim=True)[0]
|
||||
pos_overlaps = (overlaps * pos_mask).max(axis=-1, keepdim=True)[0]
|
||||
norm_align_metric = (
|
||||
alignment_metrics * pos_overlaps /
|
||||
(pos_align_metrics + self.eps)).max(-2)[0].unsqueeze(-1)
|
||||
assigned_scores = assigned_scores * norm_align_metric
|
||||
|
||||
assigned_result['assigned_labels'] = assigned_labels
|
||||
assigned_result['assigned_bboxes'] = assigned_bboxes
|
||||
assigned_result['assigned_scores'] = assigned_scores
|
||||
assigned_result['fg_mask_pre_prior'] = fg_mask_pre_prior.bool()
|
||||
return assigned_result
|
||||
|
||||
def get_pos_mask(self, pred_bboxes: Tensor, pred_scores: Tensor,
|
||||
priors: Tensor, gt_labels: Tensor, gt_bboxes: Tensor,
|
||||
pad_bbox_flag: Tensor, batch_size: int,
|
||||
num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""Get positive mask.
|
||||
|
||||
Args:
|
||||
pred_bboxes (Tensor): Predicted bboxes,
|
||||
shape(batch_size, num_priors, 4)
|
||||
pred_scores (Tensor): Scores of predicted bboxes,
|
||||
shape(batch_size, num_priors, num_classes)
|
||||
priors (Tensor): Model priors, shape (num_priors, 2)
|
||||
gt_labels (Tensor): Ground truth labels,
|
||||
shape(batch_size, num_gt, 1)
|
||||
gt_bboxes (Tensor): Ground truth bboxes,
|
||||
shape(batch_size, num_gt, 4)
|
||||
pad_bbox_flag (Tensor): Ground truth bbox mask,
|
||||
1 means bbox, 0 means no bbox,
|
||||
shape(batch_size, num_gt, 1)
|
||||
batch_size (int): Batch size.
|
||||
num_gt (int): Number of ground truth.
|
||||
Returns:
|
||||
pos_mask (Tensor): Positive mask,
|
||||
shape(batch_size, num_gt, num_priors)
|
||||
alignment_metrics (Tensor): Alignment metrics,
|
||||
shape(batch_size, num_gt, num_priors)
|
||||
overlaps (Tensor): Overlaps of gt_bboxes and pred_bboxes,
|
||||
shape(batch_size, num_gt, num_priors)
|
||||
"""
|
||||
|
||||
# Compute alignment metric between all bbox and gt
|
||||
alignment_metrics, overlaps = \
|
||||
self.get_box_metrics(pred_bboxes, pred_scores, gt_labels,
|
||||
gt_bboxes, batch_size, num_gt)
|
||||
|
||||
# get is_in_gts mask
|
||||
is_in_gts = select_candidates_in_gts(priors, gt_bboxes)
|
||||
|
||||
# get topk_metric mask
|
||||
topk_metric = self.select_topk_candidates(
|
||||
alignment_metrics * is_in_gts,
|
||||
topk_mask=pad_bbox_flag.repeat([1, 1, self.topk]).bool())
|
||||
|
||||
# merge all mask to a final mask
|
||||
pos_mask = topk_metric * is_in_gts * pad_bbox_flag
|
||||
|
||||
return pos_mask, alignment_metrics, overlaps
|
||||
|
||||
def get_box_metrics(self, pred_bboxes: Tensor, pred_scores: Tensor,
|
||||
gt_labels: Tensor, gt_bboxes: Tensor, batch_size: int,
|
||||
num_gt: int) -> Tuple[Tensor, Tensor]:
|
||||
"""Compute alignment metric between all bbox and gt.
|
||||
|
||||
Args:
|
||||
pred_bboxes (Tensor): Predicted bboxes,
|
||||
shape(batch_size, num_priors, 4)
|
||||
pred_scores (Tensor): Scores of predicted bboxes,
|
||||
shape(batch_size, num_priors, num_classes)
|
||||
gt_labels (Tensor): Ground truth labels,
|
||||
shape(batch_size, num_gt, 1)
|
||||
gt_bboxes (Tensor): Ground truth bboxes,
|
||||
shape(batch_size, num_gt, 4)
|
||||
batch_size (int): Batch size.
|
||||
num_gt (int): Number of ground truth.
|
||||
Returns:
|
||||
alignment_metrics (Tensor): Align metric,
|
||||
shape(batch_size, num_gt, num_priors)
|
||||
overlaps (Tensor): Overlaps, shape(batch_size, num_gt, num_priors)
|
||||
"""
|
||||
pred_scores = pred_scores.permute(0, 2, 1)
|
||||
gt_labels = gt_labels.to(torch.long)
|
||||
idx = torch.zeros([2, batch_size, num_gt], dtype=torch.long)
|
||||
idx[0] = torch.arange(end=batch_size).view(-1, 1).repeat(1, num_gt)
|
||||
idx[1] = gt_labels.squeeze(-1)
|
||||
bbox_scores = pred_scores[idx[0], idx[1]]
|
||||
|
||||
overlaps = yolov6_iou_calculator(gt_bboxes, pred_bboxes)
|
||||
alignment_metrics = bbox_scores.pow(self.alpha) * overlaps.pow(
|
||||
self.beta)
|
||||
|
||||
return alignment_metrics, overlaps
|
||||
|
||||
def select_topk_candidates(self,
|
||||
alignment_gt_metrics: Tensor,
|
||||
using_largest_topk: bool = True,
|
||||
topk_mask: Optional[Tensor] = None) -> Tensor:
|
||||
"""Select the top-k candidates based on the alignment metric.
|
||||
|
||||
Args:
|
||||
alignment_gt_metrics (Tensor): Alignment metric of gt candidates,
|
||||
shape(batch_size, num_gt, num_priors)
|
||||
using_largest_topk (bool): Whether to select the largest or the
|
||||
smallest elements.
|
||||
topk_mask (Tensor): Topk mask,
|
||||
shape(batch_size, num_gt, self.topk)
|
||||
Returns:
|
||||
Tensor: Topk candidates mask,
|
||||
shape(batch_size, num_gt, num_priors)
|
||||
"""
|
||||
num_priors = alignment_gt_metrics.shape[-1]
|
||||
topk_metrics, topk_idxs = torch.topk(
|
||||
alignment_gt_metrics,
|
||||
self.topk,
|
||||
axis=-1,
|
||||
largest=using_largest_topk)
|
||||
if topk_mask is None:
|
||||
topk_mask = (topk_metrics.max(axis=-1, keepdim=True) >
|
||||
self.eps).tile([1, 1, self.topk])
|
||||
topk_idxs = torch.where(topk_mask, topk_idxs,
|
||||
torch.zeros_like(topk_idxs))
|
||||
is_in_topk = F.one_hot(topk_idxs, num_priors).sum(axis=-2)
|
||||
is_in_topk = torch.where(is_in_topk > 1, torch.zeros_like(is_in_topk),
|
||||
is_in_topk)
|
||||
return is_in_topk.to(alignment_gt_metrics.dtype)
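Review note: the top-k selection above can be read as "for each gt, mark the k priors with the best metric, unless no metric exceeds eps". Below is a plain-Python sketch for a single gt. The helper name `topk_mask` is hypothetical, and the sketch assumes `topk >= 1` and a non-empty metric list. It also simplifies one detail: the tensor code zeroes the indices of invalid gts and then drops priors counted more than once, which gives an all-zero row when `topk > 1`; the sketch returns the all-zero row directly.

```python
from typing import List


def topk_mask(metrics: List[float], topk: int,
              eps: float = 1e-9) -> List[int]:
    """Return a 0/1 mask over priors for one gt, keeping the top-k metrics."""
    num_priors = len(metrics)
    # Indices of the k largest metrics (ties resolved by lower index).
    order = sorted(range(num_priors), key=lambda i: -metrics[i])[:topk]
    # If even the best metric is not above eps, treat the gt as having
    # no valid candidate (for example, a padded gt row).
    if max(metrics[i] for i in order) <= eps:
        return [0] * num_priors
    mask = [0] * num_priors
    for i in order:
        mask[i] = 1
    return mask


mask = topk_mask([0.1, 0.5, 0.0, 0.3], topk=2)
empty = topk_mask([0.0, 0.0, 0.0], topk=2)
```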
|
||||
|
||||
def get_targets(self, gt_labels: Tensor, gt_bboxes: Tensor,
|
||||
assigned_gt_idxs: Tensor, fg_mask_pre_prior: Tensor,
|
||||
batch_size: int,
|
||||
num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""Get assigner info.
|
||||
|
||||
Args:
|
||||
gt_labels (Tensor): Ground truth labels,
|
||||
shape(batch_size, num_gt, 1)
|
||||
gt_bboxes (Tensor): Ground truth bboxes,
|
||||
shape(batch_size, num_gt, 4)
|
||||
assigned_gt_idxs (Tensor): Assigned ground truth indexes,
|
||||
shape(batch_size, num_priors)
|
||||
fg_mask_pre_prior (Tensor): Force ground truth matching mask,
|
||||
shape(batch_size, num_priors)
|
||||
batch_size (int): Batch size.
|
||||
num_gt (int): Number of ground truth.
|
||||
Returns:
|
||||
assigned_labels (Tensor): Assigned labels,
|
||||
shape(batch_size, num_priors)
|
||||
assigned_bboxes (Tensor): Assigned bboxes,
|
||||
shape(batch_size, num_priors, 4)
|
||||
assigned_scores (Tensor): Assigned scores,
|
||||
shape(batch_size, num_priors, num_classes)
|
||||
"""
|
||||
# assigned target labels
|
||||
batch_ind = torch.arange(
|
||||
end=batch_size, dtype=torch.int64, device=gt_labels.device)[...,
|
||||
None]
|
||||
assigned_gt_idxs = assigned_gt_idxs + batch_ind * num_gt
|
||||
assigned_labels = gt_labels.long().flatten()[assigned_gt_idxs]
|
||||
|
||||
# assigned target boxes
|
||||
assigned_bboxes = gt_bboxes.reshape([-1, 4])[assigned_gt_idxs]
|
||||
|
||||
# assigned target scores
|
||||
assigned_labels[assigned_labels < 0] = 0
|
||||
assigned_scores = F.one_hot(assigned_labels, self.num_classes)
|
||||
force_gt_scores_mask = fg_mask_pre_prior[:, :, None].repeat(
|
||||
1, 1, self.num_classes)
|
||||
assigned_scores = torch.where(force_gt_scores_mask > 0,
|
||||
assigned_scores,
|
||||
torch.full_like(assigned_scores, 0))
|
||||
|
||||
return assigned_labels, assigned_bboxes, assigned_scores
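Review note: `get_targets` gathers labels by flattening `gt_labels` over the batch and offsetting each per-image gt index by `batch_ind * num_gt`. The plain-Python sketch below illustrates only that indexing trick. The helper name `gather_labels` and the sample values are assumptions for illustration.

```python
from typing import List


def gather_labels(gt_labels: List[List[int]],
                  assigned_gt_idxs: List[List[int]]) -> List[List[int]]:
    """Look up per-prior labels through a flattened (batch * num_gt) list."""
    num_gt = len(gt_labels[0])
    flat = [label for image in gt_labels for label in image]
    out = []
    for batch_ind, idxs in enumerate(assigned_gt_idxs):
        # Offset each per-image gt index into the flattened list.
        out.append([flat[i + batch_ind * num_gt] for i in idxs])
    return out


labels = gather_labels([[3, 7], [5, 1]], [[0, 1, 1], [1, 0, 0]])
```

The offset is what lets one flat lookup serve the whole batch; without it, every image would read the labels of the first image.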
|
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def select_candidates_in_gts(priors_points: Tensor,
|
||||
gt_bboxes: Tensor,
|
||||
eps: float = 1e-9) -> Tensor:
|
||||
"""Select the priors whose centers lie inside a gt box.
|
||||
|
||||
Args:
|
||||
priors_points (Tensor): Model priors points,
|
||||
shape(num_priors, 2)
|
||||
gt_bboxes (Tensor): Ground truth bboxes,
|
||||
shape(batch_size, num_gt, 4)
|
||||
eps (float): Default to 1e-9.
|
||||
Return:
|
||||
(Tensor): shape(batch_size, num_gt, num_priors)
|
||||
"""
|
||||
batch_size, num_gt, _ = gt_bboxes.size()
|
||||
gt_bboxes = gt_bboxes.reshape([-1, 4])
|
||||
|
||||
priors_number = priors_points.size(0)
|
||||
priors_points = priors_points.unsqueeze(0).repeat(batch_size * num_gt, 1,
|
||||
1)
|
||||
|
||||
# calculate the left, top, right, bottom distance between positive
|
||||
# prior center and gt side
|
||||
gt_bboxes_lt = gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, priors_number, 1)
|
||||
gt_bboxes_rb = gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, priors_number, 1)
|
||||
bbox_deltas = torch.cat(
|
||||
[priors_points - gt_bboxes_lt, gt_bboxes_rb - priors_points], dim=-1)
|
||||
bbox_deltas = bbox_deltas.reshape([batch_size, num_gt, priors_number, -1])
|
||||
|
||||
return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
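Review note: the test in `select_candidates_in_gts` reduces to "the smallest of the four side distances is greater than eps". A plain-Python sketch for one point and one box follows; the helper name `center_in_gt` is hypothetical, and boxes are assumed to be in (x1, y1, x2, y2) format as in the slicing above.

```python
from typing import Sequence


def center_in_gt(point: Sequence[float], box: Sequence[float],
                 eps: float = 1e-9) -> bool:
    """True if point (x, y) lies strictly inside box (x1, y1, x2, y2)."""
    x, y = point
    x1, y1, x2, y2 = box
    # Left, top, right, bottom distances from the point to the box sides.
    deltas = (x - x1, y - y1, x2 - x, y2 - y)
    return min(deltas) > eps


inside = center_in_gt((4.0, 4.0), (0.0, 0.0, 60.0, 93.0))
on_edge = center_in_gt((0.0, 4.0), (0.0, 0.0, 60.0, 93.0))
```

A point exactly on a box edge has a zero distance to that side, so it is rejected by the strict `> eps` comparison.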
|
||||
|
||||
|
||||
def select_highest_overlaps(pos_mask: Tensor, overlaps: Tensor,
|
||||
num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""If an anchor box is assigned to multiple gts, the one with the highest
|
||||
iou will be selected.
|
||||
|
||||
Args:
|
||||
pos_mask (Tensor): The assigned positive sample mask,
|
||||
shape(batch_size, num_gt, num_priors)
|
||||
overlaps (Tensor): IoU between all bbox and ground truth,
|
||||
shape(batch_size, num_gt, num_priors)
|
||||
num_gt (int): Number of ground truth.
|
||||
Return:
|
||||
gt_idx_pre_prior (Tensor): Target ground truth index,
|
||||
shape(batch_size, num_priors)
|
||||
fg_mask_pre_prior (Tensor): Force matching ground truth,
|
||||
shape(batch_size, num_priors)
|
||||
pos_mask (Tensor): The assigned positive sample mask,
|
||||
shape(batch_size, num_gt, num_priors)
|
||||
"""
|
||||
fg_mask_pre_prior = pos_mask.sum(axis=-2)
|
||||
|
||||
# Ensure each positive prior matches one gt: the one with the largest IoU
|
||||
if fg_mask_pre_prior.max() > 1:
|
||||
mask_multi_gts = (fg_mask_pre_prior.unsqueeze(1) > 1).repeat(
|
||||
[1, num_gt, 1])
|
||||
index = overlaps.argmax(axis=1)
|
||||
is_max_overlaps = F.one_hot(index, num_gt)
|
||||
is_max_overlaps = \
|
||||
is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype)
|
||||
|
||||
pos_mask = torch.where(mask_multi_gts, is_max_overlaps, pos_mask)
|
||||
fg_mask_pre_prior = pos_mask.sum(axis=-2)
|
||||
|
||||
gt_idx_pre_prior = pos_mask.argmax(axis=-2)
|
||||
return gt_idx_pre_prior, fg_mask_pre_prior, pos_mask
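Review note: the conflict resolution in `select_highest_overlaps` can be sketched per image in plain Python. The helper name `resolve_multi_gt` is hypothetical. As in the tensor code, a prior claimed by several gts is reassigned to the gt with the highest IoU over all gts, and unmatched priors get index 0 with a foreground count of 0.

```python
from typing import List, Tuple


def resolve_multi_gt(pos_mask: List[List[int]],
                     overlaps: List[List[float]]
                     ) -> Tuple[List[int], List[int]]:
    """pos_mask and overlaps are [num_gt][num_priors] for one image."""
    num_gt, num_priors = len(pos_mask), len(pos_mask[0])
    gt_idx, fg_mask = [], []
    for p in range(num_priors):
        matched = [g for g in range(num_gt) if pos_mask[g][p]]
        if len(matched) > 1:
            # Conflict: keep the gt with the highest IoU over all gts,
            # mirroring overlaps.argmax(axis=1) in the tensor version.
            best = max(range(num_gt), key=lambda g: overlaps[g][p])
            matched = [best]
        fg_mask.append(len(matched))
        gt_idx.append(matched[0] if matched else 0)
    return gt_idx, fg_mask


gt_idx, fg_mask = resolve_multi_gt(
    pos_mask=[[1, 1, 0], [0, 1, 0]],
    overlaps=[[0.6, 0.2, 0.0], [0.1, 0.7, 0.0]])
```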
|
||||
|
||||
|
||||
# TODO: 'mmdet.BboxOverlaps2D' will cause gradient inconsistency,
|
||||
# which will be found and solved in a later version.
|
||||
def yolov6_iou_calculator(bbox1: Tensor,
|
||||
bbox2: Tensor,
|
||||
eps: float = 1e-9) -> Tensor:
|
||||
"""Calculate iou for batch.
|
||||
|
||||
Args:
|
||||
bbox1 (Tensor): shape(batch_size, num_gt, 4)
|
||||
bbox2 (Tensor): shape(batch_size, num_priors, 4)
|
||||
eps (float): Default to 1e-9.
|
||||
Return:
|
||||
(Tensor): IoU, shape(batch_size, num_gt, num_priors)
|
||||
"""
|
||||
bbox1 = bbox1.unsqueeze(2) # [N, M1, 4] -> [N, M1, 1, 4]
|
||||
bbox2 = bbox2.unsqueeze(1) # [N, M2, 4] -> [N, 1, M2, 4]
|
||||
|
||||
# calculate xy info of predict and gt bbox
|
||||
bbox1_x1y1, bbox1_x2y2 = bbox1[:, :, :, 0:2], bbox1[:, :, :, 2:4]
|
||||
bbox2_x1y1, bbox2_x2y2 = bbox2[:, :, :, 0:2], bbox2[:, :, :, 2:4]
|
||||
|
||||
# calculate overlap area
|
||||
overlap = (torch.minimum(bbox1_x2y2, bbox2_x2y2) -
|
||||
torch.maximum(bbox1_x1y1, bbox2_x1y1)).clip(0).prod(-1)
|
||||
|
||||
# calculate bbox area
|
||||
bbox1_area = (bbox1_x2y2 - bbox1_x1y1).clip(0).prod(-1)
|
||||
bbox2_area = (bbox2_x2y2 - bbox2_x1y1).clip(0).prod(-1)
|
||||
|
||||
union = bbox1_area + bbox2_area - overlap + eps
|
||||
|
||||
return overlap / union
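Review note: for a single pair of boxes, `yolov6_iou_calculator` computes the clipped intersection area divided by the union plus eps. The plain-Python sketch below follows the same steps for one pair; the helper name `box_iou` is an assumption, and boxes are taken to be in (x1, y1, x2, y2) format.

```python
from typing import Sequence


def box_iou(a: Sequence[float], b: Sequence[float],
            eps: float = 1e-9) -> float:
    """IoU of two boxes in (x1, y1, x2, y2) format."""
    # Overlap width and height, clipped at zero for disjoint boxes.
    w = max(min(a[2], b[2]) - max(a[0], b[0]), 0.0)
    h = max(min(a[3], b[3]) - max(a[1], b[1]), 0.0)
    overlap = w * h
    area_a = max(a[2] - a[0], 0.0) * max(a[3] - a[1], 0.0)
    area_b = max(b[2] - b[0], 0.0) * max(b[3] - b[1], 0.0)
    # eps keeps the division finite when both boxes are degenerate.
    return overlap / (area_a + area_b - overlap + eps)


iou = box_iou((0.0, 0.0, 2.0, 2.0), (1.0, 1.0, 3.0, 3.0))
```

Here the boxes share a 1 x 1 region and the union is 4 + 4 - 1 = 7, so the result is approximately 1/7.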
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .collect_env import collect_env
|
||||
from .misc import switch_to_deploy
|
||||
from .setup_env import register_all_modules
|
||||
|
||||
__all__ = ['register_all_modules', 'collect_env']
|
||||
__all__ = ['register_all_modules', 'collect_env', 'switch_to_deploy']
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmyolo.models import RepVGGBlock
|
||||
|
||||
|
||||
def switch_to_deploy(model):
|
||||
"""Switch the model to deploy mode."""
|
||||
for layer in model.modules():
|
||||
if isinstance(layer, RepVGGBlock):
|
||||
layer.switch_to_deploy()
|
||||
|
||||
print('Switch model to deploy modality.')
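Review note: `switch_to_deploy` walks `model.modules()` and calls `switch_to_deploy()` only on layers of the re-parameterizable type. The toy sketch below shows that traversal pattern without any framework dependency. `ToyBlock` and `ToyModel` are hypothetical stand-ins for `RepVGGBlock` and a real model; they are not part of mmyolo.

```python
class ToyBlock:
    """Hypothetical stand-in for a re-parameterizable block."""

    def __init__(self) -> None:
        self.deploy = False

    def switch_to_deploy(self) -> None:
        self.deploy = True


class ToyModel:
    """Hypothetical container exposing a modules() iterator."""

    def __init__(self, layers):
        self.layers = layers

    def modules(self):
        yield self
        yield from self.layers


def switch_to_deploy(model) -> None:
    # Only layers of the re-parameterizable type are switched.
    for layer in model.modules():
        if isinstance(layer, ToyBlock):
            layer.switch_to_deploy()


model = ToyModel([ToyBlock(), object(), ToyBlock()])
switch_to_deploy(model)
```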
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
from unittest.mock import Mock
|
||||
|
||||
from mmyolo.engine.hooks import SwitchToDeployHook
|
||||
from mmyolo.models import RepVGGBlock
|
||||
from mmyolo.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class TestSwitchToDeployHook(TestCase):
|
||||
|
||||
def test(self):
|
||||
|
||||
runner = Mock()
|
||||
runner.model = RepVGGBlock(256, 256)
|
||||
|
||||
hook = SwitchToDeployHook()
|
||||
self.assertFalse(runner.model.deploy)
|
||||
|
||||
# test after switching to deploy mode
|
||||
hook.before_test_epoch(runner)
|
||||
self.assertTrue(runner.model.deploy)
|
|
@ -20,7 +20,7 @@ class TestSingleStageDetector(TestCase):
|
|||
|
||||
@parameterized.expand([
|
||||
'yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py',
|
||||
'yolov6/yolov6_s_syncbn_8xb32-400e_coco.py',
|
||||
'yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py',
|
||||
'yolox/yolox_tiny_8xb8-300e_coco.py',
|
||||
'rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py'
|
||||
])
|
||||
|
@ -67,7 +67,7 @@ class TestSingleStageDetector(TestCase):
|
|||
@parameterized.expand([
|
||||
('yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py', ('cuda',
|
||||
'cpu')),
|
||||
('yolov6/yolov6_s_syncbn_8xb32-400e_coco.py', ('cuda', 'cpu')),
|
||||
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
|
||||
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
|
||||
('rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py', ('cuda', 'cpu'))
|
||||
])
|
||||
|
@ -98,7 +98,7 @@ class TestSingleStageDetector(TestCase):
|
|||
@parameterized.expand([
|
||||
('yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py', ('cuda',
|
||||
'cpu')),
|
||||
('yolov6/yolov6_s_syncbn_8xb32-400e_coco.py', ('cuda', 'cpu')),
|
||||
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
|
||||
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
|
||||
('rtmdet/rtmdet_tiny_syncbn_8xb32-300e_coco.py', ('cuda', 'cpu'))
|
||||
])
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
@ -0,0 +1 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
@ -0,0 +1,175 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
|
||||
from mmyolo.models.task_modules.assigners import BatchATSSAssigner
|
||||
|
||||
|
||||
class TestBatchATSSAssigner(TestCase):
|
||||
|
||||
def test_batch_atss_assigner(self):
|
||||
num_classes = 2
|
||||
batch_size = 2
|
||||
batch_atss_assigner = BatchATSSAssigner(
|
||||
topk=3,
|
||||
iou_calculator=dict(type='mmdet.BboxOverlaps2D'),
|
||||
num_classes=num_classes)
|
||||
priors = torch.FloatTensor([
|
||||
[4., 4., 8., 8.],
|
||||
[12., 4., 8., 8.],
|
||||
[20., 4., 8., 8.],
|
||||
[28., 4., 8., 8.],
|
||||
]).repeat(21, 1)
|
||||
gt_bboxes = torch.FloatTensor([
|
||||
[0, 0, 60, 93],
|
||||
[229, 0, 532, 157],
|
||||
]).unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
gt_labels = torch.LongTensor([
|
||||
[0],
|
||||
[11],
|
||||
]).unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
num_level_bboxes = [64, 16, 4]
|
||||
pad_bbox_flag = torch.FloatTensor([
|
||||
[1],
|
||||
[0],
|
||||
]).unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
pred_bboxes = torch.FloatTensor([
|
||||
[-4., -4., 12., 12.],
|
||||
[4., -4., 20., 12.],
|
||||
[12., -4., 28., 12.],
|
||||
[20., -4., 36., 12.],
|
||||
]).unsqueeze(0).repeat(batch_size, 21, 1)
|
||||
batch_assign_result = batch_atss_assigner.forward(
|
||||
pred_bboxes, priors, num_level_bboxes, gt_labels, gt_bboxes,
|
||||
pad_bbox_flag)
|
||||
|
||||
assigned_labels = batch_assign_result['assigned_labels']
|
||||
assigned_bboxes = batch_assign_result['assigned_bboxes']
|
||||
assigned_scores = batch_assign_result['assigned_scores']
|
||||
fg_mask_pre_prior = batch_assign_result['fg_mask_pre_prior']
|
||||
|
||||
self.assertEqual(assigned_labels.shape, torch.Size([batch_size, 84]))
|
||||
self.assertEqual(assigned_bboxes.shape, torch.Size([batch_size, 84,
|
||||
4]))
|
||||
self.assertEqual(assigned_scores.shape,
|
||||
torch.Size([batch_size, 84, num_classes]))
|
||||
self.assertEqual(fg_mask_pre_prior.shape, torch.Size([batch_size, 84]))
|
||||
|
||||
def test_batch_atss_assigner_with_empty_gt(self):
|
||||
"""Test corner case where an image might have no true detections."""
|
||||
num_classes = 2
|
||||
batch_size = 2
|
||||
batch_atss_assigner = BatchATSSAssigner(
|
||||
topk=3,
|
||||
iou_calculator=dict(type='mmdet.BboxOverlaps2D'),
|
||||
num_classes=num_classes)
|
||||
priors = torch.FloatTensor([
|
||||
[4., 4., 8., 8.],
|
||||
[12., 4., 8., 8.],
|
||||
[20., 4., 8., 8.],
|
||||
[28., 4., 8., 8.],
|
||||
]).repeat(21, 1)
|
||||
num_level_bboxes = [64, 16, 4]
|
||||
pad_bbox_flag = torch.FloatTensor([
|
||||
[1],
|
||||
[0],
|
||||
]).unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
pred_bboxes = torch.FloatTensor([
|
||||
[-4., -4., 12., 12.],
|
||||
[4., -4., 20., 12.],
|
||||
[12., -4., 28., 12.],
|
||||
[20., -4., 36., 12.],
|
||||
]).unsqueeze(0).repeat(batch_size, 21, 1)
|
||||
|
||||
gt_bboxes = torch.empty(batch_size, 2, 4)
|
||||
gt_labels = torch.empty(batch_size, 2, 1)
|
||||
|
||||
batch_assign_result = batch_atss_assigner.forward(
|
||||
pred_bboxes, priors, num_level_bboxes, gt_labels, gt_bboxes,
|
||||
pad_bbox_flag)
|
||||
|
||||
assigned_labels = batch_assign_result['assigned_labels']
|
||||
assigned_bboxes = batch_assign_result['assigned_bboxes']
|
||||
assigned_scores = batch_assign_result['assigned_scores']
|
||||
fg_mask_pre_prior = batch_assign_result['fg_mask_pre_prior']
|
||||
|
||||
self.assertEqual(assigned_labels.shape, torch.Size([batch_size, 84]))
|
||||
self.assertEqual(assigned_bboxes.shape, torch.Size([batch_size, 84,
|
||||
4]))
|
||||
self.assertEqual(assigned_scores.shape,
|
||||
torch.Size([batch_size, 84, num_classes]))
|
||||
self.assertEqual(fg_mask_pre_prior.shape, torch.Size([batch_size, 84]))
|
||||
|
||||
def test_batch_atss_assigner_with_empty_boxes(self):
|
||||
"""Test corner case where a network might predict no boxes."""
|
||||
num_classes = 2
|
||||
batch_size = 2
|
||||
batch_atss_assigner = BatchATSSAssigner(
|
||||
topk=3,
|
||||
iou_calculator=dict(type='mmdet.BboxOverlaps2D'),
|
||||
num_classes=num_classes)
|
||||
priors = torch.empty(84, 4)
|
||||
gt_bboxes = torch.FloatTensor([
|
||||
[0, 0, 60, 93],
|
||||
[229, 0, 532, 157],
|
||||
]).unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
gt_labels = torch.LongTensor([
|
||||
[0],
|
||||
[11],
|
||||
]).unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
num_level_bboxes = [64, 16, 4]
|
||||
pad_bbox_flag = torch.FloatTensor([[1], [0]]).unsqueeze(0).repeat(
|
||||
batch_size, 1, 1)
|
||||
pred_bboxes = torch.FloatTensor([
|
||||
[-4., -4., 12., 12.],
|
||||
[4., -4., 20., 12.],
|
||||
[12., -4., 28., 12.],
|
||||
[20., -4., 36., 12.],
|
||||
]).unsqueeze(0).repeat(batch_size, 21, 1)
|
||||
|
||||
batch_assign_result = batch_atss_assigner.forward(
|
||||
pred_bboxes, priors, num_level_bboxes, gt_labels, gt_bboxes,
|
||||
pad_bbox_flag)
|
||||
assigned_labels = batch_assign_result['assigned_labels']
|
||||
assigned_bboxes = batch_assign_result['assigned_bboxes']
|
||||
assigned_scores = batch_assign_result['assigned_scores']
|
||||
fg_mask_pre_prior = batch_assign_result['fg_mask_pre_prior']
|
||||
|
||||
self.assertEqual(assigned_labels.shape, torch.Size([batch_size, 84]))
|
||||
self.assertEqual(assigned_bboxes.shape, torch.Size([batch_size, 84,
|
||||
4]))
|
||||
self.assertEqual(assigned_scores.shape,
|
||||
torch.Size([batch_size, 84, num_classes]))
|
||||
self.assertEqual(fg_mask_pre_prior.shape, torch.Size([batch_size, 84]))
|
||||
|
||||
def test_batch_atss_assigner_with_empty_boxes_and_gt(self):
|
||||
"""Test corner case where a network might predict no boxes and no
|
||||
gt."""
|
||||
num_classes = 2
|
||||
batch_size = 2
|
||||
batch_atss_assigner = BatchATSSAssigner(
|
||||
topk=3,
|
||||
iou_calculator=dict(type='mmdet.BboxOverlaps2D'),
|
||||
num_classes=num_classes)
|
||||
priors = torch.empty(84, 4)
|
||||
gt_bboxes = torch.empty(batch_size, 2, 4)
|
||||
gt_labels = torch.empty(batch_size, 2, 1)
|
||||
num_level_bboxes = [64, 16, 4]
|
||||
pad_bbox_flag = torch.empty(batch_size, 2, 1)
|
||||
pred_bboxes = torch.empty(batch_size, 84, 4)
|
||||
|
||||
batch_assign_result = batch_atss_assigner.forward(
|
||||
pred_bboxes, priors, num_level_bboxes, gt_labels, gt_bboxes,
|
||||
pad_bbox_flag)
|
||||
assigned_labels = batch_assign_result['assigned_labels']
|
||||
assigned_bboxes = batch_assign_result['assigned_bboxes']
|
||||
assigned_scores = batch_assign_result['assigned_scores']
|
||||
fg_mask_pre_prior = batch_assign_result['fg_mask_pre_prior']
|
||||
|
||||
self.assertEqual(assigned_labels.shape, torch.Size([batch_size, 84]))
|
||||
self.assertEqual(assigned_bboxes.shape, torch.Size([batch_size, 84,
|
||||
4]))
|
||||
self.assertEqual(assigned_scores.shape,
|
||||
torch.Size([batch_size, 84, num_classes]))
|
||||
self.assertEqual(fg_mask_pre_prior.shape, torch.Size([batch_size, 84]))
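Review note: the value 84 asserted on the prior dimension throughout these tests follows from the fixtures, where four hand-written rows are repeated 21 times and the per-level counts sum to the same total. A quick arithmetic check, using only values that appear in the tests above:

```python
num_level_bboxes = [64, 16, 4]   # priors per feature level, from the tests
base_rows, repeats = 4, 21       # 4 hand-written rows, .repeat(21, 1)

num_priors = base_rows * repeats
assert num_priors == sum(num_level_bboxes) == 84
```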
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
|
||||
from mmyolo.models.task_modules.assigners import BatchTaskAlignedAssigner
|
||||
|
||||
|
||||
class TestBatchTaskAlignedAssigner(TestCase):
|
||||
|
||||
def test_batch_task_aligned_assigner(self):
|
||||
batch_size = 2
|
||||
num_classes = 4
|
||||
assigner = BatchTaskAlignedAssigner(
|
||||
num_classes=num_classes, alpha=1, beta=6, topk=13, eps=1e-9)
|
||||
pred_scores = torch.FloatTensor([
|
||||
[0.1, 0.2],
|
||||
[0.2, 0.3],
|
||||
[0.3, 0.4],
|
||||
[0.4, 0.5],
|
||||
]).unsqueeze(0).repeat(batch_size, 21, 1)
|
||||
priors = torch.FloatTensor([
|
||||
[0, 0, 4., 4.],
|
||||
[0, 0, 12., 4.],
|
||||
[0, 0, 20., 4.],
|
||||
[0, 0, 28., 4.],
|
||||
]).repeat(21, 1)
|
||||
gt_bboxes = torch.FloatTensor([
|
||||
[0, 0, 60, 93],
|
||||
[229, 0, 532, 157],
|
||||
]).unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
gt_labels = torch.LongTensor([[0], [1]
|
||||
]).unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
pad_bbox_flag = torch.FloatTensor([[1], [0]]).unsqueeze(0).repeat(
|
||||
batch_size, 1, 1)
|
||||
pred_bboxes = torch.FloatTensor([
|
||||
[-4., -4., 12., 12.],
|
||||
[4., -4., 20., 12.],
|
||||
[12., -4., 28., 12.],
|
||||
[20., -4., 36., 12.],
|
||||
]).unsqueeze(0).repeat(batch_size, 21, 1)
|
||||
|
||||
assign_result = assigner.forward(pred_bboxes, pred_scores, priors,
|
||||
gt_labels, gt_bboxes, pad_bbox_flag)
|
||||
|
||||
assigned_labels = assign_result['assigned_labels']
|
||||
assigned_bboxes = assign_result['assigned_bboxes']
|
||||
assigned_scores = assign_result['assigned_scores']
|
||||
fg_mask_pre_prior = assign_result['fg_mask_pre_prior']
|
||||
|
||||
self.assertEqual(assigned_labels.shape, torch.Size([batch_size, 84]))
|
||||
self.assertEqual(assigned_bboxes.shape, torch.Size([batch_size, 84,
|
||||
4]))
|
||||
self.assertEqual(assigned_scores.shape,
|
||||
torch.Size([batch_size, 84, num_classes]))
|
||||
self.assertEqual(fg_mask_pre_prior.shape, torch.Size([batch_size, 84]))
|
|
@ -27,6 +27,10 @@ def parse_args():
|
|||
help='dump predictions to a pickle file for offline evaluation')
|
||||
parser.add_argument(
|
||||
'--show', action='store_true', help='show prediction results')
|
||||
parser.add_argument(
|
||||
'--deploy',
|
||||
action='store_true',
|
||||
help='Switch model to deployment mode')
|
||||
parser.add_argument(
|
||||
'--show-dir',
|
||||
help='directory where painted images will be saved. '
|
||||
|
@ -85,6 +89,9 @@ def main():
|
|||
if args.show or args.show_dir:
|
||||
cfg = trigger_visualization_hook(cfg, args)
|
||||
|
||||
if args.deploy:
|
||||
cfg.custom_hooks.append(dict(type='SwitchToDeployHook'))
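Review note: the `--deploy` flag is a plain `store_true` switch that, when set, appends a `SwitchToDeployHook` entry to the custom hooks. The sketch below reproduces that flow with `argparse` and a bare list standing in for `cfg.custom_hooks`. The helper `apply_deploy` is hypothetical and, unlike the script, returns a new list instead of mutating the config.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--deploy',
        action='store_true',
        help='Switch model to deployment mode')
    return parser.parse_args(argv)


def apply_deploy(custom_hooks: list, deploy: bool) -> list:
    # Return a new list so the caller's config is not mutated.
    hooks = list(custom_hooks)
    if deploy:
        hooks.append(dict(type='SwitchToDeployHook'))
    return hooks


hooks = apply_deploy([], parse_args(['--deploy']).deploy)
```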
|
||||
|
||||
# Dump predictions
|
||||
if args.out is not None:
|
||||
assert args.out.endswith(('.pkl', '.pickle')), \
|
||||
|
|