diff --git a/configs/ppyoloe/README.md b/configs/ppyoloe/README.md
new file mode 100644
index 00000000..a7b23227
--- /dev/null
+++ b/configs/ppyoloe/README.md
@@ -0,0 +1,38 @@
+# PPYOLOE
+
+
+
+## Abstract
+
+PP-YOLOE is an excellent single-stage anchor-free model based on PP-YOLOv2, surpassing a variety of popular YOLO models. PP-YOLOE has a series of models, named s/m/l/x, which are configured through width multiplier and depth multiplier. PP-YOLOE avoids using special operators, such as Deformable Convolution or Matrix NMS, to be deployed friendly on various hardware.
+
+
+

+
+
+## Results and models
+
+### PPYOLOE+ COCO
+
+| Backbone | Arch | Size | Epoch | SyncBN | Mem (GB) | Box AP | Config | Download |
+| :---------: | :--: | :--: | :---: | :----: | :------: | :----: | :-------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| PPYOLOE+ -s | P5 | 640 | 80 | Yes | 4.7 | 43.5 | [config](../ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco/ppyoloe_plus_s_fast_8xb8-80e_coco_20230101_154052-9fee7619.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco/ppyoloe_plus_s_fast_8xb8-80e_coco_20230101_154052.log.json) |
+| PPYOLOE+ -m | P5 | 640 | 80 | Yes | 8.4 | 49.5 | [config](../ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco/ppyoloe_plus_m_fast_8xb8-80e_coco_20230104_193132-e4325ada.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco/ppyoloe_plus_m_fast_8xb8-80e_coco_20230104_193132.log.json) |
+| PPYOLOE+ -l | P5 | 640 | 80 | Yes | 13.2 | 52.6 | [config](../ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco/ppyoloe_plus_l_fast_8xb8-80e_coco_20230102_203825-1864e7b3.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco/ppyoloe_plus_l_fast_8xb8-80e_coco_20230102_203825.log.json) |
+| PPYOLOE+ -x | P5 | 640 | 80 | Yes | 19.1 | 54.2 | [config](../ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco/ppyoloe_plus_x_fast_8xb8-80e_coco_20230104_194921-8c953949.pth) \| [log](https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco/ppyoloe_plus_x_fast_8xb8-80e_coco_20230104_194921.log.json) |
+
+**Note**:
+
+1. The above Box APs are all models with the best performance in COCO
+2. The gap between the above performance and the official release is about 0.3. To speed up training in mmyolo, we use pytorch to implement the image resizing in `PPYOLOEBatchRandomResize` for multi-scale training, while official PPYOLOE use opencv. And `lanczos4` is not yet supported in `PPYOLOEBatchRandomResize`. The above two reasons lead to the gap. We will continue to experiment and address the gap in future releases.
+3. The mAP of the non-Plus version needs more verification, and we will update more details of the non-Plus version in future versions.
+
+```latex
+@article{Xu2022PPYOLOEAE,
+ title={PP-YOLOE: An evolved version of YOLO},
+ author={Shangliang Xu and Xinxin Wang and Wenyu Lv and Qinyao Chang and Cheng Cui and Kaipeng Deng and Guanzhong Wang and Qingqing Dang and Shengyun Wei and Yuning Du and Baohua Lai},
+ journal={ArXiv},
+ year={2022},
+ volume={abs/2203.16250}
+}
+```
diff --git a/configs/ppyoloe/metafile.yml b/configs/ppyoloe/metafile.yml
new file mode 100644
index 00000000..5b7ed948
--- /dev/null
+++ b/configs/ppyoloe/metafile.yml
@@ -0,0 +1,69 @@
+Collections:
+ - Name: PPYOLOE
+ Metadata:
+ Training Data: COCO
+ Training Techniques:
+ - SGD with Nesterov
+ - Weight Decay
+ - Synchronize BN
+ Training Resources: 8x A100 GPUs
+ Architecture:
+ - PPYOLOECSPResNet
+ - PPYOLOECSPPAFPN
+ Paper:
+ URL: https://arxiv.org/abs/2203.16250
+ Title: 'PP-YOLOE: An evolved version of YOLO'
+ README: configs/ppyoloe/README.md
+ Code:
+ URL: https://github.com/open-mmlab/mmyolo/blob/v0.0.1/mmyolo/models/detectors/yolo_detector.py#L12
+ Version: v0.0.1
+
+Models:
+ - Name: ppyoloe_plus_s_fast_8xb8-80e_coco
+ In Collection: PPYOLOE
+ Config: configs/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py
+ Metadata:
+ Training Memory (GB): 4.7
+ Epochs: 80
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 43.5
+ Weights: https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco/ppyoloe_plus_s_fast_8xb8-80e_coco_20230101_154052-9fee7619.pth
+ - Name: ppyoloe_plus_m_fast_8xb8-80e_coco
+ In Collection: PPYOLOE
+ Config: configs/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco.py
+ Metadata:
+ Training Memory (GB): 8.4
+ Epochs: 80
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 49.5
+ Weights: https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco/ppyoloe_plus_m_fast_8xb8-80e_coco_20230104_193132-e4325ada.pth
+ - Name: ppyoloe_plus_L_fast_8xb8-80e_coco
+ In Collection: PPYOLOE
+ Config: configs/ppyoloe/ppyoloe_plus_L_fast_8xb8-80e_coco.py
+ Metadata:
+ Training Memory (GB): 13.2
+ Epochs: 80
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 52.6
+ Weights: https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco/ppyoloe_plus_l_fast_8xb8-80e_coco_20230102_203825-1864e7b3.pth
+ - Name: ppyoloe_plus_x_fast_8xb8-80e_coco
+ In Collection: PPYOLOE
+ Config: configs/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco.py
+ Metadata:
+ Training Memory (GB): 19.1
+ Epochs: 80
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 54.2
+ Weights: https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco/ppyoloe_plus_x_fast_8xb8-80e_coco_20230104_194921-8c953949.pth
diff --git a/configs/ppyoloe/ppyoloe_l_fast_8xb20-300e_coco.py b/configs/ppyoloe/ppyoloe_l_fast_8xb20-300e_coco.py
index 3ef870e5..ef1b4eaa 100644
--- a/configs/ppyoloe/ppyoloe_l_fast_8xb20-300e_coco.py
+++ b/configs/ppyoloe/ppyoloe_l_fast_8xb20-300e_coco.py
@@ -1,15 +1,23 @@
_base_ = './ppyoloe_s_fast_8xb32-300e_coco.py'
+# The pretrained model is geted and converted from official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+checkpoint = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/cspresnet_l_imagenet1k_pretrained-c0010e6c.pth' # noqa
+
deepen_factor = 1.0
widen_factor = 1.0
-# TODO: training on ppyoloe need to be implemented.
train_batch_size_per_gpu = 20
model = dict(
- backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
+ backbone=dict(
+ deepen_factor=deepen_factor,
+ widen_factor=widen_factor,
+ init_cfg=dict(checkpoint=checkpoint)),
neck=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
),
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
+
+train_dataloader = dict(batch_size=train_batch_size_per_gpu)
diff --git a/configs/ppyoloe/ppyoloe_m_fast_8xb28-300e_coco.py b/configs/ppyoloe/ppyoloe_m_fast_8xb28-300e_coco.py
index 77b49b76..abcfd783 100644
--- a/configs/ppyoloe/ppyoloe_m_fast_8xb28-300e_coco.py
+++ b/configs/ppyoloe/ppyoloe_m_fast_8xb28-300e_coco.py
@@ -1,15 +1,23 @@
_base_ = './ppyoloe_s_fast_8xb32-300e_coco.py'
+# The pretrained model is geted and converted from official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+checkpoint = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/cspresnet_m_imagenet1k_pretrained-09f1eba2.pth' # noqa
+
deepen_factor = 0.67
widen_factor = 0.75
-# TODO: training on ppyoloe need to be implemented.
train_batch_size_per_gpu = 28
model = dict(
- backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
+ backbone=dict(
+ deepen_factor=deepen_factor,
+ widen_factor=widen_factor,
+ init_cfg=dict(checkpoint=checkpoint)),
neck=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
),
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
+
+train_dataloader = dict(batch_size=train_batch_size_per_gpu)
diff --git a/configs/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco.py b/configs/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco.py
index 3741d5f0..9db53e26 100644
--- a/configs/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco.py
+++ b/configs/ppyoloe/ppyoloe_plus_l_fast_8xb8-80e_coco.py
@@ -1,5 +1,9 @@
_base_ = './ppyoloe_plus_s_fast_8xb8-80e_coco.py'
+# The pretrained model is geted and converted from official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+load_from = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/ppyoloe_plus_l_obj365_pretrained-3dd89562.pth' # noqa
+
deepen_factor = 1.0
widen_factor = 1.0
diff --git a/configs/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco.py b/configs/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco.py
index af85f310..17cb3355 100644
--- a/configs/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco.py
+++ b/configs/ppyoloe/ppyoloe_plus_m_fast_8xb8-80e_coco.py
@@ -1,5 +1,9 @@
_base_ = './ppyoloe_plus_s_fast_8xb8-80e_coco.py'
+# The pretrained model is geted and converted from official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+load_from = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/ppyoloe_plus_m_ojb365_pretrained-03206892.pth' # noqa
+
deepen_factor = 0.67
widen_factor = 0.75
diff --git a/configs/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py b/configs/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py
index a5931ec3..d46d6d82 100644
--- a/configs/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py
+++ b/configs/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py
@@ -9,21 +9,40 @@ img_scale = (640, 640) # height, width
deepen_factor = 0.33
widen_factor = 0.5
max_epochs = 80
-save_epoch_intervals = 10
+num_classes = 80
+save_epoch_intervals = 5
train_batch_size_per_gpu = 8
train_num_workers = 8
val_batch_size_per_gpu = 1
val_num_workers = 2
+# The pretrained model is geted and converted from official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+load_from = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/ppyoloe_plus_s_obj365_pretrained-bcfe8478.pth' # noqa
+
# persistent_workers must be False if num_workers is 0.
persistent_workers = True
+# Base learning rate for optim_wrapper
+base_lr = 0.001
+
strides = [8, 16, 32]
model = dict(
type='YOLODetector',
data_preprocessor=dict(
- type='YOLOv5DetDataPreprocessor',
+ # use this to support multi_scale training
+ type='PPYOLOEDetDataPreprocessor',
+ pad_size_divisor=32,
+ batch_augments=[
+ dict(
+ type='PPYOLOEBatchRandomResize',
+ random_size_range=(320, 800),
+ interval=1,
+ size_divisor=32,
+ random_interp=True,
+ keep_ratio=False)
+ ],
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True),
@@ -56,11 +75,52 @@ model = dict(
type='PPYOLOEHead',
head_module=dict(
type='PPYOLOEHeadModule',
- num_classes=80,
+ num_classes=num_classes,
in_channels=[192, 384, 768],
widen_factor=widen_factor,
featmap_strides=strides,
- num_base_priors=1)),
+ reg_max=16,
+ norm_cfg=dict(type='BN', momentum=0.1, eps=1e-5),
+ act_cfg=dict(type='SiLU', inplace=True),
+ num_base_priors=1),
+ prior_generator=dict(
+ type='mmdet.MlvlPointGenerator', offset=0.5, strides=strides),
+ bbox_coder=dict(type='DistancePointBBoxCoder'),
+ loss_cls=dict(
+ type='mmdet.VarifocalLoss',
+ use_sigmoid=True,
+ alpha=0.75,
+ gamma=2.0,
+ iou_weighted=True,
+ reduction='sum',
+ loss_weight=1.0),
+ loss_bbox=dict(
+ type='IoULoss',
+ iou_mode='giou',
+ bbox_format='xyxy',
+ reduction='mean',
+ loss_weight=2.5,
+ return_iou=False),
+ # Since the dflloss is implemented differently in the official
+ # and mmdet, we're going to divide loss_weight by 4.
+ loss_dfl=dict(
+ type='mmdet.DistributionFocalLoss',
+ reduction='mean',
+ loss_weight=0.5 / 4)),
+ train_cfg=dict(
+ initial_epoch=30,
+ initial_assigner=dict(
+ type='BatchATSSAssigner',
+ num_classes=num_classes,
+ topk=9,
+ iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
+ assigner=dict(
+ type='BatchTaskAlignedAssigner',
+ num_classes=num_classes,
+ topk=13,
+ alpha=1,
+ beta=6,
+ eps=1e-9)),
test_cfg=dict(
multi_label=True,
nms_pre=1000,
@@ -68,10 +128,36 @@ model = dict(
nms=dict(type='nms', iou_threshold=0.7),
max_per_img=300))
-test_pipeline = [
+train_pipeline = [
+ dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(type='PPYOLOERandomDistort'),
+ dict(type='mmdet.Expand', mean=(103.53, 116.28, 123.675)),
+ dict(type='PPYOLOERandomCrop'),
+ dict(type='mmdet.RandomFlip', prob=0.5),
dict(
- type='LoadImageFromFile',
- file_client_args={{_base_.file_client_args}}),
+ type='mmdet.PackDetInputs',
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
+ 'flip_direction'))
+]
+
+train_dataloader = dict(
+ batch_size=train_batch_size_per_gpu,
+ num_workers=train_num_workers,
+ persistent_workers=persistent_workers,
+ pin_memory=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ collate_fn=dict(type='yolov5_collate', use_ms_training=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='annotations/instances_train2017.json',
+ data_prefix=dict(img='train2017/'),
+ filter_cfg=dict(filter_empty_gt=True, min_size=0),
+ pipeline=train_pipeline))
+
+test_pipeline = [
+ dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(
type='mmdet.FixShapeResize',
width=img_scale[1],
@@ -103,6 +189,41 @@ val_dataloader = dict(
test_dataloader = val_dataloader
+param_scheduler = None
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(
+ type='SGD',
+ lr=base_lr,
+ momentum=0.9,
+ weight_decay=5e-4,
+ nesterov=False),
+ paramwise_cfg=dict(norm_decay_mult=0.))
+
+default_hooks = dict(
+ param_scheduler=dict(
+ type='PPYOLOEParamSchedulerHook',
+ warmup_min_iter=1000,
+ start_factor=0.,
+ warmup_epochs=5,
+ min_lr_ratio=0.0,
+ total_epochs=int(max_epochs * 1.2)),
+ checkpoint=dict(
+ type='CheckpointHook',
+ interval=save_epoch_intervals,
+ save_best='auto',
+ max_keep_ckpts=3))
+
+custom_hooks = [
+ dict(
+ type='EMAHook',
+ ema_type='ExpMomentumEMA',
+ momentum=0.0002,
+ update_buffers=True,
+ strict_load=False,
+ priority=49)
+]
+
val_evaluator = dict(
type='mmdet.CocoMetric',
proposal_nums=(100, 1, 10),
@@ -110,5 +231,9 @@ val_evaluator = dict(
metric='bbox')
test_evaluator = val_evaluator
+train_cfg = dict(
+ type='EpochBasedTrainLoop',
+ max_epochs=max_epochs,
+ val_interval=save_epoch_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
diff --git a/configs/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco.py b/configs/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco.py
index 1d598177..b8e61120 100644
--- a/configs/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco.py
+++ b/configs/ppyoloe/ppyoloe_plus_x_fast_8xb8-80e_coco.py
@@ -1,5 +1,9 @@
_base_ = './ppyoloe_plus_s_fast_8xb8-80e_coco.py'
+# The pretrained model is geted and converted from official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+load_from = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/ppyoloe_plus_x_obj365_pretrained-43a8000d.pth' # noqa
+
deepen_factor = 1.33
widen_factor = 1.25
diff --git a/configs/ppyoloe/ppyoloe_s_fast_8xb32-300e_coco.py b/configs/ppyoloe/ppyoloe_s_fast_8xb32-300e_coco.py
index 002de203..62233289 100644
--- a/configs/ppyoloe/ppyoloe_s_fast_8xb32-300e_coco.py
+++ b/configs/ppyoloe/ppyoloe_s_fast_8xb32-300e_coco.py
@@ -1,11 +1,36 @@
_base_ = './ppyoloe_plus_s_fast_8xb8-80e_coco.py'
-# TODO: training on ppyoloe need to be implemented.
+# The pretrained model is geted and converted from official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+checkpoint = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/cspresnet_s_imagenet1k_pretrained-2be81763.pth' # noqa
+
train_batch_size_per_gpu = 32
max_epochs = 300
+# Base learning rate for optim_wrapper
+base_lr = 0.01
+
model = dict(
data_preprocessor=dict(
mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
std=[0.229 * 255., 0.224 * 255., 0.225 * 255.]),
- backbone=dict(block_cfg=dict(use_alpha=False)))
+ backbone=dict(
+ block_cfg=dict(use_alpha=False),
+ init_cfg=dict(
+ type='Pretrained',
+ prefix='backbone.',
+ checkpoint=checkpoint,
+ map_location='cpu')),
+ train_cfg=dict(initial_epoch=100))
+
+train_dataloader = dict(batch_size=train_batch_size_per_gpu)
+
+optim_wrapper = dict(optimizer=dict(lr=base_lr))
+
+default_hooks = dict(param_scheduler=dict(total_epochs=int(max_epochs * 1.2)))
+
+train_cfg = dict(max_epochs=max_epochs)
+
+# PPYOLOE plus use obj365 pretrained model, but PPYOLOE not,
+# `load_from` need to set to None.
+load_from = None
diff --git a/configs/ppyoloe/ppyoloe_s_fast_8xb32-400e_coco.py b/configs/ppyoloe/ppyoloe_s_fast_8xb32-400e_coco.py
index 9efb6402..bef9e913 100644
--- a/configs/ppyoloe/ppyoloe_s_fast_8xb32-400e_coco.py
+++ b/configs/ppyoloe/ppyoloe_s_fast_8xb32-400e_coco.py
@@ -1,4 +1,9 @@
_base_ = './ppyoloe_s_fast_8xb32-300e_coco.py'
-# TODO: training on ppyoloe need to be implemented.
max_epochs = 400
+
+model = dict(train_cfg=dict(initial_epoch=133))
+
+default_hooks = dict(param_scheduler=dict(total_epochs=int(max_epochs * 1.2)))
+
+train_cfg = dict(max_epochs=max_epochs)
diff --git a/configs/ppyoloe/ppyoloe_x_fast_8xb16-300e_coco.py b/configs/ppyoloe/ppyoloe_x_fast_8xb16-300e_coco.py
index 86cdfc19..fed594f0 100644
--- a/configs/ppyoloe/ppyoloe_x_fast_8xb16-300e_coco.py
+++ b/configs/ppyoloe/ppyoloe_x_fast_8xb16-300e_coco.py
@@ -1,15 +1,23 @@
_base_ = './ppyoloe_s_fast_8xb32-300e_coco.py'
+# The pretrained model is geted and converted from official PPYOLOE.
+# https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/configs/ppyoloe/README.md
+checkpoint = 'https://download.openmmlab.com/mmyolo/v0/ppyoloe/ppyoloe_pretrain/cspresnet_x_imagenet1k_pretrained-81c33ccb.pth' # noqa
+
deepen_factor = 1.33
widen_factor = 1.25
-# TODO: training on ppyoloe need to be implemented.
train_batch_size_per_gpu = 16
model = dict(
- backbone=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
+ backbone=dict(
+ deepen_factor=deepen_factor,
+ widen_factor=widen_factor,
+ init_cfg=dict(checkpoint=checkpoint)),
neck=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
),
bbox_head=dict(head_module=dict(widen_factor=widen_factor)))
+
+train_dataloader = dict(batch_size=train_batch_size_per_gpu)
diff --git a/mmyolo/datasets/transforms/__init__.py b/mmyolo/datasets/transforms/__init__.py
index 842ad641..ea1cd41e 100644
--- a/mmyolo/datasets/transforms/__init__.py
+++ b/mmyolo/datasets/transforms/__init__.py
@@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mix_img_transforms import Mosaic, Mosaic9, YOLOv5MixUp, YOLOXMixUp
-from .transforms import (LetterResize, LoadAnnotations, YOLOv5HSVRandomAug,
+from .transforms import (LetterResize, LoadAnnotations, PPYOLOERandomCrop,
+ PPYOLOERandomDistort, YOLOv5HSVRandomAug,
YOLOv5KeepRatioResize, YOLOv5RandomAffine)
__all__ = [
'YOLOv5KeepRatioResize', 'LetterResize', 'Mosaic', 'YOLOXMixUp',
'YOLOv5MixUp', 'YOLOv5HSVRandomAug', 'LoadAnnotations',
- 'YOLOv5RandomAffine', 'Mosaic9'
+ 'YOLOv5RandomAffine', 'PPYOLOERandomDistort', 'PPYOLOERandomCrop',
+ 'Mosaic9'
]
diff --git a/mmyolo/datasets/transforms/transforms.py b/mmyolo/datasets/transforms/transforms.py
index cb025cd5..adfcfeb8 100644
--- a/mmyolo/datasets/transforms/transforms.py
+++ b/mmyolo/datasets/transforms/transforms.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
-from typing import Tuple, Union
+from typing import List, Tuple, Union
import cv2
import mmcv
@@ -675,3 +675,397 @@ class YOLOv5RandomAffine(BaseTransform):
translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]],
dtype=np.float32)
return translation_matrix
+
+
+@TRANSFORMS.register_module()
+class PPYOLOERandomDistort(BaseTransform):
+ """Random hue, saturation, contrast and brightness distortion.
+
+ Required Keys:
+
+ - img
+
+ Modified Keys:
+
+ - img (np.float32)
+
+ Args:
+ hue_cfg (dict): Hue settings. Defaults to dict(min=-18,
+ max=18, prob=0.5).
+ saturation_cfg (dict): Saturation settings. Defaults to dict(
+ min=0.5, max=1.5, prob=0.5).
+ contrast_cfg (dict): Contrast settings. Defaults to dict(
+ min=0.5, max=1.5, prob=0.5).
+ brightness_cfg (dict): Brightness settings. Defaults to dict(
+ min=0.5, max=1.5, prob=0.5).
+ num_distort_func (int): The number of distort function. Defaults
+ to 4.
+ """
+
+ def __init__(self,
+ hue_cfg: dict = dict(min=-18, max=18, prob=0.5),
+ saturation_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
+ contrast_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
+ brightness_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
+ num_distort_func: int = 4):
+ self.hue_cfg = hue_cfg
+ self.saturation_cfg = saturation_cfg
+ self.contrast_cfg = contrast_cfg
+ self.brightness_cfg = brightness_cfg
+ self.num_distort_func = num_distort_func
+ assert 0 < self.num_distort_func <= 4,\
+ 'num_distort_func must > 0 and <= 4'
+ for cfg in [
+ self.hue_cfg, self.saturation_cfg, self.contrast_cfg,
+ self.brightness_cfg
+ ]:
+ assert 0. <= cfg['prob'] <= 1., 'prob must >=0 and <=1'
+
+ def transform_hue(self, results):
+ """Transform hue randomly."""
+ if random.uniform(0., 1.) >= self.hue_cfg['prob']:
+ return results
+ img = results['img']
+ delta = random.uniform(self.hue_cfg['min'], self.hue_cfg['max'])
+ u = np.cos(delta * np.pi)
+ w = np.sin(delta * np.pi)
+ delta_iq = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
+ rgb2yiq_matrix = np.array([[0.114, 0.587, 0.299],
+ [-0.321, -0.274, 0.596],
+ [0.311, -0.523, 0.211]])
+ yiq2rgb_matric = np.array([[1.0, -1.107, 1.705], [1.0, -0.272, -0.647],
+ [1.0, 0.956, 0.621]])
+ t = np.dot(np.dot(yiq2rgb_matric, delta_iq), rgb2yiq_matrix).T
+ img = np.dot(img, t)
+ results['img'] = img
+ return results
+
+ def transform_saturation(self, results):
+ """Transform saturation randomly."""
+ if random.uniform(0., 1.) >= self.saturation_cfg['prob']:
+ return results
+ img = results['img']
+ delta = random.uniform(self.saturation_cfg['min'],
+ self.saturation_cfg['max'])
+
+ # convert bgr img to gray img
+ gray = img * np.array([[[0.114, 0.587, 0.299]]], dtype=np.float32)
+ gray = gray.sum(axis=2, keepdims=True)
+ gray *= (1.0 - delta)
+ img *= delta
+ img += gray
+ results['img'] = img
+ return results
+
+ def transform_contrast(self, results):
+ """Transform contrast randomly."""
+ if random.uniform(0., 1.) >= self.contrast_cfg['prob']:
+ return results
+ img = results['img']
+ delta = random.uniform(self.contrast_cfg['min'],
+ self.contrast_cfg['max'])
+ img *= delta
+ results['img'] = img
+ return results
+
+ def transform_brightness(self, results):
+ """Transform brightness randomly."""
+ if random.uniform(0., 1.) >= self.brightness_cfg['prob']:
+ return results
+ img = results['img']
+ delta = random.uniform(self.brightness_cfg['min'],
+ self.brightness_cfg['max'])
+ img += delta
+ results['img'] = img
+ return results
+
+ def transform(self, results: dict) -> dict:
+ """The hue, saturation, contrast and brightness distortion function.
+
+ Args:
+ results (dict): The result dict.
+
+ Returns:
+ dict: The result dict.
+ """
+ results['img'] = results['img'].astype(np.float32)
+
+ functions = [
+ self.transform_brightness, self.transform_contrast,
+ self.transform_saturation, self.transform_hue
+ ]
+ distortions = random.permutation(functions)[:self.num_distort_func]
+ for func in distortions:
+ results = func(results)
+ return results
+
+
+@TRANSFORMS.register_module()
+class PPYOLOERandomCrop(BaseTransform):
+ """Random crop the img and bboxes. Different thresholds are used in PPYOLOE
+ to judge whether the clipped image meets the requirements. This
+ implementation is different from the implementation of RandomCrop in mmdet.
+
+ Required Keys:
+
+ - img
+ - gt_bboxes (BaseBoxes[torch.float32]) (optional)
+ - gt_bboxes_labels (np.int64) (optional)
+ - gt_ignore_flags (bool) (optional)
+
+ Modified Keys:
+
+ - img
+ - img_shape
+ - gt_bboxes (optional)
+ - gt_bboxes_labels (optional)
+ - gt_ignore_flags (optional)
+
+ Added Keys:
+ - pad_param (np.float32)
+
+ Args:
+ aspect_ratio (List[float]): Aspect ratio of cropped region. Default to
+ [.5, 2].
+ thresholds (List[float]): Iou thresholds for decide a valid bbox crop
+ in [min, max] format. Defaults to [.0, .1, .3, .5, .7, .9].
+ scaling (List[float]): Ratio between a cropped region and the original
+ image in [min, max] format. Default to [.3, 1.].
+ num_attempts (int): Number of tries for each threshold before
+ giving up. Default to 50.
+ allow_no_crop (bool): Allow return without actually cropping them.
+ Default to True.
+ cover_all_box (bool): Ensure all bboxes are covered in the final crop.
+ Default to False.
+ """
+
+ def __init__(self,
+ aspect_ratio: List[float] = [.5, 2.],
+ thresholds: List[float] = [.0, .1, .3, .5, .7, .9],
+ scaling: List[float] = [.3, 1.],
+ num_attempts: int = 50,
+ allow_no_crop: bool = True,
+ cover_all_box: bool = False):
+ self.aspect_ratio = aspect_ratio
+ self.thresholds = thresholds
+ self.scaling = scaling
+ self.num_attempts = num_attempts
+ self.allow_no_crop = allow_no_crop
+ self.cover_all_box = cover_all_box
+
+ def _crop_data(self, results: dict, crop_box: Tuple[int, int, int, int],
+ valid_inds: np.ndarray) -> Union[dict, None]:
+ """Function to randomly crop images, bounding boxes, masks, semantic
+ segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ crop_box (Tuple[int, int, int, int]): Expected absolute coordinates
+ for cropping, (x1, y1, x2, y2).
+ valid_inds (np.ndarray): The indexes of gt that needs to be
+ retained.
+
+ Returns:
+ results (Union[dict, None]): Randomly cropped results, 'img_shape'
+ key in result dict is updated according to crop size. None will
+ be returned when there is no valid bbox after cropping.
+ """
+ # crop the image
+ img = results['img']
+ crop_x1, crop_y1, crop_x2, crop_y2 = crop_box
+ img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
+ results['img'] = img
+ img_shape = img.shape
+ results['img_shape'] = img.shape
+
+ # crop bboxes accordingly and clip to the image boundary
+ if results.get('gt_bboxes', None) is not None:
+ bboxes = results['gt_bboxes']
+ bboxes.translate_([-crop_x1, -crop_y1])
+ bboxes.clip_(img_shape[:2])
+
+ results['gt_bboxes'] = bboxes[valid_inds]
+
+ if results.get('gt_ignore_flags', None) is not None:
+ results['gt_ignore_flags'] = \
+ results['gt_ignore_flags'][valid_inds]
+
+ if results.get('gt_bboxes_labels', None) is not None:
+ results['gt_bboxes_labels'] = \
+ results['gt_bboxes_labels'][valid_inds]
+
+ if results.get('gt_masks', None) is not None:
+ results['gt_masks'] = results['gt_masks'][
+ valid_inds.nonzero()[0]].crop(
+ np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
+
+ # crop semantic seg
+ if results.get('gt_seg_map', None) is not None:
+ results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2,
+ crop_x1:crop_x2]
+
+ return results
+
+ @autocast_box_type()
+ def transform(self, results: dict) -> Union[dict, None]:
+ """The random crop transform function.
+
+ Args:
+ results (dict): The result dict.
+
+ Returns:
+ dict: The result dict.
+ """
+ if results.get('gt_bboxes', None) is None or len(
+ results['gt_bboxes']) == 0:
+ return results
+
+ orig_img_h, orig_img_w = results['img'].shape[:2]
+ gt_bboxes = results['gt_bboxes']
+
+ thresholds = list(self.thresholds)
+ if self.allow_no_crop:
+ thresholds.append('no_crop')
+ random.shuffle(thresholds)
+
+ for thresh in thresholds:
+ # Determine the coordinates for cropping
+ if thresh == 'no_crop':
+ return results
+
+ found = False
+ for i in range(self.num_attempts):
+ crop_h, crop_w = self._get_crop_size((orig_img_h, orig_img_w))
+ if self.aspect_ratio is None:
+ if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
+ continue
+
+ # get image crop_box
+ margin_h = max(orig_img_h - crop_h, 0)
+ margin_w = max(orig_img_w - crop_w, 0)
+ offset_h, offset_w = self._rand_offset((margin_h, margin_w))
+ crop_y1, crop_y2 = offset_h, offset_h + crop_h
+ crop_x1, crop_x2 = offset_w, offset_w + crop_w
+
+ crop_box = [crop_x1, crop_y1, crop_x2, crop_y2]
+ # Calculate the iou between gt_bboxes and crop_boxes
+ iou = self._iou_matrix(gt_bboxes,
+ np.array([crop_box], dtype=np.float32))
+ # If the maximum value of the iou is less than thresh,
+ # the current crop_box is considered invalid.
+ if iou.max() < thresh:
+ continue
+
+ # If cover_all_box == True and the minimum value of
+ # the iou is less than thresh, the current crop_box
+ # is considered invalid.
+ if self.cover_all_box and iou.min() < thresh:
+ continue
+
+ # Get which gt_bboxes to keep after cropping.
+ valid_inds = self._get_valid_inds(
+ gt_bboxes, np.array(crop_box, dtype=np.float32))
+ if valid_inds.size > 0:
+ found = True
+ break
+
+ if found:
+ results = self._crop_data(results, crop_box, valid_inds)
+ return results
+ return results
+
+ @cache_randomness
+ def _rand_offset(self, margin: Tuple[int, int]) -> Tuple[int, int]:
+ """Randomly generate crop offset.
+
+ Args:
+ margin (Tuple[int, int]): The upper bound for the offset generated
+ randomly.
+
+ Returns:
+ Tuple[int, int]: The random offset for the crop.
+ """
+ margin_h, margin_w = margin
+ offset_h = np.random.randint(0, margin_h + 1)
+ offset_w = np.random.randint(0, margin_w + 1)
+
+ return (offset_h, offset_w)
+
+ @cache_randomness
+ def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]:
+ """Randomly generates the crop size based on `image_size`.
+
+ Args:
+ image_size (Tuple[int, int]): (h, w).
+
+ Returns:
+ crop_size (Tuple[int, int]): (crop_h, crop_w) in absolute pixels.
+ """
+ h, w = image_size
+ scale = random.uniform(*self.scaling)
+ if self.aspect_ratio is not None:
+ min_ar, max_ar = self.aspect_ratio
+ aspect_ratio = random.uniform(
+ max(min_ar, scale**2), min(max_ar, scale**-2))
+ h_scale = scale / np.sqrt(aspect_ratio)
+ w_scale = scale * np.sqrt(aspect_ratio)
+ else:
+ h_scale = random.uniform(*self.scaling)
+ w_scale = random.uniform(*self.scaling)
+ crop_h = h * h_scale
+ crop_w = w * w_scale
+ return int(crop_h), int(crop_w)
+
+ def _iou_matrix(self,
+ gt_bbox: HorizontalBoxes,
+ crop_bbox: np.ndarray,
+ eps: float = 1e-10) -> np.ndarray:
+ """Calculate iou between gt and image crop box.
+
+ Args:
+ gt_bbox (HorizontalBoxes): Ground truth bounding boxes.
+ crop_bbox (np.ndarray): Image crop coordinates in
+ [x1, y1, x2, y2] format.
+ eps (float): Default to 1e-10.
+ Return:
+ (np.ndarray): IoU.
+ """
+ gt_bbox = gt_bbox.tensor.numpy()
+ lefttop = np.maximum(gt_bbox[:, np.newaxis, :2], crop_bbox[:, :2])
+ rightbottom = np.minimum(gt_bbox[:, np.newaxis, 2:], crop_bbox[:, 2:])
+
+ overlap = np.prod(
+ rightbottom - lefttop,
+ axis=2) * (lefttop < rightbottom).all(axis=2)
+ area_gt_bbox = np.prod(gt_bbox[:, 2:] - crop_bbox[:, :2], axis=1)
+ area_crop_bbox = np.prod(gt_bbox[:, 2:] - crop_bbox[:, :2], axis=1)
+ area_o = (area_gt_bbox[:, np.newaxis] + area_crop_bbox - overlap)
+ return overlap / (area_o + eps)
+
+ def _get_valid_inds(self, gt_bbox: HorizontalBoxes,
+ img_crop_bbox: np.ndarray) -> np.ndarray:
+ """Get which Bboxes to keep at the current cropping coordinates.
+
+ Args:
+ gt_bbox (HorizontalBoxes): Ground truth bounding boxes.
+ img_crop_bbox (np.ndarray): Image crop coordinates in
+ [x1, y1, x2, y2] format.
+
+ Returns:
+ (np.ndarray): Valid indexes.
+ """
+ cropped_box = gt_bbox.tensor.numpy().copy()
+ gt_bbox = gt_bbox.tensor.numpy().copy()
+
+ cropped_box[:, :2] = np.maximum(gt_bbox[:, :2], img_crop_bbox[:2])
+ cropped_box[:, 2:] = np.minimum(gt_bbox[:, 2:], img_crop_bbox[2:])
+ cropped_box[:, :2] -= img_crop_bbox[:2]
+ cropped_box[:, 2:] -= img_crop_bbox[:2]
+
+ centers = (gt_bbox[:, :2] + gt_bbox[:, 2:]) / 2
+ valid = np.logical_and(img_crop_bbox[:2] <= centers,
+ centers < img_crop_bbox[2:]).all(axis=1)
+ valid = np.logical_and(
+ valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
+
+ return np.where(valid)[0]
diff --git a/mmyolo/datasets/utils.py b/mmyolo/datasets/utils.py
index 1d84d39b..0cca341b 100644
--- a/mmyolo/datasets/utils.py
+++ b/mmyolo/datasets/utils.py
@@ -9,8 +9,14 @@ from ..registry import TASK_UTILS
@COLLATE_FUNCTIONS.register_module()
-def yolov5_collate(data_batch: Sequence) -> dict:
- """Rewrite collate_fn to get faster training speed."""
+def yolov5_collate(data_batch: Sequence,
+ use_ms_training: bool = False) -> dict:
+ """Rewrite collate_fn to get faster training speed.
+
+ Args:
+ data_batch (Sequence): Batch of data.
+ use_ms_training (bool): Whether to use multi-scale training.
+ """
batch_imgs = []
batch_bboxes_labels = []
for i in range(len(data_batch)):
@@ -25,10 +31,16 @@ def yolov5_collate(data_batch: Sequence) -> dict:
batch_bboxes_labels.append(bboxes_labels)
batch_imgs.append(inputs)
- return {
- 'inputs': torch.stack(batch_imgs, 0),
- 'data_samples': torch.cat(batch_bboxes_labels, 0)
- }
+ if use_ms_training:
+ return {
+ 'inputs': batch_imgs,
+ 'data_samples': torch.cat(batch_bboxes_labels, 0)
+ }
+ else:
+ return {
+ 'inputs': torch.stack(batch_imgs, 0),
+ 'data_samples': torch.cat(batch_bboxes_labels, 0)
+ }
@TASK_UTILS.register_module()
diff --git a/mmyolo/engine/hooks/__init__.py b/mmyolo/engine/hooks/__init__.py
index 466fa511..0b8deebc 100644
--- a/mmyolo/engine/hooks/__init__.py
+++ b/mmyolo/engine/hooks/__init__.py
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .ppyoloe_param_scheduler_hook import PPYOLOEParamSchedulerHook
from .switch_to_deploy_hook import SwitchToDeployHook
from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
from .yolox_mode_switch_hook import YOLOXModeSwitchHook
__all__ = [
- 'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook'
+ 'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook',
+ 'PPYOLOEParamSchedulerHook'
]
diff --git a/mmyolo/engine/hooks/ppyoloe_param_scheduler_hook.py b/mmyolo/engine/hooks/ppyoloe_param_scheduler_hook.py
new file mode 100644
index 00000000..26dfe6ef
--- /dev/null
+++ b/mmyolo/engine/hooks/ppyoloe_param_scheduler_hook.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Optional
+
+from mmengine.hooks import ParamSchedulerHook
+from mmengine.runner import Runner
+
+from mmyolo.registry import HOOKS
+
+
+@HOOKS.register_module()
+class PPYOLOEParamSchedulerHook(ParamSchedulerHook):
+ """A hook to update learning rate and momentum in optimizer of PPYOLOE. We
+ use this hook to implement adaptive computation for `warmup_total_iters`,
+ which is not possible with the built-in ParamScheduler in mmyolo.
+
+ Args:
+ warmup_min_iter (int): Minimum warmup iters. Defaults to 1000.
+ start_factor (float): The number we multiply learning rate in the
+ first epoch. The multiplication factor changes towards end_factor
+ in the following epochs. Defaults to 0.
+ warmup_epochs (int): Epochs for warmup. Defaults to 5.
+ min_lr_ratio (float): Minimum learning rate ratio.
+ total_epochs (int): In PPYOLOE, `total_epochs` is set to
+ training_epochs x 1.2. Defaults to 360.
+ """
+ priority = 9
+
+ def __init__(self,
+ warmup_min_iter: int = 1000,
+ start_factor: float = 0.,
+ warmup_epochs: int = 5,
+ min_lr_ratio: float = 0.0,
+ total_epochs: int = 360):
+
+ self.warmup_min_iter = warmup_min_iter
+ self.start_factor = start_factor
+ self.warmup_epochs = warmup_epochs
+ self.min_lr_ratio = min_lr_ratio
+ self.total_epochs = total_epochs
+
+ self._warmup_end = False
+ self._base_lr = None
+
+ def before_train(self, runner: Runner):
+ """Operations before train.
+
+ Args:
+ runner (Runner): The runner of the training process.
+ """
+ optimizer = runner.optim_wrapper.optimizer
+ for group in optimizer.param_groups:
+ # If the param is never be scheduled, record the current value
+ # as the initial value.
+ group.setdefault('initial_lr', group['lr'])
+
+ self._base_lr = [
+ group['initial_lr'] for group in optimizer.param_groups
+ ]
+ self._min_lr = [i * self.min_lr_ratio for i in self._base_lr]
+
+ def before_train_iter(self,
+ runner: Runner,
+ batch_idx: int,
+ data_batch: Optional[dict] = None):
+ """Operations before each training iteration.
+
+ Args:
+ runner (Runner): The runner of the training process.
+ batch_idx (int): The index of the current batch in the train loop.
+ data_batch (dict or tuple or list, optional): Data from dataloader.
+ """
+ cur_iters = runner.iter
+ optimizer = runner.optim_wrapper.optimizer
+ dataloader_len = len(runner.train_dataloader)
+
+ # The minimum warmup is self.warmup_min_iter
+ warmup_total_iters = max(
+ round(self.warmup_epochs * dataloader_len), self.warmup_min_iter)
+
+ if cur_iters <= warmup_total_iters:
+ # warm up
+ alpha = cur_iters / warmup_total_iters
+ factor = self.start_factor * (1 - alpha) + alpha
+
+ for group_idx, param in enumerate(optimizer.param_groups):
+ param['lr'] = self._base_lr[group_idx] * factor
+ else:
+ for group_idx, param in enumerate(optimizer.param_groups):
+ total_iters = self.total_epochs * dataloader_len
+ lr = self._min_lr[group_idx] + (
+ self._base_lr[group_idx] -
+ self._min_lr[group_idx]) * 0.5 * (
+ math.cos((cur_iters - warmup_total_iters) * math.pi /
+ (total_iters - warmup_total_iters)) + 1.0)
+ param['lr'] = lr
diff --git a/mmyolo/models/data_preprocessors/__init__.py b/mmyolo/models/data_preprocessors/__init__.py
index d9edbfd2..4e31aa71 100644
--- a/mmyolo/models/data_preprocessors/__init__.py
+++ b/mmyolo/models/data_preprocessors/__init__.py
@@ -1,4 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from .data_preprocessor import YOLOv5DetDataPreprocessor
+from .data_preprocessor import (PPYOLOEBatchRandomResize,
+ PPYOLOEDetDataPreprocessor,
+ YOLOv5DetDataPreprocessor)
-__all__ = ['YOLOv5DetDataPreprocessor']
+__all__ = [
+ 'YOLOv5DetDataPreprocessor', 'PPYOLOEDetDataPreprocessor',
+ 'PPYOLOEBatchRandomResize'
+]
diff --git a/mmyolo/models/data_preprocessors/data_preprocessor.py b/mmyolo/models/data_preprocessors/data_preprocessor.py
index 04a62821..c7281fa5 100644
--- a/mmyolo/models/data_preprocessors/data_preprocessor.py
+++ b/mmyolo/models/data_preprocessors/data_preprocessor.py
@@ -1,6 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import random
+from typing import List, Tuple, Union
+
import torch
+import torch.nn.functional as F
+from mmdet.models import BatchSyncRandomResize
from mmdet.models.data_preprocessors import DetDataPreprocessor
+from mmengine import MessageHub, is_list_of
+from torch import Tensor
from mmyolo.registry import MODELS
@@ -50,3 +57,191 @@ class YOLOv5DetDataPreprocessor(DetDataPreprocessor):
data_samples = {'bboxes_labels': data_samples, 'img_metas': img_metas}
return {'inputs': inputs, 'data_samples': data_samples}
+
+
+@MODELS.register_module()
+class PPYOLOEDetDataPreprocessor(DetDataPreprocessor):
+ """Image pre-processor for detection tasks.
+
+ The main difference between PPYOLOEDetDataPreprocessor and
+ DetDataPreprocessor is the normalization order. The official
+ PPYOLOE resize image first, and then normalize image.
+ In DetDataPreprocessor, the order is reversed.
+
+ Note: It must be used together with
+ `mmyolo.datasets.utils.yolov5_collate`
+ """
+
+ def forward(self, data: dict, training: bool = False) -> dict:
+ """Perform normalization、padding and bgr2rgb conversion based on
+ ``BaseDataPreprocessor``. This class use batch_augments first, and then
+ normalize the image, which is different from the `DetDataPreprocessor`
+ .
+
+ Args:
+ data (dict): Data sampled from dataloader.
+ training (bool): Whether to enable training time augmentation.
+
+ Returns:
+ dict: Data in the same format as the model input.
+ """
+ if not training:
+ return super().forward(data, training)
+
+ assert isinstance(data['inputs'], list) and is_list_of(
+ data['inputs'], torch.Tensor), \
+ '"inputs" should be a list of Tensor, but got ' \
+ f'{type(data["inputs"])}. The possible reason for this ' \
+ 'is that you are not using it with ' \
+ '"mmyolo.datasets.utils.yolov5_collate". Please refer to ' \
+ '"cconfigs/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py".'
+
+ data = self.cast_data(data)
+ inputs, data_samples = data['inputs'], data['data_samples']
+
+ # Process data.
+ batch_inputs = []
+ for _batch_input, data_sample in zip(inputs, data_samples):
+ # channel transform
+ if self._channel_conversion:
+ _batch_input = _batch_input[[2, 1, 0], ...]
+ # Convert to float after channel conversion to ensure
+ # efficiency
+ _batch_input = _batch_input.float()
+
+ batch_inputs.append(_batch_input)
+
+ # Batch random resize image.
+ if self.batch_augments is not None:
+ for batch_aug in self.batch_augments:
+ inputs, data_samples = batch_aug(batch_inputs, data_samples)
+
+ if self._enable_normalize:
+ inputs = (inputs - self.mean) / self.std
+
+ img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
+ data_samples = {'bboxes_labels': data_samples, 'img_metas': img_metas}
+
+ return {'inputs': inputs, 'data_samples': data_samples}
+
+
+# TODO: No generality. Its input data format is different
+# mmdet's batch aug, and it must be compatible in the future.
+@MODELS.register_module()
+class PPYOLOEBatchRandomResize(BatchSyncRandomResize):
+ """PPYOLOE batch random resize.
+
+ Args:
+ random_size_range (tuple): The multi-scale random range during
+ multi-scale training.
+ interval (int): The iter interval of change
+ image size. Defaults to 10.
+ size_divisor (int): Image size divisible factor.
+ Defaults to 32.
+ random_interp (bool): Whether to choose interp_mode randomly.
+ If set to True, the type of `interp_mode` must be list.
+ If set to False, the type of `interp_mode` must be str.
+ Defaults to True.
+ interp_mode (Union[List, str]): The modes available for resizing
+ are ('nearest', 'bilinear', 'bicubic', 'area').
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing
+ the image. Now we only support keep_ratio=False.
+ Defaults to False.
+ """
+
+ def __init__(self,
+ random_size_range: Tuple[int, int],
+ interval: int = 1,
+ size_divisor: int = 32,
+ random_interp=True,
+ interp_mode: Union[List[str], str] = [
+ 'nearest', 'bilinear', 'bicubic', 'area'
+ ],
+ keep_ratio: bool = False) -> None:
+ super().__init__(random_size_range, interval, size_divisor)
+ self.random_interp = random_interp
+ self.keep_ratio = keep_ratio
+ # TODO: need to support keep_ratio==True
+ assert not self.keep_ratio, 'We do not yet support keep_ratio=True'
+
+ if self.random_interp:
+ assert isinstance(interp_mode, list) and len(interp_mode) > 1,\
+ 'While random_interp==True, the type of `interp_mode`' \
+ ' must be list and len(interp_mode) must large than 1'
+ self.interp_mode_list = interp_mode
+ self.interp_mode = None
+ else:
+ assert isinstance(interp_mode, str),\
+ 'While random_interp==False, the type of ' \
+ '`interp_mode` must be str'
+ assert interp_mode in ['nearest', 'bilinear', 'bicubic', 'area']
+ self.interp_mode_list = None
+ self.interp_mode = interp_mode
+
+ def forward(self, inputs: list,
+ data_samples: Tensor) -> Tuple[Tensor, Tensor]:
+ """Resize a batch of images and bboxes to shape ``self._input_size``.
+
+ The inputs and data_samples should be list, and
+ ``PPYOLOEBatchRandomResize`` must be used with
+ ``PPYOLOEDetDataPreprocessor`` and ``yolov5_collate`` with
+ ``use_ms_training == True``.
+ """
+ assert isinstance(inputs, list),\
+ 'The type of inputs must be list. The possible reason for this ' \
+ 'is that you are not using it with `PPYOLOEDetDataPreprocessor` ' \
+ 'and `yolov5_collate` with use_ms_training == True.'
+ message_hub = MessageHub.get_current_instance()
+ if (message_hub.get_info('iter') + 1) % self._interval == 0:
+ # get current input size
+ self._input_size, interp_mode = self._get_random_size_and_interp()
+ if self.random_interp:
+ self.interp_mode = interp_mode
+
+ # TODO: need to support type(inputs)==Tensor
+ if isinstance(inputs, list):
+ outputs = []
+ for i in range(len(inputs)):
+ _batch_input = inputs[i]
+ h, w = _batch_input.shape[-2:]
+ scale_y = self._input_size[0] / h
+ scale_x = self._input_size[1] / w
+ if scale_x != 1. or scale_y != 1.:
+ if self.interp_mode in ('nearest', 'area'):
+ align_corners = None
+ else:
+ align_corners = False
+ _batch_input = F.interpolate(
+ _batch_input.unsqueeze(0),
+ size=self._input_size,
+ mode=self.interp_mode,
+ align_corners=align_corners)
+
+ # rescale boxes
+ indexes = data_samples[:, 0] == i
+ data_samples[indexes, 2] *= scale_x
+ data_samples[indexes, 3] *= scale_y
+ data_samples[indexes, 4] *= scale_x
+ data_samples[indexes, 5] *= scale_y
+ else:
+ _batch_input = _batch_input.unsqueeze(0)
+
+ outputs.append(_batch_input)
+
+ # convert to Tensor
+ return torch.cat(outputs, dim=0), data_samples
+ else:
+ raise NotImplementedError('Not implemented yet!')
+
+ def _get_random_size_and_interp(self) -> Tuple[int, int]:
+ """Randomly generate a shape in ``_random_size_range`` and a
+ interp_mode in interp_mode_list."""
+ size = random.randint(*self._random_size_range)
+ input_size = (self._size_divisor * size, self._size_divisor * size)
+
+ if self.random_interp:
+ interp_ind = random.randint(0, len(self.interp_mode_list) - 1)
+ interp_mode = self.interp_mode_list[interp_ind]
+ else:
+ interp_mode = None
+ return input_size, interp_mode
diff --git a/mmyolo/models/dense_heads/ppyoloe_head.py b/mmyolo/models/dense_heads/ppyoloe_head.py
index f643a1d5..bd246c4d 100644
--- a/mmyolo/models/dense_heads/ppyoloe_head.py
+++ b/mmyolo/models/dense_heads/ppyoloe_head.py
@@ -1,24 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Sequence, Union
+from typing import Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.utils import multi_apply
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
- OptMultiConfig)
+ OptMultiConfig, reduce_mean)
+from mmengine import MessageHub
from mmengine.model import BaseModule, bias_init_with_prob
from mmengine.structures import InstanceData
from torch import Tensor
from mmyolo.registry import MODELS
from ..layers.yolo_bricks import PPYOLOESELayer
-from .yolov5_head import YOLOv5Head
+from .yolov6_head import YOLOv6Head
@MODELS.register_module()
class PPYOLOEHeadModule(BaseModule):
- """PPYOLOEHead head module used in `PPYOLOE`
+ """PPYOLOEHead head module used in `PPYOLOE.
+
+ `_.
Args:
num_classes (int): Number of categories excluding the background
@@ -30,7 +33,8 @@ class PPYOLOEHeadModule(BaseModule):
on the feature grid.
featmap_strides (Sequence[int]): Downsample factor of each feature map.
Defaults to (8, 16, 32).
- reg_max (int): TOOD reg_max param.
+ reg_max (int): Max value of integral set :math: ``{0, ..., reg_max}``
+ in QFL setting. Defaults to 16.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
act_cfg (dict): Config dict for activation layer.
@@ -100,15 +104,12 @@ class PPYOLOEHeadModule(BaseModule):
self.reg_preds.append(
nn.Conv2d(in_channel, 4 * (self.reg_max + 1), 3, padding=1))
- self.proj_conv = nn.Conv2d(self.reg_max + 1, 1, 1, bias=False)
- self.proj = nn.Parameter(
- torch.linspace(0, self.reg_max, self.reg_max + 1),
- requires_grad=False)
- self.proj_conv.weight = nn.Parameter(
- self.proj.view([1, self.reg_max + 1, 1, 1]).clone().detach(),
- requires_grad=False)
+ # init proj
+ proj = torch.linspace(0, self.reg_max, self.reg_max + 1).view(
+ [1, self.reg_max + 1, 1, 1])
+ self.register_buffer('proj', proj, persistent=False)
- def forward(self, x: Tensor) -> Tensor:
+ def forward(self, x: Tuple[Tensor]) -> Tensor:
"""Forward features from the upstream network.
Args:
@@ -131,17 +132,24 @@ class PPYOLOEHeadModule(BaseModule):
hw = h * w
avg_feat = F.adaptive_avg_pool2d(x, (1, 1))
cls_logit = cls_pred(cls_stem(x, avg_feat) + x)
- reg_dist = reg_pred(reg_stem(x, avg_feat))
- reg_dist = reg_dist.reshape([-1, 4, self.reg_max + 1,
- hw]).permute(0, 2, 3, 1)
- reg_dist = self.proj_conv(F.softmax(reg_dist, dim=1))
+ bbox_dist_preds = reg_pred(reg_stem(x, avg_feat))
+ # TODO: Test whether use matmul instead of conv can speed up training.
+ bbox_dist_preds = bbox_dist_preds.reshape(
+ [-1, 4, self.reg_max + 1, hw]).permute(0, 2, 3, 1)
- return cls_logit, reg_dist
+ bbox_preds = F.conv2d(F.softmax(bbox_dist_preds, dim=1), self.proj)
+
+ if self.training:
+ return cls_logit, bbox_preds, bbox_dist_preds
+ else:
+ return cls_logit, bbox_preds
@MODELS.register_module()
-class PPYOLOEHead(YOLOv5Head):
- """PPYOLOEHead head used in `PPYOLOE`.
+class PPYOLOEHead(YOLOv6Head):
+ """PPYOLOEHead head used in `PPYOLOE `_.
+ The YOLOv6 head and the PPYOLOE head are only slightly different.
+ Distribution focal loss is extra used in PPYOLOE, but not in YOLOv6.
Args:
head_module(ConfigType): Base module used for YOLOv5Head
@@ -150,7 +158,8 @@ class PPYOLOEHead(YOLOv5Head):
bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
- loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
+ loss_dfl (:obj:`ConfigDict` or dict): Config of distribution focal
+ loss.
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
anchor head. Defaults to None.
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
@@ -168,17 +177,24 @@ class PPYOLOEHead(YOLOv5Head):
strides=[8, 16, 32]),
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
loss_cls: ConfigType = dict(
- type='mmdet.CrossEntropyLoss',
+ type='mmdet.VarifocalLoss',
use_sigmoid=True,
+ alpha=0.75,
+ gamma=2.0,
+ iou_weighted=True,
reduction='sum',
loss_weight=1.0),
loss_bbox: ConfigType = dict(
- type='mmdet.GIoULoss', reduction='sum', loss_weight=5.0),
- loss_obj: ConfigType = dict(
- type='mmdet.CrossEntropyLoss',
- use_sigmoid=True,
- reduction='sum',
- loss_weight=1.0),
+ type='IoULoss',
+ iou_mode='giou',
+ bbox_format='xyxy',
+ reduction='mean',
+ loss_weight=2.5,
+ return_iou=False),
+ loss_dfl: ConfigType = dict(
+ type='mmdet.DistributionFocalLoss',
+ reduction='mean',
+ loss_weight=0.5 / 4),
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None):
@@ -188,19 +204,18 @@ class PPYOLOEHead(YOLOv5Head):
bbox_coder=bbox_coder,
loss_cls=loss_cls,
loss_bbox=loss_bbox,
- loss_obj=loss_obj,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg)
-
- def special_init(self):
- """Not Implenented."""
- pass
+ self.loss_dfl = MODELS.build(loss_dfl)
+ # ppyoloe doesn't need loss_obj
+ self.loss_obj = None
def loss_by_feat(
self,
cls_scores: Sequence[Tensor],
bbox_preds: Sequence[Tensor],
+ bbox_dist_preds: Sequence[Tensor],
batch_gt_instances: Sequence[InstanceData],
batch_img_metas: Sequence[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
@@ -214,6 +229,8 @@ class PPYOLOEHead(YOLOv5Head):
bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
level, each is a 4D-tensor, the channel number is
num_priors * 4.
+ bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
+ each scale level with shape (bs, reg_max + 1, H*W, 4).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
@@ -226,4 +243,131 @@ class PPYOLOEHead(YOLOv5Head):
Returns:
dict[str, Tensor]: A dictionary of losses.
"""
- raise NotImplementedError('Not implemented yet!')
+
+ # get epoch information from message hub
+ message_hub = MessageHub.get_current_instance()
+ current_epoch = message_hub.get_info('epoch')
+
+ num_imgs = len(batch_img_metas)
+
+ current_featmap_sizes = [
+ cls_score.shape[2:] for cls_score in cls_scores
+ ]
+ # If the shape does not equal, generate new one
+ if current_featmap_sizes != self.featmap_sizes_train:
+ self.featmap_sizes_train = current_featmap_sizes
+
+ mlvl_priors_with_stride = self.prior_generator.grid_priors(
+ self.featmap_sizes_train,
+ dtype=cls_scores[0].dtype,
+ device=cls_scores[0].device,
+ with_stride=True)
+
+ self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
+ self.flatten_priors_train = torch.cat(
+ mlvl_priors_with_stride, dim=0)
+ self.stride_tensor = self.flatten_priors_train[..., [2]]
+
+ # gt info
+ gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)
+ gt_labels = gt_info[:, :, :1]
+ gt_bboxes = gt_info[:, :, 1:] # xyxy
+ pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
+
+ # pred info
+ flatten_cls_preds = [
+ cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
+ self.num_classes)
+ for cls_pred in cls_scores
+ ]
+ flatten_pred_bboxes = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
+ for bbox_pred in bbox_preds
+ ]
+ # (bs, reg_max+1, n, 4) -> (bs, n, 4, reg_max+1)
+ flatten_pred_dists = [
+ bbox_pred_org.permute(0, 2, 3, 1).reshape(
+ num_imgs, -1, (self.head_module.reg_max + 1) * 4)
+ for bbox_pred_org in bbox_dist_preds
+ ]
+
+ flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
+ flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
+ flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
+ flatten_pred_bboxes = self.bbox_coder.decode(
+ self.flatten_priors_train[..., :2], flatten_pred_bboxes,
+ self.stride_tensor[..., 0])
+ pred_scores = torch.sigmoid(flatten_cls_preds)
+
+ if current_epoch < self.initial_epoch:
+ assigned_result = self.initial_assigner(
+ flatten_pred_bboxes.detach(), self.flatten_priors_train,
+ self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
+ else:
+ assigned_result = self.assigner(flatten_pred_bboxes.detach(),
+ pred_scores.detach(),
+ self.flatten_priors_train,
+ gt_labels, gt_bboxes,
+ pad_bbox_flag)
+
+ assigned_bboxes = assigned_result['assigned_bboxes']
+ assigned_scores = assigned_result['assigned_scores']
+ fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
+
+ # cls loss
+ with torch.cuda.amp.autocast(enabled=False):
+ loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores)
+
+ # rescale bbox
+ assigned_bboxes /= self.stride_tensor
+ flatten_pred_bboxes /= self.stride_tensor
+
+ assigned_scores_sum = assigned_scores.sum()
+ # reduce_mean between all gpus
+ assigned_scores_sum = torch.clamp(
+ reduce_mean(assigned_scores_sum), min=1)
+ loss_cls /= assigned_scores_sum
+
+ # select positive samples mask
+ num_pos = fg_mask_pre_prior.sum()
+ if num_pos > 0:
+ # when num_pos > 0, assigned_scores_sum will >0, so the loss_bbox
+ # will not report an error
+ # iou loss
+ prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
+ pred_bboxes_pos = torch.masked_select(
+ flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
+ assigned_bboxes_pos = torch.masked_select(
+ assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
+ bbox_weight = torch.masked_select(
+ assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
+ loss_bbox = self.loss_bbox(
+ pred_bboxes_pos,
+ assigned_bboxes_pos,
+ weight=bbox_weight,
+ avg_factor=assigned_scores_sum)
+
+ # dfl loss
+ dist_mask = fg_mask_pre_prior.unsqueeze(-1).repeat(
+ [1, 1, (self.head_module.reg_max + 1) * 4])
+
+ pred_dist_pos = torch.masked_select(
+ flatten_dist_preds,
+ dist_mask).reshape([-1, 4, self.head_module.reg_max + 1])
+ assigned_ltrb = self.bbox_coder.encode(
+ self.flatten_priors_train[..., :2] / self.stride_tensor,
+ assigned_bboxes,
+ max_dis=self.head_module.reg_max,
+ eps=0.01)
+ assigned_ltrb_pos = torch.masked_select(
+ assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
+ loss_dfl = self.loss_dfl(
+ pred_dist_pos.reshape(-1, self.head_module.reg_max + 1),
+ assigned_ltrb_pos.reshape(-1),
+ weight=bbox_weight.expand(-1, 4).reshape(-1),
+ avg_factor=assigned_scores_sum)
+ else:
+ loss_bbox = flatten_pred_bboxes.sum() * 0
+ loss_dfl = flatten_pred_bboxes.sum() * 0
+
+ return dict(loss_cls=loss_cls, loss_bbox=loss_bbox, loss_dfl=loss_dfl)
diff --git a/mmyolo/models/dense_heads/yolov6_head.py b/mmyolo/models/dense_heads/yolov6_head.py
index e85cd828..60abf29d 100644
--- a/mmyolo/models/dense_heads/yolov6_head.py
+++ b/mmyolo/models/dense_heads/yolov6_head.py
@@ -169,7 +169,6 @@ class YOLOv6Head(YOLOv5Head):
in 2D points-based detectors.
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
- loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
anchor head. Defaults to None.
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
@@ -201,11 +200,6 @@ class YOLOv6Head(YOLOv5Head):
reduction='mean',
loss_weight=2.5,
return_iou=False),
- loss_obj: ConfigType = dict(
- type='mmdet.CrossEntropyLoss',
- use_sigmoid=True,
- reduction='sum',
- loss_weight=1.0),
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None):
@@ -215,13 +209,11 @@ class YOLOv6Head(YOLOv5Head):
bbox_coder=bbox_coder,
loss_cls=loss_cls,
loss_bbox=loss_bbox,
- loss_obj=loss_obj,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg)
-
- self.loss_bbox = MODELS.build(loss_bbox)
- self.loss_cls = MODELS.build(loss_cls)
+ # yolov6 doesn't need loss_obj
+ self.loss_obj = None
def special_init(self):
"""Since YOLO series algorithms will inherit from YOLOv5Head, but
@@ -236,10 +228,9 @@ class YOLOv6Head(YOLOv5Head):
self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
# Add common attributes to reduce calculation
- self.featmap_sizes = None
- self.mlvl_priors = None
+ self.featmap_sizes_train = None
self.num_level_priors = None
- self.flatten_priors = None
+ self.flatten_priors_train = None
self.stride_tensor = None
def loss_by_feat(
@@ -284,19 +275,19 @@ class YOLOv6Head(YOLOv5Head):
cls_score.shape[2:] for cls_score in cls_scores
]
# If the shape does not equal, generate new one
- if current_featmap_sizes != self.featmap_sizes:
- self.featmap_sizes = current_featmap_sizes
+ if current_featmap_sizes != self.featmap_sizes_train:
+ self.featmap_sizes_train = current_featmap_sizes
- mlvl_priors = self.prior_generator.grid_priors(
- self.featmap_sizes,
+ mlvl_priors_with_stride = self.prior_generator.grid_priors(
+ self.featmap_sizes_train,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device,
with_stride=True)
- self.num_level_priors = [len(n) for n in mlvl_priors]
- self.flatten_priors = torch.cat(mlvl_priors, dim=0)
- self.stride_tensor = self.flatten_priors[..., [2]]
- self.mlvl_priors = [mlvl[:, :2] for mlvl in mlvl_priors]
+ self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
+ self.flatten_priors_train = torch.cat(
+ mlvl_priors_with_stride, dim=0)
+ self.stride_tensor = self.flatten_priors_train[..., [2]]
# gt info
gt_info = self.gt_instances_preprocess(batch_gt_instances, num_imgs)
@@ -319,19 +310,20 @@ class YOLOv6Head(YOLOv5Head):
flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
flatten_pred_bboxes = self.bbox_coder.decode(
- self.flatten_priors[..., :2], flatten_pred_bboxes,
- self.flatten_priors[..., 2])
+ self.flatten_priors_train[..., :2], flatten_pred_bboxes,
+ self.stride_tensor[:, 0])
pred_scores = torch.sigmoid(flatten_cls_preds)
if current_epoch < self.initial_epoch:
assigned_result = self.initial_assigner(
- flatten_pred_bboxes.detach(), self.flatten_priors,
+ flatten_pred_bboxes.detach(), self.flatten_priors_train,
self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
else:
assigned_result = self.assigner(flatten_pred_bboxes.detach(),
pred_scores.detach(),
- self.flatten_priors, gt_labels,
- gt_bboxes, pad_bbox_flag)
+ self.flatten_priors_train,
+ gt_labels, gt_bboxes,
+ pad_bbox_flag)
assigned_bboxes = assigned_result['assigned_bboxes']
assigned_scores = assigned_result['assigned_scores']
diff --git a/mmyolo/models/task_modules/assigners/batch_atss_assigner.py b/mmyolo/models/task_modules/assigners/batch_atss_assigner.py
index 5b2ed50e..45b3069a 100644
--- a/mmyolo/models/task_modules/assigners/batch_atss_assigner.py
+++ b/mmyolo/models/task_modules/assigners/batch_atss_assigner.py
@@ -92,7 +92,7 @@ class BatchATSSAssigner(nn.Module):
Args:
pred_bboxes (Tensor): Predicted bounding boxes,
shape(batch_size, num_priors, 4)
- priors (Tensor): Model priors, shape(num_priors, 4)
+ priors (Tensor): Model priors with stride, shape(num_priors, 4)
num_level_priors (List): Number of bboxes in each level, len(3)
gt_labels (Tensor): Ground truth label,
shape(batch_size, num_gt, 1)
diff --git a/mmyolo/models/task_modules/coders/distance_point_bbox_coder.py b/mmyolo/models/task_modules/coders/distance_point_bbox_coder.py
index f43890ec..16417b8a 100644
--- a/mmyolo/models/task_modules/coders/distance_point_bbox_coder.py
+++ b/mmyolo/models/task_modules/coders/distance_point_bbox_coder.py
@@ -4,7 +4,7 @@ from typing import Optional, Sequence, Union
import torch
from mmdet.models.task_modules.coders import \
DistancePointBBoxCoder as MMDET_DistancePointBBoxCoder
-from mmdet.structures.bbox import distance2bbox
+from mmdet.structures.bbox import bbox2distance, distance2bbox
from mmyolo.registry import TASK_UTILS
@@ -51,3 +51,29 @@ class DistancePointBBoxCoder(MMDET_DistancePointBBoxCoder):
pred_bboxes = pred_bboxes * stride[None, :, None]
return distance2bbox(points, pred_bboxes, max_shape)
+
+ def encode(self,
+ points: torch.Tensor,
+ gt_bboxes: torch.Tensor,
+ max_dis: float = 16.,
+ eps: float = 0.01) -> torch.Tensor:
+ """Encode bounding box to distances. The rewrite is to support batch
+ operations.
+
+ Args:
+ points (Tensor): Shape (B, N, 2) or (N, 2), The format is [x, y].
+ gt_bboxes (Tensor or :obj:`BaseBoxes`): Shape (N, 4), The format
+ is "xyxy"
+ max_dis (float): Upper bound of the distance. Default to 16..
+ eps (float): a small value to ensure target < max_dis, instead <=.
+ Default 0.01.
+
+ Returns:
+ Tensor: Box transformation deltas. The shape is (N, 4) or
+ (B, N, 4).
+ """
+
+ assert points.size(-2) == gt_bboxes.size(-2)
+ assert points.size(-1) == 2
+ assert gt_bboxes.size(-1) == 4
+ return bbox2distance(points, gt_bboxes, max_dis, eps)
diff --git a/model-index.yml b/model-index.yml
index de8794ca..d804a939 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -4,3 +4,4 @@ Import:
- configs/yolox/metafile.yml
- configs/rtmdet/metafile.yml
- configs/yolov7/metafile.yml
+ - configs/ppyoloe/metafile.yml
diff --git a/tests/test_datasets/test_transforms/test_transforms.py b/tests/test_datasets/test_transforms/test_transforms.py
index eb61b508..d256dd9f 100644
--- a/tests/test_datasets/test_transforms/test_transforms.py
+++ b/tests/test_datasets/test_transforms/test_transforms.py
@@ -13,6 +13,8 @@ from mmyolo.datasets.transforms import (LetterResize, LoadAnnotations,
YOLOv5HSVRandomAug,
YOLOv5KeepRatioResize,
YOLOv5RandomAffine)
+from mmyolo.datasets.transforms.transforms import (PPYOLOERandomCrop,
+ PPYOLOERandomDistort)
class TestLetterResize(unittest.TestCase):
@@ -355,3 +357,100 @@ class TestYOLOv5RandomAffine(unittest.TestCase):
self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
self.assertTrue(results['gt_bboxes'].dtype == torch.float32)
self.assertTrue(results['gt_ignore_flags'].dtype == bool)
+
+
+class TestPPYOLOERandomCrop(unittest.TestCase):
+
+ def setUp(self):
+ """Setup the data info which are used in every test method.
+
+ TestCase calls functions in this order: setUp() -> testMethod() ->
+ tearDown() -> cleanUp()
+ """
+ self.results = {
+ 'img':
+ np.random.random((224, 224, 3)),
+ 'img_shape': (224, 224),
+ 'gt_bboxes_labels':
+ np.array([1, 2, 3], dtype=np.int64),
+ 'gt_bboxes':
+ np.array([[10, 10, 20, 20], [20, 20, 40, 40], [40, 40, 80, 80]],
+ dtype=np.float32),
+ 'gt_ignore_flags':
+ np.array([0, 0, 1], dtype=bool),
+ }
+
+ def test_transform(self):
+ transform = PPYOLOERandomCrop()
+ results = transform(copy.deepcopy(self.results))
+ self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
+ results['gt_bboxes'].shape[0])
+ self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
+ self.assertTrue(results['gt_bboxes'].dtype == np.float32)
+ self.assertTrue(results['gt_ignore_flags'].dtype == bool)
+
+ def test_transform_with_boxlist(self):
+ results = copy.deepcopy(self.results)
+ results['gt_bboxes'] = HorizontalBoxes(results['gt_bboxes'])
+
+ transform = PPYOLOERandomCrop()
+ results = transform(copy.deepcopy(results))
+ self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
+ results['gt_bboxes'].shape[0])
+ self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
+ self.assertTrue(results['gt_bboxes'].dtype == torch.float32)
+ self.assertTrue(results['gt_ignore_flags'].dtype == bool)
+
+
+class TestPPYOLOERandomDistort(unittest.TestCase):
+
+ def setUp(self):
+ """Setup the data info which are used in every test method.
+
+ TestCase calls functions in this order: setUp() -> testMethod() ->
+ tearDown() -> cleanUp()
+ """
+ self.results = {
+ 'img':
+ np.random.random((224, 224, 3)),
+ 'img_shape': (224, 224),
+ 'gt_bboxes_labels':
+ np.array([1, 2, 3], dtype=np.int64),
+ 'gt_bboxes':
+ np.array([[10, 10, 20, 20], [20, 20, 40, 40], [40, 40, 80, 80]],
+ dtype=np.float32),
+ 'gt_ignore_flags':
+ np.array([0, 0, 1], dtype=bool),
+ }
+
+ def test_transform(self):
+ # test assertion for invalid prob
+ with self.assertRaises(AssertionError):
+ transform = PPYOLOERandomDistort(
+ hue_cfg=dict(min=-18, max=18, prob=1.5))
+
+ # test assertion for invalid num_distort_func
+ with self.assertRaises(AssertionError):
+ transform = PPYOLOERandomDistort(num_distort_func=5)
+
+ transform = PPYOLOERandomDistort()
+ results = transform(copy.deepcopy(self.results))
+ self.assertTrue(results['img'].shape[:2] == (224, 224))
+ self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
+ results['gt_bboxes'].shape[0])
+ self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
+ self.assertTrue(results['gt_bboxes'].dtype == np.float32)
+ self.assertTrue(results['gt_ignore_flags'].dtype == bool)
+
+ def test_transform_with_boxlist(self):
+ results = copy.deepcopy(self.results)
+ results['gt_bboxes'] = HorizontalBoxes(results['gt_bboxes'])
+
+ transform = PPYOLOERandomDistort()
+ results = transform(copy.deepcopy(results))
+ self.assertTrue(results['img'].shape[:2] == (224, 224))
+ self.assertTrue(results['gt_bboxes_labels'].shape[0] ==
+ results['gt_bboxes'].shape[0])
+ self.assertTrue(results['gt_bboxes_labels'].dtype == np.int64)
+ self.assertTrue(results['gt_bboxes'].dtype == torch.float32)
+ self.assertTrue(results['gt_ignore_flags'].dtype == bool)
diff --git a/tests/test_datasets/test_utils.py b/tests/test_datasets/test_utils.py
index 136eda53..43c8e61f 100644
--- a/tests/test_datasets/test_utils.py
+++ b/tests/test_datasets/test_utils.py
@@ -47,6 +47,36 @@ class TestYOLOv5Collate(unittest.TestCase):
self.assertTrue(out['inputs'].shape == (2, 3, 10, 10))
self.assertTrue(out['data_samples'].shape == (8, 6))
+ def test_yolov5_collate_with_multi_scale(self):
+ rng = np.random.RandomState(0)
+
+ inputs = torch.randn((3, 10, 10))
+ data_samples = DetDataSample()
+ gt_instances = InstanceData()
+ bboxes = _rand_bboxes(rng, 4, 6, 8)
+ gt_instances.bboxes = HorizontalBoxes(bboxes, dtype=torch.float32)
+ labels = rng.randint(1, 2, size=len(bboxes))
+ gt_instances.labels = torch.LongTensor(labels)
+ data_samples.gt_instances = gt_instances
+
+ out = yolov5_collate([dict(inputs=inputs, data_samples=data_samples)],
+ use_ms_training=True)
+ self.assertIsInstance(out, dict)
+ self.assertTrue(out['inputs'][0].shape == (3, 10, 10))
+ print(out['data_samples'].shape)
+ self.assertTrue(out['data_samples'].shape == (4, 6))
+ self.assertIsInstance(out['inputs'], list)
+ self.assertIsInstance(out['data_samples'], torch.Tensor)
+
+ out = yolov5_collate(
+ [dict(inputs=inputs, data_samples=data_samples)] * 2,
+ use_ms_training=True)
+ self.assertIsInstance(out, dict)
+ self.assertTrue(out['inputs'][0].shape == (3, 10, 10))
+ self.assertTrue(out['data_samples'].shape == (8, 6))
+ self.assertIsInstance(out['inputs'], list)
+ self.assertIsInstance(out['data_samples'], torch.Tensor)
+
class TestBatchShapePolicy(unittest.TestCase):
diff --git a/tests/test_models/test_data_preprocessor/test_data_preprocessor.py b/tests/test_models/test_data_preprocessor/test_data_preprocessor.py
index 85cdb742..203660ae 100644
--- a/tests/test_models/test_data_preprocessor/test_data_preprocessor.py
+++ b/tests/test_models/test_data_preprocessor/test_data_preprocessor.py
@@ -3,8 +3,13 @@ from unittest import TestCase
import torch
from mmdet.structures import DetDataSample
+from mmengine import MessageHub
+from mmyolo.models import PPYOLOEBatchRandomResize, PPYOLOEDetDataPreprocessor
from mmyolo.models.data_preprocessors import YOLOv5DetDataPreprocessor
+from mmyolo.utils import register_all_modules
+
+register_all_modules()
class TestYOLOv5DetDataPreprocessor(TestCase):
@@ -69,3 +74,51 @@ class TestYOLOv5DetDataPreprocessor(TestCase):
# data_samples must be tensor
with self.assertRaises(AssertionError):
processor(data, training=True)
+
+
+class TestPPYOLOEDetDataPreprocessor(TestCase):
+
+ def test_batch_random_resize(self):
+ processor = PPYOLOEDetDataPreprocessor(
+ pad_size_divisor=32,
+ batch_augments=[
+ dict(
+ type='PPYOLOEBatchRandomResize',
+ random_size_range=(320, 480),
+ interval=1,
+ size_divisor=32,
+ random_interp=True,
+ keep_ratio=False)
+ ],
+ mean=[0., 0., 0.],
+ std=[255., 255., 255.],
+ bgr_to_rgb=True)
+ self.assertTrue(
+ isinstance(processor.batch_augments[0], PPYOLOEBatchRandomResize))
+ message_hub = MessageHub.get_instance('test_batch_random_resize')
+ message_hub.update_info('iter', 0)
+
+ # test training
+ data = {
+ 'inputs': [
+ torch.randint(0, 256, (3, 10, 11)),
+ torch.randint(0, 256, (3, 10, 11))
+ ],
+ 'data_samples':
+ torch.randint(0, 11, (18, 6)).float(),
+ }
+ out_data = processor(data, training=True)
+ batch_data_samples = out_data['data_samples']
+ self.assertIn('img_metas', batch_data_samples)
+ self.assertIn('bboxes_labels', batch_data_samples)
+ self.assertIsInstance(batch_data_samples['bboxes_labels'],
+ torch.Tensor)
+ self.assertIsInstance(batch_data_samples['img_metas'], list)
+
+ data = {
+ 'inputs': [torch.randint(0, 256, (3, 11, 10))],
+ 'data_samples': DetDataSample()
+ }
+ # data_samples must be list
+ with self.assertRaises(TypeError):
+ processor(data, training=True)
diff --git a/tests/test_models/test_dense_heads/test_ppyoloe_head.py b/tests/test_models/test_dense_heads/test_ppyoloe_head.py
index 15879bd8..20e0c457 100644
--- a/tests/test_models/test_dense_heads/test_ppyoloe_head.py
+++ b/tests/test_models/test_dense_heads/test_ppyoloe_head.py
@@ -2,6 +2,7 @@
from unittest import TestCase
import torch
+from mmengine import ConfigDict, MessageHub
from mmengine.config import Config
from mmengine.model import bias_init_with_prob
from mmengine.testing import assert_allclose
@@ -12,11 +13,14 @@ from mmyolo.utils import register_all_modules
register_all_modules()
-class TestYOLOXHead(TestCase):
+class TestPPYOLOEHead(TestCase):
def setUp(self):
self.head_module = dict(
- type='PPYOLOEHeadModule', num_classes=4, in_channels=[32, 64, 128])
+ type='PPYOLOEHeadModule',
+ num_classes=4,
+ in_channels=[32, 64, 128],
+ featmap_strides=(8, 16, 32))
def test_init_weights(self):
head = PPYOLOEHead(head_module=self.head_module)
@@ -50,6 +54,7 @@ class TestYOLOXHead(TestCase):
max_per_img=300))
head = PPYOLOEHead(head_module=self.head_module, test_cfg=test_cfg)
+ head.eval()
feat = [
torch.rand(1, in_channels, s // feat_size, s // feat_size)
for in_channels, feat_size in [[32, 8], [64, 16], [128, 32]]
@@ -71,3 +76,130 @@ class TestYOLOXHead(TestCase):
cfg=test_cfg,
rescale=False,
with_nms=False)
+
+ def test_loss_by_feat(self):
+ message_hub = MessageHub.get_instance('test_ppyoloe_loss_by_feat')
+ message_hub.update_info('epoch', 1)
+
+ s = 256
+ img_metas = [{
+ 'img_shape': (s, s, 3),
+ 'batch_input_shape': (s, s),
+ 'scale_factor': 1,
+ }]
+
+ head = PPYOLOEHead(
+ head_module=self.head_module,
+ train_cfg=ConfigDict(
+ initial_epoch=31,
+ initial_assigner=dict(
+ type='BatchATSSAssigner',
+ num_classes=4,
+ topk=9,
+ iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
+ assigner=dict(
+ type='BatchTaskAlignedAssigner',
+ num_classes=4,
+ topk=13,
+ alpha=1,
+ beta=6)))
+ head.train()
+
+ feat = []
+ for i in range(len(self.head_module['in_channels'])):
+ in_channel = self.head_module['in_channels'][i]
+ feat_size = self.head_module['featmap_strides'][i]
+ feat.append(
+ torch.rand(1, in_channel, s // feat_size, s // feat_size))
+
+ cls_scores, bbox_preds, bbox_dist_preds = head.forward(feat)
+
+ # Test that empty ground truth encourages the network to predict
+ # background
+ gt_instances = torch.empty((0, 6), dtype=torch.float32)
+
+ empty_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
+ bbox_dist_preds, gt_instances,
+ img_metas)
+ # When there is no truth, the cls loss should be nonzero but there
+ # should be no box loss.
+ empty_cls_loss = empty_gt_losses['loss_cls'].sum()
+ empty_box_loss = empty_gt_losses['loss_bbox'].sum()
+ empty_dfl_loss = empty_gt_losses['loss_dfl'].sum()
+ self.assertGreater(empty_cls_loss.item(), 0,
+ 'cls loss should be non-zero')
+ self.assertEqual(
+ empty_box_loss.item(), 0,
+ 'there should be no box loss when there are no true boxes')
+ self.assertEqual(
+ empty_dfl_loss.item(), 0,
+ 'there should be df loss when there are no true boxes')
+
+ # When truth is non-empty then both cls and box loss should be nonzero
+ # for random inputs
+ head = PPYOLOEHead(
+ head_module=self.head_module,
+ train_cfg=ConfigDict(
+ initial_epoch=31,
+ initial_assigner=dict(
+ type='BatchATSSAssigner',
+ num_classes=4,
+ topk=9,
+ iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
+ assigner=dict(
+ type='BatchTaskAlignedAssigner',
+ num_classes=4,
+ topk=13,
+ alpha=1,
+ beta=6)))
+ head.train()
+ gt_instances = torch.Tensor(
+ [[0., 0., 23.6667, 23.8757, 238.6326, 151.8874]])
+
+ one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
+ bbox_dist_preds, gt_instances,
+ img_metas)
+ onegt_cls_loss = one_gt_losses['loss_cls'].sum()
+ onegt_box_loss = one_gt_losses['loss_bbox'].sum()
+ onegt_loss_dfl = one_gt_losses['loss_dfl'].sum()
+ self.assertGreater(onegt_cls_loss.item(), 0,
+ 'cls loss should be non-zero')
+ self.assertGreater(onegt_box_loss.item(), 0,
+ 'box loss should be non-zero')
+ self.assertGreater(onegt_loss_dfl.item(), 0,
+ 'obj loss should be non-zero')
+
+ # test num_class = 1
+ self.head_module['num_classes'] = 1
+ head = PPYOLOEHead(
+ head_module=self.head_module,
+ train_cfg=ConfigDict(
+ initial_epoch=31,
+ initial_assigner=dict(
+ type='BatchATSSAssigner',
+ num_classes=1,
+ topk=9,
+ iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
+ assigner=dict(
+ type='BatchTaskAlignedAssigner',
+ num_classes=1,
+ topk=13,
+ alpha=1,
+ beta=6)))
+ head.train()
+ gt_instances = torch.Tensor(
+ [[0., 0., 23.6667, 23.8757, 238.6326, 151.8874]])
+ cls_scores, bbox_preds, bbox_dist_preds = head.forward(feat)
+
+ one_gt_losses = head.loss_by_feat(cls_scores, bbox_preds,
+ bbox_dist_preds, gt_instances,
+ img_metas)
+ onegt_cls_loss = one_gt_losses['loss_cls'].sum()
+ onegt_box_loss = one_gt_losses['loss_bbox'].sum()
+ onegt_loss_dfl = one_gt_losses['loss_dfl'].sum()
+ self.assertGreater(onegt_cls_loss.item(), 0,
+ 'cls loss should be non-zero')
+ self.assertGreater(onegt_box_loss.item(), 0,
+ 'box loss should be non-zero')
+ self.assertGreater(onegt_loss_dfl.item(), 0,
+ 'obj loss should be non-zero')
diff --git a/tools/model_converters/ppyoloe_to_mmyolo.py b/tools/model_converters/ppyoloe_to_mmyolo.py
index fa8d2233..75c4af69 100644
--- a/tools/model_converters/ppyoloe_to_mmyolo.py
+++ b/tools/model_converters/ppyoloe_to_mmyolo.py
@@ -5,13 +5,13 @@ from collections import OrderedDict
import torch
-def convert_bn(k):
+def convert_bn(k: str):
name = k.replace('._mean',
'.running_mean').replace('._variance', '.running_var')
return name
-def convert_repvgg(k):
+def convert_repvgg(k: str):
if '.conv2.conv1.' in k:
name = k.replace('.conv2.conv1.', '.conv2.rbr_dense.')
return name
@@ -22,111 +22,142 @@ def convert_repvgg(k):
return k
-def convert(src, dst):
- # TODO: add pretrained model convert
+def convert(src: str, dst: str, imagenet_pretrain: bool = False):
with open(src, 'rb') as f:
model = pickle.load(f)
new_state_dict = OrderedDict()
- for k, v in model.items():
- name = k
- if k.startswith('backbone.'):
- if '.stem.' in k:
+ if imagenet_pretrain:
+ for k, v in model.items():
+ if '@@' in k:
+ continue
+ if 'stem.' in k:
# backbone.stem.conv1.conv.weight
# -> backbone.stem.0.conv.weight
- org_ind = k.split('.')[2][-1]
+ org_ind = k.split('.')[1][-1]
new_ind = str(int(org_ind) - 1)
- name = k.replace('.stem.conv%s.' % org_ind,
- '.stem.%s.' % new_ind)
+ name = k.replace('stem.conv%s.' % org_ind,
+ 'stem.%s.' % new_ind)
else:
# backbone.stages.1.conv2.bn._variance
# -> backbone.stage2.0.conv2.bn.running_var
- org_stage_ind = k.split('.')[2]
+ org_stage_ind = k.split('.')[1]
new_stage_ind = str(int(org_stage_ind) + 1)
- name = k.replace('.stages.%s.' % org_stage_ind,
- '.stage%s.0.' % new_stage_ind)
+ name = k.replace('stages.%s.' % org_stage_ind,
+ 'stage%s.0.' % new_stage_ind)
name = convert_repvgg(name)
if '.attn.' in k:
name = name.replace('.attn.fc.', '.attn.fc.conv.')
name = convert_bn(name)
- elif k.startswith('neck.'):
- # fpn_stages
- if k.startswith('neck.fpn_stages.'):
- # neck.fpn_stages.0.0.conv1.conv.weight
- # -> neck.reduce_layers.2.0.conv1.conv.weight
- if k.startswith('neck.fpn_stages.0.0.'):
- name = k.replace('neck.fpn_stages.0.0.',
- 'neck.reduce_layers.2.0.')
- if '.spp.' in name:
- name = name.replace('.spp.conv.', '.spp.conv2.')
- # neck.fpn_stages.1.0.conv1.conv.weight
- # -> neck.top_down_layers.0.0.conv1.conv.weight
- elif k.startswith('neck.fpn_stages.1.0.'):
- name = k.replace('neck.fpn_stages.1.0.',
- 'neck.top_down_layers.0.0.')
- elif k.startswith('neck.fpn_stages.2.0.'):
- name = k.replace('neck.fpn_stages.2.0.',
- 'neck.top_down_layers.1.0.')
+ name = 'backbone.' + name
+
+ new_state_dict[name] = torch.from_numpy(v)
+ else:
+ for k, v in model.items():
+ name = k
+ if k.startswith('backbone.'):
+ if '.stem.' in k:
+ # backbone.stem.conv1.conv.weight
+ # -> backbone.stem.0.conv.weight
+ org_ind = k.split('.')[2][-1]
+ new_ind = str(int(org_ind) - 1)
+ name = k.replace('.stem.conv%s.' % org_ind,
+ '.stem.%s.' % new_ind)
else:
- raise NotImplementedError('Not implemented.')
- name = name.replace('.0.convs.', '.0.blocks.')
- elif k.startswith('neck.fpn_routes.'):
- # neck.fpn_routes.0.conv.weight
- # -> neck.upsample_layers.0.0.conv.weight
- index = k.split('.')[2]
- name = 'neck.upsample_layers.' + index + '.0.' + '.'.join(
- k.split('.')[-2:])
- name = name.replace('.0.convs.', '.0.blocks.')
- elif k.startswith('neck.pan_stages.'):
- # neck.pan_stages.0.0.conv1.conv.weight
- # -> neck.bottom_up_layers.1.0.conv1.conv.weight
- ind = k.split('.')[2]
- name = k.replace(
- 'neck.pan_stages.' + ind,
- 'neck.bottom_up_layers.' + ('0' if ind == '1' else '1'))
- name = name.replace('.0.convs.', '.0.blocks.')
- elif k.startswith('neck.pan_routes.'):
- # neck.pan_routes.0.conv.weight
- # -> neck.downsample_layers.0.conv.weight
- ind = k.split('.')[2]
- name = k.replace(
- 'neck.pan_routes.' + ind,
- 'neck.downsample_layers.' + ('0' if ind == '1' else '1'))
- name = name.replace('.0.convs.', '.0.blocks.')
+ # backbone.stages.1.conv2.bn._variance
+ # -> backbone.stage2.0.conv2.bn.running_var
+ org_stage_ind = k.split('.')[2]
+ new_stage_ind = str(int(org_stage_ind) + 1)
+ name = k.replace('.stages.%s.' % org_stage_ind,
+ '.stage%s.0.' % new_stage_ind)
+ name = convert_repvgg(name)
+ if '.attn.' in k:
+ name = name.replace('.attn.fc.', '.attn.fc.conv.')
+ name = convert_bn(name)
+ elif k.startswith('neck.'):
+ # fpn_stages
+ if k.startswith('neck.fpn_stages.'):
+ # neck.fpn_stages.0.0.conv1.conv.weight
+ # -> neck.reduce_layers.2.0.conv1.conv.weight
+ if k.startswith('neck.fpn_stages.0.0.'):
+ name = k.replace('neck.fpn_stages.0.0.',
+ 'neck.reduce_layers.2.0.')
+ if '.spp.' in name:
+ name = name.replace('.spp.conv.', '.spp.conv2.')
+ # neck.fpn_stages.1.0.conv1.conv.weight
+ # -> neck.top_down_layers.0.0.conv1.conv.weight
+ elif k.startswith('neck.fpn_stages.1.0.'):
+ name = k.replace('neck.fpn_stages.1.0.',
+ 'neck.top_down_layers.0.0.')
+ elif k.startswith('neck.fpn_stages.2.0.'):
+ name = k.replace('neck.fpn_stages.2.0.',
+ 'neck.top_down_layers.1.0.')
+ else:
+ raise NotImplementedError('Not implemented.')
+ name = name.replace('.0.convs.', '.0.blocks.')
+ elif k.startswith('neck.fpn_routes.'):
+ # neck.fpn_routes.0.conv.weight
+ # -> neck.upsample_layers.0.0.conv.weight
+ index = k.split('.')[2]
+ name = 'neck.upsample_layers.' + index + '.0.' + '.'.join(
+ k.split('.')[-2:])
+ name = name.replace('.0.convs.', '.0.blocks.')
+ elif k.startswith('neck.pan_stages.'):
+ # neck.pan_stages.0.0.conv1.conv.weight
+ # -> neck.bottom_up_layers.1.0.conv1.conv.weight
+ ind = k.split('.')[2]
+ name = k.replace(
+ 'neck.pan_stages.' + ind, 'neck.bottom_up_layers.' +
+ ('0' if ind == '1' else '1'))
+ name = name.replace('.0.convs.', '.0.blocks.')
+ elif k.startswith('neck.pan_routes.'):
+ # neck.pan_routes.0.conv.weight
+ # -> neck.downsample_layers.0.conv.weight
+ ind = k.split('.')[2]
+ name = k.replace(
+ 'neck.pan_routes.' + ind, 'neck.downsample_layers.' +
+ ('0' if ind == '1' else '1'))
+ name = name.replace('.0.convs.', '.0.blocks.')
+ else:
+ raise NotImplementedError('Not implement.')
+ name = convert_repvgg(name)
+ name = convert_bn(name)
+ elif k.startswith('yolo_head.'):
+ if ('anchor_points' in k) or ('stride_tensor' in k):
+ continue
+ if 'proj_conv' in k:
+ name = k.replace('yolo_head.proj_conv.',
+ 'bbox_head.head_module.proj_conv.')
+ else:
+ for org_key, rep_key in [
+ [
+ 'yolo_head.stem_cls.',
+ 'bbox_head.head_module.cls_stems.'
+ ],
+ [
+ 'yolo_head.stem_reg.',
+ 'bbox_head.head_module.reg_stems.'
+ ],
+ [
+ 'yolo_head.pred_cls.',
+ 'bbox_head.head_module.cls_preds.'
+ ],
+ [
+ 'yolo_head.pred_reg.',
+ 'bbox_head.head_module.reg_preds.'
+ ]
+ ]:
+ name = name.replace(org_key, rep_key)
+ name = name.split('.')
+ ind = name[3]
+ name[3] = str(2 - int(ind))
+ name = '.'.join(name)
+ name = convert_bn(name)
else:
- raise NotImplementedError('Not implement.')
- name = convert_repvgg(name)
- name = convert_bn(name)
- elif k.startswith('yolo_head.'):
- if ('anchor_points' in k) or ('stride_tensor' in k):
continue
- if 'proj_conv' in k:
- name = k.replace('yolo_head.proj_conv.',
- 'bbox_head.head_module.proj_conv.')
- else:
- for org_key, rep_key in [[
- 'yolo_head.stem_cls.',
- 'bbox_head.head_module.cls_stems.'
- ], ['yolo_head.stem_reg.', 'bbox_head.head_module.reg_stems.'],
- [
- 'yolo_head.pred_cls.',
- 'bbox_head.head_module.cls_preds.'
- ],
- [
- 'yolo_head.pred_reg.',
- 'bbox_head.head_module.reg_preds.'
- ]]:
- name = name.replace(org_key, rep_key)
- name = name.split('.')
- ind = name[3]
- name[3] = str(2 - int(ind))
- name = '.'.join(name)
- name = convert_bn(name)
- else:
- continue
- new_state_dict[name] = torch.from_numpy(v)
+ new_state_dict[name] = torch.from_numpy(v)
data = {'state_dict': new_state_dict}
torch.save(data, dst)
@@ -139,8 +170,14 @@ def main():
help='src ppyoloe model path')
parser.add_argument(
'--dst', default='mmppyoloe_plus_s.pt', help='save path')
+ parser.add_argument(
+ '--imagenet-pretrain',
+ action='store_true',
+ default=False,
+ help='Load model pretrained on imagenet dataset which only '
+ 'have weight for backbone.')
args = parser.parse_args()
- convert(args.src, args.dst)
+ convert(args.src, args.dst, args.imagenet_pretrain)
if __name__ == '__main__':