mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Support PPYOLOE training (#259)
* add ppyoloe backbone, neck
* add ppyoloe test
* add docstring
* add ppyoloe m/l/x configfile
* add ppyoloe_coco.py
* rename config
* add typehint
* format code; add ut
* add datapre
* add datapre
* add ppyoloe datapre
* add ppyoloe datapre
* add ppyoloe datapre
* reproduce coco v0.1
* add ut
* add ut, docstring
* fix transforms bug
* use mmdet dfloss
* add non plus model config
* add non plus model config
* fix
* add ut
* produce coco v0.2
* fix config
* fix config
* fix eps and transforms bug
* add ema
* fix resize
* fix transforms.py
* fix transforms.py
* fix transforms.py
* old version
* old version
* old version
* old version
* old version
* old version
* fix stride loss error
* add INTER_LANCZOS4
* fix crop bug
* init commit
* format code
* format code
* bgr transforms.py
* add typehint and doc in transforms.py
* inherit the new YOLOv6Head implementation and remove unnecessary comments
* fix transforms var name bug
* bbox decode uses stride tensor instead of priors
* add headmodule todo
* add ppyoloe README.md
* add ppyoloe README.md
* Update tests/test_datasets/test_transforms/test_transforms.py Co-authored-by: Range King <RangeKingHZ@gmail.com>
* Update tests/test_datasets/test_transforms/test_transforms.py Co-authored-by: Range King <RangeKingHZ@gmail.com>
* save ckpt last 10 epochs
* save_best ckpt
* del ppyoloe collate
* change name of ppyoloebatchrandomresize
* add the reason for rewritten PPYOLOEDetDataPreprocessor
* solve ppyoloerandomresize name error
* rm PPYOLOERandomExpand
* rm l1 loss
* rm yolov6 loss_obj
* add difference between yolov6 and ppyoloe
* add reason for rewrite paramscheduler
* change proj init way
* fix error
* rm proj_conv in pth
* format code
* add load_from
* update
* support fast training
* add pretrained model url
* update
* add pretrained model url
* fix error
* add imagenet model convert and use init_cfg to init backbone
* add plus model pretrain model
* add ut
* add ut
* fix ut
* fix withstride bug
* cat in yolov5_collate
* merge
* fix typehint
* update readme
* add reason for gap
* fix log in README.md
* rollback yolov6
* change inherit
* fix ut
* fix ut

Co-authored-by: Range King <RangeKingHZ@gmail.com>
Co-authored-by: hha <1286304229@qq.com>
Co-authored-by: huanghaian <huanghaian@sensetime.com>
parent a20b160f0f
commit 8127805dd3
@@ -0,0 +1,38 @@
# PPYOLOE

<!-- [ALGORITHM] -->

## Abstract

PP-YOLOE is an excellent single-stage anchor-free model based on PP-YOLOv2, surpassing a variety of popular YOLO models. PP-YOLOE has a series of models, named s/m/l/x, which are configured through width and depth multipliers. PP-YOLOE avoids using special operators, such as Deformable Convolution or Matrix NMS, so that it can be deployed easily on various hardware.

<div align=center>
<img src="https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/docs/images/ppyoloe_plus_map_fps.png" width="600" />
</div>

## Results and models

### PPYOLOE+ COCO

|  Backbone   | Arch | Size | Epoch | SyncBN | Mem (GB) | Box AP |                          Config                           | Download |
| :---------: | :--: | :--: | :---: | :----: | :------: | :----: | :-------------------------------------------------------: | :------: |
| PPYOLOE+ -s | P5 | 640 | 80 | Yes | 4.7 | 43.5 | [config](../ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco/ppyoloe_plus_s_fast_8xb8-80e_coco_20230101_154052-9fee7619.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco/ppyoloe_plus_s_fast_8xb8-80e_coco_20230101_154052.log.json) |
| PPYOLOE+ -m | P5 | 640 | 80 | Yes | 8.4 | 49.5 | [config](../ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco/ppyoloe_plus_m_fast_8xb8-80e_coco_20230104_193132-e4325ada.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco/ppyoloe_plus_m_fast_8xb8-80e_coco_20230104_193132.log.json) |
| PPYOLOE+ -l | P5 | 640 | 80 | Yes | 13.2 | 52.6 | [config](../ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco/ppyoloe_plus_l_fast_8xb8-80e_coco_20230102_203825-1864e7b3.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco/ppyoloe_plus_l_fast_8xb8-80e_coco_20230102_203825.log.json) |
| PPYOLOE+ -x | P5 | 640 | 80 | Yes | 19.1 | 54.2 | [config](../ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco/ppyoloe_plus_x_fast_8xb8-80e_coco_20230104_194921-8c953949.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco/ppyoloe_plus_x_fast_8xb8-80e_coco_20230104_194921.log.json) |
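For quick reference, the depth and width multipliers used by the four variants in this PR can be summarised as below. This is an illustrative Python mapping only (the dictionary name is ours, not part of any config); the values are copied from the config files shown later in this diff.

```python
# Illustrative summary of the PPYOLOE variants added in this PR.
# Values are taken from the deepen_factor / widen_factor settings in the
# configs further down in this diff.
ppyoloe_variants = {
    's': dict(deepen_factor=0.33, widen_factor=0.50),
    'm': dict(deepen_factor=0.67, widen_factor=0.75),
    'l': dict(deepen_factor=1.00, widen_factor=1.00),
    'x': dict(deepen_factor=1.33, widen_factor=1.25),
}
```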
**Note**:

1. The Box AP values above are reported for the best-performing checkpoint of each model on COCO.
2. The gap between the above performance and the official release is about 0.3 box AP. To speed up training in MMYOLO, the image resizing in `PPYOLOEBatchRandomResize` for multi-scale training is implemented with PyTorch, whereas the official PPYOLOE uses OpenCV, and `lanczos4` interpolation is not yet supported in `PPYOLOEBatchRandomResize`. These two differences account for the gap (see the sketch after these notes); we will keep experimenting and close the gap in future releases.
3. The mAP of the non-Plus version still needs more verification, and we will publish more details of the non-Plus models in future versions.
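The difference described in note 2 can be shown with a minimal, self-contained sketch (the tensor shapes and target size are illustrative only): `PPYOLOEBatchRandomResize` resizes the whole batch with `torch.nn.functional.interpolate`, which has no Lanczos kernel, while the official implementation resizes each image with OpenCV, where `INTER_LANCZOS4` is available.

```python
import cv2
import numpy as np
import torch
import torch.nn.functional as F

# PyTorch path (as used by PPYOLOEBatchRandomResize): one call resizes the
# whole batch, but only 'nearest', 'bilinear', 'bicubic' and 'area' modes
# are available, so lanczos4 cannot be reproduced exactly.
batch = torch.rand(8, 3, 640, 640)
resized = F.interpolate(
    batch, size=(480, 480), mode='bilinear', align_corners=False)

# OpenCV path (official PPYOLOE): per-image resize on the CPU, which supports
# INTER_LANCZOS4 but is slower for multi-scale training.
img = (np.random.rand(640, 640, 3) * 255).astype(np.uint8)
resized_cv = cv2.resize(img, (480, 480), interpolation=cv2.INTER_LANCZOS4)
```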
```latex
@article{Xu2022PPYOLOEAE,
  title={PP-YOLOE: An evolved version of YOLO},
  author={Shangliang Xu and Xinxin Wang and Wenyu Lv and Qinyao Chang and Cheng Cui and Kaipeng Deng and Guanzhong Wang and Qingqing Dang and Shengyun Wei and Yuning Du and Baohua Lai},
  journal={ArXiv},
  year={2022},
  volume={abs/2203.16250}
}
```
@@ -0,0 +1,69 @@
Collections:
  - Name: PPYOLOE
    Metadata:
      Training Data: COCO
      Training Techniques:
        - SGD with Nesterov
        - Weight Decay
        - Synchronize BN
      Training Resources: 8x A100 GPUs
      Architecture:
        - PPYOLOECSPResNet
        - PPYOLOECSPPAFPN
    Paper:
      URL: https://arxiv.org/abs/2203.16250
      Title: 'PP-YOLOE: An evolved version of YOLO'
    README: configs/ppyoloe/README.md
    Code:
      URL: https://github.com/open-mmlab/mmyolo/blob/v0.0.1/mmyolo/models/detectors/yolo_detector.py#L12
      Version: v0.0.1

Models:
  - Name: ppyoloe_plus_s_fast_8xb8-80e_coco
    In Collection: PPYOLOE
    Config: configs/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py
    Metadata:
      Training Memory (GB): 4.7
      Epochs: 80
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 43.5
    Weights: https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco/ppyoloe_plus_s_fast_8xb8-80e_coco_20230101_154052-9fee7619.pth
  - Name: ppyoloe_plus_m_fast_8xb8-80e_coco
    In Collection: PPYOLOE
    Config: configs/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco.py
    Metadata:
      Training Memory (GB): 8.4
      Epochs: 80
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 49.5
    Weights: https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco/ppyoloe_plus_m_fast_8xb8-80e_coco_20230104_193132-e4325ada.pth
  - Name: ppyoloe_plus_l_fast_8xb8-80e_coco
    In Collection: PPYOLOE
    Config: configs/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco.py
    Metadata:
      Training Memory (GB): 13.2
      Epochs: 80
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 52.6
    Weights: https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco/ppyoloe_plus_l_fast_8xb8-80e_coco_20230102_203825-1864e7b3.pth
  - Name: ppyoloe_plus_x_fast_8xb8-80e_coco
    In Collection: PPYOLOE
    Config: configs/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco.py
    Metadata:
      Training Memory (GB): 19.1
      Epochs: 80
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 54.2
    Weights: https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco/ppyoloe_plus_x_fast_8xb8-80e_coco_20230104_194921-8c953949.pth
@@ -1,15 +1,23 @@
 _base_ = './ppyoloe_s_fast_8xb32-300e_coco.py'

+# The pretrained model is obtained and converted from the official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+checkpoint = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/cspresnet_l_imagenet1k_pretrained-c0010e6c.pth'  # noqa
+
 deepen_factor = 1.0
 widen_factor = 1.0

-# TODO: training on ppyoloe need to be implemented.
 train_batch_size_per_gpu = 20

 model = dict(
-    backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
+    backbone=dict(
+        deepen_factor=deepen_factor,
+        widen_factor=widen_factor,
+        init_cfg=dict(checkpoint=checkpoint)),
     neck=dict(
         deepen_factor=deepen_factor,
         widen_factor=widen_factor,
     ),
     bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
+
+train_dataloader = dict(batch_size=train_batch_size_per_gpu)
@@ -1,15 +1,23 @@
 _base_ = './ppyoloe_s_fast_8xb32-300e_coco.py'

+# The pretrained model is obtained and converted from the official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+checkpoint = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/cspresnet_m_imagenet1k_pretrained-09f1eba2.pth'  # noqa
+
 deepen_factor = 0.67
 widen_factor = 0.75

-# TODO: training on ppyoloe need to be implemented.
 train_batch_size_per_gpu = 28

 model = dict(
-    backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
+    backbone=dict(
+        deepen_factor=deepen_factor,
+        widen_factor=widen_factor,
+        init_cfg=dict(checkpoint=checkpoint)),
     neck=dict(
         deepen_factor=deepen_factor,
         widen_factor=widen_factor,
     ),
     bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
+
+train_dataloader = dict(batch_size=train_batch_size_per_gpu)
@@ -1,5 +1,9 @@
 _base_ = './ppyoloe_plus_s_fast_8xb8-80e_coco.py'

+# The pretrained model is obtained and converted from the official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+load_from = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/ppyoloe_plus_l_obj365_pretrained-3dd89562.pth'  # noqa
+
 deepen_factor = 1.0
 widen_factor = 1.0
@@ -1,5 +1,9 @@
 _base_ = './ppyoloe_plus_s_fast_8xb8-80e_coco.py'

+# The pretrained model is obtained and converted from the official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+load_from = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/ppyoloe_plus_m_ojb365_pretrained-03206892.pth'  # noqa
+
 deepen_factor = 0.67
 widen_factor = 0.75
@@ -9,21 +9,40 @@ img_scale = (640, 640)  # height, width
 deepen_factor = 0.33
 widen_factor = 0.5
 max_epochs = 80
-save_epoch_intervals = 10
+num_classes = 80
+save_epoch_intervals = 5
 train_batch_size_per_gpu = 8
 train_num_workers = 8
 val_batch_size_per_gpu = 1
 val_num_workers = 2

+# The pretrained model is obtained and converted from the official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+load_from = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/ppyoloe_plus_s_obj365_pretrained-bcfe8478.pth'  # noqa
+
 # persistent_workers must be False if num_workers is 0.
 persistent_workers = True

+# Base learning rate for optim_wrapper
+base_lr = 0.001
+
 strides = [8, 16, 32]

 model = dict(
     type='YOLODetector',
     data_preprocessor=dict(
-        type='YOLOv5DetDataPreprocessor',
+        # use this to support multi_scale training
+        type='PPYOLOEDetDataPreprocessor',
+        pad_size_divisor=32,
+        batch_augments=[
+            dict(
+                type='PPYOLOEBatchRandomResize',
+                random_size_range=(320, 800),
+                interval=1,
+                size_divisor=32,
+                random_interp=True,
+                keep_ratio=False)
+        ],
         mean=[0., 0., 0.],
         std=[255., 255., 255.],
         bgr_to_rgb=True),
@@ -56,11 +75,52 @@ model = dict(
         type='PPYOLOEHead',
         head_module=dict(
             type='PPYOLOEHeadModule',
-            num_classes=80,
+            num_classes=num_classes,
             in_channels=[192, 384, 768],
             widen_factor=widen_factor,
             featmap_strides=strides,
-            num_base_priors=1)),
+            reg_max=16,
+            norm_cfg=dict(type='BN', momentum=0.1, eps=1e-5),
+            act_cfg=dict(type='SiLU', inplace=True),
+            num_base_priors=1),
+        prior_generator=dict(
+            type='mmdet.MlvlPointGenerator', offset=0.5, strides=strides),
+        bbox_coder=dict(type='DistancePointBBoxCoder'),
+        loss_cls=dict(
+            type='mmdet.VarifocalLoss',
+            use_sigmoid=True,
+            alpha=0.75,
+            gamma=2.0,
+            iou_weighted=True,
+            reduction='sum',
+            loss_weight=1.0),
+        loss_bbox=dict(
+            type='IoULoss',
+            iou_mode='giou',
+            bbox_format='xyxy',
+            reduction='mean',
+            loss_weight=2.5,
+            return_iou=False),
+        # Since the DFL loss is implemented differently in the official
+        # PPYOLOE and in mmdet, the loss_weight is divided by 4 here.
+        loss_dfl=dict(
+            type='mmdet.DistributionFocalLoss',
+            reduction='mean',
+            loss_weight=0.5 / 4)),
+    train_cfg=dict(
+        initial_epoch=30,
+        initial_assigner=dict(
+            type='BatchATSSAssigner',
+            num_classes=num_classes,
+            topk=9,
+            iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
+        assigner=dict(
+            type='BatchTaskAlignedAssigner',
+            num_classes=num_classes,
+            topk=13,
+            alpha=1,
+            beta=6,
+            eps=1e-9)),
     test_cfg=dict(
         multi_label=True,
         nms_pre=1000,
@@ -68,10 +128,36 @@ model = dict(
         nms=dict(type='nms', iou_threshold=0.7),
         max_per_img=300))

-test_pipeline = [
+train_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='PPYOLOERandomDistort'),
+    dict(type='mmdet.Expand', mean=(103.53, 116.28, 123.675)),
+    dict(type='PPYOLOERandomCrop'),
+    dict(type='mmdet.RandomFlip', prob=0.5),
     dict(
-        type='LoadImageFromFile',
-        file_client_args={{_base_.file_client_args}}),
+        type='mmdet.PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
+                   'flip_direction'))
+]
+
+train_dataloader = dict(
+    batch_size=train_batch_size_per_gpu,
+    num_workers=train_num_workers,
+    persistent_workers=persistent_workers,
+    pin_memory=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    collate_fn=dict(type='yolov5_collate', use_ms_training=True),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file='annotations/instances_train2017.json',
+        data_prefix=dict(img='train2017/'),
+        filter_cfg=dict(filter_empty_gt=True, min_size=0),
+        pipeline=train_pipeline))
+
+test_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
     dict(
         type='mmdet.FixShapeResize',
         width=img_scale[1],
@@ -103,6 +189,41 @@ val_dataloader = dict(

 test_dataloader = val_dataloader

+param_scheduler = None
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(
+        type='SGD',
+        lr=base_lr,
+        momentum=0.9,
+        weight_decay=5e-4,
+        nesterov=False),
+    paramwise_cfg=dict(norm_decay_mult=0.))
+
+default_hooks = dict(
+    param_scheduler=dict(
+        type='PPYOLOEParamSchedulerHook',
+        warmup_min_iter=1000,
+        start_factor=0.,
+        warmup_epochs=5,
+        min_lr_ratio=0.0,
+        total_epochs=int(max_epochs * 1.2)),
+    checkpoint=dict(
+        type='CheckpointHook',
+        interval=save_epoch_intervals,
+        save_best='auto',
+        max_keep_ckpts=3))
+
+custom_hooks = [
+    dict(
+        type='EMAHook',
+        ema_type='ExpMomentumEMA',
+        momentum=0.0002,
+        update_buffers=True,
+        strict_load=False,
+        priority=49)
+]
+
 val_evaluator = dict(
     type='mmdet.CocoMetric',
     proposal_nums=(100, 1, 10),
@@ -110,5 +231,9 @@ val_evaluator = dict(
     metric='bbox')
 test_evaluator = val_evaluator

+train_cfg = dict(
+    type='EpochBasedTrainLoop',
+    max_epochs=max_epochs,
+    val_interval=save_epoch_intervals)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
@@ -1,5 +1,9 @@
 _base_ = './ppyoloe_plus_s_fast_8xb8-80e_coco.py'

+# The pretrained model is obtained and converted from the official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+load_from = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/ppyoloe_plus_x_obj365_pretrained-43a8000d.pth'  # noqa
+
 deepen_factor = 1.33
 widen_factor = 1.25
@@ -1,11 +1,36 @@
 _base_ = './ppyoloe_plus_s_fast_8xb8-80e_coco.py'

-# TODO: training on ppyoloe need to be implemented.
+# The pretrained model is obtained and converted from the official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+checkpoint = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/cspresnet_s_imagenet1k_pretrained-2be81763.pth'  # noqa
+
 train_batch_size_per_gpu = 32
 max_epochs = 300

+# Base learning rate for optim_wrapper
+base_lr = 0.01
+
 model = dict(
     data_preprocessor=dict(
         mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
         std=[0.229 * 255., 0.224 * 255., 0.225 * 255.]),
-    backbone=dict(block_cfg=dict(use_alpha=False)))
+    backbone=dict(
+        block_cfg=dict(use_alpha=False),
+        init_cfg=dict(
+            type='Pretrained',
+            prefix='backbone.',
+            checkpoint=checkpoint,
+            map_location='cpu')),
+    train_cfg=dict(initial_epoch=100))
+
+train_dataloader = dict(batch_size=train_batch_size_per_gpu)
+
+optim_wrapper = dict(optimizer=dict(lr=base_lr))
+
+default_hooks = dict(param_scheduler=dict(total_epochs=int(max_epochs * 1.2)))
+
+train_cfg = dict(max_epochs=max_epochs)
+
+# PPYOLOE+ uses an Objects365 pretrained model, but PPYOLOE does not,
+# so `load_from` needs to be set to None.
+load_from = None
@@ -1,4 +1,9 @@
 _base_ = './ppyoloe_s_fast_8xb32-300e_coco.py'

-# TODO: training on ppyoloe need to be implemented.
 max_epochs = 400
+
+model = dict(train_cfg=dict(initial_epoch=133))
+
+default_hooks = dict(param_scheduler=dict(total_epochs=int(max_epochs * 1.2)))
+
+train_cfg = dict(max_epochs=max_epochs)
@@ -1,15 +1,23 @@
 _base_ = './ppyoloe_s_fast_8xb32-300e_coco.py'

+# The pretrained model is obtained and converted from the official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+checkpoint = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/cspresnet_x_imagenet1k_pretrained-81c33ccb.pth'  # noqa
+
 deepen_factor = 1.33
 widen_factor = 1.25

-# TODO: training on ppyoloe need to be implemented.
 train_batch_size_per_gpu = 16

 model = dict(
-    backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
+    backbone=dict(
+        deepen_factor=deepen_factor,
+        widen_factor=widen_factor,
+        init_cfg=dict(checkpoint=checkpoint)),
     neck=dict(
         deepen_factor=deepen_factor,
         widen_factor=widen_factor,
     ),
     bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
+
+train_dataloader = dict(batch_size=train_batch_size_per_gpu)
@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .mix_img_transforms import Mosaic, Mosaic9, YOLOv5MixUp, YOLOXMixUp
-from .transforms import (LetterResize, LoadAnnotations, YOLOv5HSVRandomAug,
+from .transforms import (LetterResize, LoadAnnotations, PPYOLOERandomCrop,
+                         PPYOLOERandomDistort, YOLOv5HSVRandomAug,
                          YOLOv5KeepRatioResize, YOLOv5RandomAffine)

 __all__ = [
     'YOLOv5KeepRatioResize', 'LetterResize', 'Mosaic', 'YOLOXMixUp',
     'YOLOv5MixUp', 'YOLOv5HSVRandomAug', 'LoadAnnotations',
-    'YOLOv5RandomAffine', 'Mosaic9'
+    'YOLOv5RandomAffine', 'PPYOLOERandomDistort', 'PPYOLOERandomCrop',
+    'Mosaic9'
 ]
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
-from typing import Tuple, Union
+from typing import List, Tuple, Union

 import cv2
 import mmcv
@@ -675,3 +675,397 @@ class YOLOv5RandomAffine(BaseTransform):
         translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]],
                                       dtype=np.float32)
         return translation_matrix
+
+
+@TRANSFORMS.register_module()
+class PPYOLOERandomDistort(BaseTransform):
+    """Random hue, saturation, contrast and brightness distortion.
+
+    Required Keys:
+
+    - img
+
+    Modified Keys:
+
+    - img (np.float32)
+
+    Args:
+        hue_cfg (dict): Hue settings. Defaults to dict(min=-18,
+            max=18, prob=0.5).
+        saturation_cfg (dict): Saturation settings. Defaults to dict(
+            min=0.5, max=1.5, prob=0.5).
+        contrast_cfg (dict): Contrast settings. Defaults to dict(
+            min=0.5, max=1.5, prob=0.5).
+        brightness_cfg (dict): Brightness settings. Defaults to dict(
+            min=0.5, max=1.5, prob=0.5).
+        num_distort_func (int): The number of distort function. Defaults
+            to 4.
+    """
+
+    def __init__(self,
+                 hue_cfg: dict = dict(min=-18, max=18, prob=0.5),
+                 saturation_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
+                 contrast_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
+                 brightness_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
+                 num_distort_func: int = 4):
+        self.hue_cfg = hue_cfg
+        self.saturation_cfg = saturation_cfg
+        self.contrast_cfg = contrast_cfg
+        self.brightness_cfg = brightness_cfg
+        self.num_distort_func = num_distort_func
+        assert 0 < self.num_distort_func <= 4, \
+            'num_distort_func must > 0 and <= 4'
+        for cfg in [
+                self.hue_cfg, self.saturation_cfg, self.contrast_cfg,
+                self.brightness_cfg
+        ]:
+            assert 0. <= cfg['prob'] <= 1., 'prob must >=0 and <=1'
+
+    def transform_hue(self, results):
+        """Transform hue randomly."""
+        if random.uniform(0., 1.) >= self.hue_cfg['prob']:
+            return results
+        img = results['img']
+        delta = random.uniform(self.hue_cfg['min'], self.hue_cfg['max'])
+        u = np.cos(delta * np.pi)
+        w = np.sin(delta * np.pi)
+        delta_iq = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
+        rgb2yiq_matrix = np.array([[0.114, 0.587, 0.299],
+                                   [-0.321, -0.274, 0.596],
+                                   [0.311, -0.523, 0.211]])
+        yiq2rgb_matric = np.array([[1.0, -1.107, 1.705], [1.0, -0.272, -0.647],
+                                   [1.0, 0.956, 0.621]])
+        t = np.dot(np.dot(yiq2rgb_matric, delta_iq), rgb2yiq_matrix).T
+        img = np.dot(img, t)
+        results['img'] = img
+        return results
+
+    def transform_saturation(self, results):
+        """Transform saturation randomly."""
+        if random.uniform(0., 1.) >= self.saturation_cfg['prob']:
+            return results
+        img = results['img']
+        delta = random.uniform(self.saturation_cfg['min'],
+                               self.saturation_cfg['max'])
+
+        # convert bgr img to gray img
+        gray = img * np.array([[[0.114, 0.587, 0.299]]], dtype=np.float32)
+        gray = gray.sum(axis=2, keepdims=True)
+        gray *= (1.0 - delta)
+        img *= delta
+        img += gray
+        results['img'] = img
+        return results
+
+    def transform_contrast(self, results):
+        """Transform contrast randomly."""
+        if random.uniform(0., 1.) >= self.contrast_cfg['prob']:
+            return results
+        img = results['img']
+        delta = random.uniform(self.contrast_cfg['min'],
+                               self.contrast_cfg['max'])
+        img *= delta
+        results['img'] = img
+        return results
+
+    def transform_brightness(self, results):
+        """Transform brightness randomly."""
+        if random.uniform(0., 1.) >= self.brightness_cfg['prob']:
+            return results
+        img = results['img']
+        delta = random.uniform(self.brightness_cfg['min'],
+                               self.brightness_cfg['max'])
+        img += delta
+        results['img'] = img
+        return results
+
+    def transform(self, results: dict) -> dict:
+        """The hue, saturation, contrast and brightness distortion function.
+
+        Args:
+            results (dict): The result dict.
+
+        Returns:
+            dict: The result dict.
+        """
+        results['img'] = results['img'].astype(np.float32)
+
+        functions = [
+            self.transform_brightness, self.transform_contrast,
+            self.transform_saturation, self.transform_hue
+        ]
+        distortions = random.permutation(functions)[:self.num_distort_func]
+        for func in distortions:
+            results = func(results)
+        return results
+
+
+@TRANSFORMS.register_module()
+class PPYOLOERandomCrop(BaseTransform):
+    """Random crop the img and bboxes. Different thresholds are used in
+    PPYOLOE to judge whether the clipped image meets the requirements. This
+    implementation is different from the implementation of RandomCrop in
+    mmdet.
+
+    Required Keys:
+
+    - img
+    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
+    - gt_bboxes_labels (np.int64) (optional)
+    - gt_ignore_flags (bool) (optional)
+
+    Modified Keys:
+
+    - img
+    - img_shape
+    - gt_bboxes (optional)
+    - gt_bboxes_labels (optional)
+    - gt_ignore_flags (optional)
+
+    Added Keys:
+    - pad_param (np.float32)
+
+    Args:
+        aspect_ratio (List[float]): Aspect ratio of cropped region. Default to
+            [.5, 2].
+        thresholds (List[float]): Iou thresholds for deciding a valid bbox
+            crop in [min, max] format. Defaults to [.0, .1, .3, .5, .7, .9].
+        scaling (List[float]): Ratio between a cropped region and the original
+            image in [min, max] format. Default to [.3, 1.].
+        num_attempts (int): Number of tries for each threshold before
+            giving up. Default to 50.
+        allow_no_crop (bool): Allow return without actually cropping them.
+            Default to True.
+        cover_all_box (bool): Ensure all bboxes are covered in the final crop.
+            Default to False.
+    """
+
+    def __init__(self,
+                 aspect_ratio: List[float] = [.5, 2.],
+                 thresholds: List[float] = [.0, .1, .3, .5, .7, .9],
+                 scaling: List[float] = [.3, 1.],
+                 num_attempts: int = 50,
+                 allow_no_crop: bool = True,
+                 cover_all_box: bool = False):
+        self.aspect_ratio = aspect_ratio
+        self.thresholds = thresholds
+        self.scaling = scaling
+        self.num_attempts = num_attempts
+        self.allow_no_crop = allow_no_crop
+        self.cover_all_box = cover_all_box
+
+    def _crop_data(self, results: dict, crop_box: Tuple[int, int, int, int],
+                   valid_inds: np.ndarray) -> Union[dict, None]:
+        """Function to randomly crop images, bounding boxes, masks, semantic
+        segmentation maps.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+            crop_box (Tuple[int, int, int, int]): Expected absolute
+                coordinates for cropping, (x1, y1, x2, y2).
+            valid_inds (np.ndarray): The indexes of gt that needs to be
+                retained.
+
+        Returns:
+            results (Union[dict, None]): Randomly cropped results, 'img_shape'
+                key in result dict is updated according to crop size. None
+                will be returned when there is no valid bbox after cropping.
+        """
+        # crop the image
+        img = results['img']
+        crop_x1, crop_y1, crop_x2, crop_y2 = crop_box
+        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
+        results['img'] = img
+        img_shape = img.shape
+        results['img_shape'] = img.shape
+
+        # crop bboxes accordingly and clip to the image boundary
+        if results.get('gt_bboxes', None) is not None:
+            bboxes = results['gt_bboxes']
+            bboxes.translate_([-crop_x1, -crop_y1])
+            bboxes.clip_(img_shape[:2])
+
+            results['gt_bboxes'] = bboxes[valid_inds]
+
+            if results.get('gt_ignore_flags', None) is not None:
+                results['gt_ignore_flags'] = \
+                    results['gt_ignore_flags'][valid_inds]
+
+            if results.get('gt_bboxes_labels', None) is not None:
+                results['gt_bboxes_labels'] = \
+                    results['gt_bboxes_labels'][valid_inds]
+
+            if results.get('gt_masks', None) is not None:
+                results['gt_masks'] = results['gt_masks'][
+                    valid_inds.nonzero()[0]].crop(
+                        np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
+
+        # crop semantic seg
+        if results.get('gt_seg_map', None) is not None:
+            results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2,
+                                                          crop_x1:crop_x2]
+
+        return results
+
+    @autocast_box_type()
+    def transform(self, results: dict) -> Union[dict, None]:
+        """The random crop transform function.
+
+        Args:
+            results (dict): The result dict.
+
+        Returns:
+            dict: The result dict.
+        """
+        if results.get('gt_bboxes', None) is None or len(
+                results['gt_bboxes']) == 0:
+            return results
+
+        orig_img_h, orig_img_w = results['img'].shape[:2]
+        gt_bboxes = results['gt_bboxes']
+
+        thresholds = list(self.thresholds)
+        if self.allow_no_crop:
+            thresholds.append('no_crop')
+        random.shuffle(thresholds)
+
+        for thresh in thresholds:
+            # Determine the coordinates for cropping
+            if thresh == 'no_crop':
+                return results
+
+            found = False
+            for i in range(self.num_attempts):
+                crop_h, crop_w = self._get_crop_size((orig_img_h, orig_img_w))
+                if self.aspect_ratio is None:
+                    if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
+                        continue
+
+                # get image crop_box
+                margin_h = max(orig_img_h - crop_h, 0)
+                margin_w = max(orig_img_w - crop_w, 0)
+                offset_h, offset_w = self._rand_offset((margin_h, margin_w))
+                crop_y1, crop_y2 = offset_h, offset_h + crop_h
+                crop_x1, crop_x2 = offset_w, offset_w + crop_w
+
+                crop_box = [crop_x1, crop_y1, crop_x2, crop_y2]
+                # Calculate the iou between gt_bboxes and crop_boxes
+                iou = self._iou_matrix(gt_bboxes,
+                                       np.array([crop_box], dtype=np.float32))
+                # If the maximum value of the iou is less than thresh,
+                # the current crop_box is considered invalid.
+                if iou.max() < thresh:
+                    continue
+
+                # If cover_all_box == True and the minimum value of
+                # the iou is less than thresh, the current crop_box
+                # is considered invalid.
+                if self.cover_all_box and iou.min() < thresh:
+                    continue
+
+                # Get which gt_bboxes to keep after cropping.
+                valid_inds = self._get_valid_inds(
+                    gt_bboxes, np.array(crop_box, dtype=np.float32))
+                if valid_inds.size > 0:
+                    found = True
+                    break
+
+            if found:
+                results = self._crop_data(results, crop_box, valid_inds)
+                return results
+        return results
+
+    @cache_randomness
+    def _rand_offset(self, margin: Tuple[int, int]) -> Tuple[int, int]:
+        """Randomly generate crop offset.
+
+        Args:
+            margin (Tuple[int, int]): The upper bound for the offset generated
+                randomly.
+
+        Returns:
+            Tuple[int, int]: The random offset for the crop.
+        """
+        margin_h, margin_w = margin
+        offset_h = np.random.randint(0, margin_h + 1)
+        offset_w = np.random.randint(0, margin_w + 1)
+
+        return (offset_h, offset_w)
+
+    @cache_randomness
+    def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]:
+        """Randomly generates the crop size based on `image_size`.
+
+        Args:
+            image_size (Tuple[int, int]): (h, w).
+
+        Returns:
+            crop_size (Tuple[int, int]): (crop_h, crop_w) in absolute pixels.
+        """
+        h, w = image_size
+        scale = random.uniform(*self.scaling)
+        if self.aspect_ratio is not None:
+            min_ar, max_ar = self.aspect_ratio
+            aspect_ratio = random.uniform(
+                max(min_ar, scale**2), min(max_ar, scale**-2))
+            h_scale = scale / np.sqrt(aspect_ratio)
+            w_scale = scale * np.sqrt(aspect_ratio)
+        else:
+            h_scale = random.uniform(*self.scaling)
+            w_scale = random.uniform(*self.scaling)
+        crop_h = h * h_scale
+        crop_w = w * w_scale
+        return int(crop_h), int(crop_w)
+
+    def _iou_matrix(self,
+                    gt_bbox: HorizontalBoxes,
+                    crop_bbox: np.ndarray,
+                    eps: float = 1e-10) -> np.ndarray:
+        """Calculate iou between gt and image crop box.
+
+        Args:
+            gt_bbox (HorizontalBoxes): Ground truth bounding boxes.
+            crop_bbox (np.ndarray): Image crop coordinates in
+                [x1, y1, x2, y2] format.
+            eps (float): Default to 1e-10.
+        Return:
+            (np.ndarray): IoU.
+        """
+        gt_bbox = gt_bbox.tensor.numpy()
+        lefttop = np.maximum(gt_bbox[:, np.newaxis, :2], crop_bbox[:, :2])
+        rightbottom = np.minimum(gt_bbox[:, np.newaxis, 2:], crop_bbox[:, 2:])
+
+        overlap = np.prod(
+            rightbottom - lefttop,
+            axis=2) * (lefttop < rightbottom).all(axis=2)
+        # areas of the gt boxes and of the crop box respectively
+        area_gt_bbox = np.prod(gt_bbox[:, 2:] - gt_bbox[:, :2], axis=1)
+        area_crop_bbox = np.prod(crop_bbox[:, 2:] - crop_bbox[:, :2], axis=1)
+        area_o = (area_gt_bbox[:, np.newaxis] + area_crop_bbox - overlap)
+        return overlap / (area_o + eps)
+
+    def _get_valid_inds(self, gt_bbox: HorizontalBoxes,
+                        img_crop_bbox: np.ndarray) -> np.ndarray:
+        """Get which Bboxes to keep at the current cropping coordinates.
+
+        Args:
+            gt_bbox (HorizontalBoxes): Ground truth bounding boxes.
+            img_crop_bbox (np.ndarray): Image crop coordinates in
+                [x1, y1, x2, y2] format.
+
+        Returns:
+            (np.ndarray): Valid indexes.
+        """
+        cropped_box = gt_bbox.tensor.numpy().copy()
+        gt_bbox = gt_bbox.tensor.numpy().copy()
+
+        cropped_box[:, :2] = np.maximum(gt_bbox[:, :2], img_crop_bbox[:2])
+        cropped_box[:, 2:] = np.minimum(gt_bbox[:, 2:], img_crop_bbox[2:])
+        cropped_box[:, :2] -= img_crop_bbox[:2]
+        cropped_box[:, 2:] -= img_crop_bbox[:2]
+
+        centers = (gt_bbox[:, :2] + gt_bbox[:, 2:]) / 2
+        valid = np.logical_and(img_crop_bbox[:2] <= centers,
+                               centers < img_crop_bbox[2:]).all(axis=1)
+        valid = np.logical_and(
+            valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
+
+        return np.where(valid)[0]
@@ -9,8 +9,14 @@ from ..registry import TASK_UTILS


 @COLLATE_FUNCTIONS.register_module()
-def yolov5_collate(data_batch: Sequence) -> dict:
-    """Rewrite collate_fn to get faster training speed."""
+def yolov5_collate(data_batch: Sequence,
+                   use_ms_training: bool = False) -> dict:
+    """Rewrite collate_fn to get faster training speed.
+
+    Args:
+        data_batch (Sequence): Batch of data.
+        use_ms_training (bool): Whether to use multi-scale training.
+    """
     batch_imgs = []
     batch_bboxes_labels = []
     for i in range(len(data_batch)):
@@ -25,10 +31,16 @@ def yolov5_collate(data_batch: Sequence) -> dict:
         batch_bboxes_labels.append(bboxes_labels)

         batch_imgs.append(inputs)
-    return {
-        'inputs': torch.stack(batch_imgs, 0),
-        'data_samples': torch.cat(batch_bboxes_labels, 0)
-    }
+    if use_ms_training:
+        return {
+            'inputs': batch_imgs,
+            'data_samples': torch.cat(batch_bboxes_labels, 0)
+        }
+    else:
+        return {
+            'inputs': torch.stack(batch_imgs, 0),
+            'data_samples': torch.cat(batch_bboxes_labels, 0)
+        }


 @TASK_UTILS.register_module()
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .ppyoloe_param_scheduler_hook import PPYOLOEParamSchedulerHook
 from .switch_to_deploy_hook import SwitchToDeployHook
 from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
 from .yolox_mode_switch_hook import YOLOXModeSwitchHook

 __all__ = [
-    'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook'
+    'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook',
+    'PPYOLOEParamSchedulerHook'
 ]
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional

from mmengine.hooks import ParamSchedulerHook
from mmengine.runner import Runner

from mmyolo.registry import HOOKS


@HOOKS.register_module()
class PPYOLOEParamSchedulerHook(ParamSchedulerHook):
    """A hook to update learning rate and momentum in the optimizer of
    PPYOLOE. We use this hook to implement adaptive computation for
    `warmup_total_iters`, which is not possible with the built-in
    ParamScheduler in mmyolo.

    Args:
        warmup_min_iter (int): Minimum warmup iters. Defaults to 1000.
        start_factor (float): The number we multiply learning rate in the
            first epoch. The multiplication factor changes towards end_factor
            in the following epochs. Defaults to 0.
        warmup_epochs (int): Epochs for warmup. Defaults to 5.
        min_lr_ratio (float): Minimum learning rate ratio.
        total_epochs (int): In PPYOLOE, `total_epochs` is set to
            training_epochs x 1.2. Defaults to 360.
    """
    priority = 9

    def __init__(self,
                 warmup_min_iter: int = 1000,
                 start_factor: float = 0.,
                 warmup_epochs: int = 5,
                 min_lr_ratio: float = 0.0,
                 total_epochs: int = 360):

        self.warmup_min_iter = warmup_min_iter
        self.start_factor = start_factor
        self.warmup_epochs = warmup_epochs
        self.min_lr_ratio = min_lr_ratio
        self.total_epochs = total_epochs

        self._warmup_end = False
        self._base_lr = None

    def before_train(self, runner: Runner):
        """Operations before train.

        Args:
            runner (Runner): The runner of the training process.
        """
        optimizer = runner.optim_wrapper.optimizer
        for group in optimizer.param_groups:
            # If the param has never been scheduled, record the current value
            # as the initial value.
            group.setdefault('initial_lr', group['lr'])

        self._base_lr = [
            group['initial_lr'] for group in optimizer.param_groups
        ]
        self._min_lr = [i * self.min_lr_ratio for i in self._base_lr]

    def before_train_iter(self,
                          runner: Runner,
                          batch_idx: int,
                          data_batch: Optional[dict] = None):
        """Operations before each training iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader.
        """
        cur_iters = runner.iter
        optimizer = runner.optim_wrapper.optimizer
        dataloader_len = len(runner.train_dataloader)

        # The minimum warmup is self.warmup_min_iter
        warmup_total_iters = max(
            round(self.warmup_epochs * dataloader_len), self.warmup_min_iter)

        if cur_iters <= warmup_total_iters:
            # warm up
            alpha = cur_iters / warmup_total_iters
            factor = self.start_factor * (1 - alpha) + alpha

            for group_idx, param in enumerate(optimizer.param_groups):
                param['lr'] = self._base_lr[group_idx] * factor
        else:
            for group_idx, param in enumerate(optimizer.param_groups):
                total_iters = self.total_epochs * dataloader_len
                lr = self._min_lr[group_idx] + (
                    self._base_lr[group_idx] -
                    self._min_lr[group_idx]) * 0.5 * (
                        math.cos((cur_iters - warmup_total_iters) * math.pi /
                                 (total_iters - warmup_total_iters)) + 1.0)
                param['lr'] = lr
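In equation form, the schedule implemented by the hook above is a linear warmup followed by a half-cosine decay. With warmup length $T_w = \max(\mathrm{warmup\_epochs} \times \mathrm{len(dataloader)},\ \mathrm{warmup\_min\_iter})$, total length $T = \mathrm{total\_epochs} \times \mathrm{len(dataloader)}$, start factor $s$ and $lr_{\min} = \mathrm{min\_lr\_ratio} \times lr_{base}$:

```latex
lr(t) =
\begin{cases}
  lr_{base} \left( s \left(1 - \frac{t}{T_w}\right) + \frac{t}{T_w} \right), & t \le T_w \\
  lr_{\min} + \frac{1}{2} \left( lr_{base} - lr_{\min} \right)
    \left( \cos\!\left( \pi \, \frac{t - T_w}{T - T_w} \right) + 1 \right), & t > T_w
\end{cases}
```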
@@ -1,4 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .data_preprocessor import YOLOv5DetDataPreprocessor
+from .data_preprocessor import (PPYOLOEBatchRandomResize,
+                                PPYOLOEDetDataPreprocessor,
+                                YOLOv5DetDataPreprocessor)

-__all__ = ['YOLOv5DetDataPreprocessor']
+__all__ = [
+    'YOLOv5DetDataPreprocessor', 'PPYOLOEDetDataPreprocessor',
+    'PPYOLOEBatchRandomResize'
+]
@@ -1,6 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import random
+from typing import List, Tuple, Union
+
 import torch
+import torch.nn.functional as F
+from mmdet.models import BatchSyncRandomResize
 from mmdet.models.data_preprocessors import DetDataPreprocessor
+from mmengine import MessageHub, is_list_of
+from torch import Tensor

 from mmyolo.registry import MODELS
@ -50,3 +57,191 @@ class YOLOv5DetDataPreprocessor(DetDataPreprocessor):
|
||||||
data_samples = {'bboxes_labels': data_samples, 'img_metas': img_metas}
|
data_samples = {'bboxes_labels': data_samples, 'img_metas': img_metas}
|
||||||
|
|
||||||
return {'inputs': inputs, 'data_samples': data_samples}
|
return {'inputs': inputs, 'data_samples': data_samples}
|
||||||
|
|
||||||
|
|
||||||
|
@MODELS.register_module()
|
||||||
|
class PPYOLOEDetDataPreprocessor(DetDataPreprocessor):
|
||||||
|
"""Image pre-processor for detection tasks.
|
||||||
|
|
||||||
|
The main difference between PPYOLOEDetDataPreprocessor and
|
||||||
|
DetDataPreprocessor is the normalization order. The official
|
||||||
|
PPYOLOE resize image first, and then normalize image.
|
||||||
|
In DetDataPreprocessor, the order is reversed.
|
||||||
|
|
||||||
|
Note: It must be used together with
|
||||||
|
`mmyolo.datasets.utils.yolov5_collate`
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, data: dict, training: bool = False) -> dict:
|
||||||
|
"""Perform normalization、padding and bgr2rgb conversion based on
|
||||||
|
``BaseDataPreprocessor``. This class use batch_augments first, and then
|
||||||
|
normalize the image, which is different from the `DetDataPreprocessor`
|
||||||
|
.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (dict): Data sampled from dataloader.
|
||||||
|
training (bool): Whether to enable training time augmentation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Data in the same format as the model input.
|
||||||
|
"""
|
||||||
|
if not training:
|
||||||
|
return super().forward(data, training)
|
||||||
|
|
||||||
|
assert isinstance(data['inputs'], list) and is_list_of(
|
||||||
|
data['inputs'], torch.Tensor), \
|
||||||
|
'"inputs" should be a list of Tensor, but got ' \
|
||||||
|
f'{type(data["inputs"])}. The possible reason for this ' \
|
||||||
|
'is that you are not using it with ' \
|
||||||
|
'"mmyolo.datasets.utils.yolov5_collate". Please refer to ' \
|
||||||
|
'"cconfigs/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py".'
|
||||||
|
|
||||||
|
data = self.cast_data(data)
|
||||||
|
inputs, data_samples = data['inputs'], data['data_samples']
|
||||||
|
|
||||||
|
# Process data.
|
||||||
|
batch_inputs = []
|
||||||
|
for _batch_input, data_sample in zip(inputs, data_samples):
|
||||||
|
# channel transform
|
||||||
|
if self._channel_conversion:
|
||||||
|
_batch_input = _batch_input[[2, 1, 0], ...]
|
||||||
|
# Convert to float after channel conversion to ensure
|
||||||
|
# efficiency
|
||||||
|
_batch_input = _batch_input.float()
|
||||||
|
|
||||||
|
batch_inputs.append(_batch_input)
|
||||||
|
|
||||||
|
# Batch random resize image.
|
||||||
|
if self.batch_augments is not None:
|
||||||
|
        for batch_aug in self.batch_augments:
            inputs, data_samples = batch_aug(batch_inputs, data_samples)

        if self._enable_normalize:
            inputs = (inputs - self.mean) / self.std

        img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
        data_samples = {'bboxes_labels': data_samples, 'img_metas': img_metas}

        return {'inputs': inputs, 'data_samples': data_samples}


# TODO: No generality. Its input data format is different
# mmdet's batch aug, and it must be compatible in the future.
@MODELS.register_module()
class PPYOLOEBatchRandomResize(BatchSyncRandomResize):
    """PPYOLOE batch random resize.

    Args:
        random_size_range (tuple): The multi-scale random range during
            multi-scale training.
        interval (int): The iter interval of change
            image size. Defaults to 10.
        size_divisor (int): Image size divisible factor.
            Defaults to 32.
        random_interp (bool): Whether to choose interp_mode randomly.
            If set to True, the type of `interp_mode` must be list.
            If set to False, the type of `interp_mode` must be str.
            Defaults to True.
        interp_mode (Union[List, str]): The modes available for resizing
            are ('nearest', 'bilinear', 'bicubic', 'area').
        keep_ratio (bool): Whether to keep the aspect ratio when resizing
            the image. Now we only support keep_ratio=False.
            Defaults to False.
    """

    def __init__(self,
                 random_size_range: Tuple[int, int],
                 interval: int = 1,
                 size_divisor: int = 32,
                 random_interp=True,
                 interp_mode: Union[List[str], str] = [
                     'nearest', 'bilinear', 'bicubic', 'area'
                 ],
                 keep_ratio: bool = False) -> None:
        super().__init__(random_size_range, interval, size_divisor)
        self.random_interp = random_interp
        self.keep_ratio = keep_ratio
        # TODO: need to support keep_ratio==True
        assert not self.keep_ratio, 'We do not yet support keep_ratio=True'

        if self.random_interp:
            assert isinstance(interp_mode, list) and len(interp_mode) > 1,\
                'While random_interp==True, the type of `interp_mode`' \
                ' must be list and len(interp_mode) must large than 1'
            self.interp_mode_list = interp_mode
            self.interp_mode = None
        else:
            assert isinstance(interp_mode, str),\
                'While random_interp==False, the type of ' \
                '`interp_mode` must be str'
            assert interp_mode in ['nearest', 'bilinear', 'bicubic', 'area']
            self.interp_mode_list = None
            self.interp_mode = interp_mode

    def forward(self, inputs: list,
                data_samples: Tensor) -> Tuple[Tensor, Tensor]:
        """Resize a batch of images and bboxes to shape ``self._input_size``.

        The inputs and data_samples should be list, and
        ``PPYOLOEBatchRandomResize`` must be used with
        ``PPYOLOEDetDataPreprocessor`` and ``yolov5_collate`` with
        ``use_ms_training == True``.
        """
        assert isinstance(inputs, list),\
            'The type of inputs must be list. The possible reason for this ' \
            'is that you are not using it with `PPYOLOEDetDataPreprocessor` ' \
            'and `yolov5_collate` with use_ms_training == True.'
        message_hub = MessageHub.get_current_instance()
        if (message_hub.get_info('iter') + 1) % self._interval == 0:
            # get current input size
            self._input_size, interp_mode = self._get_random_size_and_interp()
            if self.random_interp:
                self.interp_mode = interp_mode

        # TODO: need to support type(inputs)==Tensor
        if isinstance(inputs, list):
            outputs = []
            for i in range(len(inputs)):
                _batch_input = inputs[i]
                h, w = _batch_input.shape[-2:]
                scale_y = self._input_size[0] / h
                scale_x = self._input_size[1] / w
                if scale_x != 1. or scale_y != 1.:
                    if self.interp_mode in ('nearest', 'area'):
                        align_corners = None
                    else:
                        align_corners = False
                    _batch_input = F.interpolate(
                        _batch_input.unsqueeze(0),
                        size=self._input_size,
                        mode=self.interp_mode,
                        align_corners=align_corners)

                    # rescale boxes
                    indexes = data_samples[:, 0] == i
                    data_samples[indexes, 2] *= scale_x
                    data_samples[indexes, 3] *= scale_y
                    data_samples[indexes, 4] *= scale_x
                    data_samples[indexes, 5] *= scale_y
                else:
                    _batch_input = _batch_input.unsqueeze(0)

                outputs.append(_batch_input)

            # convert to Tensor
            return torch.cat(outputs, dim=0), data_samples
        else:
            raise NotImplementedError('Not implemented yet!')

    def _get_random_size_and_interp(self) -> Tuple[int, int]:
        """Randomly generate a shape in ``_random_size_range`` and a
        interp_mode in interp_mode_list."""
        size = random.randint(*self._random_size_range)
        input_size = (self._size_divisor * size, self._size_divisor * size)

        if self.random_interp:
            interp_ind = random.randint(0, len(self.interp_mode_list) - 1)
            interp_mode = self.interp_mode_list[interp_ind]
        else:
            interp_mode = None
        return input_size, interp_mode
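Note: the docstring above is strict about how this augment must be wired — it only works behind `PPYOLOEDetDataPreprocessor` together with `yolov5_collate(use_ms_training=True)`. A minimal config sketch of that wiring (values mirror the unit test added later in this commit; the collate_fn hookup shown here is an assumption, not part of the diff):

data_preprocessor = dict(
    type='PPYOLOEDetDataPreprocessor',
    pad_size_divisor=32,
    batch_augments=[
        dict(
            type='PPYOLOEBatchRandomResize',
            random_size_range=(320, 480),
            interval=1,
            size_divisor=32,
            random_interp=True,
            keep_ratio=False)
    ],
    mean=[0., 0., 0.],
    std=[255., 255., 255.],
    bgr_to_rgb=True)

# yolov5_collate must keep the images as a list of per-image tensors so that
# the assert in PPYOLOEBatchRandomResize.forward passes.
train_dataloader = dict(
    collate_fn=dict(type='yolov5_collate', use_ms_training=True))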
@@ -1,24 +1,27 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Sequence, Union
+from typing import Sequence, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from mmdet.models.utils import multi_apply
 from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
-                         OptMultiConfig)
+                         OptMultiConfig, reduce_mean)
+from mmengine import MessageHub
 from mmengine.model import BaseModule, bias_init_with_prob
 from mmengine.structures import InstanceData
 from torch import Tensor

 from mmyolo.registry import MODELS
 from ..layers.yolo_bricks import PPYOLOESELayer
-from .yolov5_head import YOLOv5Head
+from .yolov6_head import YOLOv6Head


 @MODELS.register_module()
 class PPYOLOEHeadModule(BaseModule):
-    """PPYOLOEHead head module used in `PPYOLOE`
+    """PPYOLOEHead head module used in `PPYOLOE.

+    <https://arxiv.org/abs/2203.16250>`_.
+
     Args:
         num_classes (int): Number of categories excluding the background
@@ -30,7 +33,8 @@ class PPYOLOEHeadModule(BaseModule):
             on the feature grid.
         featmap_strides (Sequence[int]): Downsample factor of each feature map.
             Defaults to (8, 16, 32).
-        reg_max (int): TOOD reg_max param.
+        reg_max (int): Max value of integral set :math: ``{0, ..., reg_max}``
+            in QFL setting. Defaults to 16.
         norm_cfg (dict): Config dict for normalization layer.
             Defaults to dict(type='BN', momentum=0.03, eps=0.001).
         act_cfg (dict): Config dict for activation layer.
@@ -100,15 +104,12 @@ class PPYOLOEHeadModule(BaseModule):
             self.reg_preds.append(
                 nn.Conv2d(in_channel, 4 * (self.reg_max + 1), 3, padding=1))

-        self.proj_conv = nn.Conv2d(self.reg_max + 1, 1, 1, bias=False)
-        self.proj = nn.Parameter(
-            torch.linspace(0, self.reg_max, self.reg_max + 1),
-            requires_grad=False)
-        self.proj_conv.weight = nn.Parameter(
-            self.proj.view([1, self.reg_max + 1, 1, 1]).clone().detach(),
-            requires_grad=False)
+        # init proj
+        proj = torch.linspace(0, self.reg_max, self.reg_max + 1).view(
+            [1, self.reg_max + 1, 1, 1])
+        self.register_buffer('proj', proj, persistent=False)

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: Tuple[Tensor]) -> Tensor:
         """Forward features from the upstream network.

         Args:
@@ -131,17 +132,24 @@ class PPYOLOEHeadModule(BaseModule):
         hw = h * w
         avg_feat = F.adaptive_avg_pool2d(x, (1, 1))
         cls_logit = cls_pred(cls_stem(x, avg_feat) + x)
-        reg_dist = reg_pred(reg_stem(x, avg_feat))
-        reg_dist = reg_dist.reshape([-1, 4, self.reg_max + 1,
-                                     hw]).permute(0, 2, 3, 1)
-        reg_dist = self.proj_conv(F.softmax(reg_dist, dim=1))
+        bbox_dist_preds = reg_pred(reg_stem(x, avg_feat))
+        # TODO: Test whether use matmul instead of conv can speed up training.
+        bbox_dist_preds = bbox_dist_preds.reshape(
+            [-1, 4, self.reg_max + 1, hw]).permute(0, 2, 3, 1)

-        return cls_logit, reg_dist
+        bbox_preds = F.conv2d(F.softmax(bbox_dist_preds, dim=1), self.proj)
+
+        if self.training:
+            return cls_logit, bbox_preds, bbox_dist_preds
+        else:
+            return cls_logit, bbox_preds


 @MODELS.register_module()
-class PPYOLOEHead(YOLOv5Head):
-    """PPYOLOEHead head used in `PPYOLOE`.
+class PPYOLOEHead(YOLOv6Head):
+    """PPYOLOEHead head used in `PPYOLOE <https://arxiv.org/abs/2203.16250>`_.
+
+    The YOLOv6 head and the PPYOLOE head are only slightly different.
+    Distribution focal loss is extra used in PPYOLOE, but not in YOLOv6.

     Args:
         head_module(ConfigType): Base module used for YOLOv5Head
@@ -150,7 +158,8 @@ class PPYOLOEHead(YOLOv5Head):
         bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
         loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
         loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
-        loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
+        loss_dfl (:obj:`ConfigDict` or dict): Config of distribution focal
+            loss.
         train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
             anchor head. Defaults to None.
         test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
@@ -168,17 +177,24 @@ class PPYOLOEHead(YOLOv5Head):
                      strides=[8, 16, 32]),
                  bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
                  loss_cls: ConfigType = dict(
-                     type='mmdet.CrossEntropyLoss',
+                     type='mmdet.VarifocalLoss',
                      use_sigmoid=True,
+                     alpha=0.75,
+                     gamma=2.0,
+                     iou_weighted=True,
                      reduction='sum',
                      loss_weight=1.0),
                  loss_bbox: ConfigType = dict(
-                     type='mmdet.GIoULoss', reduction='sum', loss_weight=5.0),
-                 loss_obj: ConfigType = dict(
-                     type='mmdet.CrossEntropyLoss',
-                     use_sigmoid=True,
-                     reduction='sum',
-                     loss_weight=1.0),
+                     type='IoULoss',
+                     iou_mode='giou',
+                     bbox_format='xyxy',
+                     reduction='mean',
+                     loss_weight=2.5,
+                     return_iou=False),
+                 loss_dfl: ConfigType = dict(
+                     type='mmdet.DistributionFocalLoss',
+                     reduction='mean',
+                     loss_weight=0.5 / 4),
                  train_cfg: OptConfigType = None,
                  test_cfg: OptConfigType = None,
                  init_cfg: OptMultiConfig = None):
@@ -188,19 +204,18 @@ class PPYOLOEHead(YOLOv5Head):
             bbox_coder=bbox_coder,
             loss_cls=loss_cls,
             loss_bbox=loss_bbox,
-            loss_obj=loss_obj,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
             init_cfg=init_cfg)
+        self.loss_dfl = MODELS.build(loss_dfl)
+        # ppyoloe doesn't need loss_obj
+        self.loss_obj = None

-    def special_init(self):
-        """Not Implenented."""
-        pass

     def loss_by_feat(
             self,
             cls_scores: Sequence[Tensor],
             bbox_preds: Sequence[Tensor],
+            bbox_dist_preds: Sequence[Tensor],
             batch_gt_instances: Sequence[InstanceData],
             batch_img_metas: Sequence[dict],
             batch_gt_instances_ignore: OptInstanceList = None) -> dict:
@@ -214,6 +229,8 @@ class PPYOLOEHead(YOLOv5Head):
             bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
                 level, each is a 4D-tensor, the channel number is
                 num_priors * 4.
+            bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
+                each scale level with shape (bs, reg_max + 1, H*W, 4).
             batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                 gt_instance. It usually includes ``bboxes`` and ``labels``
                 attributes.
@@ -226,4 +243,131 @@ class PPYOLOEHead(YOLOv5Head):
         Returns:
             dict[str, Tensor]: A dictionary of losses.
         """
-        raise NotImplementedError('Not implemented yet!')
+
+        # get epoch information from message hub
+        message_hub = MessageHub.get_current_instance()
+        current_epoch = message_hub.get_info('epoch')
+
+        num_imgs = len(batch_img_metas)
+
+        current_featmap_sizes = [
+            cls_score.shape[2:] for cls_score in cls_scores
+        ]
+        # If the shape does not equal, generate new one
+        if current_featmap_sizes != self.featmap_sizes_train:
+            self.featmap_sizes_train = current_featmap_sizes
+
+            mlvl_priors_with_stride = self.prior_generator.grid_priors(
+                self.featmap_sizes_train,
+                dtype=cls_scores[0].dtype,
+                device=cls_scores[0].device,
+                with_stride=True)
+
+            self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
+            self.flatten_priors_train = torch.cat(
+                mlvl_priors_with_stride, dim=0)
+            self.stride_tensor = self.flatten_priors_train[..., [2]]
+
+        # gt info
+        gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)
+        gt_labels = gt_info[:, :, :1]
+        gt_bboxes = gt_info[:, :, 1:]  # xyxy
+        pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
+
+        # pred info
+        flatten_cls_preds = [
+            cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
+                                                 self.num_classes)
+            for cls_pred in cls_scores
+        ]
+        flatten_pred_bboxes = [
+            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
+            for bbox_pred in bbox_preds
+        ]
+        # (bs, reg_max+1, n, 4) -> (bs, n, 4, reg_max+1)
+        flatten_pred_dists = [
+            bbox_pred_org.permute(0, 2, 3, 1).reshape(
+                num_imgs, -1, (self.head_module.reg_max + 1) * 4)
+            for bbox_pred_org in bbox_dist_preds
+        ]
+
+        flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
+        flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
+        flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
+        flatten_pred_bboxes = self.bbox_coder.decode(
+            self.flatten_priors_train[..., :2], flatten_pred_bboxes,
+            self.stride_tensor[..., 0])
+        pred_scores = torch.sigmoid(flatten_cls_preds)
+
+        if current_epoch < self.initial_epoch:
+            assigned_result = self.initial_assigner(
+                flatten_pred_bboxes.detach(), self.flatten_priors_train,
+                self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
+        else:
+            assigned_result = self.assigner(flatten_pred_bboxes.detach(),
+                                            pred_scores.detach(),
+                                            self.flatten_priors_train,
+                                            gt_labels, gt_bboxes,
+                                            pad_bbox_flag)
+
+        assigned_bboxes = assigned_result['assigned_bboxes']
+        assigned_scores = assigned_result['assigned_scores']
+        fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
+
+        # cls loss
+        with torch.cuda.amp.autocast(enabled=False):
+            loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores)
+
+        # rescale bbox
+        assigned_bboxes /= self.stride_tensor
+        flatten_pred_bboxes /= self.stride_tensor
+
+        assigned_scores_sum = assigned_scores.sum()
+        # reduce_mean between all gpus
+        assigned_scores_sum = torch.clamp(
+            reduce_mean(assigned_scores_sum), min=1)
+        loss_cls /= assigned_scores_sum
+
+        # select positive samples mask
+        num_pos = fg_mask_pre_prior.sum()
+        if num_pos > 0:
+            # when num_pos > 0, assigned_scores_sum will >0, so the loss_bbox
+            # will not report an error
+            # iou loss
+            prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
+            pred_bboxes_pos = torch.masked_select(
+                flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
+            assigned_bboxes_pos = torch.masked_select(
+                assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
+            bbox_weight = torch.masked_select(
+                assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
+            loss_bbox = self.loss_bbox(
+                pred_bboxes_pos,
+                assigned_bboxes_pos,
+                weight=bbox_weight,
+                avg_factor=assigned_scores_sum)
+
+            # dfl loss
+            dist_mask = fg_mask_pre_prior.unsqueeze(-1).repeat(
+                [1, 1, (self.head_module.reg_max + 1) * 4])
+
+            pred_dist_pos = torch.masked_select(
+                flatten_dist_preds,
+                dist_mask).reshape([-1, 4, self.head_module.reg_max + 1])
+            assigned_ltrb = self.bbox_coder.encode(
+                self.flatten_priors_train[..., :2] / self.stride_tensor,
+                assigned_bboxes,
+                max_dis=self.head_module.reg_max,
+                eps=0.01)
+            assigned_ltrb_pos = torch.masked_select(
+                assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
+            loss_dfl = self.loss_dfl(
+                pred_dist_pos.reshape(-1, self.head_module.reg_max + 1),
+                assigned_ltrb_pos.reshape(-1),
+                weight=bbox_weight.expand(-1, 4).reshape(-1),
+                avg_factor=assigned_scores_sum)
+        else:
+            loss_bbox = flatten_pred_bboxes.sum() * 0
+            loss_dfl = flatten_pred_bboxes.sum() * 0
+
+        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox, loss_dfl=loss_dfl)
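Note on the `proj` change above: registering the linspace as a buffer and applying it with `F.conv2d` makes box decoding the expectation of the per-edge DFL distribution, so the old `proj_conv` weights no longer need to live in the checkpoint. A standalone sketch of that equivalence (not part of the diff; the shapes and random tensor are made up for illustration):

import torch
import torch.nn.functional as F

reg_max = 16
hw = 8 * 8  # number of grid cells on one feature level
# logits over the integral set {0, ..., reg_max} for each of l/t/r/b
bbox_dist_preds = torch.randn(2, reg_max + 1, hw, 4)
proj = torch.linspace(0, reg_max, reg_max + 1).view(1, reg_max + 1, 1, 1)

# The 1x1 convolution with the linspace kernel is exactly the expected value
# of the softmax distribution, i.e. the predicted distance in stride units.
expected = F.conv2d(F.softmax(bbox_dist_preds, dim=1), proj)
manual = (F.softmax(bbox_dist_preds, dim=1) * proj).sum(dim=1, keepdim=True)
assert torch.allclose(expected, manual, atol=1e-5)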
@@ -169,7 +169,6 @@ class YOLOv6Head(YOLOv5Head):
             in 2D points-based detectors.
         loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
         loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
-        loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
         train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
             anchor head. Defaults to None.
         test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
@@ -201,11 +200,6 @@ class YOLOv6Head(YOLOv5Head):
                     reduction='mean',
                     loss_weight=2.5,
                     return_iou=False),
-                 loss_obj: ConfigType = dict(
-                     type='mmdet.CrossEntropyLoss',
-                     use_sigmoid=True,
-                     reduction='sum',
-                     loss_weight=1.0),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
@@ -215,13 +209,11 @@ class YOLOv6Head(YOLOv5Head):
             bbox_coder=bbox_coder,
             loss_cls=loss_cls,
             loss_bbox=loss_bbox,
-            loss_obj=loss_obj,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
             init_cfg=init_cfg)
-        self.loss_bbox = MODELS.build(loss_bbox)
-        self.loss_cls = MODELS.build(loss_cls)
+        # yolov6 doesn't need loss_obj
+        self.loss_obj = None

     def special_init(self):
         """Since YOLO series algorithms will inherit from YOLOv5Head, but
@@ -236,10 +228,9 @@ class YOLOv6Head(YOLOv5Head):
         self.assigner = TASK_UTILS.build(self.train_cfg.assigner)

         # Add common attributes to reduce calculation
-        self.featmap_sizes = None
-        self.mlvl_priors = None
+        self.featmap_sizes_train = None
         self.num_level_priors = None
-        self.flatten_priors = None
+        self.flatten_priors_train = None
         self.stride_tensor = None

     def loss_by_feat(
@@ -284,19 +275,19 @@ class YOLOv6Head(YOLOv5Head):
             cls_score.shape[2:] for cls_score in cls_scores
         ]
         # If the shape does not equal, generate new one
-        if current_featmap_sizes != self.featmap_sizes:
-            self.featmap_sizes = current_featmap_sizes
+        if current_featmap_sizes != self.featmap_sizes_train:
+            self.featmap_sizes_train = current_featmap_sizes

-            mlvl_priors = self.prior_generator.grid_priors(
-                self.featmap_sizes,
+            mlvl_priors_with_stride = self.prior_generator.grid_priors(
+                self.featmap_sizes_train,
                 dtype=cls_scores[0].dtype,
                 device=cls_scores[0].device,
                 with_stride=True)

-            self.num_level_priors = [len(n) for n in mlvl_priors]
-            self.flatten_priors = torch.cat(mlvl_priors, dim=0)
-            self.stride_tensor = self.flatten_priors[..., [2]]
-            self.mlvl_priors = [mlvl[:, :2] for mlvl in mlvl_priors]
+            self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
+            self.flatten_priors_train = torch.cat(
+                mlvl_priors_with_stride, dim=0)
+            self.stride_tensor = self.flatten_priors_train[..., [2]]

         # gt info
         gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)
@@ -319,19 +310,20 @@ class YOLOv6Head(YOLOv5Head):
         flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
         flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
         flatten_pred_bboxes = self.bbox_coder.decode(
-            self.flatten_priors[..., :2], flatten_pred_bboxes,
-            self.flatten_priors[..., 2])
+            self.flatten_priors_train[..., :2], flatten_pred_bboxes,
+            self.stride_tensor[:, 0])
         pred_scores = torch.sigmoid(flatten_cls_preds)

         if current_epoch < self.initial_epoch:
             assigned_result = self.initial_assigner(
-                flatten_pred_bboxes.detach(), self.flatten_priors,
+                flatten_pred_bboxes.detach(), self.flatten_priors_train,
                 self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
         else:
             assigned_result = self.assigner(flatten_pred_bboxes.detach(),
                                             pred_scores.detach(),
-                                            self.flatten_priors, gt_labels,
-                                            gt_bboxes, pad_bbox_flag)
+                                            self.flatten_priors_train,
+                                            gt_labels, gt_bboxes,
+                                            pad_bbox_flag)

         assigned_bboxes = assigned_result['assigned_bboxes']
         assigned_scores = assigned_result['assigned_scores']
@@ -92,7 +92,7 @@ class BatchATSSAssigner(nn.Module):
         Args:
             pred_bboxes (Tensor): Predicted bounding boxes,
                 shape(batch_size, num_priors, 4)
-            priors (Tensor): Model priors, shape(num_priors, 4)
+            priors (Tensor): Model priors with stride, shape(num_priors, 4)
             num_level_priors (List): Number of bboxes in each level, len(3)
             gt_labels (Tensor): Ground truth label,
                 shape(batch_size, num_gt, 1)
@@ -4,7 +4,7 @@ from typing import Optional, Sequence, Union
 import torch
 from mmdet.models.task_modules.coders import \
     DistancePointBBoxCoder as MMDET_DistancePointBBoxCoder
-from mmdet.structures.bbox import distance2bbox
+from mmdet.structures.bbox import bbox2distance, distance2bbox

 from mmyolo.registry import TASK_UTILS

@@ -51,3 +51,29 @@ class DistancePointBBoxCoder(MMDET_DistancePointBBoxCoder):
             pred_bboxes = pred_bboxes * stride[None, :, None]

         return distance2bbox(points, pred_bboxes, max_shape)
+
+    def encode(self,
+               points: torch.Tensor,
+               gt_bboxes: torch.Tensor,
+               max_dis: float = 16.,
+               eps: float = 0.01) -> torch.Tensor:
+        """Encode bounding box to distances. The rewrite is to support batch
+        operations.
+
+        Args:
+            points (Tensor): Shape (B, N, 2) or (N, 2), The format is [x, y].
+            gt_bboxes (Tensor or :obj:`BaseBoxes`): Shape (N, 4), The format
+                is "xyxy"
+            max_dis (float): Upper bound of the distance. Default to 16..
+            eps (float): a small value to ensure target < max_dis, instead <=.
+                Default 0.01.
+
+        Returns:
+            Tensor: Box transformation deltas. The shape is (N, 4) or
+                (B, N, 4).
+        """
+
+        assert points.size(-2) == gt_bboxes.size(-2)
+        assert points.size(-1) == 2
+        assert gt_bboxes.size(-1) == 4
+        return bbox2distance(points, gt_bboxes, max_dis, eps)
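Note: the new `encode` is the inverse of the existing `decode` as long as the targets stay inside the integral set, which is why `loss_by_feat` divides both the priors and the assigned boxes by the stride before encoding. A quick sanity sketch calling the underlying mmdet helpers directly (values are made up and kept below `max_dis` so no clamping occurs):

import torch
from mmdet.structures.bbox import bbox2distance, distance2bbox

points = torch.tensor([[4., 4.], [10., 6.]])       # prior centers, stride units
gt_bboxes = torch.tensor([[1., 2., 7., 9.],
                          [6., 3., 14., 12.]])     # xyxy, stride units
ltrb = bbox2distance(points, gt_bboxes, max_dis=16., eps=0.01)
decoded = distance2bbox(points, ltrb)
assert torch.allclose(decoded, gt_bboxes)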
@@ -4,3 +4,4 @@ Import:
   - configs/yolox/metafile.yml
   - configs/rtmdet/metafile.yml
   - configs/yolov7/metafile.yml
+  - configs/ppyoloe/metafile.yml
@@ -13,6 +13,8 @@ from mmyolo.datasets.transforms import (LetterResize, LoadAnnotations,
                                         YOLOv5HSVRandomAug,
                                         YOLOv5KeepRatioResize,
                                         YOLOv5RandomAffine)
+from mmyolo.datasets.transforms.transforms import (PPYOLOERandomCrop,
+                                                    PPYOLOERandomDistort)


 class TestLetterResize(unittest.TestCase):
@@ -355,3 +357,100 @@ class TestYOLOv5RandomAffine(unittest.TestCase):
         self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
         self.assertTrue(results['gt_bboxes'].dtype == torch.float32)
         self.assertTrue(results['gt_ignore_flags'].dtype == bool)
+
+
+class TestPPYOLOERandomCrop(unittest.TestCase):
+
+    def setUp(self):
+        """Setup the data info which are used in every test method.
+
+        TestCase calls functions in this order: setUp() -> testMethod() ->
+        tearDown() -> cleanUp()
+        """
+        self.results = {
+            'img':
+            np.random.random((224, 224, 3)),
+            'img_shape': (224, 224),
+            'gt_bboxes_labels':
+            np.array([1, 2, 3], dtype=np.int64),
+            'gt_bboxes':
+            np.array([[10, 10, 20, 20], [20, 20, 40, 40], [40, 40, 80, 80]],
+                     dtype=np.float32),
+            'gt_ignore_flags':
+            np.array([0, 0, 1], dtype=bool),
+        }
+
+    def test_transform(self):
+        transform = PPYOLOERandomCrop()
+        results = transform(copy.deepcopy(self.results))
+        self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
+                        results['gt_bboxes'].shape[0])
+        self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
+        self.assertTrue(results['gt_bboxes'].dtype == np.float32)
+        self.assertTrue(results['gt_ignore_flags'].dtype == bool)
+
+    def test_transform_with_boxlist(self):
+        results = copy.deepcopy(self.results)
+        results['gt_bboxes'] = HorizontalBoxes(results['gt_bboxes'])
+
+        transform = PPYOLOERandomCrop()
+        results = transform(copy.deepcopy(results))
+        self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
+                        results['gt_bboxes'].shape[0])
+        self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
+        self.assertTrue(results['gt_bboxes'].dtype == torch.float32)
+        self.assertTrue(results['gt_ignore_flags'].dtype == bool)
+
+
+class TestPPYOLOERandomDistort(unittest.TestCase):
+
+    def setUp(self):
+        """Setup the data info which are used in every test method.
+
+        TestCase calls functions in this order: setUp() -> testMethod() ->
+        tearDown() -> cleanUp()
+        """
+        self.results = {
+            'img':
+            np.random.random((224, 224, 3)),
+            'img_shape': (224, 224),
+            'gt_bboxes_labels':
+            np.array([1, 2, 3], dtype=np.int64),
+            'gt_bboxes':
+            np.array([[10, 10, 20, 20], [20, 20, 40, 40], [40, 40, 80, 80]],
+                     dtype=np.float32),
+            'gt_ignore_flags':
+            np.array([0, 0, 1], dtype=bool),
+        }
+
+    def test_transform(self):
+        # test assertion for invalid prob
+        with self.assertRaises(AssertionError):
+            transform = PPYOLOERandomDistort(
+                hue_cfg=dict(min=-18, max=18, prob=1.5))
+
+        # test assertion for invalid num_distort_func
+        with self.assertRaises(AssertionError):
+            transform = PPYOLOERandomDistort(num_distort_func=5)
+
+        transform = PPYOLOERandomDistort()
+        results = transform(copy.deepcopy(self.results))
+        self.assertTrue(results['img'].shape[:2] == (224, 224))
+        self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
+                        results['gt_bboxes'].shape[0])
+        self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
+        self.assertTrue(results['gt_bboxes'].dtype == np.float32)
+        self.assertTrue(results['gt_ignore_flags'].dtype == bool)
+
+    def test_transform_with_boxlist(self):
+        results = copy.deepcopy(self.results)
+        results['gt_bboxes'] = HorizontalBoxes(results['gt_bboxes'])
+
+        transform = PPYOLOERandomDistort()
+        results = transform(copy.deepcopy(results))
+        self.assertTrue(results['img'].shape[:2] == (224, 224))
+        self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
+                        results['gt_bboxes'].shape[0])
+        self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
+        self.assertTrue(results['gt_bboxes'].dtype == torch.float32)
+        self.assertTrue(results['gt_ignore_flags'].dtype == bool)
@@ -47,6 +47,36 @@ class TestYOLOv5Collate(unittest.TestCase):
         self.assertTrue(out['inputs'].shape == (2, 3, 10, 10))
         self.assertTrue(out['data_samples'].shape == (8, 6))

+    def test_yolov5_collate_with_multi_scale(self):
+        rng = np.random.RandomState(0)
+
+        inputs = torch.randn((3, 10, 10))
+        data_samples = DetDataSample()
+        gt_instances = InstanceData()
+        bboxes = _rand_bboxes(rng, 4, 6, 8)
+        gt_instances.bboxes = HorizontalBoxes(bboxes, dtype=torch.float32)
+        labels = rng.randint(1, 2, size=len(bboxes))
+        gt_instances.labels = torch.LongTensor(labels)
+        data_samples.gt_instances = gt_instances
+
+        out = yolov5_collate([dict(inputs=inputs, data_samples=data_samples)],
+                             use_ms_training=True)
+        self.assertIsInstance(out, dict)
+        self.assertTrue(out['inputs'][0].shape == (3, 10, 10))
+        print(out['data_samples'].shape)
+        self.assertTrue(out['data_samples'].shape == (4, 6))
+        self.assertIsInstance(out['inputs'], list)
+        self.assertIsInstance(out['data_samples'], torch.Tensor)
+
+        out = yolov5_collate(
+            [dict(inputs=inputs, data_samples=data_samples)] * 2,
+            use_ms_training=True)
+        self.assertIsInstance(out, dict)
+        self.assertTrue(out['inputs'][0].shape == (3, 10, 10))
+        self.assertTrue(out['data_samples'].shape == (8, 6))
+        self.assertIsInstance(out['inputs'], list)
+        self.assertIsInstance(out['data_samples'], torch.Tensor)
+
+
 class TestBatchShapePolicy(unittest.TestCase):
@@ -3,8 +3,13 @@

 import torch
 from mmdet.structures import DetDataSample
+from mmengine import MessageHub

+from mmyolo.models import PPYOLOEBatchRandomResize, PPYOLOEDetDataPreprocessor
 from mmyolo.models.data_preprocessors import YOLOv5DetDataPreprocessor
+from mmyolo.utils import register_all_modules
+
+register_all_modules()


 class TestYOLOv5DetDataPreprocessor(TestCase):
@@ -69,3 +74,51 @@ class TestYOLOv5DetDataPreprocessor(TestCase):
         # data_samples must be tensor
         with self.assertRaises(AssertionError):
             processor(data, training=True)
+
+
+class TestPPYOLOEDetDataPreprocessor(TestCase):
+
+    def test_batch_random_resize(self):
+        processor = PPYOLOEDetDataPreprocessor(
+            pad_size_divisor=32,
+            batch_augments=[
+                dict(
+                    type='PPYOLOEBatchRandomResize',
+                    random_size_range=(320, 480),
+                    interval=1,
+                    size_divisor=32,
+                    random_interp=True,
+                    keep_ratio=False)
+            ],
+            mean=[0., 0., 0.],
+            std=[255., 255., 255.],
+            bgr_to_rgb=True)
+        self.assertTrue(
+            isinstance(processor.batch_augments[0], PPYOLOEBatchRandomResize))
+        message_hub = MessageHub.get_instance('test_batch_random_resize')
+        message_hub.update_info('iter', 0)
+
+        # test training
+        data = {
+            'inputs': [
+                torch.randint(0, 256, (3, 10, 11)),
+                torch.randint(0, 256, (3, 10, 11))
+            ],
+            'data_samples':
+            torch.randint(0, 11, (18, 6)).float(),
+        }
+        out_data = processor(data, training=True)
+        batch_data_samples = out_data['data_samples']
+        self.assertIn('img_metas', batch_data_samples)
+        self.assertIn('bboxes_labels', batch_data_samples)
+        self.assertIsInstance(batch_data_samples['bboxes_labels'],
+                              torch.Tensor)
+        self.assertIsInstance(batch_data_samples['img_metas'], list)
+
+        data = {
+            'inputs': [torch.randint(0, 256, (3, 11, 10))],
+            'data_samples': DetDataSample()
+        }
+        # data_samples must be list
+        with self.assertRaises(TypeError):
+            processor(data, training=True)
@@ -2,6 +2,7 @@
 from unittest import TestCase

 import torch
+from mmengine import ConfigDict, MessageHub
 from mmengine.config import Config
 from mmengine.model import bias_init_with_prob
 from mmengine.testing import assert_allclose
@@ -12,11 +13,14 @@ from mmyolo.utils import register_all_modules
 register_all_modules()


-class TestYOLOXHead(TestCase):
+class TestPPYOLOEHead(TestCase):

     def setUp(self):
         self.head_module = dict(
-            type='PPYOLOEHeadModule', num_classes=4, in_channels=[32, 64, 128])
+            type='PPYOLOEHeadModule',
+            num_classes=4,
+            in_channels=[32, 64, 128],
+            featmap_strides=(8, 16, 32))

     def test_init_weights(self):
         head = PPYOLOEHead(head_module=self.head_module)
@@ -50,6 +54,7 @@ class TestYOLOXHead(TestCase):
                 max_per_img=300))

         head = PPYOLOEHead(head_module=self.head_module, test_cfg=test_cfg)
+        head.eval()
         feat = [
             torch.rand(1, in_channels, s // feat_size, s // feat_size)
             for in_channels, feat_size in [[32, 8], [64, 16], [128, 32]]
@@ -71,3 +76,130 @@ class TestYOLOXHead(TestCase):
             cfg=test_cfg,
             rescale=False,
             with_nms=False)
+
+    def test_loss_by_feat(self):
+        message_hub = MessageHub.get_instance('test_ppyoloe_loss_by_feat')
+        message_hub.update_info('epoch', 1)
+
+        s = 256
+        img_metas = [{
+            'img_shape': (s, s, 3),
+            'batch_input_shape': (s, s),
+            'scale_factor': 1,
+        }]
+
+        head = PPYOLOEHead(
+            head_module=self.head_module,
+            train_cfg=ConfigDict(
+                initial_epoch=31,
+                initial_assigner=dict(
+                    type='BatchATSSAssigner',
+                    num_classes=4,
+                    topk=9,
+                    iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
+                assigner=dict(
+                    type='BatchTaskAlignedAssigner',
+                    num_classes=4,
+                    topk=13,
+                    alpha=1,
+                    beta=6)))
+        head.train()
+
+        feat = []
+        for i in range(len(self.head_module['in_channels'])):
+            in_channel = self.head_module['in_channels'][i]
+            feat_size = self.head_module['featmap_strides'][i]
+            feat.append(
+                torch.rand(1, in_channel, s // feat_size, s // feat_size))
+
+        cls_scores, bbox_preds, bbox_dist_preds = head.forward(feat)
+
+        # Test that empty ground truth encourages the network to predict
+        # background
+        gt_instances = torch.empty((0, 6), dtype=torch.float32)
+
+        empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
+                                            bbox_dist_preds, gt_instances,
+                                            img_metas)
+        # When there is no truth, the cls loss should be nonzero but there
+        # should be no box loss.
+        empty_cls_loss = empty_gt_losses['loss_cls'].sum()
+        empty_box_loss = empty_gt_losses['loss_bbox'].sum()
+        empty_dfl_loss = empty_gt_losses['loss_dfl'].sum()
+        self.assertGreater(empty_cls_loss.item(), 0,
+                           'cls loss should be non-zero')
+        self.assertEqual(
+            empty_box_loss.item(), 0,
+            'there should be no box loss when there are no true boxes')
+        self.assertEqual(
+            empty_dfl_loss.item(), 0,
+            'there should be df loss when there are no true boxes')
+
+        # When truth is non-empty then both cls and box loss should be nonzero
+        # for random inputs
+        head = PPYOLOEHead(
+            head_module=self.head_module,
+            train_cfg=ConfigDict(
+                initial_epoch=31,
+                initial_assigner=dict(
+                    type='BatchATSSAssigner',
+                    num_classes=4,
+                    topk=9,
+                    iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
+                assigner=dict(
+                    type='BatchTaskAlignedAssigner',
+                    num_classes=4,
+                    topk=13,
+                    alpha=1,
+                    beta=6)))
+        head.train()
+        gt_instances = torch.Tensor(
+            [[0., 0., 23.6667, 23.8757, 238.6326, 151.8874]])
+
+        one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
+                                          bbox_dist_preds, gt_instances,
+                                          img_metas)
+        onegt_cls_loss = one_gt_losses['loss_cls'].sum()
+        onegt_box_loss = one_gt_losses['loss_bbox'].sum()
+        onegt_loss_dfl = one_gt_losses['loss_dfl'].sum()
+        self.assertGreater(onegt_cls_loss.item(), 0,
+                           'cls loss should be non-zero')
+        self.assertGreater(onegt_box_loss.item(), 0,
+                           'box loss should be non-zero')
+        self.assertGreater(onegt_loss_dfl.item(), 0,
+                           'obj loss should be non-zero')
+
+        # test num_class = 1
+        self.head_module['num_classes'] = 1
+        head = PPYOLOEHead(
+            head_module=self.head_module,
+            train_cfg=ConfigDict(
+                initial_epoch=31,
+                initial_assigner=dict(
+                    type='BatchATSSAssigner',
+                    num_classes=1,
+                    topk=9,
+                    iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
+                assigner=dict(
+                    type='BatchTaskAlignedAssigner',
+                    num_classes=1,
+                    topk=13,
+                    alpha=1,
+                    beta=6)))
+        head.train()
+        gt_instances = torch.Tensor(
+            [[0., 0., 23.6667, 23.8757, 238.6326, 151.8874]])
+        cls_scores, bbox_preds, bbox_dist_preds = head.forward(feat)
+
+        one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
+                                          bbox_dist_preds, gt_instances,
+                                          img_metas)
+        onegt_cls_loss = one_gt_losses['loss_cls'].sum()
+        onegt_box_loss = one_gt_losses['loss_bbox'].sum()
+        onegt_loss_dfl = one_gt_losses['loss_dfl'].sum()
+        self.assertGreater(onegt_cls_loss.item(), 0,
+                           'cls loss should be non-zero')
+        self.assertGreater(onegt_box_loss.item(), 0,
+                           'box loss should be non-zero')
+        self.assertGreater(onegt_loss_dfl.item(), 0,
+                           'obj loss should be non-zero')
@@ -5,13 +5,13 @@ from collections import OrderedDict
 import torch


-def convert_bn(k):
+def convert_bn(k: str):
     name = k.replace('._mean',
                      '.running_mean').replace('._variance', '.running_var')
     return name


-def convert_repvgg(k):
+def convert_repvgg(k: str):
     if '.conv2.conv1.' in k:
         name = k.replace('.conv2.conv1.', '.conv2.rbr_dense.')
         return name
@@ -22,111 +22,142 @@ def convert_repvgg(k):
     return k


-def convert(src, dst):
-    # TODO: add pretrained model convert
+def convert(src: str, dst: str, imagenet_pretrain: bool = False):
     with open(src, 'rb') as f:
         model = pickle.load(f)

     new_state_dict = OrderedDict()
-    for k, v in model.items():
-        name = k
-        if k.startswith('backbone.'):
-            if '.stem.' in k:
-                # backbone.stem.conv1.conv.weight
-                # -> backbone.stem.0.conv.weight
-                org_ind = k.split('.')[2][-1]
-                new_ind = str(int(org_ind) - 1)
-                name = k.replace('.stem.conv%s.' % org_ind,
-                                 '.stem.%s.' % new_ind)
-            else:
-                # backbone.stages.1.conv2.bn._variance
-                # -> backbone.stage2.0.conv2.bn.running_var
-                org_stage_ind = k.split('.')[2]
-                new_stage_ind = str(int(org_stage_ind) + 1)
-                name = k.replace('.stages.%s.' % org_stage_ind,
-                                 '.stage%s.0.' % new_stage_ind)
-            name = convert_repvgg(name)
-            if '.attn.' in k:
-                name = name.replace('.attn.fc.', '.attn.fc.conv.')
-            name = convert_bn(name)
-        elif k.startswith('neck.'):
-            # fpn_stages
-            if k.startswith('neck.fpn_stages.'):
-                # neck.fpn_stages.0.0.conv1.conv.weight
-                # -> neck.reduce_layers.2.0.conv1.conv.weight
-                if k.startswith('neck.fpn_stages.0.0.'):
-                    name = k.replace('neck.fpn_stages.0.0.',
-                                     'neck.reduce_layers.2.0.')
-                    if '.spp.' in name:
-                        name = name.replace('.spp.conv.', '.spp.conv2.')
-                # neck.fpn_stages.1.0.conv1.conv.weight
-                # -> neck.top_down_layers.0.0.conv1.conv.weight
-                elif k.startswith('neck.fpn_stages.1.0.'):
-                    name = k.replace('neck.fpn_stages.1.0.',
-                                     'neck.top_down_layers.0.0.')
-                elif k.startswith('neck.fpn_stages.2.0.'):
-                    name = k.replace('neck.fpn_stages.2.0.',
-                                     'neck.top_down_layers.1.0.')
-                else:
-                    raise NotImplementedError('Not implemented.')
-                name = name.replace('.0.convs.', '.0.blocks.')
-            elif k.startswith('neck.fpn_routes.'):
-                # neck.fpn_routes.0.conv.weight
-                # -> neck.upsample_layers.0.0.conv.weight
-                index = k.split('.')[2]
-                name = 'neck.upsample_layers.' + index + '.0.' + '.'.join(
-                    k.split('.')[-2:])
-                name = name.replace('.0.convs.', '.0.blocks.')
-            elif k.startswith('neck.pan_stages.'):
-                # neck.pan_stages.0.0.conv1.conv.weight
-                # -> neck.bottom_up_layers.1.0.conv1.conv.weight
-                ind = k.split('.')[2]
-                name = k.replace(
-                    'neck.pan_stages.' + ind,
-                    'neck.bottom_up_layers.' + ('0' if ind == '1' else '1'))
-                name = name.replace('.0.convs.', '.0.blocks.')
-            elif k.startswith('neck.pan_routes.'):
-                # neck.pan_routes.0.conv.weight
-                # -> neck.downsample_layers.0.conv.weight
-                ind = k.split('.')[2]
-                name = k.replace(
-                    'neck.pan_routes.' + ind,
-                    'neck.downsample_layers.' + ('0' if ind == '1' else '1'))
-                name = name.replace('.0.convs.', '.0.blocks.')
-            else:
-                raise NotImplementedError('Not implement.')
-            name = convert_repvgg(name)
-            name = convert_bn(name)
-        elif k.startswith('yolo_head.'):
-            if ('anchor_points' in k) or ('stride_tensor' in k):
-                continue
-            if 'proj_conv' in k:
-                name = k.replace('yolo_head.proj_conv.',
-                                 'bbox_head.head_module.proj_conv.')
-            else:
-                for org_key, rep_key in [[
-                        'yolo_head.stem_cls.',
-                        'bbox_head.head_module.cls_stems.'
-                ], ['yolo_head.stem_reg.', 'bbox_head.head_module.reg_stems.'],
-                                         [
-                                             'yolo_head.pred_cls.',
-                                             'bbox_head.head_module.cls_preds.'
-                                         ],
-                                         [
-                                             'yolo_head.pred_reg.',
-                                             'bbox_head.head_module.reg_preds.'
-                                         ]]:
-                    name = name.replace(org_key, rep_key)
-                name = name.split('.')
-                ind = name[3]
-                name[3] = str(2 - int(ind))
-                name = '.'.join(name)
-            name = convert_bn(name)
-        else:
-            continue
-
-        new_state_dict[name] = torch.from_numpy(v)
+    if imagenet_pretrain:
+        for k, v in model.items():
+            if '@@' in k:
+                continue
+            if 'stem.' in k:
+                # backbone.stem.conv1.conv.weight
+                # -> backbone.stem.0.conv.weight
+                org_ind = k.split('.')[1][-1]
+                new_ind = str(int(org_ind) - 1)
+                name = k.replace('stem.conv%s.' % org_ind,
+                                 'stem.%s.' % new_ind)
+            else:
+                # backbone.stages.1.conv2.bn._variance
+                # -> backbone.stage2.0.conv2.bn.running_var
+                org_stage_ind = k.split('.')[1]
+                new_stage_ind = str(int(org_stage_ind) + 1)
+                name = k.replace('stages.%s.' % org_stage_ind,
+                                 'stage%s.0.' % new_stage_ind)
+            name = convert_repvgg(name)
+            if '.attn.' in k:
+                name = name.replace('.attn.fc.', '.attn.fc.conv.')
+            name = convert_bn(name)
+            name = 'backbone.' + name
+
+            new_state_dict[name] = torch.from_numpy(v)
+    else:
+        for k, v in model.items():
+            name = k
+            if k.startswith('backbone.'):
+                if '.stem.' in k:
+                    # backbone.stem.conv1.conv.weight
+                    # -> backbone.stem.0.conv.weight
+                    org_ind = k.split('.')[2][-1]
+                    new_ind = str(int(org_ind) - 1)
+                    name = k.replace('.stem.conv%s.' % org_ind,
+                                     '.stem.%s.' % new_ind)
+                else:
+                    # backbone.stages.1.conv2.bn._variance
+                    # -> backbone.stage2.0.conv2.bn.running_var
+                    org_stage_ind = k.split('.')[2]
+                    new_stage_ind = str(int(org_stage_ind) + 1)
+                    name = k.replace('.stages.%s.' % org_stage_ind,
+                                     '.stage%s.0.' % new_stage_ind)
+                name = convert_repvgg(name)
+                if '.attn.' in k:
+                    name = name.replace('.attn.fc.', '.attn.fc.conv.')
+                name = convert_bn(name)
+            elif k.startswith('neck.'):
+                # fpn_stages
+                if k.startswith('neck.fpn_stages.'):
+                    # neck.fpn_stages.0.0.conv1.conv.weight
+                    # -> neck.reduce_layers.2.0.conv1.conv.weight
+                    if k.startswith('neck.fpn_stages.0.0.'):
+                        name = k.replace('neck.fpn_stages.0.0.',
+                                         'neck.reduce_layers.2.0.')
+                        if '.spp.' in name:
+                            name = name.replace('.spp.conv.', '.spp.conv2.')
+                    # neck.fpn_stages.1.0.conv1.conv.weight
+                    # -> neck.top_down_layers.0.0.conv1.conv.weight
+                    elif k.startswith('neck.fpn_stages.1.0.'):
+                        name = k.replace('neck.fpn_stages.1.0.',
+                                         'neck.top_down_layers.0.0.')
+                    elif k.startswith('neck.fpn_stages.2.0.'):
+                        name = k.replace('neck.fpn_stages.2.0.',
+                                         'neck.top_down_layers.1.0.')
+                    else:
+                        raise NotImplementedError('Not implemented.')
+                    name = name.replace('.0.convs.', '.0.blocks.')
+                elif k.startswith('neck.fpn_routes.'):
+                    # neck.fpn_routes.0.conv.weight
+                    # -> neck.upsample_layers.0.0.conv.weight
+                    index = k.split('.')[2]
+                    name = 'neck.upsample_layers.' + index + '.0.' + '.'.join(
+                        k.split('.')[-2:])
+                    name = name.replace('.0.convs.', '.0.blocks.')
+                elif k.startswith('neck.pan_stages.'):
+                    # neck.pan_stages.0.0.conv1.conv.weight
+                    # -> neck.bottom_up_layers.1.0.conv1.conv.weight
+                    ind = k.split('.')[2]
+                    name = k.replace(
+                        'neck.pan_stages.' + ind, 'neck.bottom_up_layers.' +
+                        ('0' if ind == '1' else '1'))
+                    name = name.replace('.0.convs.', '.0.blocks.')
+                elif k.startswith('neck.pan_routes.'):
+                    # neck.pan_routes.0.conv.weight
+                    # -> neck.downsample_layers.0.conv.weight
+                    ind = k.split('.')[2]
+                    name = k.replace(
+                        'neck.pan_routes.' + ind, 'neck.downsample_layers.' +
+                        ('0' if ind == '1' else '1'))
+                    name = name.replace('.0.convs.', '.0.blocks.')
+                else:
+                    raise NotImplementedError('Not implement.')
+                name = convert_repvgg(name)
+                name = convert_bn(name)
+            elif k.startswith('yolo_head.'):
+                if ('anchor_points' in k) or ('stride_tensor' in k):
+                    continue
+                if 'proj_conv' in k:
+                    name = k.replace('yolo_head.proj_conv.',
+                                     'bbox_head.head_module.proj_conv.')
+                else:
+                    for org_key, rep_key in [
+                        [
+                            'yolo_head.stem_cls.',
+                            'bbox_head.head_module.cls_stems.'
+                        ],
+                        [
+                            'yolo_head.stem_reg.',
+                            'bbox_head.head_module.reg_stems.'
+                        ],
+                        [
+                            'yolo_head.pred_cls.',
+                            'bbox_head.head_module.cls_preds.'
+                        ],
+                        [
+                            'yolo_head.pred_reg.',
+                            'bbox_head.head_module.reg_preds.'
+                        ]
+                    ]:
+                        name = name.replace(org_key, rep_key)
+                    name = name.split('.')
+                    ind = name[3]
+                    name[3] = str(2 - int(ind))
+                    name = '.'.join(name)
+                name = convert_bn(name)
+            else:
+                continue
+
+            new_state_dict[name] = torch.from_numpy(v)
     data = {'state_dict': new_state_dict}
     torch.save(data, dst)

@@ -139,8 +170,14 @@ def main():
         help='src ppyoloe model path')
     parser.add_argument(
         '--dst', default='mmppyoloe_plus_s.pt', help='save path')
+    parser.add_argument(
+        '--imagenet-pretrain',
+        action='store_true',
+        default=False,
+        help='Load model pretrained on imagenet dataset which only '
+        'have weight for backbone.')
     args = parser.parse_args()
-    convert(args.src, args.dst)
+    convert(args.src, args.dst, args.imagenet_pretrain)


 if __name__ == '__main__':
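Usage note: with the new flag the converter handles both kinds of PaddleDetection checkpoints. Example invocations (the script path and checkpoint file names below are placeholders, not taken from this diff):

# full PPYOLOE+ detector weights (backbone + neck + head)
python ppyoloe_to_mmyolo.py --src ppyoloe_plus_crn_s_80e_coco.pdparams --dst mmppyoloe_plus_s.pt

# ImageNet-pretrained backbone only, to be loaded through init_cfg
python ppyoloe_to_mmyolo.py --src CSPResNetb_s_pretrained.pdparams --dst ppyoloe_backbone_s.pt --imagenet-pretrain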