[Feature] Implement fast version of YOLOX (#518)

* Implement fast version of YOLOX

* config change

* Update yolox_head.py

* Update mmyolo/models/data_preprocessors/data_preprocessor.py

Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>

* Update mmyolo/models/data_preprocessors/data_preprocessor.py

Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>

* add test and modify faults

* fix lint

* fix lint

* modify metafile and README

* modify metafile and readme

* fix

* fix

* fix

* fix

* fix

* fix test

---------

Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>
pull/532/head^2
Youfu 2023-02-08 20:10:03 +08:00 committed by GitHub
parent 031e7450bc
commit 2813e89f44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 130 additions and 35 deletions

View File

@ -19,10 +19,10 @@ YOLOX-l model structure
## Results and Models
| Backbone | size | Mem (GB) | box AP | Config | Download |
| :--------: | :--: | :------: | :----: | :---------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| YOLOX-tiny | 416 | 2.8 | 32.7 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_tiny_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908.log.json) |
| YOLOX-s | 640 | 5.6 | 40.8 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_s_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738-d7e60cb2.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738.log.json) |
| Backbone | size | Mem (GB) | box AP | Config | Download |
| :--------: | :--: | :------: | :----: | :--------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| YOLOX-tiny | 416 | 2.8 | 32.7 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_tiny_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908.log.json) |
| YOLOX-s | 640 | 5.6 | 40.8 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_s_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738-d7e60cb2.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738.log.json) |
**Note**:

View File

@ -20,9 +20,9 @@ Collections:
Models:
- Name: yolox_tiny_8xb8-300e_coco
- Name: yolox_tiny_fast_8xb8-300e_coco
In Collection: YOLOX
Config: configs/yolox/yolox_tiny_8xb8-300e_coco.py
Config: configs/yolox/yolox_tiny_fast_8xb8-300e_coco.py
Metadata:
Training Memory (GB): 2.8
Epochs: 300
@ -32,9 +32,9 @@ Models:
Metrics:
box AP: 32.7
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth
- Name: yolox_s_8xb8-300e_coco
- Name: yolox_s_fast_8xb8-300e_coco
In Collection: YOLOX
Config: configs/yolox/yolox_s_8xb8-300e_coco.py
Config: configs/yolox/yolox_s_fast_8xb8-300e_coco.py
Metadata:
Training Memory (GB): 5.6
Epochs: 300

View File

@ -1,4 +1,4 @@
_base_ = './yolox_s_8xb8-300e_coco.py'
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
deepen_factor = 1.0
widen_factor = 1.0

View File

@ -1,4 +1,4 @@
_base_ = './yolox_s_8xb8-300e_coco.py'
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
deepen_factor = 0.67
widen_factor = 0.75

View File

@ -1,4 +1,4 @@
_base_ = './yolox_tiny_8xb8-300e_coco.py'
_base_ = './yolox_tiny_fast_8xb8-300e_coco.py'
deepen_factor = 0.33
widen_factor = 0.25

View File

