[Feature] Support YOLOv7 P6 training (#310)

* support p6 train

* fix bug

* add readme

* add link
Haian Huang(深度眸) 2022-11-30 18:45:08 +08:00 committed by GitHub
parent 4a8699d6fe
commit e85ddeac0d
6 changed files with 336 additions and 61 deletions

View File

@@ -0,0 +1,45 @@
# YOLOv7
> [YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors](https://arxiv.org/abs/2207.02696)
<!-- [ALGORITHM] -->
## Abstract
YOLOv7 surpasses all known object detectors in both speed and accuracy in the range from 5 FPS to 160 FPS and has the highest accuracy 56.8% AP among all known real-time object detectors with 30 FPS or higher on GPU V100. YOLOv7-E6 object detector (56 FPS V100, 55.9% AP) outperforms both transformer-based detector SWIN-L Cascade-Mask R-CNN (9.2 FPS A100, 53.9% AP) by 509% in speed and 2% in accuracy, and convolutional-based detector ConvNeXt-XL Cascade-Mask R-CNN (8.6 FPS A100, 55.2% AP) by 551% in speed and 0.7% AP in accuracy, as well as YOLOv7 outperforms: YOLOR, YOLOX, Scaled-YOLOv4, YOLOv5, DETR, Deformable DETR, DINO-5scale-R50, ViT-Adapter-B and many other object detectors in speed and accuracy. Moreover, we train YOLOv7 only on MS COCO dataset from scratch without using any other datasets or pre-trained weights. Source code is released in [this https URL](https://github.com/WongKinYiu/yolov7).
<div align=center>
<img src="https://user-images.githubusercontent.com/17425982/204231759-cc5c77a9-38c6-4a41-85be-eb97e4b2bcbb.png"/>
</div>
## Results and models
### COCO
| Backbone | Arch | Size | SyncBN | AMP | Mem (GB) | Box AP | Config | Download |
| :---------: | :--: | :--: | :----: | :-: | :------: | :----: | :----: | :------: |
| YOLOv7-tiny | P5 | 640 | Yes | Yes | 2.7 | 37.5 | [config](../yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco/yolov7_tiny_syncbn_fast_8x16b-300e_coco_20221126_102719-0ee5bbdf.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco/yolov7_tiny_syncbn_fast_8x16b-300e_coco_20221126_102719.log.json) |
| YOLOv7-l | P5 | 640 | Yes | Yes | 10.3 | 50.9 | [config](../yolov7/yolov7_l_syncbn_fast_8x16b-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_l_syncbn_fast_8x16b-300e_coco/yolov7_l_syncbn_fast_8x16b-300e_coco_20221123_023601-8113c0eb.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_l_syncbn_fast_8x16b-300e_coco/yolov7_l_syncbn_fast_8x16b-300e_coco_20221123_023601.log.json) |
| YOLOv7-x | P5 | 640 | Yes | Yes | 13.7 | 52.8 | [config](../yolov7/yolov7_x_syncbn_fast_8x16b-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_x_syncbn_fast_8x16b-300e_coco/yolov7_x_syncbn_fast_8x16b-300e_coco_20221124_215331-ef949a68.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_x_syncbn_fast_8x16b-300e_coco/yolov7_x_syncbn_fast_8x16b-300e_coco_20221124_215331.log.json) |
| YOLOv7-w | P6 | 1280 | Yes | Yes | 27.0 | 54.1 | [config](../yolov7/yolov7_w-p6_syncbn_fast_8x16b-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_w-p6_syncbn_fast_8x16b-300e_coco/yolov7_w-p6_syncbn_fast_8x16b-300e_coco_20221123_053031-a68ef9d2.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_w-p6_syncbn_fast_8x16b-300e_coco/yolov7_w-p6_syncbn_fast_8x16b-300e_coco_20221123_053031.log.json) |
| YOLOv7-e | P6 | 1280 | Yes | Yes | 42.5 | 55.1 | [config](../yolov7/yolov7_e-p6_syncbn_fast_8x16b-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_e-p6_syncbn_fast_8x16b-300e_coco/yolov7_e-p6_syncbn_fast_8x16b-300e_coco_20221126_102636-34425033.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_e-p6_syncbn_fast_8x16b-300e_coco/yolov7_e-p6_syncbn_fast_8x16b-300e_coco_20221126_102636.log.json) |
**Note**:
In the official YOLOv7 code, the `random_perspective` data augmentation used during COCO object detection training also consumes mask annotations, which leads to higher performance. Since object detection should not rely on mask annotations, only box annotations are used in `MMYOLO`; the mask annotations will be used in the instance segmentation task.
1. The performance is unstable and may fluctuate by about 0.3 mAP; the numbers shown above are from the best checkpoint.
2. If users need the weights of `YOLOv7-e2e`, they can either train with the configs provided by us or convert the official weights with the [converter script](https://github.com/open-mmlab/mmyolo/blob/main/tools/model_converters/yolov7_to_mmyolo.py).
3. `fast` means that `YOLOv5DetDataPreprocessor` and `yolov5_collate` are used for data preprocessing, which speeds up training but is less flexible for multi-task use. We recommend the `fast` configs if you only care about object detection; a minimal sketch of what they enable follows these notes.
4. `SyncBN` indicates that synchronized batch normalization is used across GPUs during training; `AMP` indicates training with automatic mixed precision.
5. We train on 8x A100 GPUs with a batch size of 16 per GPU, which differs from the official setup.
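For reference, the `fast` variants mentioned in note 3 roughly amount to the following fragment (a minimal sketch; the exact settings live in each config file):

```python
# Minimal sketch of what the `fast` configs enable; values are assumed to
# mirror the shipped YOLOv7 configs rather than copied from them.
model = dict(
    data_preprocessor=dict(
        type='YOLOv5DetDataPreprocessor',  # normalizes and batches on GPU
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        bgr_to_rgb=True))

train_dataloader = dict(
    collate_fn=dict(type='yolov5_collate'))  # lightweight batch collation
```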
## Citation
```latex
@article{wang2022yolov7,
title={{YOLOv7}: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors},
author={Wang, Chien-Yao and Bochkovskiy, Alexey and Liao, Hong-Yuan Mark},
journal={arXiv preprint arXiv:2207.02696},
year={2022}
}
```

View File

@@ -0,0 +1,83 @@
Collections:
  - Name: YOLOv7
    Metadata:
      Training Data: COCO
      Training Techniques:
        - SGD with Nesterov
        - Weight Decay
        - AMP
        - Synchronize BN
      Training Resources: 8x A100 GPUs
      Architecture:
        - EELAN
        - PAFPN
        - RepVGG
    Paper:
      URL: https://arxiv.org/abs/2207.02696
      Title: 'YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors'
    README: configs/yolov7/README.md
    Code:
      URL: https://github.com/open-mmlab/mmyolo/blob/v0.0.1/mmyolo/models/detectors/yolo_detector.py#L12
      Version: v0.0.1

Models:
  - Name: yolov7_tiny_syncbn_fast_8x16b-300e_coco
    In Collection: YOLOv7
    Config: configs/yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py
    Metadata:
      Training Memory (GB): 2.7
      Epochs: 300
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 37.5
    Weights: https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco/yolov7_tiny_syncbn_fast_8x16b-300e_coco_20221126_102719-0ee5bbdf.pth
  - Name: yolov7_l_syncbn_fast_8x16b-300e_coco
    In Collection: YOLOv7
    Config: configs/yolov7/yolov7_l_syncbn_fast_8x16b-300e_coco.py
    Metadata:
      Training Memory (GB): 10.3
      Epochs: 300
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 50.9
    Weights: https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_l_syncbn_fast_8x16b-300e_coco/yolov7_l_syncbn_fast_8x16b-300e_coco_20221123_023601-8113c0eb.pth
  - Name: yolov7_x_syncbn_fast_8x16b-300e_coco
    In Collection: YOLOv7
    Config: configs/yolov7/yolov7_x_syncbn_fast_8x16b-300e_coco.py
    Metadata:
      Training Memory (GB): 13.7
      Epochs: 300
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 52.8
    Weights: https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_x_syncbn_fast_8x16b-300e_coco/yolov7_x_syncbn_fast_8x16b-300e_coco_20221124_215331-ef949a68.pth
  - Name: yolov7_w-p6_syncbn_fast_8x16b-300e_coco
    In Collection: YOLOv7
    Config: configs/yolov7/yolov7_w-p6_syncbn_fast_8x16b-300e_coco.py
    Metadata:
      Training Memory (GB): 27.0
      Epochs: 300
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 54.1
    Weights: https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_w-p6_syncbn_fast_8x16b-300e_coco/yolov7_w-p6_syncbn_fast_8x16b-300e_coco_20221123_053031-a68ef9d2.pth
  - Name: yolov7_e-p6_syncbn_fast_8x16b-300e_coco
    In Collection: YOLOv7
    Config: configs/yolov7/yolov7_e-p6_syncbn_fast_8x16b-300e_coco.py
    Metadata:
      Training Memory (GB): 42.5
      Epochs: 300
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 55.1
    Weights: https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_e-p6_syncbn_fast_8x16b-300e_coco/yolov7_e-p6_syncbn_fast_8x16b-300e_coco_20221126_102636-34425033.pth

View File

@@ -29,8 +29,72 @@ model = dict(
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True)),
prior_generator=dict(base_sizes=anchors, strides=strides),
simota_candidate_topk=20, # note
# scaled based on number of detection layers
loss_cls=dict(loss_weight=0.3 *
(num_classes / 80 * 3 / num_det_layers)),
loss_bbox=dict(loss_weight=0.05 * (3 / num_det_layers)),
loss_obj=dict(loss_weight=0.7 *
((img_scale[0] / 640)**2 * 3 / num_det_layers)),
obj_level_weights=[4.0, 1.0, 0.25, 0.06]))
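# Worked example of the scaling above, assuming num_classes = 80,
# num_det_layers = 4 and img_scale = (1280, 1280) for this P6 config:
#   loss_cls  weight: 0.3 * (80 / 80 * 3 / 4) = 0.225
#   loss_bbox weight: 0.05 * (3 / 4) = 0.0375
#   loss_obj  weight: 0.7 * ((1280 / 640) ** 2 * 3 / 4) = 2.1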
pre_transform = _base_.pre_transform
mosiac4_pipeline = [
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
max_translate_ratio=0.2, # note
scaling_ratio_range=(0.1, 2.0), # note
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114)),
]
mosiac9_pipeline = [
dict(
type='Mosaic9',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_shear_degree=0.0,
max_translate_ratio=0.2, # note
scaling_ratio_range=(0.1, 2.0), # note
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114)),
]
randchoice_mosaic_pipeline = dict(
type='RandomChoice',
transforms=[mosiac4_pipeline, mosiac9_pipeline],
prob=[0.8, 0.2])
train_pipeline = [
*pre_transform,
randchoice_mosaic_pipeline,
dict(
type='YOLOv5MixUp',
alpha=8.0, # note
beta=8.0, # note
prob=0.15,
pre_transform=[*pre_transform, randchoice_mosaic_pipeline]),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
@@ -45,8 +109,10 @@ test_pipeline = [
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]
val_dataloader = dict(
dataset=dict(pipeline=test_pipeline, batch_shapes_cfg=batch_shapes_cfg))
test_dataloader = val_dataloader
# The only difference between P6 and P5 in terms of
# hyperparameters is lr_factor
default_hooks = dict(param_scheduler=dict(lr_factor=0.2))
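Because MMEngine configs are composable, a downstream variant only needs to override the keys it changes. A hypothetical user config built on top of the P6 config added in this PR might look like this (the base file name matches the table above; the override value is purely illustrative):

```python
# Hypothetical user config extending the P6 config from this PR.
_base_ = './yolov7_w-p6_syncbn_fast_8x16b-300e_coco.py'

# Illustrative override: try a different lr_factor than the default 0.2.
default_hooks = dict(param_scheduler=dict(lr_factor=0.1))
```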

View File

@@ -167,6 +167,7 @@ class YOLOv5Head(BaseDenseHead):
reduction='mean',
loss_weight=1.0),
prior_match_thr: float = 4.0,
near_neighbor_thr: float = 0.5,
obj_level_weights: List[float] = [4.0, 1.0, 0.4],
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
@@ -192,6 +193,7 @@ class YOLOv5Head(BaseDenseHead):
self.featmap_sizes = [torch.empty(1)] * self.num_levels
self.prior_match_thr = prior_match_thr
self.near_neighbor_thr = near_neighbor_thr
self.obj_level_weights = obj_level_weights
self.special_init()
@@ -231,7 +233,7 @@ class YOLOv5Head(BaseDenseHead):
[0, 1], # up
[-1, 0], # right
[0, -1], # bottom
]).float() * 0.5
]).float()
self.register_buffer(
'grid_offset', grid_offset[:, None], persistent=False)
@@ -534,9 +536,10 @@ class YOLOv5Head(BaseDenseHead):
# them as positive samples as well.
batch_targets_cxcy = batch_targets_scaled[:, 2:4]
grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
left, up = ((batch_targets_cxcy % 1 < 0.5) &
left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
(batch_targets_cxcy > 1)).T
right, bottom = ((grid_xy % 1 < 0.5) & (grid_xy > 1)).T
right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
(grid_xy > 1)).T
offset_inds = torch.stack(
(torch.ones_like(left), left, up, right, bottom))
@@ -552,7 +555,8 @@ class YOLOv5Head(BaseDenseHead):
priors_inds, (img_inds, class_inds) = priors_inds.long().view(
-1), img_class_inds.long().T
grid_xy_long = (grid_xy - retained_offsets).long()
grid_xy_long = (grid_xy -
retained_offsets * self.near_neighbor_thr).long()
grid_x_inds, grid_y_inds = grid_xy_long.T
bboxes_targets = torch.cat((grid_xy - grid_xy_long, grid_wh), 1)
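With the hard-coded `* 0.5` removed from the `grid_offset` buffer, the scale is applied where the offsets are consumed, so one buffer can serve different thresholds. A minimal sketch of the effect, with assumed values:

```python
import torch

# Unscaled offset buffer, as registered above.
grid_offset = torch.tensor(
    [[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1]]).float()

near_neighbor_thr = 0.5
main_offsets = grid_offset * near_neighbor_thr       # matches the old "* 0.5"
aux_offsets = grid_offset * (near_neighbor_thr * 2)  # looser aux assignment
```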

View File

@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
@@ -47,7 +47,6 @@ class YOLOv7HeadModule(YOLOv5HeadModule):
mi.bias.data = b.view(-1)
# TODO: to check
@MODELS.register_module()
class YOLOv7p6HeadModule(YOLOv5HeadModule):
"""YOLOv7Head head module used in YOLOv7."""
@@ -56,12 +55,14 @@ class YOLOv7p6HeadModule(YOLOv5HeadModule):
*args,
main_out_channels: Sequence[int] = [256, 512, 768, 1024],
aux_out_channels: Sequence[int] = [320, 640, 960, 1280],
use_aux: bool = True,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.03, eps=0.001),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
**kwargs):
self.main_out_channels = main_out_channels
self.aux_out_channels = aux_out_channels
self.use_aux = use_aux
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
super().__init__(*args, **kwargs)
@@ -69,7 +70,6 @@ class YOLOv7p6HeadModule(YOLOv5HeadModule):
def _init_layers(self):
"""initialize conv layers in YOLOv7 head."""
self.main_convs_pred = nn.ModuleList()
self.aux_convs_pred = nn.ModuleList()
for i in range(self.num_levels):
conv_pred = nn.Sequential(
ConvModule(
@@ -86,17 +86,22 @@ class YOLOv7p6HeadModule(YOLOv5HeadModule):
)
self.main_convs_pred.append(conv_pred)
aux_pred = nn.Sequential(
ConvModule(
self.in_channels[i],
self.aux_out_channels[i],
3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(self.aux_out_channels[i],
self.num_base_priors * self.num_out_attrib, 1))
self.aux_convs_pred.append(aux_pred)
if self.use_aux:
self.aux_convs_pred = nn.ModuleList()
for i in range(self.num_levels):
aux_pred = nn.Sequential(
ConvModule(
self.in_channels[i],
self.aux_out_channels[i],
3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(self.aux_out_channels[i],
self.num_base_priors * self.num_out_attrib, 1))
self.aux_convs_pred.append(aux_pred)
else:
self.aux_convs_pred = [None] * len(self.main_convs_pred)
def init_weights(self):
"""Initialize the bias of YOLOv5 head."""
@@ -110,12 +115,13 @@ class YOLOv7p6HeadModule(YOLOv5HeadModule):
b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
mi.bias.data = b.view(-1)
aux = aux[1] # nn.Conv2d
b = aux.bias.data.view(3, -1)
# obj (8 objects per 640 image)
b.data[:, 4] += math.log(8 / (640 / s)**2)
b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
mi.bias.data = b.view(-1)
if self.use_aux:
aux = aux[1] # nn.Conv2d
b = aux.bias.data.view(3, -1)
# obj (8 objects per 640 image)
b.data[:, 4] += math.log(8 / (640 / s)**2)
b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
mi.bias.data = b.view(-1)
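            # Worked example of the bias init above: for stride s = 8 on a
            # 640x640 input there are (640 / 8) ** 2 = 6400 grid cells, so the
            # objectness bias starts at log(8 / 6400) ≈ -6.68, i.e. a prior of
            # roughly 8 expected objects per image at this level.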
def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
"""Forward features from the upstream network.
@@ -132,7 +138,7 @@ class YOLOv7p6HeadModule(YOLOv5HeadModule):
self.aux_convs_pred)
def forward_single(self, x: Tensor, convs: nn.Module,
aux_convs: nn.Module) \
aux_convs: Optional[nn.Module]) \
-> Tuple[Union[Tensor, List], Union[Tensor, List],
Union[Tensor, List]]:
"""Forward feature of a single scale level."""
@@ -146,7 +152,7 @@ class YOLOv7p6HeadModule(YOLOv5HeadModule):
bbox_pred = pred_map[:, :, :4, ...].reshape(bs, -1, ny, nx)
objectness = pred_map[:, :, 4:5, ...].reshape(bs, -1, ny, nx)
if not self.training:
if not self.training or not self.use_aux:
return cls_score, bbox_pred, objectness
else:
aux_pred_map = aux_convs(x)
@@ -178,11 +184,13 @@ class YOLOv7Head(YOLOv5Head):
def __init__(self,
*args,
simota_candidate_topk: int = 10,
simota_candidate_topk: int = 20,
simota_iou_weight: float = 3.0,
simota_cls_weight: float = 1.0,
aux_loss_weights: float = 0.25,
**kwargs):
super().__init__(*args, **kwargs)
self.aux_loss_weights = aux_loss_weights
self.assigner = BatchYOLOv7Assigner(
num_classes=self.num_classes,
num_base_priors=self.num_base_priors,
@@ -194,9 +202,9 @@ class YOLOv7Head(YOLOv5Head):
def loss_by_feat(
self,
cls_scores: Sequence[Tensor],
bbox_preds: Sequence[Tensor],
objectnesses: Sequence[Tensor],
cls_scores: Sequence[Union[Tensor, List]],
bbox_preds: Sequence[Union[Tensor, List]],
objectnesses: Sequence[Union[Tensor, List]],
batch_gt_instances: Sequence[InstanceData],
batch_img_metas: Sequence[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
@@ -225,36 +233,91 @@ class YOLOv7Head(YOLOv5Head):
Returns:
dict[str, Tensor]: A dictionary of losses.
"""
batch_size = cls_scores[0].shape[0]
device = cls_scores[0].device
loss_cls = torch.zeros(1, device=device)
loss_box = torch.zeros(1, device=device)
loss_obj = torch.zeros(1, device=device)
head_preds = self._merge_predict_results(bbox_preds, objectnesses,
cls_scores)
scaled_factors = [
torch.tensor(head_pred.shape, device=device)[[3, 2, 3, 2]]
for head_pred in head_preds
]
if isinstance(cls_scores[0], Sequence):
with_aux = True
batch_size = cls_scores[0][0].shape[0]
device = cls_scores[0][0].device
            # Split each (main, aux) prediction pair per level
bbox_preds_main, bbox_preds_aux = zip(*bbox_preds)
objectnesses_main, objectnesses_aux = zip(*objectnesses)
cls_scores_main, cls_scores_aux = zip(*cls_scores)
head_preds = self._merge_predict_results(bbox_preds_main,
objectnesses_main,
cls_scores_main)
head_preds_aux = self._merge_predict_results(
bbox_preds_aux, objectnesses_aux, cls_scores_aux)
else:
with_aux = False
batch_size = cls_scores[0].shape[0]
device = cls_scores[0].device
head_preds = self._merge_predict_results(bbox_preds, objectnesses,
cls_scores)
# Convert gt to norm xywh format
# (num_base_priors, num_batch_gt, 7)
        # 7 means (batch_idx, cls_id, x_norm, y_norm,
# w_norm, h_norm, prior_idx)
batch_targets_normed = self._convert_gt_to_norm_format(
batch_gt_instances, batch_img_metas)
scaled_factors = [
torch.tensor(head_pred.shape, device=device)[[3, 2, 3, 2]]
for head_pred in head_preds
]
loss_cls, loss_obj, loss_box = self._calc_loss(
head_preds=head_preds,
head_preds_aux=None,
batch_targets_normed=batch_targets_normed,
near_neighbor_thr=self.near_neighbor_thr,
scaled_factors=scaled_factors,
batch_img_metas=batch_img_metas,
device=device)
if with_aux:
loss_cls_aux, loss_obj_aux, loss_box_aux = self._calc_loss(
head_preds=head_preds,
head_preds_aux=head_preds_aux,
batch_targets_normed=batch_targets_normed,
near_neighbor_thr=self.near_neighbor_thr * 2,
scaled_factors=scaled_factors,
batch_img_metas=batch_img_metas,
device=device)
loss_cls += self.aux_loss_weights * loss_cls_aux
loss_obj += self.aux_loss_weights * loss_obj_aux
loss_box += self.aux_loss_weights * loss_box_aux
_, world_size = get_dist_info()
return dict(
loss_cls=loss_cls * batch_size * world_size,
loss_obj=loss_obj * batch_size * world_size,
loss_bbox=loss_box * batch_size * world_size)
def _calc_loss(self, head_preds, head_preds_aux, batch_targets_normed,
near_neighbor_thr, scaled_factors, batch_img_metas, device):
loss_cls = torch.zeros(1, device=device)
loss_box = torch.zeros(1, device=device)
loss_obj = torch.zeros(1, device=device)
assigner_results = self.assigner(
head_preds, batch_targets_normed,
batch_img_metas[0]['batch_input_shape'], self.priors_base_sizes,
self.grid_offset)
head_preds,
batch_targets_normed,
batch_img_metas[0]['batch_input_shape'],
self.priors_base_sizes,
self.grid_offset,
near_neighbor_thr=near_neighbor_thr)
        # mlvl means multi-level
mlvl_positive_infos = assigner_results['mlvl_positive_infos']
mlvl_priors = assigner_results['mlvl_priors']
mlvl_targets_normed = assigner_results['mlvl_targets_normed']
# calc losses
if head_preds_aux is not None:
            # Compute the loss of the auxiliary branch instead
head_preds = head_preds_aux
for i, head_pred in enumerate(head_preds):
            batch_inds, prior_idx, grid_x, grid_y = mlvl_positive_infos[i].T
num_pred_positive = batch_inds.shape[0]
@@ -299,12 +362,7 @@ class YOLOv7Head(YOLOv5Head):
target_class)
else:
loss_cls += head_pred_positive[:, 5:].sum() * 0
_, world_size = get_dist_info()
return dict(
loss_cls=loss_cls * batch_size * world_size,
loss_obj=loss_obj * batch_size * world_size,
loss_bbox=loss_box * batch_size * world_size)
return loss_cls, loss_obj, loss_box
def _merge_predict_results(self, bbox_preds: Sequence[Tensor],
objectnesses: Sequence[Tensor],

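Schematically, the auxiliary ("coarse") head contributes a down-weighted term on top of each lead-head loss. A sketch with stand-in numbers (not taken from any run):

```python
import torch

aux_loss_weights = 0.25  # default in YOLOv7Head.__init__ above
loss_cls_main = torch.tensor(1.2)  # lead head, assigned with near_neighbor_thr
loss_cls_aux = torch.tensor(2.0)   # aux head, assigned with near_neighbor_thr * 2

loss_cls = loss_cls_main + aux_loss_weights * loss_cls_aux
print(loss_cls)  # tensor(1.7000)
```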
View File

@@ -49,8 +49,13 @@ class BatchYOLOv7Assigner(nn.Module):
self.cls_weight = cls_weight
@torch.no_grad()
def forward(self, pred_results, batch_targets_normed, batch_input_shape,
priors_base_sizes, grid_offset) -> dict:
def forward(self,
pred_results,
batch_targets_normed,
batch_input_shape,
priors_base_sizes,
grid_offset,
near_neighbor_thr=0.5) -> dict:
# (num_base_priors, num_batch_gt, 7)
        # 7 means (batch_idx, cls_id, x_norm, y_norm,
# w_norm, h_norm, prior_idx)
@@ -65,8 +70,16 @@ class BatchYOLOv7Assigner(nn.Module):
mlvl_priors=[] * num_levels,
mlvl_targets_normed=[] * num_levels)
        # near_neighbor_thr = 0.5 means that the nearest
        # 3 neighbors are also considered positive samples;
        # near_neighbor_thr = 1.0 means that the nearest
        # 5 neighbors are also considered positive samples.
mlvl_positive_infos, mlvl_priors = self.yolov5_assigner(
pred_results, batch_targets_normed, priors_base_sizes, grid_offset)
pred_results,
batch_targets_normed,
priors_base_sizes,
grid_offset,
near_neighbor_thr=near_neighbor_thr)
mlvl_positive_infos, mlvl_priors, \
mlvl_targets_normed = self.simota_assigner(
@@ -85,8 +98,12 @@ class BatchYOLOv7Assigner(nn.Module):
mlvl_priors=mlvl_priors,
mlvl_targets_normed=mlvl_targets_normed)
def yolov5_assigner(self, pred_results, batch_targets_normed,
priors_base_sizes, grid_offset):
def yolov5_assigner(self,
pred_results,
batch_targets_normed,
priors_base_sizes,
grid_offset,
near_neighbor_thr=0.5):
num_batch_gts = batch_targets_normed.shape[1]
assert num_batch_gts > 0
@@ -121,9 +138,10 @@ class BatchYOLOv7Assigner(nn.Module):
# Positive samples with additional neighbors
batch_targets_cxcy = batch_targets_scaled[:, 2:4]
grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
left, up = ((batch_targets_cxcy % 1 < 0.5) &
left, up = ((batch_targets_cxcy % 1 < near_neighbor_thr) &
(batch_targets_cxcy > 1)).T
right, bottom = ((grid_xy % 1 < 0.5) & (grid_xy > 1)).T
right, bottom = ((grid_xy % 1 < near_neighbor_thr) &
(grid_xy > 1)).T
offset_inds = torch.stack(
(torch.ones_like(left), left, up, right, bottom))
batch_targets_scaled = batch_targets_scaled.repeat(
@@ -138,6 +156,7 @@ class BatchYOLOv7Assigner(nn.Module):
# mlvl_positive_info: (num_matched_target, 4)
            # 4 means (batch_idx, prior_idx, x_scaled, y_scaled)
mlvl_positive_info = batch_targets_scaled[:, [0, 6, 2, 3]]
retained_offsets = retained_offsets * near_neighbor_thr
mlvl_positive_info[:,
2:] = mlvl_positive_info[:,
2:] - retained_offsets
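
To make the near-neighbor rule concrete, here is a self-contained sketch of the selection logic above; the feature-map size and the GT center are made-up values:

```python
import torch

near_neighbor_thr = 0.5
grid_offset = torch.tensor(
    [[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1]]).float()  # self/left/up/right/bottom

cxcy = torch.tensor([[3.3, 7.8]])           # one gt center, in grid units
inv_xy = torch.tensor([[80., 80.]]) - cxcy  # distance to far border (80x80 map)

left, up = ((cxcy % 1 < near_neighbor_thr) & (cxcy > 1)).T
right, bottom = ((inv_xy % 1 < near_neighbor_thr) & (inv_xy > 1)).T

keep = torch.stack(
    (torch.ones_like(left), left, up, right, bottom)).squeeze(-1)
cells = (cxcy - grid_offset[keep] * near_neighbor_thr).long()
print(cells)  # tensor([[3, 7], [2, 7], [3, 8]]): center, left and bottom cells
```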