mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Implement fast version of YOLOX (#518)
* Implement fast version of YOLOX * config change * Update yolox_head.py * Update mmyolo/models/data_preprocessors/data_preprocessor.py Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com> * Update mmyolo/models/data_preprocessors/data_preprocessor.py Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com> * add test and modify faults * fix lint * fix lint * modify metafile and README * modify metafile and readme * fix * fix * fix * fix * fix * fix test --------- Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>pull/532/head^2
parent
031e7450bc
commit
2813e89f44
|
@ -19,10 +19,10 @@ YOLOX-l model structure
|
|||
|
||||
## Results and Models
|
||||
|
||||
| Backbone | size | Mem (GB) | box AP | Config | Download |
|
||||
| :--------: | :--: | :------: | :----: | :---------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| YOLOX-tiny | 416 | 2.8 | 32.7 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_tiny_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908.log.json) |
|
||||
| YOLOX-s | 640 | 5.6 | 40.8 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_s_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738-d7e60cb2.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738.log.json) |
|
||||
| Backbone | size | Mem (GB) | box AP | Config | Download |
|
||||
| :--------: | :--: | :------: | :----: | :--------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| YOLOX-tiny | 416 | 2.8 | 32.7 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_tiny_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908.log.json) |
|
||||
| YOLOX-s | 640 | 5.6 | 40.8 | [config](https://github.com/open-mmlab/mmyolo/tree/master/configs/yolox/yolox_s_fast_8xb8-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738-d7e60cb2.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_8xb8-300e_coco/yolox_s_8xb8-300e_coco_20220917_030738.log.json) |
|
||||
|
||||
**Note**:
|
||||
|
||||
|
|
|
@ -20,9 +20,9 @@ Collections:
|
|||
|
||||
|
||||
Models:
|
||||
- Name: yolox_tiny_8xb8-300e_coco
|
||||
- Name: yolox_tiny_fast_8xb8-300e_coco
|
||||
In Collection: YOLOX
|
||||
Config: configs/yolox/yolox_tiny_8xb8-300e_coco.py
|
||||
Config: configs/yolox/yolox_tiny_fast_8xb8-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 2.8
|
||||
Epochs: 300
|
||||
|
@ -32,9 +32,9 @@ Models:
|
|||
Metrics:
|
||||
box AP: 32.7
|
||||
Weights: https://download.openmmlab.com/mmyolo/v0/yolox/yolox_tiny_8xb8-300e_coco/yolox_tiny_8xb8-300e_coco_20220919_090908-0e40a6fc.pth
|
||||
- Name: yolox_s_8xb8-300e_coco
|
||||
- Name: yolox_s_fast_8xb8-300e_coco
|
||||
In Collection: YOLOX
|
||||
Config: configs/yolox/yolox_s_8xb8-300e_coco.py
|
||||
Config: configs/yolox/yolox_s_fast_8xb8-300e_coco.py
|
||||
Metadata:
|
||||
Training Memory (GB): 5.6
|
||||
Epochs: 300
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
_base_ = './yolox_s_8xb8-300e_coco.py'
|
||||
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
|
||||
|
||||
deepen_factor = 1.0
|
||||
widen_factor = 1.0
|
|
@ -1,4 +1,4 @@
|
|||
_base_ = './yolox_s_8xb8-300e_coco.py'
|
||||
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
|
||||
|
||||
deepen_factor = 0.67
|
||||
widen_factor = 0.75
|
|
@ -1,4 +1,4 @@
|
|||
_base_ = './yolox_tiny_8xb8-300e_coco.py'
|
||||
_base_ = './yolox_tiny_fast_8xb8-300e_coco.py'
|
||||
|
||||
deepen_factor = 0.33
|
||||
widen_factor = 0.25
|
|
@ -29,11 +29,11 @@ model = dict(
|
|||
# TODO: Waiting for mmengine support
|
||||
use_syncbn=False,
|
||||
data_preprocessor=dict(
|
||||
type='mmdet.DetDataPreprocessor',
|
||||
type='YOLOv5DetDataPreprocessor',
|
||||
pad_size_divisor=32,
|
||||
batch_augments=[
|
||||
dict(
|
||||
type='mmdet.BatchSyncRandomResize',
|
||||
type='YOLOXBatchSyncRandomResize',
|
||||
random_size_range=(480, 800),
|
||||
size_divisor=32,
|
||||
interval=10)
|
||||
|
@ -157,6 +157,7 @@ train_dataloader = dict(
|
|||
num_workers=train_num_workers,
|
||||
persistent_workers=True,
|
||||
pin_memory=True,
|
||||
collate_fn=dict(type='yolov5_collate'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
|
@ -1,4 +1,4 @@
|
|||
_base_ = './yolox_s_8xb8-300e_coco.py'
|
||||
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
|
||||
|
||||
deepen_factor = 0.33
|
||||
widen_factor = 0.375
|
|
@ -1,4 +1,4 @@
|
|||
_base_ = './yolox_s_8xb8-300e_coco.py'
|
||||
_base_ = './yolox_s_fast_8xb8-300e_coco.py'
|
||||
|
||||
deepen_factor = 1.33
|
||||
widen_factor = 1.25
|
|
@ -1,9 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .data_preprocessor import (PPYOLOEBatchRandomResize,
|
||||
PPYOLOEDetDataPreprocessor,
|
||||
YOLOv5DetDataPreprocessor)
|
||||
YOLOv5DetDataPreprocessor,
|
||||
YOLOXBatchSyncRandomResize)
|
||||
|
||||
__all__ = [
|
||||
'YOLOv5DetDataPreprocessor', 'PPYOLOEDetDataPreprocessor',
|
||||
'PPYOLOEBatchRandomResize'
|
||||
'PPYOLOEBatchRandomResize', 'YOLOXBatchSyncRandomResize'
|
||||
]
|
||||
|
|
|
@ -16,6 +16,47 @@ CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str,
|
|||
None]
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class YOLOXBatchSyncRandomResize(BatchSyncRandomResize):
|
||||
"""YOLOX batch random resize.
|
||||
|
||||
Args:
|
||||
random_size_range (tuple): The multi-scale random range during
|
||||
multi-scale training.
|
||||
interval (int): The iter interval of change
|
||||
image size. Defaults to 10.
|
||||
size_divisor (int): Image size divisible factor.
|
||||
Defaults to 32.
|
||||
"""
|
||||
|
||||
def forward(self, inputs: Tensor, data_samples: dict) -> Tensor and dict:
|
||||
"""resize a batch of images and bboxes to shape ``self._input_size``"""
|
||||
h, w = inputs.shape[-2:]
|
||||
inputs = inputs.float()
|
||||
assert isinstance(data_samples, dict)
|
||||
|
||||
if self._input_size is None:
|
||||
self._input_size = (h, w)
|
||||
scale_y = self._input_size[0] / h
|
||||
scale_x = self._input_size[1] / w
|
||||
if scale_x != 1 or scale_y != 1:
|
||||
inputs = F.interpolate(
|
||||
inputs,
|
||||
size=self._input_size,
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
|
||||
data_samples['bboxes_labels'][:, 2::2] *= scale_x
|
||||
data_samples['bboxes_labels'][:, 3::2] *= scale_y
|
||||
|
||||
message_hub = MessageHub.get_current_instance()
|
||||
if (message_hub.get_info('iter') + 1) % self._interval == 0:
|
||||
self._input_size = self._get_random_size(
|
||||
aspect_ratio=float(w / h), device=inputs.device)
|
||||
|
||||
return inputs, data_samples
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class YOLOv5DetDataPreprocessor(DetDataPreprocessor):
|
||||
"""Rewrite collate_fn to get faster training speed.
|
||||
|
|
|
@ -265,7 +265,7 @@ class YOLOXHead(YOLOv5Head):
|
|||
cls_scores: Sequence[Tensor],
|
||||
bbox_preds: Sequence[Tensor],
|
||||
objectnesses: Sequence[Tensor],
|
||||
batch_gt_instances: Sequence[InstanceData],
|
||||
batch_gt_instances: Tensor,
|
||||
batch_img_metas: Sequence[dict],
|
||||
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
|
||||
"""Calculate the loss based on the features extracted by the detection
|
||||
|
@ -297,6 +297,9 @@ class YOLOXHead(YOLOv5Head):
|
|||
if batch_gt_instances_ignore is None:
|
||||
batch_gt_instances_ignore = [None] * num_imgs
|
||||
|
||||
batch_gt_instances = self.gt_instances_preprocess(
|
||||
batch_gt_instances, len(batch_img_metas))
|
||||
|
||||
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
|
||||
mlvl_priors = self.prior_generator.grid_priors(
|
||||
featmap_sizes,
|
||||
|
@ -484,3 +487,28 @@ class YOLOXHead(YOLOv5Head):
|
|||
bbox_aux_target[:,
|
||||
2:] = torch.log(gt_cxcywh[:, 2:] / priors[:, 2:] + eps)
|
||||
return bbox_aux_target
|
||||
|
||||
@staticmethod
|
||||
def gt_instances_preprocess(batch_gt_instances: Tensor,
|
||||
batch_size: int) -> List[InstanceData]:
|
||||
"""Split batch_gt_instances with batch size.
|
||||
|
||||
Args:
|
||||
batch_gt_instances (Tensor): Ground truth
|
||||
a 2D-Tensor for whole batch, shape [all_gt_bboxes, 6]
|
||||
batch_size (int): Batch size.
|
||||
|
||||
Returns:
|
||||
List: batch gt instances data, shape [batch_size, InstanceData]
|
||||
"""
|
||||
# faster version
|
||||
batch_instance_list = []
|
||||
for i in range(batch_size):
|
||||
batch_gt_instance_ = InstanceData()
|
||||
single_batch_instance = \
|
||||
batch_gt_instances[batch_gt_instances[:, 0] == i, :]
|
||||
batch_gt_instance_.bboxes = single_batch_instance[:, 2:]
|
||||
batch_gt_instance_.labels = single_batch_instance[:, 1]
|
||||
batch_instance_list.append(batch_gt_instance_)
|
||||
|
||||
return batch_instance_list
|
||||
|
|
|
@ -6,7 +6,8 @@ from mmdet.structures import DetDataSample
|
|||
from mmengine import MessageHub
|
||||
|
||||
from mmyolo.models import PPYOLOEBatchRandomResize, PPYOLOEDetDataPreprocessor
|
||||
from mmyolo.models.data_preprocessors import YOLOv5DetDataPreprocessor
|
||||
from mmyolo.models.data_preprocessors import (YOLOv5DetDataPreprocessor,
|
||||
YOLOXBatchSyncRandomResize)
|
||||
from mmyolo.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
@ -125,3 +126,31 @@ class TestPPYOLOEDetDataPreprocessor(TestCase):
|
|||
# data_samples must be list
|
||||
with self.assertRaises(AssertionError):
|
||||
processor(data, training=True)
|
||||
|
||||
|
||||
class TestYOLOXDetDataPreprocessor(TestCase):
|
||||
|
||||
def test_batch_sync_random_size(self):
|
||||
processor = YOLOXBatchSyncRandomResize(
|
||||
random_size_range=(480, 800), size_divisor=32, interval=1)
|
||||
self.assertTrue(isinstance(processor, YOLOXBatchSyncRandomResize))
|
||||
message_hub = MessageHub.get_instance(
|
||||
'test_yolox_batch_sync_random_resize')
|
||||
message_hub.update_info('iter', 0)
|
||||
|
||||
# test training
|
||||
inputs = torch.randint(0, 256, (4, 3, 10, 11))
|
||||
data_samples = {'bboxes_labels': torch.randint(0, 11, (18, 6)).float()}
|
||||
|
||||
inputs, data_samples = processor(inputs, data_samples)
|
||||
|
||||
self.assertIn('bboxes_labels', data_samples)
|
||||
self.assertIsInstance(data_samples['bboxes_labels'], torch.Tensor)
|
||||
self.assertIsInstance(inputs, torch.Tensor)
|
||||
|
||||
inputs = torch.randint(0, 256, (4, 3, 10, 11))
|
||||
data_samples = DetDataSample()
|
||||
|
||||
# data_samples must be dict
|
||||
with self.assertRaises(AssertionError):
|
||||
processor(inputs, data_samples)
|
||||
|
|
|
@ -4,7 +4,6 @@ from unittest import TestCase
|
|||
import torch
|
||||
from mmengine.config import Config
|
||||
from mmengine.model import bias_init_with_prob
|
||||
from mmengine.structures import InstanceData
|
||||
from mmengine.testing import assert_allclose
|
||||
|
||||
from mmyolo.models.dense_heads import YOLOXHead
|
||||
|
@ -98,11 +97,10 @@ class TestYOLOXHead(TestCase):
|
|||
|
||||
# Test that empty ground truth encourages the network to predict
|
||||
# background
|
||||
gt_instances = InstanceData(
|
||||
bboxes=torch.empty((0, 4)), labels=torch.LongTensor([]))
|
||||
gt_instances = torch.empty((0, 6))
|
||||
|
||||
empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
|
||||
objectnesses, [gt_instances],
|
||||
objectnesses, gt_instances,
|
||||
img_metas)
|
||||
# When there is no truth, the cls loss should be nonzero but there
|
||||
# should be no box loss.
|
||||
|
@ -122,12 +120,11 @@ class TestYOLOXHead(TestCase):
|
|||
# for random inputs
|
||||
head = YOLOXHead(head_module=self.head_module, train_cfg=train_cfg)
|
||||
head.use_bbox_aux = True
|
||||
gt_instances = InstanceData(
|
||||
bboxes=torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
|
||||
labels=torch.LongTensor([2]))
|
||||
gt_instances = torch.Tensor(
|
||||
[[0, 2, 23.6667, 23.8757, 238.6326, 151.8874]])
|
||||
|
||||
one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds, objectnesses,
|
||||
[gt_instances], img_metas)
|
||||
gt_instances, img_metas)
|
||||
onegt_cls_loss = one_gt_losses['loss_cls'].sum()
|
||||
onegt_box_loss = one_gt_losses['loss_bbox'].sum()
|
||||
onegt_obj_loss = one_gt_losses['loss_obj'].sum()
|
||||
|
@ -142,11 +139,10 @@ class TestYOLOXHead(TestCase):
|
|||
'l1 loss should be non-zero')
|
||||
|
||||
# Test groud truth out of bound
|
||||
gt_instances = InstanceData(
|
||||
bboxes=torch.Tensor([[s * 4, s * 4, s * 4 + 10, s * 4 + 10]]),
|
||||
labels=torch.LongTensor([2]))
|
||||
gt_instances = torch.Tensor(
|
||||
[[0, 2, s * 4, s * 4, s * 4 + 10, s * 4 + 10]])
|
||||
empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
|
||||
objectnesses, [gt_instances],
|
||||
objectnesses, gt_instances,
|
||||
img_metas)
|
||||
# When gt_bboxes out of bound, the assign results should be empty,
|
||||
# so the cls and bbox loss should be zero.
|
||||
|
|
|
@ -21,7 +21,7 @@ class TestSingleStageDetector(TestCase):
|
|||
@parameterized.expand([
|
||||
'yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py',
|
||||
'yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py',
|
||||
'yolox/yolox_tiny_8xb8-300e_coco.py',
|
||||
'yolox/yolox_tiny_fast_8xb8-300e_coco.py',
|
||||
'rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py',
|
||||
'yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py',
|
||||
'yolov8/yolov8_n_syncbn_fast_8xb16-500e_coco.py'
|
||||
|
@ -38,7 +38,6 @@ class TestSingleStageDetector(TestCase):
|
|||
|
||||
@parameterized.expand([
|
||||
('yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py', ('cuda', 'cpu')),
|
||||
('yolox/yolox_s_8xb8-300e_coco.py', ('cuda', 'cpu')),
|
||||
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
|
||||
('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu')),
|
||||
('yolov8/yolov8_n_syncbn_fast_8xb16-500e_coco.py', ('cuda', 'cpu'))
|
||||
|
@ -79,7 +78,7 @@ class TestSingleStageDetector(TestCase):
|
|||
('yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py', ('cuda',
|
||||
'cpu')),
|
||||
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
|
||||
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
|
||||
('yolox/yolox_tiny_fast_8xb8-300e_coco.py', ('cuda', 'cpu')),
|
||||
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
|
||||
('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu')),
|
||||
('yolov8/yolov8_n_syncbn_fast_8xb16-500e_coco.py', ('cuda', 'cpu'))
|
||||
|
@ -112,7 +111,7 @@ class TestSingleStageDetector(TestCase):
|
|||
('yolov5/yolov5_n-v61_syncbn_fast_8xb16-300e_coco.py', ('cuda',
|
||||
'cpu')),
|
||||
('yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py', ('cuda', 'cpu')),
|
||||
('yolox/yolox_tiny_8xb8-300e_coco.py', ('cuda', 'cpu')),
|
||||
('yolox/yolox_tiny_fast_8xb8-300e_coco.py', ('cuda', 'cpu')),
|
||||
('yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco.py', ('cuda', 'cpu')),
|
||||
('rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py', ('cuda', 'cpu')),
|
||||
('yolov8/yolov8_n_syncbn_fast_8xb16-500e_coco.py', ('cuda', 'cpu'))
|
||||
|
|
Loading…
Reference in New Issue