mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Support YOLOv7 P6 training (#310)
* support p6 train * fix bug * add readme * add linkpull/330/head
parent
4a8699d6fe
commit
e85ddeac0d
|
@ -0,0 +1,45 @@
|
|||
# YOLOv7
|
||||
|
||||
> [YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors](https://arxiv.org/abs/2207.02696)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
YOLOv7 surpasses all known object detectors in both speed and accuracy in the range from 5 FPS to 160 FPS and has the highest accuracy 56.8% AP among all known real-time object detectors with 30 FPS or higher on GPU V100. YOLOv7-E6 object detector (56 FPS V100, 55.9% AP) outperforms both transformer-based detector SWIN-L Cascade-Mask R-CNN (9.2 FPS A100, 53.9% AP) by 509% in speed and 2% in accuracy, and convolutional-based detector ConvNeXt-XL Cascade-Mask R-CNN (8.6 FPS A100, 55.2% AP) by 551% in speed and 0.7% AP in accuracy, as well as YOLOv7 outperforms: YOLOR, YOLOX, Scaled-YOLOv4, YOLOv5, DETR, Deformable DETR, DINO-5scale-R50, ViT-Adapter-B and many other object detectors in speed and accuracy. Moreover, we train YOLOv7 only on MS COCO dataset from scratch without using any other datasets or pre-trained weights. Source code is released in [this https URL](https://github.com/WongKinYiu/yolov7).
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/17425982/204231759-cc5c77a9-38c6-4a41-85be-eb97e4b2bcbb.png"/>
|
||||
</div>
|
||||
|
||||
## Results and models
|
||||
|
||||
### COCO
|
||||
|
||||
| Backbone | Arch | Size | SyncBN | AMP | Mem (GB) | Box AP | Config | Download |
|
||||
| :------: | :--: | :--: | :---: | :----: | :-: | :------: | :----: | :---------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| YOLOv7-tiny | P5 | 640 | Yes | Yes | 2.7 | 37.5 | [config](../yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco/yolov7_tiny_syncbn_fast_8x16b-300e_coco_20221126_102719-0ee5bbdf.pth) | [log](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco/yolov7_tiny_syncbn_fast_8x16b-300e_coco_20221126_102719.log.json) |
|
||||
| YOLOv7-l | P5 | 640 | Yes | Yes | 10.3 | 50.9 | [config](../yolov7/yolov7_l_syncbn_fast_8x16b-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_l_syncbn_fast_8x16b-300e_coco/yolov7_l_syncbn_fast_8x16b-300e_coco_20221123_023601-8113c0eb.pth) | [log](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_l_syncbn_fast_8x16b-300e_coco/yolov7_l_syncbn_fast_8x16b-300e_coco_20221123_023601.log.json) |
|
||||
| YOLOv7-x | P5 | 640 | Yes | Yes | 13.7 | 52.8 | [config](../yolov7/yolov7_x_syncbn_fast_8x16b-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_x_syncbn_fast_8x16b-300e_coco/yolov7_x_syncbn_fast_8x16b-300e_coco_20221124_215331-ef949a68.pth) | [log](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_x_syncbn_fast_8x16b-300e_coco/yolov7_x_syncbn_fast_8x16b-300e_coco_20221124_215331.log.json) |
|
||||
| YOLOv7-w | P6 | 1280 | Yes | Yes | 27.0 | 54.1 | [config](../yolov7/yolov7_w-p6_syncbn_fast_8x16b-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_w-p6_syncbn_fast_8x16b-300e_coco/yolov7_w-p6_syncbn_fast_8x16b-300e_coco_20221123_053031-a68ef9d2.pth) | [log](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_w-p6_syncbn_fast_8x16b-300e_coco/yolov7_w-p6_syncbn_fast_8x16b-300e_coco_20221123_053031.log.json) |
|
||||
| YOLOv7-e | P6 | 1280 | Yes | Yes | 42.5 | 55.1 | [config](../yolov7/yolov7_e-p6_syncbn_fast_8x16b-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_e-p6_syncbn_fast_8x16b-300e_coco/yolov7_e-p6_syncbn_fast_8x16b-300e_coco_20221126_102636-34425033.pth) | [log](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_e-p6_syncbn_fast_8x16b-300e_coco/yolov7_e-p6_syncbn_fast_8x16b-300e_coco_20221126_102636.log.json) |
|
||||
|
||||
**Note**:
|
||||
In the official YOLOv7 code, the `random_perspective` data augmentation in COCO object detection task training uses mask annotation information, which leads to higher performance. Object detection should not use mask annotation, so only box annotation information is used in `MMYOLO`. We will use the mask annotation information in the instance segmentation task.
|
||||
|
||||
1. The performance is unstable and may fluctuate by about 0.3 mAP. The performance shown above is the best model.
|
||||
2. If users need the weight of `YOLOv7-e2e`, they can train according to the configs provided by us, or convert the official weight according to the [converter script](https://github.com/open-mmlab/mmyolo/blob/main/tools/model_converters/yolov7_to_mmyolo.py).
|
||||
3. `fast` means that `YOLOv5DetDataPreprocessor` and `yolov5_collate` are used for data preprocessing, which is faster for training, but less flexible for multitasking. Recommended to use fast version config if you only care about object detection.
|
||||
4. `SyncBN` means use SyncBN, `AMP` indicates training with mixed precision.
|
||||
5. We use 8x A100 for training, and the single-GPU batch size is 16. This is different from the official code.
|
||||
|
||||
## Citation
|
||||
|
||||
```latex
|
||||
@article{wang2022yolov7,
|
||||
title={{YOLOv7}: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors},
|
||||
author={Wang, Chien-Yao and Bochkovskiy, Alexey and Liao, Hong-Yuan Mark},
|
||||
journal={arXiv preprint arXiv:2207.02696},
|
||||
year={2022}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,83 @@
|
|||
Collections:
|
||||
- Name: YOLOv7
|
||||
Metadata:
|
||||
Training Data: COCO
|
||||
Training Techniques:
|
||||
- SGD with Nesterov
|
||||
- Weight Decay
|
||||
- AMP
|
||||
- Synchronize BN
|
||||
Training Resources: 8x A100 GPUs
|
||||
Architecture:
|
||||
- EELAN
|
||||
- PAFPN
|
||||
- RepVGG
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/2207.02696
|
||||
Title: 'YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors'
|
||||
README: configs/yolov7/README.md
|
||||
Code:
|
||||
URL: https://github.com/open-mmlab/mmyolo/blob/v0.0.1/mmyolo/models/detectors/yolo_detector.py#L12
|
||||
Version: v0.0.1
|
||||
|
||||
Models:
|
||||
- Name: yolov7_tiny_syncbn_fast_8x16b-300e_coco
|
||||
In Collection: YOLOv7
|
||||
Config: configs/yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 2.7
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 37.5
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco/yolov7_tiny_syncbn_fast_8x16b-300e_coco_20221126_102719-0ee5bbdf.pth
|
||||
- Name: yolov7_l_syncbn_fast_8x16b-300e_coco
|
||||
In Collection: YOLOv7
|
||||
Config: configs/yolov7/yolov7_l_syncbn_fast_8x16b-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 10.3
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 50.9
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_l_syncbn_fast_8x16b-300e_coco/yolov7_l_syncbn_fast_8x16b-300e_coco_20221123_023601-8113c0eb.pth
|
||||
- Name: yolov7_x_syncbn_fast_8x16b-300e_coco
|
||||
In Collection: YOLOv7
|
||||
Config: configs/yolov7/yolov7_x_syncbn_fast_8x16b-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 13.7
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 52.8
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_x_syncbn_fast_8x16b-300e_coco/yolov7_x_syncbn_fast_8x16b-300e_coco_20221124_215331-ef949a68.pth
|
||||
- Name: yolov7_w-p6_syncbn_fast_8x16b-300e_coco
|
||||
In Collection: YOLOv7
|
||||
Config: configs/yolov7/yolov7_w-p6_syncbn_fast_8x16b-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 27.0
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 54.1
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_w-p6_syncbn_fast_8x16b-300e_coco/yolov7_w-p6_syncbn_fast_8x16b-300e_coco_20221123_053031-a68ef9d2.pth
|
||||
- Name: yolov7_e-p6_syncbn_fast_8x16b-300e_coco
|
||||
In Collection: YOLOv7
|
||||
Config: configs/yolov7/yolov7_e-p6_syncbn_fast_8x16b-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 42.5
|
||||
Epochs: 300
|
||||
Results:
|
||||
- Task: Object Detection
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
box AP: 55.1
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_e-p6_syncbn_fast_8x16b-300e_coco/yolov7_e-p6_syncbn_fast_8x16b-300e_coco_20221126_102636-34425033.pth
|
|
@ -29,8 +29,72 @@ model = dict(
|
|||
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg=dict(type='SiLU', inplace=True)),
|
||||
prior_generator=dict(base_sizes=anchors, strides=strides),
|
||||
simota_candidate_topk=20, # note
|
||||
# scaled based on number of detection layers
|
||||
loss_cls=dict(loss_weight=0.3 *
|
||||
(num_classes / 80 * 3 / num_det_layers)),
|
||||
loss_bbox=dict(loss_weight=0.05 * (3 / num_det_layers)),
|
||||
loss_obj=dict(loss_weight=0.7 *
|
||||
((img_scale[0] / 640)**2 * 3 / num_det_layers)),
|
||||
obj_level_weights=[4.0, 1.0, 0.25, 0.06]))
|
||||
|
||||
pre_transform = _base_.pre_transform
|
||||
|
||||
mosiac4_pipeline = [
|
||||
dict(
|
||||
type='Mosaic',
|
||||
img_scale=img_scale,
|
||||
pad_val=114.0,
|
||||
pre_transform=pre_transform),
|
||||
dict(
|
||||
type='YOLOv5RandomAffine',
|
||||
max_rotate_degree=0.0,
|
||||
max_shear_degree=0.0,
|
||||
max_translate_ratio=0.2, # note
|
||||
scaling_ratio_range=(0.1, 2.0), # note
|
||||
border=(-img_scale[0] // 2, -img_scale[1] // 2),
|
||||
border_val=(114, 114, 114)),
|
||||
]
|
||||
|
||||
mosiac9_pipeline = [
|
||||
dict(
|
||||
type='Mosaic9',
|
||||
img_scale=img_scale,
|
||||
pad_val=114.0,
|
||||
pre_transform=pre_transform),
|
||||
dict(
|
||||
type='YOLOv5RandomAffine',
|
||||
max_rotate_degree=0.0,
|
||||
max_shear_degree=0.0,
|
||||
max_translate_ratio=0.2, # note
|
||||
scaling_ratio_range=(0.1, 2.0), # note
|
||||
border=(-img_scale[0] // 2, -img_scale[1] // 2),
|
||||
border_val=(114, 114, 114)),
|
||||
]
|
||||
|
||||
randchoice_mosaic_pipeline = dict(
|
||||
type='RandomChoice',
|
||||
transforms=[mosiac4_pipeline, mosiac9_pipeline],
|
||||
prob=[0.8, 0.2])
|
||||
|
||||
train_pipeline = [
|
||||
*pre_transform,
|
||||
randchoice_mosaic_pipeline,
|
||||
dict(
|
||||
type='YOLOv5MixUp',
|
||||
alpha=8.0, # note
|
||||
beta=8.0, # note
|
||||
prob=0.15,
|
||||
pre_transform=[*pre_transform, randchoice_mosaic_pipeline]),
|
||||
dict(type='YOLOv5HSVRandomAug'),
|
||||
dict(type='mmdet.RandomFlip', prob=0.5),
|
||||
dict(
|
||||
type='mmdet.PackDetInputs',
|
||||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
|
||||
'flip_direction'))
|
||||
]
|
||||
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
|
||||
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
|
||||
|
@ -45,8 +109,10 @@ test_pipeline = [
|
|||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
|
||||
'scale_factor', 'pad_param'))
|
||||
]
|
||||
|
||||
val_dataloader = dict(
|
||||
dataset=dict(pipeline=test_pipeline, batch_shapes_cfg=batch_shapes_cfg))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
# The only difference between P6 and P5 in terms of
|
||||
# hyperparameters is lr_factor
|
||||
default_hooks = dict(param_scheduler=dict(lr_factor=0.2))
|
||||
|
|
|
@ -167,6 +167,7 @@ class YOLOv5Head(BaseDenseHead):
|
|||
reduction='mean',
|
||||
loss_weight=1.0),
|
||||
prior_match_thr: float = 4.0,
|
||||
near_neighbor_thr: float = 0.5,
|
||||
obj_level_weights: List[float] = [4.0, 1.0, 0.4],
|
||||
train_cfg: OptConfigType = None,
|
||||
test_cfg: OptConfigType = None,
|
||||
|
@ -192,6 +193,7 @@ class YOLOv5Head(BaseDenseHead):
|
|||
self.featmap_sizes = [torch.empty(1)] * self.num_levels
|
||||
|
||||
self.prior_match_thr = prior_match_thr
|
||||
self.near_neighbor_thr = near_neighbor_thr
|
||||
self.obj_level_weights = obj_level_weights
|
||||
|
||||
self.special_init()
|
||||
|
@ -231,7 +233,7 @@ class YOLOv5Head(BaseDenseHead):
|
|||
[0, 1], # up
|
||||
[-1, 0], # right
|
||||
[0, -1], # bottom
|
||||
]).float() * 0.5
|
||||
]).float()
|
||||
self.register_buffer(
|
||||
'grid_offset', grid_offset[:, None], persistent=False)
|
||||
|
||||
|
@ -534,9 +536,10 @@ class YOLOv5Head(BaseDenseHead):
|
|||
# them as positive samples as well.
|
||||
batch_targets_cxcy = batch_targets_scaled[:, 2:4]
|
||||
grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
|
||||
left, up = ((batch_targets_cxcy % 1 < 0.5) &
|
||||
left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
|
||||
(batch_targets_cxcy > 1)).T
|
||||
right, bottom = ((grid_xy % 1 < 0.5) & (grid_xy > 1)).T
|
||||
right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
|
||||
(grid_xy > 1)).T
|
||||
offset_inds = torch.stack(
|
||||
(torch.ones_like(left), left, up, right, bottom))
|
||||
|
||||
|
@ -552,7 +555,8 @@ class YOLOv5Head(BaseDenseHead):
|
|||
priors_inds, (img_inds, class_inds) = priors_inds.long().view(
|
||||
-1), img_class_inds.long().T
|
||||
|
||||
grid_xy_long = (grid_xy - retained_offsets).long()
|
||||
grid_xy_long = (grid_xy -
|
||||
retained_offsets * self.near_neighbor_thr).long()
|
||||
grid_x_inds, grid_y_inds = grid_xy_long.T
|
||||
bboxes_targets = torch.cat((grid_xy - grid_xy_long, grid_wh), 1)
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import List, Sequence, Tuple, Union
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -47,7 +47,6 @@ class YOLOv7HeadModule(YOLOv5HeadModule):
|
|||
mi.bias.data = b.view(-1)
|
||||
|
||||
|
||||
# TODO: to check
|
||||
@MODELS.register_module()
|
||||
class YOLOv7p6HeadModule(YOLOv5HeadModule):
|
||||
"""YOLOv7Head head module used in YOLOv7."""
|
||||
|
@ -56,12 +55,14 @@ class YOLOv7p6HeadModule(YOLOv5HeadModule):
|
|||
*args,
|
||||
main_out_channels: Sequence[int] = [256, 512, 768, 1024],
|
||||
aux_out_channels: Sequence[int] = [320, 640, 960, 1280],
|
||||
use_aux: bool = True,
|
||||
norm_cfg: ConfigType = dict(
|
||||
type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
||||
**kwargs):
|
||||
self.main_out_channels = main_out_channels
|
||||
self.aux_out_channels = aux_out_channels
|
||||
self.use_aux = use_aux
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
super().__init__(*args, **kwargs)
|
||||
|
@ -69,7 +70,6 @@ class YOLOv7p6HeadModule(YOLOv5HeadModule):
|
|||
def _init_layers(self):
|
||||
"""initialize conv layers in YOLOv7 head."""
|
||||
self.main_convs_pred = nn.ModuleList()
|
||||
self.aux_convs_pred = nn.ModuleList()
|
||||
for i in range(self.num_levels):
|
||||
conv_pred = nn.Sequential(
|
||||
ConvModule(
|
||||
|
@ -86,17 +86,22 @@ class YOLOv7p6HeadModule(YOLOv5HeadModule):
|
|||
)
|
||||
self.main_convs_pred.append(conv_pred)
|
||||
|
||||
aux_pred = nn.Sequential(
|
||||
ConvModule(
|
||||
self.in_channels[i],
|
||||
self.aux_out_channels[i],
|
||||
3,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
nn.Conv2d(self.aux_out_channels[i],
|
||||
self.num_base_priors * self.num_out_attrib, 1))
|
||||
self.aux_convs_pred.append(aux_pred)
|
||||
if self.use_aux:
|
||||
self.aux_convs_pred = nn.ModuleList()
|
||||
for i in range(self.num_levels):
|
||||
aux_pred = nn.Sequential(
|
||||
ConvModule(
|
||||
self.in_channels[i],
|
||||
self.aux_out_channels[i],
|
||||
3,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
nn.Conv2d(self.aux_out_channels[i],
|
||||
self.num_base_priors * self.num_out_attrib, 1))
|
||||
self.aux_convs_pred.append(aux_pred)
|
||||
else:
|
||||
self.aux_convs_pred = [None] * len(self.main_convs_pred)
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize the bias of YOLOv5 head."""
|
||||
|
@ -110,12 +115,13 @@ class YOLOv7p6HeadModule(YOLOv5HeadModule):
|
|||
b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
|
||||
mi.bias.data = b.view(-1)
|
||||
|
||||
aux = aux[1] # nn.Conv2d
|
||||
b = aux.bias.data.view(3, -1)
|
||||
# obj (8 objects per 640 image)
|
||||
b.data[:, 4] += math.log(8 / (640 / s)**2)
|
||||
b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
|
||||
mi.bias.data = b.view(-1)
|
||||
if self.use_aux:
|
||||
aux = aux[1] # nn.Conv2d
|
||||
b = aux.bias.data.view(3, -1)
|
||||
# obj (8 objects per 640 image)
|
||||
b.data[:, 4] += math.log(8 / (640 / s)**2)
|
||||
b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
|
||||
mi.bias.data = b.view(-1)
|
||||
|
||||
def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
|
||||
"""Forward features from the upstream network.
|
||||
|
@ -132,7 +138,7 @@ class YOLOv7p6HeadModule(YOLOv5HeadModule):
|
|||
self.aux_convs_pred)
|
||||
|
||||
def forward_single(self, x: Tensor, convs: nn.Module,
|
||||
aux_convs: nn.Module) \
|
||||
aux_convs: Optional[nn.Module]) \
|
||||
-> Tuple[Union[Tensor, List], Union[Tensor, List],
|
||||
Union[Tensor, List]]:
|
||||
"""Forward feature of a single scale level."""
|
||||
|
@ -146,7 +152,7 @@ class YOLOv7p6HeadModule(YOLOv5HeadModule):
|
|||
bbox_pred = pred_map[:, :, :4, ...].reshape(bs, -1, ny, nx)
|
||||
objectness = pred_map[:, :, 4:5, ...].reshape(bs, -1, ny, nx)
|
||||
|
||||
if not self.training:
|
||||
if not self.training or not self.use_aux:
|
||||
return cls_score, bbox_pred, objectness
|
||||
else:
|
||||
aux_pred_map = aux_convs(x)
|
||||
|
@ -178,11 +184,13 @@ class YOLOv7Head(YOLOv5Head):
|
|||
|
||||
def __init__(self,
|
||||
*args,
|
||||
simota_candidate_topk: int = 10,
|
||||
simota_candidate_topk: int = 20,
|
||||
simota_iou_weight: float = 3.0,
|
||||
simota_cls_weight: float = 1.0,
|
||||
aux_loss_weights: float = 0.25,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.aux_loss_weights = aux_loss_weights
|
||||
self.assigner = BatchYOLOv7Assigner(
|
||||
num_classes=self.num_classes,
|
||||
num_base_priors=self.num_base_priors,
|
||||
|
@ -194,9 +202,9 @@ class YOLOv7Head(YOLOv5Head):
|
|||
|
||||
def loss_by_feat(
|
||||
self,
|
||||
cls_scores: Sequence[Tensor],
|
||||
bbox_preds: Sequence[Tensor],
|
||||
objectnesses: Sequence[Tensor],
|
||||
cls_scores: Sequence[Union[Tensor, List]],
|
||||
bbox_preds: Sequence[Union[Tensor, List]],
|
||||
objectnesses: Sequence[Union[Tensor, List]],
|
||||
batch_gt_instances: Sequence[InstanceData],
|
||||
batch_img_metas: Sequence[dict],
|
||||
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
|
||||
|
@ -225,36 +233,91 @@ class YOLOv7Head(YOLOv5Head):
|
|||
Returns:
|
||||
dict[str, Tensor]: A dictionary of losses.
|
||||
"""
|
||||
batch_size = cls_scores[0].shape[0]
|
||||
device = cls_scores[0].device
|
||||
loss_cls = torch.zeros(1, device=device)
|
||||
loss_box = torch.zeros(1, device=device)
|
||||
loss_obj = torch.zeros(1, device=device)
|
||||
|
||||
head_preds = self._merge_predict_results(bbox_preds, objectnesses,
|
||||
cls_scores)
|
||||
scaled_factors = [
|
||||
torch.tensor(head_pred.shape, device=device)[[3, 2, 3, 2]]
|
||||
for head_pred in head_preds
|
||||
]
|
||||
if isinstance(cls_scores[0], Sequence):
|
||||
with_aux = True
|
||||
batch_size = cls_scores[0][0].shape[0]
|
||||
device = cls_scores[0][0].device
|
||||
|
||||
# 1. Convert gt to norm xywh format
|
||||
bbox_preds_main, bbox_preds_aux = zip(*bbox_preds)
|
||||
objectnesses_main, objectnesses_aux = zip(*objectnesses)
|
||||
cls_scores_main, cls_scores_aux = zip(*cls_scores)
|
||||
|
||||
head_preds = self._merge_predict_results(bbox_preds_main,
|
||||
objectnesses_main,
|
||||
cls_scores_main)
|
||||
head_preds_aux = self._merge_predict_results(
|
||||
bbox_preds_aux, objectnesses_aux, cls_scores_aux)
|
||||
else:
|
||||
with_aux = False
|
||||
batch_size = cls_scores[0].shape[0]
|
||||
device = cls_scores[0].device
|
||||
|
||||
head_preds = self._merge_predict_results(bbox_preds, objectnesses,
|
||||
cls_scores)
|
||||
|
||||
# Convert gt to norm xywh format
|
||||
# (num_base_priors, num_batch_gt, 7)
|
||||
# 7 is mean (batch_idx, cls_id, x_norm, y_norm,
|
||||
# w_norm, h_norm, prior_idx)
|
||||
batch_targets_normed = self._convert_gt_to_norm_format(
|
||||
batch_gt_instances, batch_img_metas)
|
||||
|
||||
scaled_factors = [
|
||||
torch.tensor(head_pred.shape, device=device)[[3, 2, 3, 2]]
|
||||
for head_pred in head_preds
|
||||
]
|
||||
|
||||
loss_cls, loss_obj, loss_box = self._calc_loss(
|
||||
head_preds=head_preds,
|
||||
head_preds_aux=None,
|
||||
batch_targets_normed=batch_targets_normed,
|
||||
near_neighbor_thr=self.near_neighbor_thr,
|
||||
scaled_factors=scaled_factors,
|
||||
batch_img_metas=batch_img_metas,
|
||||
device=device)
|
||||
|
||||
if with_aux:
|
||||
loss_cls_aux, loss_obj_aux, loss_box_aux = self._calc_loss(
|
||||
head_preds=head_preds,
|
||||
head_preds_aux=head_preds_aux,
|
||||
batch_targets_normed=batch_targets_normed,
|
||||
near_neighbor_thr=self.near_neighbor_thr * 2,
|
||||
scaled_factors=scaled_factors,
|
||||
batch_img_metas=batch_img_metas,
|
||||
device=device)
|
||||
loss_cls += self.aux_loss_weights * loss_cls_aux
|
||||
loss_obj += self.aux_loss_weights * loss_obj_aux
|
||||
loss_box += self.aux_loss_weights * loss_box_aux
|
||||
|
||||
_, world_size = get_dist_info()
|
||||
return dict(
|
||||
loss_cls=loss_cls * batch_size * world_size,
|
||||
loss_obj=loss_obj * batch_size * world_size,
|
||||
loss_bbox=loss_box * batch_size * world_size)
|
||||
|
||||
def _calc_loss(self, head_preds, head_preds_aux, batch_targets_normed,
|
||||
near_neighbor_thr, scaled_factors, batch_img_metas, device):
|
||||
loss_cls = torch.zeros(1, device=device)
|
||||
loss_box = torch.zeros(1, device=device)
|
||||
loss_obj = torch.zeros(1, device=device)
|
||||
|
||||
assigner_results = self.assigner(
|
||||
head_preds, batch_targets_normed,
|
||||
batch_img_metas[0]['batch_input_shape'], self.priors_base_sizes,
|
||||
self.grid_offset)
|
||||
head_preds,
|
||||
batch_targets_normed,
|
||||
batch_img_metas[0]['batch_input_shape'],
|
||||
self.priors_base_sizes,
|
||||
self.grid_offset,
|
||||
near_neighbor_thr=near_neighbor_thr)
|
||||
# mlvl is mean multi_level
|
||||
mlvl_positive_infos = assigner_results['mlvl_positive_infos']
|
||||
mlvl_priors = assigner_results['mlvl_priors']
|
||||
mlvl_targets_normed = assigner_results['mlvl_targets_normed']
|
||||
|
||||
# calc losses
|
||||
if head_preds_aux is not None:
|
||||
# This is mean calc aux branch loss
|
||||
head_preds = head_preds_aux
|
||||
|
||||
for i, head_pred in enumerate(head_preds):
|
||||
batch_inds, proir_idx, grid_x, grid_y = mlvl_positive_infos[i].T
|
||||
num_pred_positive = batch_inds.shape[0]
|
||||
|
@ -299,12 +362,7 @@ class YOLOv7Head(YOLOv5Head):
|
|||
target_class)
|
||||
else:
|
||||
loss_cls += head_pred_positive[:, 5:].sum() * 0
|
||||
|
||||
_, world_size = get_dist_info()
|
||||
return dict(
|
||||
loss_cls=loss_cls * batch_size * world_size,
|
||||
loss_obj=loss_obj * batch_size * world_size,
|
||||
loss_bbox=loss_box * batch_size * world_size)
|
||||
return loss_cls, loss_obj, loss_box
|
||||
|
||||
def _merge_predict_results(self, bbox_preds: Sequence[Tensor],
|
||||
objectnesses: Sequence[Tensor],
|
||||
|
|
|
@ -49,8 +49,13 @@ class BatchYOLOv7Assigner(nn.Module):
|
|||
self.cls_weight = cls_weight
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, pred_results, batch_targets_normed, batch_input_shape,
|
||||
priors_base_sizes, grid_offset) -> dict:
|
||||
def forward(self,
|
||||
pred_results,
|
||||
batch_targets_normed,
|
||||
batch_input_shape,
|
||||
priors_base_sizes,
|
||||
grid_offset,
|
||||
near_neighbor_thr=0.5) -> dict:
|
||||
# (num_base_priors, num_batch_gt, 7)
|
||||
# 7 is mean (batch_idx, cls_id, x_norm, y_norm,
|
||||
# w_norm, h_norm, prior_idx)
|
||||
|
@ -65,8 +70,16 @@ class BatchYOLOv7Assigner(nn.Module):
|
|||
mlvl_priors=[] * num_levels,
|
||||
mlvl_targets_normed=[] * num_levels)
|
||||
|
||||
# if near_neighbor_thr = 0.5 are mean the nearest
|
||||
# 3 neighbors are also considered positive samples.
|
||||
# if near_neighbor_thr = 1.0 are mean the nearest
|
||||
# 5 neighbors are also considered positive samples.
|
||||
mlvl_positive_infos, mlvl_priors = self.yolov5_assigner(
|
||||
pred_results, batch_targets_normed, priors_base_sizes, grid_offset)
|
||||
pred_results,
|
||||
batch_targets_normed,
|
||||
priors_base_sizes,
|
||||
grid_offset,
|
||||
near_neighbor_thr=near_neighbor_thr)
|
||||
|
||||
mlvl_positive_infos, mlvl_priors, \
|
||||
mlvl_targets_normed = self.simota_assigner(
|
||||
|
@ -85,8 +98,12 @@ class BatchYOLOv7Assigner(nn.Module):
|
|||
mlvl_priors=mlvl_priors,
|
||||
mlvl_targets_normed=mlvl_targets_normed)
|
||||
|
||||
def yolov5_assigner(self, pred_results, batch_targets_normed,
|
||||
priors_base_sizes, grid_offset):
|
||||
def yolov5_assigner(self,
|
||||
pred_results,
|
||||
batch_targets_normed,
|
||||
priors_base_sizes,
|
||||
grid_offset,
|
||||
near_neighbor_thr=0.5):
|
||||
num_batch_gts = batch_targets_normed.shape[1]
|
||||
assert num_batch_gts > 0
|
||||
|
||||
|
@ -121,9 +138,10 @@ class BatchYOLOv7Assigner(nn.Module):
|
|||
# Positive samples with additional neighbors
|
||||
batch_targets_cxcy = batch_targets_scaled[:, 2:4]
|
||||
grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
|
||||
left, up = ((batch_targets_cxcy % 1 < 0.5) &
|
||||
left, up = ((batch_targets_cxcy % 1 < near_neighbor_thr) &
|
||||
(batch_targets_cxcy > 1)).T
|
||||
right, bottom = ((grid_xy % 1 < 0.5) & (grid_xy > 1)).T
|
||||
right, bottom = ((grid_xy % 1 < near_neighbor_thr) &
|
||||
(grid_xy > 1)).T
|
||||
offset_inds = torch.stack(
|
||||
(torch.ones_like(left), left, up, right, bottom))
|
||||
batch_targets_scaled = batch_targets_scaled.repeat(
|
||||
|
@ -138,6 +156,7 @@ class BatchYOLOv7Assigner(nn.Module):
|
|||
# mlvl_positive_info: (num_matched_target, 4)
|
||||
# 4 is mean (batch_idx, prior_idx, x_scaled, y_scaled)
|
||||
mlvl_positive_info = batch_targets_scaled[:, [0, 6, 2, 3]]
|
||||
retained_offsets = retained_offsets * near_neighbor_thr
|
||||
mlvl_positive_info[:,
|
||||
2:] = mlvl_positive_info[:,
|
||||
2:] - retained_offsets
|
||||
|
|
Loading…
Reference in New Issue