@ -29,11 +29,11 @@ model = dict(
# TODO: Waiting for mmengine support
use_syncbn=False,
data_preprocessor=dict(
type='mmdet.DetDataPreprocessor',
type='YOLOv5DetDataPreprocessor',
pad_size_divisor=32,
batch_augments=[
dict(
type='mmdet.BatchSyncRandomResize',
type='YOLOXBatchSyncRandomResize',
random_size_range=(480, 800),
size_divisor=32,
interval=10)
@ -157,6 +157,7 @@ train_dataloader = dict(
num_workers=train_num_workers,
persistent_workers=True,
pin_memory=True,
collate_fn=dict(type='yolov5_collate'),
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,

View File

@ -1,4 +1,4 @@
_base_ = './yolox_s_8xb8-300e_coco.py'
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
deepen_factor = 0.33
widen_factor = 0.375

View File

@ -1,4 +1,4 @@
_base_ = './yolox_s_8xb8-300e_coco.py'
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
deepen_factor = 1.33
widen_factor = 1.25

View File

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .data_preprocessor import (PPYOLOEBatchRandomResize,
PPYOLOEDetDataPreprocessor,
YOLOv5DetDataPreprocessor)
YOLOv5DetDataPreprocessor,
YOLOXBatchSyncRandomResize)
__all__ = [
'YOLOv5DetDataPreprocessor', 'PPYOLOEDetDataPreprocessor',
'PPYOLOEBatchRandomResize'
'PPYOLOEBatchRandomResize', 'YOLOXBatchSyncRandomResize'
]

View File

@ -16,6 +16,47 @@ CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str,
None]
@MODELS.register_module()
class YOLOXBatchSyncRandomResize(BatchSyncRandomResize):
"""YOLOX batch random resize.
Args:
random_size_range (tuple): The multi-scale random range during
multi-scale training.
interval (int): The iter interval of change
image size. Defaults to 10.
size_divisor (int): Image size divisible factor.
Defaults to 32.
"""
def forward(self, inputs: Tensor, data_samples: dict) -> Tensor and dict:
"""resize a batch of images and bboxes to shape ``self._input_size``"""
h, w = inputs.shape[-2:]
inputs = inputs.float()
assert isinstance(data_samples, dict)
if self._input_size is None:
self._input_size = (h, w)
scale_y = self._input_size[0] / h
scale_x = self._input_size[1] / w
if scale_x != 1 or scale_y != 1:
inputs = F.interpolate(
inputs,
size=self._input_size,
mode='bilinear',
align_corners=False)
data_samples['bboxes_labels'][:, 2::2] *= scale_x
data_samples['bboxes_labels'][:, 3::2] *= scale_y
message_hub = MessageHub.get_current_instance()
if (message_hub.get_info('iter') + 1) % self._interval == 0:
self._input_size = self._get_random_size(
aspect_ratio=float(w / h), device=inputs.device)
return inputs, data_samples
@MODELS.register_module()
class YOLOv5DetDataPreprocessor(DetDataPreprocessor):
"""Rewrite collate_fn to get faster training speed.

View File

@ -265,7 +265,7 @@ class YOLOXHead(YOLOv5Head):
cls_scores: Sequence[Tensor],
bbox_preds: Sequence[Tensor],
objectnesses: Sequence[Tensor],
batch_gt_instances: Sequence[InstanceData],
batch_gt_instances: Tensor,
batch_img_metas: Sequence[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
"""Calculate the loss based on the features extracted by the detection
@ -297,6 +297,9 @@ class YOLOXHead(YOLOv5Head):
if batch_gt_instances_ignore is None:
batch_gt_instances_ignore = [None] * num_imgs
batch_gt_instances = self.gt_instances_preprocess(
batch_gt_instances, len(batch_img_metas))
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes,
@ -484,3 +487,28 @@ class YOLOXHead(YOLOv5Head):
bbox_aux_target[:,
2:] = torch.log(gt_cxcywh[:, 2:] / priors[:, 2:] + eps)
return bbox_aux_target
@staticmethod
def gt_instances_preprocess(batch_gt_instances: Tensor,
batch_size: int) -> List[InstanceData]:
"""Split batch_gt_instances with batch size.
Args:
batch_gt_instances (Tensor): Ground truth
a 2D-Tensor for whole batch, shape [all_gt_bboxes, 6]
batch_size (int): Batch size.
Returns:
List: batch gt instances data, shape [batch_size, InstanceData]
"""
# faster version
batch_instance_list = []
for i in range(batch_size):
batch_gt_instance_ = InstanceData()
single_batch_instance = \
batch_gt_instances[batch_gt_instances[:, 0] == i, :]
batch_gt_instance_.bboxes = single_batch_instance[:, 2:]
batch_gt_instance_.labels = single_batch_instance[:, 1]
batch_instance_list.append(batch_gt_instance_)
return batch_instance_list

View File

@ -6,7 +6,8 @@ from mmdet.structures import DetDataSample
from mmengine import MessageHub
from mmyolo.models import PPYOLOEBatchRandomResize, PPYOLOEDetDataPreprocessor
from mmyolo.models.data_preprocessors import YOLOv5DetDataPreprocessor
from mmyolo.models.data_preprocessors import (YOLOv5DetDataPreprocessor,
YOLOXBatchSyncRandomResize)
from mmyolo.utils import register_all_modules
register_all_modules()
@ -125,3 +126,31 @@ class TestPPYOLOEDetDataPreprocessor(TestCase):
# data_samples must be list
with self.assertRaises(AssertionError):
processor(data, training=True)
class TestYOLOXDetDataPreprocessor(TestCase):
def test_batch_sync_random_size(self):
processor = YOLOXBatchSyncRandomResize(
random_size_range=(480, 800), size_divisor=32, interval=1)
self.assertTrue(isinstance(processor, YOLOXBatchSyncRandomResize))
message_hub = MessageHub.get_instance(
'test_yolox_batch_sync_random_resize')
message_hub.update_info('iter', 0)
# test training
inputs = torch.randint(0, 256, (4, 3, 10, 11))
data_samples = {'bboxes_labels': torch.randint(0, 11, (18, 6)).float()}
inputs, data_samples = processor(inputs, data_samples)
self.assertIn('bboxes_labels', data_samples)
self.assertIsInstance(data_samples['bboxes_labels'], torch.Tensor)
self.assertIsInstance(inputs, torch.Tensor)
inputs = torch.randint(0, 256, (4, 3, 10, 11))
data_samples = DetDataSample()
# data_samples must be dict
with self.assertRaises(AssertionError):
processor(inputs, data_samples)

View File

@ -4,7 +4,6 @@ from unittest import TestCase
import torch
from mmengine.config import Config
from mmengine.model import bias_init_with_prob
from mmengine.structures import InstanceData
from mmengine.testing import assert_allclose
from mmyolo.models.dense_heads import YOLOXHead
@ -98,11 +97,10 @@ class TestYOLOXHead(TestCase):
# Test that empty ground truth encourages the network to predict
# background
gt_instances = InstanceData(
bboxes=torch.empty((0, 4)), labels=torch.LongTensor([]))
gt_instances = torch.empty((0, 6))
empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
objectnesses, [gt_instances],
objectnesses, gt_instances,
img_metas)
# When there is no truth, the cls loss should be nonzero but there
# should be no box loss.
@ -122,12 +120,11 @@ class TestYOLOXHead(TestCase):
# for random inputs
head = YOLOXHead(head_module=self.head_module, train_cfg=train_cfg)
head.use_bbox_aux = True
gt_instances = InstanceData(
bboxes=torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
labels=torch.LongTensor([2]))
gt_instances = torch.Tensor(
[[0, 2, 23.6667, 23.8757, 238.6326, 151.8874]])
one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds, objectnesses,
[gt_instances], img_metas)
gt_instances, img_metas)
onegt_cls_loss = one_gt_losses['loss_cls'].sum()
onegt_box_loss = one_gt_losses['loss_bbox'].sum()
onegt_obj_loss = one_gt_losses['loss_obj'].sum()
@ -142,11 +139,10 @@ class TestYOLOXHead(TestCase):
'l1 loss should be non-zero')
# Test groud truth out of bound
gt_instances = InstanceData(
bboxes=torch.Tensor([[s * 4, s * 4, s * 4 + 10, s * 4 + 10]]),
labels=torch.LongTensor([2]))
gt_instances = torch.Tensor(
[[0, 2, s * 4, s * 4, s * 4 + 10, s * 4 + 10]])
empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
objectnesses, [gt_instances],
objectnesses, gt_instances,
img_metas)
# When gt_bboxes out of bound, the assign results should be empty,
# so the cls and bbox loss should be zero.

