mirror of https://github.com/open-mmlab/mmyolo.git
Add change log of v0.5.0 (#612)
* update
* update
* update
* update
* add configs
* update
* add tta
* update

pull/617/head
parent 30cc772524
commit e32838abe1
24
README.md
|
@ -77,10 +77,17 @@ English | [简体中文](README_zh-CN.md)
|
|||
|
||||
## 🥳 🚀 What's New [🔝](#-table-of-contents)
|
||||
|
||||
💎 **v0.4.0** was released on 18/1/2023:
|
||||
💎 **v0.5.0** was released on 2/3/2023:
|
||||
|
||||
1. Implemented [YOLOv8](https://github.com/open-mmlab/mmyolo/blob/dev/configs/yolov8/README.md) object detection model, and supports model deployment in [projects/easydeploy](https://github.com/open-mmlab/mmyolo/blob/dev/projects/easydeploy)
|
||||
2. Added Chinese and English versions of [Algorithm principles and implementation with YOLOv8](https://github.com/open-mmlab/mmyolo/blob/dev/docs/en/algorithm_descriptions/yolov8_description.md)
|
||||
1. Support [RTMDet-R](https://github.com/open-mmlab/mmyolo/blob/dev/configs/rtmdet/README.md#rotated-object-detection) rotated object detection
|
||||
2. Support for using mask annotation to improve [YOLOv8](https://github.com/open-mmlab/mmyolo/blob/dev/configs/yolov8/README.md) object detection performance
|
||||
3. Support [MMRazor](https://github.com/open-mmlab/mmyolo/blob/dev/configs/razor/subnets/README.md) searchable NAS sub-network as the backbone of YOLO series algorithm
|
||||
4. Support calling [MMRazor](https://github.com/open-mmlab/mmyolo/blob/dev/configs/rtmdet/distillation/README.md) to distill the knowledge of RTMDet
|
||||
5. [MMYOLO](https://mmyolo.readthedocs.io/zh_CN/dev/) document structure optimization, comprehensive content upgrade
|
||||
6. Improve YOLOX mAP and training speed based on RTMDet training hyperparameters
|
||||
7. Support calculation of model parameters and FLOPs, provide GPU latency data on T4 devices, and update [Model Zoo](https://github.com/open-mmlab/mmyolo/blob/dev/docs/en/model_zoo.md)
|
||||
8. Support test-time augmentation (TTA)
|
||||
9. Support RTMDet, YOLOv8 and YOLOv7 assigner visualization
|
||||
|
||||
For release history and update details, please refer to [changelog](https://mmyolo.readthedocs.io/en/latest/notes/changelog.html).
|
||||
|
||||
|
@ -102,7 +109,7 @@ We are excited to announce our latest work on real-time object recognition tasks
|
|||
<img src="https://user-images.githubusercontent.com/12907710/208044554-1e8de6b5-48d8-44e4-a7b5-75076c7ebb71.png"/>
|
||||
</div>
|
||||
|
||||
MMYOLO currently implements only the object detection algorithm, but training is significantly accelerated compared to the MMDetection version. The training speed is 2.6 times faster than the previous version.
|
||||
MMYOLO currently implements the object detection and rotated object detection algorithms, and training is significantly accelerated compared to the MMDetection version. The training speed is 2.6 times faster than the previous version.
|
||||
|
||||
## 📖 Introduction [🔝](#-table-of-contents)
|
||||
|
||||
|
@ -138,8 +145,8 @@ And the figure of P6 model is in [model_design.md](docs/en/recommended_topics/mo
|
|||
MMYOLO relies on PyTorch, MMCV, MMEngine, and MMDetection. Below are quick steps for installation. Please refer to the [Install Guide](docs/en/get_started/installation.md) for more detailed instructions.
|
||||
|
||||
```shell
|
||||
conda create -n open-mmlab python=3.8 pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.3 -c pytorch -y
|
||||
conda activate open-mmlab
|
||||
conda create -n mmyolo python=3.8 pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.3 -c pytorch -y
|
||||
conda activate mmyolo
|
||||
pip install openmim
|
||||
mim install "mmengine>=0.6.0"
|
||||
mim install "mmcv>=2.0.0rc4,<2.1.0"
|
||||
|
@ -258,6 +265,10 @@ For different parts from MMDetection, we have also prepared user guides and adva
|
|||
|
||||
## 📊 Overview of Benchmark and Model Zoo [🔝](#-table-of-contents)
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/17425982/222087414-168175cc-dae6-4c5c-a8e3-3109a152dd19.png"/>
|
||||
</div>
|
||||
|
||||
Results and models are available in the [model zoo](docs/en/model_zoo.md).
|
||||
|
||||
<details open>
|
||||
|
@ -274,6 +285,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
|
|||
- [x] [YOLOv5](configs/yolov5)
|
||||
- [x] [YOLOX](configs/yolox)
|
||||
- [x] [RTMDet](configs/rtmdet)
|
||||
- [x] [RTMDet-Rotated](configs/rtmdet)
|
||||
- [x] [YOLOv6](configs/yolov6)
|
||||
- [x] [YOLOv7](configs/yolov7)
|
||||
- [x] [PPYOLOE](configs/ppyoloe)
|
||||
|
|
|
@ -78,10 +78,17 @@
|
|||
|
||||
## 🥳 🚀 What's New [🔝](#-table-of-contents)
|
||||
|
||||
💎 **v0.4.0** was released on 2023.1.18:
|
||||
💎 **v0.5.0** was released on 2023.3.2:
|
||||
|
||||
1. Implemented the [YOLOv8](https://github.com/open-mmlab/mmyolo/blob/dev/configs/yolov8/README.md) object detection model, with model deployment supported through [projects/easydeploy](https://github.com/open-mmlab/mmyolo/blob/dev/projects/easydeploy)
|
||||
2. Added Chinese and English versions of the [complete guide to YOLOv8 principles and implementation](https://github.com/open-mmlab/mmyolo/blob/dev/docs/zh_cn/algorithm_descriptions/yolov8_description.md)
|
||||
1. Added support for the [RTMDet-R](https://github.com/open-mmlab/mmyolo/blob/dev/configs/rtmdet/README.md#rotated-object-detection) rotated object detection task and algorithm
|
||||
2. [YOLOv8](https://github.com/open-mmlab/mmyolo/blob/dev/configs/yolov8/README.md) supports using mask annotations to improve object detection performance
|
||||
3. Support NAS sub-networks searched by [MMRazor](https://github.com/open-mmlab/mmyolo/blob/dev/configs/razor/subnets/README.md) as the backbone of YOLO series algorithms
|
||||
4. Support calling [MMRazor](https://github.com/open-mmlab/mmyolo/blob/dev/configs/rtmdet/distillation/README.md) to perform knowledge distillation on RTMDet
|
||||
5. Restructured the [MMYOLO](https://mmyolo.readthedocs.io/zh_CN/dev/) documentation and comprehensively upgraded its content
|
||||
6. Improved YOLOX accuracy and training speed based on RTMDet training hyperparameters
|
||||
7. Support calculation of model parameter counts and FLOPs, provide GPU latency data on T4 devices, and update the [Model Zoo](https://github.com/open-mmlab/mmyolo/blob/dev/docs/zh_cn/model_zoo.md)
|
||||
8. Support test-time augmentation (TTA)
|
||||
9. Support assigner visualization for RTMDet, YOLOv8 and YOLOv7
|
||||
|
||||
We provide a practical **script command cheat sheet**
|
||||
|
||||
|
@ -123,7 +130,7 @@
|
|||
<img src="https://user-images.githubusercontent.com/12907710/208044554-1e8de6b5-48d8-44e4-a7b5-75076c7ebb71.png"/>
|
||||
</div>
|
||||
|
||||
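MMYOLO currently implements only the object detection algorithm, but training is significantly accelerated compared to the MMDetection version. The training speed is 2.6 times faster than the previous version.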
|
||||
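MMYOLO currently implements the object detection and rotated object detection algorithms, and training is significantly accelerated compared to the MMDetection version. The training speed is 2.6 times faster than the previous version.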
|
||||
|
||||
## 📖 Introduction [🔝](#-table-of-contents)
|
||||
|
||||
|
@ -159,8 +166,8 @@ For the P6 model figure, see [model_design.md](docs/zh_cn/recommended_topics/model_design.
|
|||
MMYOLO depends on PyTorch, MMCV, MMEngine, and MMDetection. Below are brief installation steps. For more detailed instructions, please refer to the [installation guide](docs/zh_cn/get_started/installation.md).
|
||||
|
||||
```shell
|
||||
conda create -n open-mmlab python=3.8 pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.3 -c pytorch -y
|
||||
conda activate open-mmlab
|
||||
conda create -n mmyolo python=3.8 pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.3 -c pytorch -y
|
||||
conda activate mmyolo
|
||||
pip install openmim
|
||||
mim install "mmengine>=0.6.0"
|
||||
mim install "mmcv>=2.0.0rc4,<2.1.0"
|
||||
|
@ -280,6 +287,10 @@ MMYOLO usage is almost identical to MMDetection, all tutorials are interchangeable, you can also
|
|||
|
||||
## 📊 Benchmark and Model Zoo [🔝](#-table-of-contents)
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/17425982/222087414-168175cc-dae6-4c5c-a8e3-3109a152dd19.png"/>
|
||||
</div>
|
||||
|
||||
Results and models are available in the [model zoo](docs/zh_cn/model_zoo.md).
|
||||
|
||||
<details open>
|
||||
|
@ -296,6 +307,7 @@ MMYOLO usage is almost identical to MMDetection, all tutorials are interchangeable, you can also
|
|||
- [x] [YOLOv5](configs/yolov5)
|
||||
- [x] [YOLOX](configs/yolox)
|
||||
- [x] [RTMDet](configs/rtmdet)
|
||||
- [x] [RTMDet-Rotated](configs/rtmdet)
|
||||
- [x] [YOLOv6](configs/yolov6)
|
||||
- [x] [YOLOv7](configs/yolov7)
|
||||
- [x] [PPYOLOE](configs/ppyoloe)
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
# Compared to other same scale models, this configuration consumes too much
|
||||
# GPU memory and is not validated for now
|
||||
_base_ = 'ppyoloe_plus_s_fast_8xb8-80e_coco.py'
|
||||
|
||||
data_root = './data/cat/'
|
||||
class_name = ('cat', )
|
||||
num_classes = len(class_name)
|
||||
metainfo = dict(classes=class_name, palette=[(20, 220, 60)])
|
||||
|
||||
num_last_epochs = 5
|
||||
|
||||
max_epochs = 40
|
||||
train_batch_size_per_gpu = 12
|
||||
train_num_workers = 2
|
||||
|
||||
load_from = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco/ppyoloe_plus_s_fast_8xb8-80e_coco_20230101_154052-9fee7619.pth' # noqa
|
||||
|
||||
model = dict(
|
||||
backbone=dict(frozen_stages=4),
|
||||
bbox_head=dict(head_module=dict(num_classes=num_classes)),
|
||||
train_cfg=dict(
|
||||
initial_assigner=dict(num_classes=num_classes),
|
||||
assigner=dict(num_classes=num_classes)))
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=train_batch_size_per_gpu,
|
||||
num_workers=train_num_workers,
|
||||
dataset=dict(
|
||||
data_root=data_root,
|
||||
metainfo=metainfo,
|
||||
ann_file='annotations/trainval.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
val_dataloader = dict(
|
||||
dataset=dict(
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
ann_file='annotations/test.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
default_hooks = dict(
|
||||
param_scheduler=dict(
|
||||
warmup_min_iter=10,
|
||||
warmup_epochs=3,
|
||||
total_epochs=int(max_epochs * 1.2)))
|
||||
|
||||
val_evaluator = dict(ann_file=data_root + 'annotations/test.json')
|
||||
test_evaluator = val_evaluator
|
||||
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(interval=10, max_keep_ckpts=2, save_best='auto'),
|
||||
logger=dict(type='LoggerHook', interval=5))
|
||||
train_cfg = dict(max_epochs=max_epochs, val_interval=10)
|
||||
# visualizer = dict(vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]) # noqa
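
A note on the config above: `default_hooks` is assigned twice at module level. Assuming the file is evaluated as ordinary Python, a later assignment to the same name replaces the earlier one rather than merging with it, so the first dict's `param_scheduler` entry would likely not take effect. The following is a minimal plain-Python sketch of that behaviour (no MMEngine involved); the name `merged_hooks` and the merged form are illustrative assumptions, not the repository's config.

```python
# Plain-Python illustration; the values are copied from the config above.
max_epochs = 40

default_hooks = dict(
    param_scheduler=dict(
        warmup_min_iter=10,
        warmup_epochs=3,
        total_epochs=int(max_epochs * 1.2)))

# A second top-level assignment rebinds the name; it does not merge.
default_hooks = dict(
    checkpoint=dict(interval=10, max_keep_ckpts=2, save_best='auto'),
    logger=dict(type='LoggerHook', interval=5))

rebound_keys = sorted(default_hooks)

# Hypothetical merged form that keeps all three hook overrides together.
merged_hooks = dict(
    param_scheduler=dict(
        warmup_min_iter=10,
        warmup_epochs=3,
        total_epochs=int(max_epochs * 1.2)),
    checkpoint=dict(interval=10, max_keep_ckpts=2, save_best='auto'),
    logger=dict(type='LoggerHook', interval=5))

merged_keys = sorted(merged_hooks)

print(rebound_keys)
print(merged_keys)
```

If the warm-up settings are meant to apply, writing a single `default_hooks` dict is one way to make sure they reach the base config merge.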
|
|
@ -0,0 +1,70 @@
|
|||
_base_ = 'rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py'
|
||||
|
||||
data_root = './data/cat/'
|
||||
class_name = ('cat', )
|
||||
num_classes = len(class_name)
|
||||
metainfo = dict(classes=class_name, palette=[(20, 220, 60)])
|
||||
|
||||
num_epochs_stage2 = 5
|
||||
|
||||
max_epochs = 40
|
||||
train_batch_size_per_gpu = 12
|
||||
train_num_workers = 4
|
||||
val_batch_size_per_gpu = 1
|
||||
val_num_workers = 2
|
||||
|
||||
load_from = 'https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco/rtmdet_tiny_syncbn_fast_8xb32-300e_coco_20230102_140117-dbb1dc83.pth' # noqa
|
||||
|
||||
model = dict(
|
||||
backbone=dict(frozen_stages=4),
|
||||
bbox_head=dict(head_module=dict(num_classes=num_classes)),
|
||||
train_cfg=dict(assigner=dict(num_classes=num_classes)))
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=train_batch_size_per_gpu,
|
||||
num_workers=train_num_workers,
|
||||
dataset=dict(
|
||||
data_root=data_root,
|
||||
metainfo=metainfo,
|
||||
ann_file='annotations/trainval.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=val_batch_size_per_gpu,
|
||||
num_workers=val_num_workers,
|
||||
dataset=dict(
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
ann_file='annotations/test.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=_base_.lr_start_factor,
|
||||
by_epoch=False,
|
||||
begin=0,
|
||||
end=30),
|
||||
dict(
|
||||
# use cosine lr from epoch max_epochs // 2 to max_epochs (20 to 40 here)
|
||||
type='CosineAnnealingLR',
|
||||
eta_min=_base_.base_lr * 0.05,
|
||||
begin=max_epochs // 2,
|
||||
end=max_epochs,
|
||||
T_max=max_epochs // 2,
|
||||
by_epoch=True,
|
||||
convert_to_iter_based=True),
|
||||
]
|
||||
|
||||
_base_.custom_hooks[1].switch_epoch = max_epochs - num_epochs_stage2
|
||||
|
||||
val_evaluator = dict(ann_file=data_root + 'annotations/test.json')
|
||||
test_evaluator = val_evaluator
|
||||
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(interval=10, max_keep_ckpts=2, save_best='auto'),
|
||||
logger=dict(type='LoggerHook', interval=5))
|
||||
train_cfg = dict(max_epochs=max_epochs, val_interval=10)
|
||||
# visualizer = dict(vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]) # noqa
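
The schedule in the config above can be read as a timeline derived from `max_epochs` and `num_epochs_stage2`. The sketch below computes the phase boundaries in plain Python. The helper name `schedule_boundaries` is hypothetical and not part of MMYOLO, and treating `custom_hooks[1]` as the hook that switches to the second-stage pipeline is an assumption based only on the attribute name `switch_epoch`.

```python
def schedule_boundaries(max_epochs: int, num_epochs_stage2: int,
                        warmup_iters: int = 30) -> dict:
    """Summarise where each phase of the schedule begins and ends."""
    return {
        # LinearLR runs per iteration (by_epoch=False) from 0 to `end`.
        'linear_warmup_iters': (0, warmup_iters),
        # CosineAnnealingLR covers the second half of training, in epochs.
        'cosine_epochs': (max_epochs // 2, max_epochs),
        'cosine_T_max': max_epochs // 2,
        # Epoch assigned to custom_hooks[1].switch_epoch in the config.
        'switch_epoch': max_epochs - num_epochs_stage2,
    }


boundaries = schedule_boundaries(max_epochs=40, num_epochs_stage2=5)
for name, value in boundaries.items():
    print(f'{name}: {value}')
```

With the values from this config, the cosine phase spans epochs 20 to 40 and the switch happens at epoch 35.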
|
|
@ -36,17 +36,21 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
ann_file='annotations/trainval.json',
|
||||
ann_file='annotations/test.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(ann_file=data_root + 'annotations/trainval.json')
|
||||
_base_.optim_wrapper.optimizer.batch_size_per_gpu = train_batch_size_per_gpu
|
||||
|
||||
val_evaluator = dict(ann_file=data_root + 'annotations/test.json')
|
||||
test_evaluator = val_evaluator
|
||||
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(interval=10, max_keep_ckpts=2, save_best='auto'),
|
||||
param_scheduler=dict(max_epochs=max_epochs),
|
||||
# The warmup_mim_iter parameter is critical.
|
||||
# The default value is 1000 which is not suitable for cat datasets.
|
||||
param_scheduler=dict(max_epochs=max_epochs, warmup_mim_iter=10),
|
||||
logger=dict(type='LoggerHook', interval=5))
|
||||
train_cfg = dict(max_epochs=max_epochs, val_interval=10)
|
||||
# visualizer = dict(vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]) # noqa
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
_base_ = './yolov6_s_syncbn_fast_8xb32-400e_coco.py'
|
||||
|
||||
data_root = './data/cat/'
|
||||
class_name = ('cat', )
|
||||
num_classes = len(class_name)
|
||||
metainfo = dict(classes=class_name, palette=[(20, 220, 60)])
|
||||
|
||||
max_epochs = 40
|
||||
train_batch_size_per_gpu = 12
|
||||
train_num_workers = 4
|
||||
num_last_epochs = 5
|
||||
|
||||
load_from = 'https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221102_203035-932e1d91.pth' # noqa
|
||||
|
||||
model = dict(
|
||||
backbone=dict(frozen_stages=4),
|
||||
bbox_head=dict(head_module=dict(num_classes=num_classes)),
|
||||
train_cfg=dict(
|
||||
initial_assigner=dict(num_classes=num_classes),
|
||||
assigner=dict(num_classes=num_classes)))
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=train_batch_size_per_gpu,
|
||||
num_workers=train_num_workers,
|
||||
dataset=dict(
|
||||
data_root=data_root,
|
||||
metainfo=metainfo,
|
||||
ann_file='annotations/trainval.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
val_dataloader = dict(
|
||||
dataset=dict(
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
ann_file='annotations/test.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(ann_file=data_root + 'annotations/test.json')
|
||||
test_evaluator = val_evaluator
|
||||
|
||||
_base_.optim_wrapper.optimizer.batch_size_per_gpu = train_batch_size_per_gpu
|
||||
_base_.custom_hooks[1].switch_epoch = max_epochs - num_last_epochs
|
||||
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(interval=10, max_keep_ckpts=2, save_best='auto'),
|
||||
# The warmup_mim_iter parameter is critical.
|
||||
# The default value is 1000 which is not suitable for cat datasets.
|
||||
param_scheduler=dict(max_epochs=max_epochs, warmup_mim_iter=10),
|
||||
logger=dict(type='LoggerHook', interval=5))
|
||||
train_cfg = dict(
|
||||
max_epochs=max_epochs,
|
||||
val_interval=10,
|
||||
dynamic_intervals=[(max_epochs - num_last_epochs, 1)])
|
||||
# visualizer = dict(vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]) # noqa
|
|
@ -0,0 +1,56 @@
|
|||
_base_ = 'yolov7_tiny_syncbn_fast_8x16b-300e_coco.py'
|
||||
|
||||
data_root = './data/cat/'
|
||||
class_name = ('cat', )
|
||||
num_classes = len(class_name)
|
||||
metainfo = dict(classes=class_name, palette=[(20, 220, 60)])
|
||||
|
||||
anchors = [
|
||||
[(68, 69), (154, 91), (143, 162)], # P3/8
|
||||
[(242, 160), (189, 287), (391, 207)], # P4/16
|
||||
[(353, 337), (539, 341), (443, 432)] # P5/32
|
||||
]
|
||||
|
||||
max_epochs = 40
|
||||
train_batch_size_per_gpu = 12
|
||||
train_num_workers = 4
|
||||
|
||||
load_from = 'https://download.openmmlab.com/mmyolo/v0/yolov7/yolov7_tiny_syncbn_fast_8x16b-300e_coco/yolov7_tiny_syncbn_fast_8x16b-300e_coco_20221126_102719-0ee5bbdf.pth' # noqa
|
||||
|
||||
model = dict(
|
||||
backbone=dict(frozen_stages=4),
|
||||
bbox_head=dict(
|
||||
head_module=dict(num_classes=num_classes),
|
||||
prior_generator=dict(base_sizes=anchors)))
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=train_batch_size_per_gpu,
|
||||
num_workers=train_num_workers,
|
||||
dataset=dict(
|
||||
data_root=data_root,
|
||||
metainfo=metainfo,
|
||||
ann_file='annotations/trainval.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
val_dataloader = dict(
|
||||
dataset=dict(
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
ann_file='annotations/test.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
_base_.optim_wrapper.optimizer.batch_size_per_gpu = train_batch_size_per_gpu
|
||||
|
||||
val_evaluator = dict(ann_file=data_root + 'annotations/test.json')
|
||||
test_evaluator = val_evaluator
|
||||
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(interval=10, max_keep_ckpts=2, save_best='auto'),
|
||||
# The warmup_mim_iter parameter is critical.
|
||||
# The default value is 1000 which is not suitable for cat datasets.
|
||||
param_scheduler=dict(max_epochs=max_epochs, warmup_mim_iter=10),
|
||||
logger=dict(type='LoggerHook', interval=5))
|
||||
train_cfg = dict(max_epochs=max_epochs, val_interval=10)
|
||||
# visualizer = dict(vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]) # noqa
|
|
@ -0,0 +1,52 @@
|
|||
_base_ = 'yolov8_s_syncbn_fast_8xb16-500e_coco.py'
|
||||
|
||||
data_root = './data/cat/'
|
||||
class_name = ('cat', )
|
||||
num_classes = len(class_name)
|
||||
metainfo = dict(classes=class_name, palette=[(20, 220, 60)])
|
||||
|
||||
close_mosaic_epochs = 5
|
||||
|
||||
max_epochs = 40
|
||||
train_batch_size_per_gpu = 12
|
||||
train_num_workers = 4
|
||||
|
||||
load_from = 'https://download.openmmlab.com/mmyolo/v0/yolov8/yolov8_s_syncbn_fast_8xb16-500e_coco/yolov8_s_syncbn_fast_8xb16-500e_coco_20230117_180101-5aa5f0f1.pth' # noqa
|
||||
|
||||
model = dict(
|
||||
backbone=dict(frozen_stages=4),
|
||||
bbox_head=dict(head_module=dict(num_classes=num_classes)),
|
||||
train_cfg=dict(assigner=dict(num_classes=num_classes)))
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=train_batch_size_per_gpu,
|
||||
num_workers=train_num_workers,
|
||||
dataset=dict(
|
||||
data_root=data_root,
|
||||
metainfo=metainfo,
|
||||
ann_file='annotations/trainval.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
val_dataloader = dict(
|
||||
dataset=dict(
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
ann_file='annotations/test.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
_base_.optim_wrapper.optimizer.batch_size_per_gpu = train_batch_size_per_gpu
|
||||
_base_.custom_hooks[1].switch_epoch = max_epochs - close_mosaic_epochs
|
||||
|
||||
val_evaluator = dict(ann_file=data_root + 'annotations/test.json')
|
||||
test_evaluator = val_evaluator
|
||||
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(interval=10, max_keep_ckpts=2, save_best='auto'),
|
||||
# The warmup_mim_iter parameter is critical.
|
||||
# The default value is 1000 which is not suitable for cat datasets.
|
||||
param_scheduler=dict(max_epochs=max_epochs, warmup_mim_iter=10),
|
||||
logger=dict(type='LoggerHook', interval=5))
|
||||
train_cfg = dict(max_epochs=max_epochs, val_interval=10)
|
||||
# visualizer = dict(vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]) # noqa
|
|
@ -0,0 +1,76 @@
|
|||
_base_ = './yolox_s_fast_8xb32-300e-rtmdet-hyp_coco.py'
|
||||
|
||||
data_root = './data/cat/'
|
||||
class_name = ('cat', )
|
||||
num_classes = len(class_name)
|
||||
metainfo = dict(classes=class_name, palette=[(20, 220, 60)])
|
||||
|
||||
num_last_epochs = 5
|
||||
|
||||
max_epochs = 40
|
||||
train_batch_size_per_gpu = 12
|
||||
train_num_workers = 4
|
||||
|
||||
load_from = 'https://download.openmmlab.com/mmyolo/v0/yolox/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco/yolox_s_fast_8xb32-300e-rtmdet-hyp_coco_20230210_134645-3a8dfbd7.pth' # noqa
|
||||
|
||||
model = dict(
|
||||
backbone=dict(frozen_stages=4),
|
||||
bbox_head=dict(head_module=dict(num_classes=num_classes)))
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=train_batch_size_per_gpu,
|
||||
num_workers=train_num_workers,
|
||||
dataset=dict(
|
||||
data_root=data_root,
|
||||
metainfo=metainfo,
|
||||
ann_file='annotations/trainval.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
val_dataloader = dict(
|
||||
dataset=dict(
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
ann_file='annotations/test.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
param_scheduler = [
|
||||
dict(
|
||||
# use quadratic formula to warm up 3 epochs
|
||||
# and lr is updated by iteration
|
||||
# TODO: fix default scope in get function
|
||||
type='mmdet.QuadraticWarmupLR',
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=3,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
# use cosine lr from 5 to 35 epoch
|
||||
type='CosineAnnealingLR',
|
||||
eta_min=_base_.base_lr * 0.05,
|
||||
begin=5,
|
||||
T_max=max_epochs - num_last_epochs,
|
||||
end=max_epochs - num_last_epochs,
|
||||
by_epoch=True,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
# use fixed lr during last num_last_epochs epochs
|
||||
type='ConstantLR',
|
||||
by_epoch=True,
|
||||
factor=1,
|
||||
begin=max_epochs - num_last_epochs,
|
||||
end=max_epochs,
|
||||
)
|
||||
]
|
||||
|
||||
_base_.custom_hooks[0].num_last_epochs = num_last_epochs
|
||||
|
||||
val_evaluator = dict(ann_file=data_root + 'annotations/test.json')
|
||||
test_evaluator = val_evaluator
|
||||
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(interval=10, max_keep_ckpts=2, save_best='auto'),
|
||||
logger=dict(type='LoggerHook', interval=5))
|
||||
train_cfg = dict(max_epochs=max_epochs, val_interval=10)
|
||||
# visualizer = dict(vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]) # noqa
|
|
@ -1,9 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
import mmcv
|
||||
from mmdet.apis import inference_detector, init_detector
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.utils import ProgressBar, path
|
||||
|
||||
|
@ -29,6 +31,10 @@ def parse_args():
|
|||
'--deploy',
|
||||
action='store_true',
|
||||
help='Switch model to deployment mode')
|
||||
parser.add_argument(
|
||||
'--tta',
|
||||
action='store_true',
|
||||
help='Whether to use test time augmentation')
|
||||
parser.add_argument(
|
||||
'--score-thr', type=float, default=0.3, help='Bbox score threshold')
|
||||
parser.add_argument(
|
||||
|
@ -50,9 +56,37 @@ def main():
|
|||
if args.to_labelme and args.show:
|
||||
raise RuntimeError('`--to-labelme` or `--show` only '
|
||||
'can choose one at the same time.')
|
||||
config = args.config
|
||||
|
||||
if isinstance(config, (str, Path)):
|
||||
config = Config.fromfile(config)
|
||||
elif not isinstance(config, Config):
|
||||
raise TypeError('config must be a filename or Config object, '
|
||||
f'but got {type(config)}')
|
||||
if 'init_cfg' in config.model.backbone:
|
||||
config.model.backbone.init_cfg = None
|
||||
|
||||
if args.tta:
|
||||
assert 'tta_model' in config, 'Cannot find ``tta_model`` in config.' \
|
||||
" Can't use tta !"
|
||||
assert 'tta_pipeline' in config, 'Cannot find ``tta_pipeline`` ' \
|
||||
"in config. Can't use tta !"
|
||||
config.model = ConfigDict(**config.tta_model, module=config.model)
|
||||
test_data_cfg = config.test_dataloader.dataset
|
||||
while 'dataset' in test_data_cfg:
|
||||
test_data_cfg = test_data_cfg['dataset']
|
||||
|
||||
# batch_shapes_cfg will force control the size of the output image,
|
||||
# it is not compatible with tta.
|
||||
if 'batch_shapes_cfg' in test_data_cfg:
|
||||
test_data_cfg.batch_shapes_cfg = None
|
||||
test_data_cfg.pipeline = config.tta_pipeline
|
||||
|
||||
# TODO: TTA mode will error if cfg_options is not set.
|
||||
# This is an mmdet issue and needs to be fixed later.
|
||||
# build the model from a config file and a checkpoint file
|
||||
model = init_detector(args.config, args.checkpoint, device=args.device)
|
||||
model = init_detector(
|
||||
config, args.checkpoint, device=args.device, cfg_options={})
|
||||
|
||||
if args.deploy:
|
||||
switch_to_deploy(model)
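
The TTA branch in this hunk rewrites the config in three steps: it wraps the model inside `tta_model`, descends to the innermost dataset, and swaps in `tta_pipeline`. The sketch below mirrors those steps with plain dicts instead of `Config`/`ConfigDict`, so it runs without MMEngine. The function name `apply_tta` and every value in the toy config (`ToyDetector`, `ToyTTAModel`, and so on) are illustrative assumptions.

```python
import copy


def apply_tta(config: dict) -> dict:
    """Return a copy of ``config`` rewritten for test-time augmentation."""
    if 'tta_model' not in config or 'tta_pipeline' not in config:
        raise KeyError('config needs both tta_model and tta_pipeline')
    cfg = copy.deepcopy(config)

    # Wrap the original model inside the TTA model under the key `module`.
    cfg['model'] = dict(**cfg['tta_model'], module=cfg['model'])

    # Descend through wrapper datasets until the innermost one is reached.
    data_cfg = cfg['test_dataloader']['dataset']
    while 'dataset' in data_cfg:
        data_cfg = data_cfg['dataset']

    # The source notes batch_shapes_cfg is not compatible with TTA.
    if 'batch_shapes_cfg' in data_cfg:
        data_cfg['batch_shapes_cfg'] = None
    data_cfg['pipeline'] = cfg['tta_pipeline']
    return cfg


toy = {
    'model': {'type': 'ToyDetector'},
    'tta_model': {'type': 'ToyTTAModel'},
    'tta_pipeline': [{'type': 'ToyFlip'}],
    'test_dataloader': {
        'dataset': {
            'type': 'ToyWrapper',
            'dataset': {
                'type': 'ToyDataset',
                'batch_shapes_cfg': {'img_size': 640},
                'pipeline': [{'type': 'ToyLoad'}],
            },
        },
    },
}

result = apply_tta(toy)
inner = result['test_dataloader']['dataset']['dataset']
print(result['model'])
print(inner['batch_shapes_cfg'], inner['pipeline'])
```

The sketch copies the config before modifying it so the caller's dict stays untouched; the demo script in the diff edits its `Config` object in place instead.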
|
||||
|
|
|
@ -14,10 +14,12 @@ python demo/large_image_demo.py \
|
|||
import os
|
||||
import random
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmdet.apis import inference_detector, init_detector
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.utils import ProgressBar
|
||||
|
||||
|
@ -50,6 +52,10 @@ def parse_args():
|
|||
'--deploy',
|
||||
action='store_true',
|
||||
help='Switch model to deployment mode')
|
||||
parser.add_argument(
|
||||
'--tta',
|
||||
action='store_true',
|
||||
help='Whether to use test time augmentation')
|
||||
parser.add_argument(
|
||||
'--score-thr', type=float, default=0.3, help='Bbox score threshold')
|
||||
parser.add_argument(
|
||||
|
@ -90,8 +96,37 @@ def parse_args():
|
|||
def main():
|
||||
args = parse_args()
|
||||
|
||||
config = args.config
|
||||
|
||||
if isinstance(config, (str, Path)):
|
||||
config = Config.fromfile(config)
|
||||
elif not isinstance(config, Config):
|
||||
raise TypeError('config must be a filename or Config object, '
|
||||
f'but got {type(config)}')
|
||||
if 'init_cfg' in config.model.backbone:
|
||||
config.model.backbone.init_cfg = None
|
||||
|
||||
if args.tta:
|
||||
assert 'tta_model' in config, 'Cannot find ``tta_model`` in config.' \
|
||||
" Can't use tta !"
|
||||
assert 'tta_pipeline' in config, 'Cannot find ``tta_pipeline`` ' \
|
||||
"in config. Can't use tta !"
|
||||
config.model = ConfigDict(**config.tta_model, module=config.model)
|
||||
test_data_cfg = config.test_dataloader.dataset
|
||||
while 'dataset' in test_data_cfg:
|
||||
test_data_cfg = test_data_cfg['dataset']
|
||||
|
||||
# batch_shapes_cfg will force control the size of the output image,
|
||||
# it is not compatible with tta.
|
||||
if 'batch_shapes_cfg' in test_data_cfg:
|
||||
test_data_cfg.batch_shapes_cfg = None
|
||||
test_data_cfg.pipeline = config.tta_pipeline
|
||||
|
||||
# TODO: TTA mode will error if cfg_options is not set.
|
||||
# This is an mmdet issue and needs to be fixed later.
|
||||
# build the model from a config file and a checkpoint file
|
||||
model = init_detector(args.config, args.checkpoint, device=args.device)
|
||||
model = init_detector(
|
||||
config, args.checkpoint, device=args.device, cfg_options={})
|
||||
|
||||
if args.deploy:
|
||||
switch_to_deploy(model)
|
||||
|
|
|
@ -15,6 +15,8 @@ Take the small dataset of cat as an example, you can easily learn MMYOLO object
|
|||
- [Testing](#testing)
|
||||
- [EasyDeploy](#easydeploy-deployment)
|
||||
|
||||
In this article, we take YOLOv5-s as an example. For the rest of the YOLO series algorithms, please see the corresponding algorithm configuration folder.
|
||||
|
||||
## Installation
|
||||
|
||||
Assuming you've already installed Conda in advance, install PyTorch
|
||||
|
@ -135,19 +137,23 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
ann_file='annotations/trainval.json',
|
||||
ann_file='annotations/test.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(ann_file=data_root + 'annotations/trainval.json')
|
||||
_base_.optim_wrapper.optimizer.batch_size_per_gpu = train_batch_size_per_gpu
|
||||
|
||||
val_evaluator = dict(ann_file=data_root + 'annotations/test.json')
|
||||
test_evaluator = val_evaluator
|
||||
|
||||
default_hooks = dict(
|
||||
# Save weights every 10 epochs and a maximum of two weights can be saved.
|
||||
# The best model is saved automatically during model evaluation
|
||||
checkpoint=dict(interval=10, max_keep_ckpts=2, save_best='auto'),
|
||||
param_scheduler=dict(max_epochs=max_epochs),
|
||||
# The warmup_mim_iter parameter is critical.
|
||||
# The default value is 1000 which is not suitable for cat datasets.
|
||||
param_scheduler=dict(max_epochs=max_epochs, warmup_mim_iter=10),
|
||||
# The log printing interval is 5
|
||||
logger=dict(type='LoggerHook', interval=5))
|
||||
# The evaluation interval is 10
|
||||
|
@ -168,21 +174,21 @@ Run the above training command, `work_dirs/yolov5_s-v61_fast_1xb12-40e_cat` fold
|
|||
<img src="https://user-images.githubusercontent.com/17425982/220236361-bd113606-248e-4a0e-a484-c0dc9e355b5b.png" alt="image"/>
|
||||
</div>
|
||||
|
||||
The performance on `trainval.json` is as follows:
|
||||
The performance on `test.json` is as follows:
|
||||
|
||||
```text
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.685
|
||||
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.953
|
||||
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.852
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.631
|
||||
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.909
|
||||
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.747
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.685
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.664
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.749
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.761
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.631
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.627
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.703
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.703
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.761
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.703
|
||||
```
|
||||
|
||||
The above metrics are printed via the COCO API, where -1 indicates that no objects exist at that scale. By COCO's size definitions, every object in the Cat dataset is large; there are no small or medium-sized objects.
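
To make the -1 entries concrete, here is a minimal sketch of the COCO size buckets, which split objects at areas of 32² and 96² pixels. The helper `coco_size_bucket` and the example box sizes are hypothetical. COCO evaluates on each annotation's `area` field, while this sketch approximates it with width × height, and it assigns exact boundary values to the larger bucket for simplicity.

```python
from collections import Counter


def coco_size_bucket(area: float) -> str:
    """Classify an object area (in pixels) into a COCO size bucket."""
    if area < 32 ** 2:
        return 'small'
    if area < 96 ** 2:
        return 'medium'
    return 'large'


# Hypothetical (width, height) boxes; all are well above 96 x 96 pixels.
boxes = [(400, 300), (120, 200), (250, 180)]
counts = Counter(coco_size_bucket(w * h) for w, h in boxes)
print(dict(counts))
```

When a bucket contains no ground-truth objects, as with `small` and `medium` here, the evaluator has nothing to average over and reports -1 for that row.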
|
||||
|
@ -244,10 +250,10 @@ python tools/train.py configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py
|
|||
```
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/17425982/220238131-08eacedc-28a7-4008-af8c-f36dc239ecaa.png" alt="image"/>
|
||||
<img src="https://user-images.githubusercontent.com/17425982/222131114-30a79285-56bc-427d-a38d-8d6a6982ad60.png" alt="image"/>
|
||||
</div>
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/17425982/220238535-f363a6ba-876c-4bb7-80d6-9d8d8ca9b966.png" alt="image"/>
|
||||
<img src="https://user-images.githubusercontent.com/17425982/222132585-4b4962f1-211b-46f7-86b3-7534fc52a1b4.png" alt="image"/>
|
||||
</div>
|
||||
|
||||
#### 2 Tensorboard
|
||||
|
@ -335,7 +341,7 @@ Let's choose the `data/cat/images/IMG_20221020_112705.jpg` image as an example t
|
|||
```shell
|
||||
python demo/featmap_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
||||
configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py \
|
||||
work_dirs/yolov5_s-v61_fast_1xb8-40e_cat/epoch_40.pth \
|
||||
work_dirs/yolov5_s-v61_fast_1xb12-40e_cat/epoch_40.pth \
|
||||
--target-layers backbone \
|
||||
--channel-reduction squeeze_mean
|
||||
```
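
The `--channel-reduction squeeze_mean` option in the command above collapses a multi-channel feature map into a single channel so that it can be drawn as one image. A minimal pure-Python sketch of that idea, assuming the reduction is a plain mean over the channel axis; `squeeze_mean` here is a hypothetical helper written for illustration, not the visualizer's own function:

```python
def squeeze_mean(featmap):
    """Collapse a C x H x W nested list into an H x W map by averaging over channels."""
    num_channels = len(featmap)
    height = len(featmap[0])
    width = len(featmap[0][0])
    return [
        [sum(featmap[c][y][x] for c in range(num_channels)) / num_channels
         for x in range(width)]
        for y in range(height)
    ]


# Toy feature map with 2 channels of size 2 x 2 (values are illustrative only).
featmap = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[3.0, 4.0], [5.0, 6.0]],
]
reduced = squeeze_mean(featmap)
```

Averaging keeps the spatial layout intact, so bright regions in the reduced map correspond to locations where activations are high across many channels.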
|
||||
|
@ -351,7 +357,7 @@ The result will be saved to the output folder in current path. Three output feat
|
|||
```shell
|
||||
python demo/featmap_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
||||
configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py \
|
||||
work_dirs/yolov5_s-v61_fast_1xb8-40e_cat/epoch_40.pth \
|
||||
work_dirs/yolov5_s-v61_fast_1xb12-40e_cat/epoch_40.pth \
|
||||
--target-layers neck \
|
||||
--channel-reduction squeeze_mean
|
||||
```
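
The checkpoint path passed to these demo commands has to match the config that was actually trained. A minimal sketch of how the path can be derived, assuming the training script writes to `work_dirs/<config file name without extension>` as the paths in this tutorial suggest; `default_work_dir` and `checkpoint_path` are hypothetical helpers, not MMYOLO functions:

```python
import posixpath


def default_work_dir(config_path: str) -> str:
    """Assumed default: work_dirs/<config file name without its extension>."""
    config_name = posixpath.splitext(posixpath.basename(config_path))[0]
    return posixpath.join("work_dirs", config_name)


def checkpoint_path(config_path: str, epoch: int) -> str:
    """Build the path of the checkpoint saved after the given epoch."""
    return posixpath.join(default_work_dir(config_path), f"epoch_{epoch}.pth")


config = "configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py"
ckpt = checkpoint_path(config, 40)
```

Deriving the path from the config name avoids mismatches such as pointing a `1xb12` config at a `1xb8` work directory.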
|
||||
|
@ -371,7 +377,7 @@ Based on the above feature map visualization, we can analyze Grad CAM at the fea
|
|||
```shell
|
||||
python demo/boxam_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
||||
configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py \
|
||||
work_dirs/yolov5_s-v61_fast_1xb8-40e_cat/epoch_40.pth \
|
||||
work_dirs/yolov5_s-v61_fast_1xb12-40e_cat/epoch_40.pth \
|
||||
--target-layer neck.out_layers[2]
|
||||
```
|
||||
|
||||
|
@ -384,7 +390,7 @@ python demo/boxam_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
|||
```shell
|
||||
python demo/boxam_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
||||
configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py \
|
||||
work_dirs/yolov5_s-v61_fast_1xb8-40e_cat/epoch_40.pth \
|
||||
work_dirs/yolov5_s-v61_fast_1xb12-40e_cat/epoch_40.pth \
|
||||
--target-layer neck.out_layers[1]
|
||||
```
|
||||
|
||||
|
@ -397,7 +403,7 @@ python demo/boxam_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
|||
```shell
|
||||
python demo/boxam_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
||||
configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py \
|
||||
work_dirs/yolov5_s-v61_fast_1xb8-40e_cat/epoch_40.pth \
|
||||
work_dirs/yolov5_s-v61_fast_1xb12-40e_cat/epoch_40.pth \
|
||||
--target-layer neck.out_layers[0]
|
||||
```
|
||||
|
||||
|
@ -520,4 +526,4 @@ Here we choose to save the inference results under `output` instead of displayin
|
|||
|
||||
This completes converting and deploying the trained model and checking the inference results. This is the end of the tutorial.
|
||||
|
||||
The full content above can be viewed: [15_minutes_object_detection.ipynb](<>)
|
||||
The full content above can be viewed in [15_minutes_object_detection.ipynb](<>). If you encounter problems during training or testing, please check the [common troubleshooting steps](../recommended_topics/troubleshooting_steps.md) first, and feel free to raise an issue if you still can't solve them.
|
||||
|
|
|
@ -17,8 +17,8 @@ The following tasks are currently supported:
|
|||
<details open>
|
||||
<summary><b>Tasks currently supported</b></summary>
|
||||
|
||||
- object detection
|
||||
- rotated object detection
|
||||
- Object detection
|
||||
- Rotated object detection
|
||||
|
||||
</details>
|
||||
|
||||
|
@ -30,6 +30,7 @@ The YOLO series of algorithms currently supported are as follows:
|
|||
- YOLOv5
|
||||
- YOLOX
|
||||
- RTMDet
|
||||
- RTMDet-Rotated
|
||||
- YOLOv6
|
||||
- YOLOv7
|
||||
- PPYOLOE
|
||||
|
|
|
@ -1,5 +1,60 @@
|
|||
# Changelog
|
||||
|
||||
## v0.5.0 (2/3/2023)
|
||||
|
||||
### Highlights
|
||||
|
||||
1. Support [RTMDet-R](https://github.com/open-mmlab/mmyolo/blob/dev/configs/rtmdet/README.md#rotated-object-detection) rotated object detection
|
||||
2. Support using mask annotations to improve [YOLOv8](https://github.com/open-mmlab/mmyolo/blob/dev/configs/yolov8/README.md) object detection performance
|
||||
3. Support using NAS sub-networks searched by [MMRazor](https://github.com/open-mmlab/mmyolo/blob/dev/configs/razor/subnets/README.md) as the backbone of YOLO series algorithms
|
||||
4. Support calling [MMRazor](https://github.com/open-mmlab/mmyolo/blob/dev/configs/rtmdet/distillation/README.md) to perform knowledge distillation on RTMDet
|
||||
5. Optimize the [MMYOLO](https://mmyolo.readthedocs.io/zh_CN/dev/) documentation structure and comprehensively upgrade its content
|
||||
6. Improve YOLOX mAP and training speed based on RTMDet training hyperparameters
|
||||
7. Support calculation of model parameters and FLOPs, provide GPU latency data on T4 devices, and update [Model Zoo](https://github.com/open-mmlab/mmyolo/blob/dev/docs/en/model_zoo.md)
|
||||
8. Support test-time augmentation (TTA)
|
||||
9. Support RTMDet, YOLOv8 and YOLOv7 assigner visualization
|
||||
|
||||
### New Features
|
||||
|
||||
01. Support inference for RTMDet instance segmentation tasks (#583)
|
||||
02. Tidy up the configuration files in MMYOLO and add more comments (#501, #506, #516, #529, #531, #539)
|
||||
03. Refactor and optimize documentation (#568, #573, #579, #584, #587, #589, #596, #599, #600)
|
||||
04. Support fast version of YOLOX (#518)
|
||||
05. Support DeepStream in EasyDeploy and add documentation (#485, #545, #571)
|
||||
06. Add confusion matrix drawing script (#572)
|
||||
07. Add a single-channel application example (#460)
|
||||
08. Support auto registration (#597)
|
||||
09. Support Box CAM of YOLOv7, YOLOv8 and PPYOLOE (#601)
|
||||
10. Add automated generation of MM series repo registration information and tools scripts (#559)
|
||||
11. Add a YOLOv7 model structure diagram (#504)
|
||||
12. Add documentation on how to specify particular GPUs for training and inference (#503)
|
||||
13. Add a check that `metainfo` is all lowercase when training or testing (#535)
|
||||
14. Add links to Twitter, Discord, Medium, YouTube, etc. (#555)
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
1. Fix isort version issue (#492, #497)
|
||||
2. Fix type error of assigner visualization (#509)
|
||||
3. Fix YOLOv8 documentation link error (#517)
|
||||
4. Fix RTMDet Decoder error in EasyDeploy (#519)
|
||||
5. Fix some documentation link errors (#537)
|
||||
6. Fix RTMDet-Tiny weight path error (#580)
|
||||
|
||||
### Improvements
|
||||
|
||||
1. Update `contributing.md`
|
||||
2. Optimize `DetDataPreprocessor` to support multiple tasks (#511)
|
||||
3. Optimize `gt_instances_preprocess` so it can be used for other YOLO algorithms (#532)
|
||||
4. Add `yolov7-e6e` weight conversion script (#570)
|
||||
5. Modify PPYOLOE with reference to the YOLOv8 inference code (#614)
|
||||
|
||||
### Contributors
|
||||
|
||||
A total of 22 developers contributed to this release.
|
||||
|
||||
Thanks to @triple-Mu, @isLinXu, @Audrey528, @TianWen580, @yechenzhi, @RangeKing, @lyviva, @Nioolek, @PeterH0323, @tianleiSHI, @aptsunny, @satuoqaq, @vansin, @xin-li-67, @VoyagerXvoyagerx,
|
||||
@landhill, @kitecats, @tang576225574, @HIT-cwh, @AI-Tianlong, @RangiLyu, @hhaAndroid
|
||||
|
||||
## v0.4.0 (18/1/2023)
|
||||
|
||||
### Highlights
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
- [Model testing](#模型测试)
|
||||
- [EasyDeploy model deployment](#easydeploy-模型部署)
|
||||
|
||||
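This tutorial uses YOLOv5-s as an example. For the small cat dataset demo configs of the other YOLO series algorithms, see the corresponding algorithm's config folder.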
|
||||
|
||||
## Environment installation
|
||||
|
||||
Assuming you have already installed Conda, install PyTorch next
|
||||
|
@ -134,19 +136,22 @@ val_dataloader = dict(
|
|||
dataset=dict(
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
ann_file='annotations/trainval.json',
|
||||
ann_file='annotations/test.json',
|
||||
data_prefix=dict(img='images/')))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(ann_file=data_root + 'annotations/trainval.json')
|
||||
_base_.optim_wrapper.optimizer.batch_size_per_gpu = train_batch_size_per_gpu
|
||||
|
||||
val_evaluator = dict(ann_file=data_root + 'annotations/test.json')
|
||||
test_evaluator = val_evaluator
|
||||
|
||||
default_hooks = dict(
|
||||
# Save a checkpoint every 10 epochs and keep at most 2 checkpoints
|
||||
# Automatically save the best model during evaluation
|
||||
checkpoint=dict(interval=10, max_keep_ckpts=2, save_best='auto'),
|
||||
param_scheduler=dict(max_epochs=max_epochs),
|
||||
# The warmup_mim_iter parameter is critical: the cat dataset is very small, and the default minimum warmup_mim_iter of 1000 keeps the learning rate too low during training
|
||||
param_scheduler=dict(max_epochs=max_epochs, warmup_mim_iter=10),
|
||||
# The logging interval is 5
|
||||
logger=dict(type='LoggerHook', interval=5))
|
||||
# The evaluation interval is 10
|
||||
|
@ -167,21 +172,21 @@ python tools/train.py configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py
|
|||
<img src="https://user-images.githubusercontent.com/17425982/220236361-bd113606-248e-4a0e-a484-c0dc9e355b5b.png" alt="image"/>
|
||||
</div>
|
||||
|
||||
The performance on `trainval.json` is as follows:
|
||||
The performance on `test.json` is as follows:
|
||||
|
||||
```text
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.685
|
||||
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.953
|
||||
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.852
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.631
|
||||
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.909
|
||||
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.747
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.685
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.664
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.749
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.761
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.631
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.627
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.703
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.703
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.761
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.703
|
||||
```
|
||||
|
||||
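The above metrics are printed by the COCO API, where -1 indicates that no objects of that scale exist. By COCO's size definitions, every object in the Cat dataset is large; there are no small or medium-sized objects.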
|
||||
|
@ -243,10 +248,10 @@ python tools/train.py configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py
|
|||
```
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/17425982/220238131-08eacedc-28a7-4008-af8c-f36dc239ecaa.png" alt="image"/>
|
||||
<img src="https://user-images.githubusercontent.com/17425982/222131114-30a79285-56bc-427d-a38d-8d6a6982ad60.png" alt="image"/>
|
||||
</div>
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/17425982/220238535-f363a6ba-876c-4bb7-80d6-9d8d8ca9b966.png" alt="image"/>
|
||||
<img src="https://user-images.githubusercontent.com/17425982/222132585-4b4962f1-211b-46f7-86b3-7534fc52a1b4.png" alt="image"/>
|
||||
</div>
|
||||
|
||||
#### 2 Tensorboard visualization
|
||||
|
@ -260,7 +265,7 @@ pip install tensorboard
|
|||
As above, add the `tensorboard` configuration at the end of the config file `configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py`
|
||||
|
||||
```python
|
||||
visualizer = dict(vis_backends=[dict(type='LocalVisBackend'),dict(type='TensorboardVisBackend')])
|
||||
visualizer = dict(vis_backends=[dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend')])
|
||||
```
|
||||
|
||||
After rerunning the training command, the Tensorboard files are generated in the visualization folder `work_dirs/yolov5_s-v61_fast_1xb12-40e_cat.py/{timestamp}/vis_data`,
|
||||
|
@ -334,7 +339,7 @@ test_pipeline = [
|
|||
```shell
|
||||
python demo/featmap_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
||||
configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py \
|
||||
work_dirs/yolov5_s-v61_fast_1xb8-40e_cat/epoch_40.pth \
|
||||
work_dirs/yolov5_s-v61_fast_1xb12-40e_cat/epoch_40.pth \
|
||||
--target-layers backbone \
|
||||
--channel-reduction squeeze_mean
|
||||
```
|
||||
|
@ -350,7 +355,7 @@ python demo/featmap_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
|||
```shell
|
||||
python demo/featmap_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
||||
configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py \
|
||||
work_dirs/yolov5_s-v61_fast_1xb8-40e_cat/epoch_40.pth \
|
||||
work_dirs/yolov5_s-v61_fast_1xb12-40e_cat/epoch_40.pth \
|
||||
--target-layers neck \
|
||||
--channel-reduction squeeze_mean
|
||||
```
|
||||
|
@ -370,7 +375,7 @@ python demo/featmap_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
|||
```shell
|
||||
python demo/boxam_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
||||
configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py \
|
||||
work_dirs/yolov5_s-v61_fast_1xb8-40e_cat/epoch_40.pth \
|
||||
work_dirs/yolov5_s-v61_fast_1xb12-40e_cat/epoch_40.pth \
|
||||
--target-layer neck.out_layers[2]
|
||||
```
|
||||
|
||||
|
@ -383,7 +388,7 @@ python demo/boxam_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
|||
```shell
|
||||
python demo/boxam_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
||||
configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py \
|
||||
work_dirs/yolov5_s-v61_fast_1xb8-40e_cat/epoch_40.pth \
|
||||
work_dirs/yolov5_s-v61_fast_1xb12-40e_cat/epoch_40.pth \
|
||||
--target-layer neck.out_layers[1]
|
||||
```
|
||||
|
||||
|
@ -396,7 +401,7 @@ python demo/boxam_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
|||
```shell
|
||||
python demo/boxam_vis_demo.py data/cat/images/IMG_20221020_112705.jpg \
|
||||
configs/yolov5/yolov5_s-v61_fast_1xb12-40e_cat.py \
|
||||
work_dirs/yolov5_s-v61_fast_1xb8-40e_cat/epoch_40.pth \
|
||||
work_dirs/yolov5_s-v61_fast_1xb12-40e_cat/epoch_40.pth \
|
||||
--target-layer neck.out_layers[0]
|
||||
```
|
||||
|
||||
|
@ -519,4 +524,4 @@ python projects/easydeploy/tools/image-demo.py \
|
|||
|
||||
This completes converting and deploying the trained model and checking the inference results. This is the end of the tutorial.
|
||||
|
||||
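The full content above can be viewed in [15_minutes_object_detection.ipynb](<>)
|
||||
The full content above can be viewed in [15_minutes_object_detection.ipynb](<>). If you encounter problems during training or testing, please check the [common troubleshooting steps](../recommended_topics/troubleshooting_steps.md) first, and feel free to raise an issue if you still can't solve them.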
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
- [MMYOLO 社区倾情贡献,RTMDet 原理社区开发者解读来啦!](https://zhuanlan.zhihu.com/p/569777684)
|
||||
- [MMYOLO 自定义数据集从标注到部署保姆级教程](https://zhuanlan.zhihu.com/p/595497726)
|
||||
- [满足一切需求的 MMYOLO 可视化:测试过程可视化](https://zhuanlan.zhihu.com/p/593179372)
|
||||
- [MMYOLO 想你所想: 训练过程可视化](https://zhuanlan.zhihu.com/p/608586878)
|
||||
- [YOLOv8 深度详解!一文看懂,快速上手](https://zhuanlan.zhihu.com/p/598566644)
|
||||
- [玩转 MMYOLO 基础类第一期: 配置文件太复杂?继承用法看不懂?配置全解读来了](https://zhuanlan.zhihu.com/p/577715188)
|
||||
- [玩转 MMYOLO 工具类第一期: 特征图可视化](https://zhuanlan.zhihu.com/p/578141381?)
|
||||
|
|
|
@ -30,6 +30,7 @@ MMYOLO 是一个基于 PyTorch 和 MMDetection 的 YOLO 系列算法开源工具
|
|||
- YOLOv5
|
||||
- YOLOX
|
||||
- RTMDet
|
||||
- RTMDet-Rotated
|
||||
- YOLOv6
|
||||
- YOLOv7
|
||||
- PPYOLOE
|
||||
|
|
|
@ -1,5 +1,60 @@
|
|||
# Changelog
|
||||
|
||||
## v0.5.0 (2/3/2023)
|
||||
|
||||
### Highlights
|
||||
|
||||
1. Support the [RTMDet-R](https://github.com/open-mmlab/mmyolo/blob/dev/configs/rtmdet/README.md#rotated-object-detection) rotated object detection task and algorithm
|
||||
2. [YOLOv8](https://github.com/open-mmlab/mmyolo/blob/dev/configs/yolov8/README.md) supports using mask annotations to improve object detection performance
|
||||
3. Support using NAS sub-networks searched by [MMRazor](https://github.com/open-mmlab/mmyolo/blob/dev/configs/razor/subnets/README.md) as the backbone of YOLO series algorithms
|
||||
4. Support calling [MMRazor](https://github.com/open-mmlab/mmyolo/blob/dev/configs/rtmdet/distillation/README.md) to perform knowledge distillation on RTMDet
|
||||
5. Optimize the [MMYOLO](https://mmyolo.readthedocs.io/zh_CN/dev/) documentation structure and comprehensively upgrade its content
|
||||
6. Improve YOLOX accuracy and training speed based on RTMDet training hyperparameters
|
||||
7. Support calculating model parameters and FLOPs, provide GPU latency data on T4 devices, and update the [Model Zoo](https://github.com/open-mmlab/mmyolo/blob/dev/docs/zh_cn/model_zoo.md)
|
||||
8. Support test-time augmentation (TTA)
|
||||
9. Support RTMDet, YOLOv8 and YOLOv7 assigner visualization
|
||||
|
||||
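### New Features
|
||||
|
||||
01. Support inference for RTMDet instance segmentation tasks (#583)
|
||||
02. Tidy up the configuration files in MMYOLO and add more comments (#501, #506, #516, #529, #531, #539)
|
||||
03. Refactor and optimize the Chinese and English documentation (#568, #573, #579, #584, #587, #589, #596, #599, #600)
|
||||
04. Support the fast version of YOLOX (#518)
|
||||
05. Support DeepStream in EasyDeploy and add documentation (#485, #545, #571)
|
||||
06. Add a confusion matrix drawing script (#572)
|
||||
07. Add a single-channel application example (#460)
|
||||
08. Support auto registration (#597)
|
||||
09. Support Box CAM for YOLOv7, YOLOv8 and PPYOLOE (#601)
|
||||
10. Add automated generation of MM series repo registration information and tools scripts (#559)
|
||||
11. Add a YOLOv7 model structure diagram (#504)
|
||||
12. Add documentation on how to specify particular GPUs for training and inference (#503)
|
||||
13. Add a check that `metainfo` is all lowercase when training or testing (#535)
|
||||
14. Add links to Twitter, Discord, Medium, YouTube, etc. (#555)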
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
1. Fix isort version issue (#492, #497)
|
||||
2. Fix type error in the assigner visualization module (#509)
|
||||
3. Fix YOLOv8 documentation link error (#517)
|
||||
4. Fix RTMDet Decoder error in EasyDeploy (#519)
|
||||
5. Fix some documentation link errors (#537)
|
||||
6. Fix RTMDet-Tiny weight path error (#580)
|
||||
|
||||
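### Improvements
|
||||
|
||||
1. Update `contributing.md`
|
||||
2. Optimize `DetDataPreprocessor` to support multiple tasks (#511)
|
||||
3. Optimize `gt_instances_preprocess` so it can be used for other YOLO algorithms (#532)
|
||||
4. Add `yolov7-e6e` weight conversion script (#570)
|
||||
5. Modify PPYOLOE with reference to the YOLOv8 inference code (#614)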
|
||||
|
||||
### Contributors
|
||||
|
||||
A total of 22 developers contributed to this release.
|
||||
|
||||
@triple-Mu, @isLinXu, @Audrey528, @TianWen580, @yechenzhi, @RangeKing, @lyviva, @Nioolek, @PeterH0323, @tianleiSHI, @aptsunny, @satuoqaq, @vansin, @xin-li-67, @VoyagerXvoyagerx,
|
||||
@landhill, @kitecats, @tang576225574, @HIT-cwh, @AI-Tianlong, @RangiLyu, @hhaAndroid
|
||||
|
||||
## v0.4.0 (18/1/2023)
|
||||
|
||||
### Highlights
|
||||
|
|