View File

@ -21,7 +21,7 @@ class TestSingleStageDetector(TestCase):
@parameterized.expand([
'yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py',
'yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py',
'yolox/yolox_tiny_8xb8-300e_coco.py',
'yolox/yolox_tiny_fast_8xb8-300e_coco.py',
'rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py',
'yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py',
'yolov8/yolov8_n_syncbn_fast_8xb16-500e_coco.py'
@ -38,7 +38,6 @@ class TestSingleStageDetector(TestCase):
@parameterized.expand([
('yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py', ('cuda', 'cpu')),
('yolox/yolox_s_8xb8-300e_coco.py', ('cuda', 'cpu')),
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu')),
('yolov8/yolov8_n_syncbn_fast_8xb16-500e_coco.py', ('cuda', 'cpu'))
@ -79,7 +78,7 @@ class TestSingleStageDetector(TestCase):
('yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py', ('cuda',
'cpu')),
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
('yolox/yolox_tiny_fast_8xb8-300e_coco.py', ('cuda', 'cpu')),
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu')),
('yolov8/yolov8_n_syncbn_fast_8xb16-500e_coco.py', ('cuda', 'cpu'))
@ -112,7 +111,7 @@ class TestSingleStageDetector(TestCase):
('yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py', ('cuda',
'cpu')),
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
('yolox/yolox_tiny_fast_8xb8-300e_coco.py', ('cuda', 'cpu')),
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu')),
('yolov8/yolov8_n_syncbn_fast_8xb16-500e_coco.py', ('cuda', 'cpu'))