Mirror of https://github.com/alibaba/EasyCV.git

Commit b198c5a81f (parent 38ae771e7d, release/0.6.0): support DINO algo (#144)

Benchmark results (COCO bbox mAP): dino_4sc_r50_12e: 48.71, dino_4sc_r50_24e: 50.53, dino_4sc_r50_36e: 50.69, dino_4sc_swinl_12e: 56.86, dino_4sc_swinl_36e: 58.04, dino_5sc_swinl_36e: 58.47
@@ -32,7 +32,7 @@ EasyCV is an all-in-one computer vision toolbox based on PyTorch, mainly focuses

- **Vision Transformers**

EasyCV aims to provide an easy way to use the off-the-shelf SOTA transformer models trained either using supervised learning or self-supervised learning, such as ViT, Swin-Transformer and Shuffle Transformer. More models will be added in the future. In addition, we support all the pretrained models from [timm](https://github.com/rwightman/pytorch-image-models).

EasyCV aims to provide an easy way to use the off-the-shelf SOTA transformer models trained either using supervised learning or self-supervised learning, such as ViT, Swin Transformer and DETR Series. More models will be added in the future. In addition, we support all the pretrained models from [timm](https://github.com/rwightman/pytorch-image-models).

- **Functionality & Extensibility**
@@ -144,6 +144,7 @@ notebook

<li><a href="configs/detection/detr">DETR (ECCV'2020)</a></li>
<li><a href="configs/detection/dab_detr">DAB-DETR (ICLR'2022)</a></li>
<li><a href="configs/detection/dab_detr">DN-DETR (CVPR'2022)</a></li>
<li><a href="configs/detection/dino">DINO (ArXiv'2022)</a></li>
</ul>
</td>
<td>
@@ -135,6 +135,7 @@ EasyCV is an all-in-one PyTorch-based computer vision toolbox covering multiple domains,

<li><a href="configs/detection/detr">DETR (ECCV'2020)</a></li>
<li><a href="configs/detection/dab_detr">DAB-DETR (ICLR'2022)</a></li>
<li><a href="configs/detection/dab_detr">DN-DETR (CVPR'2022)</a></li>
<li><a href="configs/detection/dino">DINO (ArXiv'2022)</a></li>
</ul>
</td>
<td>
@@ -12,6 +12,7 @@ import torch

from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint

from easycv.apis import set_random_seed
from easycv.datasets import build_dataloader, build_dataset
from easycv.file import io
from easycv.models import build_model

@@ -20,25 +21,6 @@ from easycv.utils.config_tools import mmcv_config_fromfile

from easycv.utils.logger import get_root_logger


def set_random_seed(seed, deterministic=True):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


class ExtractProcess(object):

    def __init__(self, extract_list=['neck']):
@ -23,36 +23,41 @@ train_pipeline = [
|
|||
dict(type='MMRandomFlip', flip_ratio=0.5),
|
||||
dict(
|
||||
type='MMAutoAugment',
|
||||
policies=[[
|
||||
dict(
|
||||
type='MMResize',
|
||||
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
|
||||
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
|
||||
(736, 1333), (768, 1333), (800, 1333)],
|
||||
multiscale_mode='value',
|
||||
keep_ratio=True)
|
||||
],
|
||||
[
|
||||
dict(
|
||||
type='MMResize',
|
||||
img_scale=[(400, 1333), (500, 1333), (600, 1333)],
|
||||
multiscale_mode='value',
|
||||
keep_ratio=True),
|
||||
dict(
|
||||
type='MMRandomCrop',
|
||||
crop_type='absolute_range',
|
||||
crop_size=(384, 600),
|
||||
allow_negative_crop=True),
|
||||
dict(
|
||||
type='MMResize',
|
||||
img_scale=[(480, 1333), (512, 1333), (544, 1333),
|
||||
(576, 1333), (608, 1333), (640, 1333),
|
||||
(672, 1333), (704, 1333), (736, 1333),
|
||||
(768, 1333), (800, 1333)],
|
||||
multiscale_mode='value',
|
||||
override=True,
|
||||
keep_ratio=True)
|
||||
]]),
|
||||
policies=[
|
||||
[
|
||||
dict(
|
||||
type='MMResize',
|
||||
img_scale=[(480, 1333), (512, 1333), (544, 1333),
|
||||
(576, 1333), (608, 1333), (640, 1333),
|
||||
(672, 1333), (704, 1333), (736, 1333),
|
||||
(768, 1333), (800, 1333)],
|
||||
multiscale_mode='value',
|
||||
keep_ratio=True)
|
||||
],
|
||||
[
|
||||
dict(
|
||||
type='MMResize',
|
||||
# The radio of all image in train dataset < 7
|
||||
# follow the original impl
|
||||
img_scale=[(400, 4200), (500, 4200), (600, 4200)],
|
||||
multiscale_mode='value',
|
||||
keep_ratio=True),
|
||||
dict(
|
||||
type='MMRandomCrop',
|
||||
crop_type='absolute_range',
|
||||
crop_size=(384, 600),
|
||||
allow_negative_crop=True),
|
||||
dict(
|
||||
type='MMResize',
|
||||
img_scale=[(480, 1333), (512, 1333), (544, 1333),
|
||||
(576, 1333), (608, 1333), (640, 1333),
|
||||
(672, 1333), (704, 1333), (736, 1333),
|
||||
(768, 1333), (800, 1333)],
|
||||
multiscale_mode='value',
|
||||
override=True,
|
||||
keep_ratio=True)
|
||||
]
|
||||
]),
|
||||
dict(type='MMNormalize', **img_norm_cfg),
|
||||
dict(type='MMPad', size_divisor=1),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
|
@@ -96,7 +101,7 @@ train_dataset = dict(

        ],
        classes=CLASSES,
        test_mode=False,
        filter_empty_gt=True,
        filter_empty_gt=False,
        iscrowd=False),
    pipeline=train_pipeline)
@@ -118,13 +123,18 @@ val_dataset = dict(

    pipeline=test_pipeline)

data = dict(
    imgs_per_gpu=2, workers_per_gpu=2, train=train_dataset, val=val_dataset)
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=train_dataset,
    val=val_dataset,
    drop_last=True)

# evaluation
eval_config = dict(interval=1, gpu_collect=False)
eval_pipelines = [
    dict(
        mode='test',
        dist_eval=True,
        evaluators=[
            dict(type='CocoDetectionEvaluator', classes=CLASSES),
        ],
@@ -1,5 +1,5 @@

_base_ = [
    './dab_detr.py', '../_base_/dataset/autoaug_coco_detection.py',
    './dab_detr.py', '../common/dataset/autoaug_coco_detection.py',
    'configs/base.py'
]
@@ -1,5 +1,5 @@

_base_ = [
    './detr.py', '../_base_/dataset/autoaug_coco_detection.py',
    './detr.py', '../common/dataset/autoaug_coco_detection.py',
    'configs/base.py'
]
@@ -0,0 +1,36 @@

# DINO

> [DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection](https://arxiv.org/abs/2203.03605)

<!-- [ALGORITHM] -->

## Abstract

We present DINO (DETR with Improved deNoising anchOr boxes), a state-of-the-art end-to-end object detector. DINO improves over previous DETR-like models in performance and efficiency by using a contrastive way for denoising training, a mixed query selection method for anchor initialization, and a look forward twice scheme for box prediction. DINO achieves 49.4AP in 12 epochs and 51.3AP in 24 epochs on COCO with a ResNet-50 backbone and multi-scale features, yielding a significant improvement of +6.0AP and +2.7AP, respectively, compared to DN-DETR, the previous best DETR-like model. DINO scales well in both model size and data size. Without bells and whistles, after pre-training on the Objects365 dataset with a SwinL backbone, DINO obtains the best results on both COCO val2017 (63.2AP) and test-dev (63.3AP). Compared to other models on the leaderboard, DINO significantly reduces its model size and pre-training data size while achieving better results.

<div align=center>
<img src="https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/algo_images/detection/DINO.png"/>
</div>

## Results and Models

| Algorithm | Config | Params<br/>(backbone/total) | Inference time (V100)<br/>(ms/img) | bbox_mAP<sup>val</sup><br/><sub>0.5:0.95</sub> | AP<sup>val</sup><br/><sub>50</sub> | Download |
| ---------- | ------ | --------------------------- | ---------------------------------- | ---------------------------------------------- | ---------------------------------- | -------- |
| DINO_4sc_r50_12e | [DINO_4sc_r50_12e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_4sc_r50_12e_coco.py) | 23M/47M | 184ms | 48.71 | 66.27 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_r50_12e/epoch_12.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_r50_12e/20220815_141403.log.json) |
| DINO_4sc_r50_36e | [DINO_4sc_r50_36e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_4sc_r50_36e_coco.py) | 23M/47M | 184ms | 50.69 | 68.60 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_r50_36e/epoch_29.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_r50_36e/20220817_101549.log.json) |
| DINO_4sc_swinl_12e | [DINO_4sc_swinl_12e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_4sc_swinl_12e_coco.py) | 195M/217M | 155ms | 56.86 | 75.61 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_12e/epoch_12.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_12e/20220815_211633.log.json) |
| DINO_4sc_swinl_36e | [DINO_4sc_swinl_36e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_4sc_swinl_36e_coco.py) | 195M/217M | 155ms | 58.04 | 76.76 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_36e/epoch_34.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_36e/20220817_101416.log.json) |
| DINO_5sc_swinl_36e | [DINO_5sc_swinl_36e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_5sc_swinl_36e_coco.py) | 195M/217M | 235ms | 58.47 | 77.10 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_5sc_swinl_36e/epoch_35.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_5sc_swinl_36e/20220820_215711.log.json) |

## Citation

```latex
@misc{zhang2022dino,
      title={DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection},
      author={Hao Zhang and Feng Li and Shilong Liu and Lei Zhang and Hang Su and Jun Zhu and Lionel M. Ni and Heung-Yeung Shum},
      year={2022},
      eprint={2203.03605},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
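A quick orientation note: the configs in the table above are regular EasyCV configs, so they can be loaded with the same helpers this commit already uses elsewhere (`mmcv_config_fromfile`, `build_model`, `load_checkpoint`). The snippet below is a minimal sketch, assuming it is run from the repository root with `epoch_12.pth` downloaded from the table; the local file name is illustrative and this is not the project's official inference recipe.

```python
# Minimal sketch: instantiate DINO-4scale-R50 from the released config and load weights.
# Assumptions: run from the EasyCV repo root; 'epoch_12.pth' was downloaded from the table above.
from mmcv.runner import load_checkpoint

from easycv.models import build_model
from easycv.utils.config_tools import mmcv_config_fromfile

cfg = mmcv_config_fromfile('configs/detection/dino/dino_4sc_r50_12e_coco.py')
model = build_model(cfg.model)  # 'Detection' wrapper with a DINOHead
load_checkpoint(model, 'epoch_12.pth', map_location='cpu')
model.eval()
```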
@ -0,0 +1,94 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='Detection',
|
||||
pretrained=True,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(2, 3, 4),
|
||||
frozen_stages=1,
|
||||
norm_cfg=dict(type='BN', requires_grad=False),
|
||||
norm_eval=True,
|
||||
style='pytorch'),
|
||||
head=dict(
|
||||
type='DINOHead',
|
||||
transformer=dict(
|
||||
type='DeformableTransformer',
|
||||
d_model=256,
|
||||
nhead=8,
|
||||
num_queries=900,
|
||||
num_encoder_layers=6,
|
||||
num_unicoder_layers=0,
|
||||
num_decoder_layers=6,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.0,
|
||||
activation='relu',
|
||||
normalize_before=False,
|
||||
return_intermediate_dec=True,
|
||||
query_dim=4,
|
||||
num_patterns=0,
|
||||
modulate_hw_attn=True,
|
||||
# for deformable encoder
|
||||
deformable_encoder=True,
|
||||
deformable_decoder=True,
|
||||
num_feature_levels=4,
|
||||
enc_n_points=4,
|
||||
dec_n_points=4,
|
||||
# init query
|
||||
decoder_query_perturber=None,
|
||||
add_channel_attention=False,
|
||||
random_refpoints_xy=False,
|
||||
# two stage
|
||||
two_stage_type=
|
||||
'standard', # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
|
||||
two_stage_pat_embed=0,
|
||||
two_stage_add_query_num=0,
|
||||
two_stage_learn_wh=False,
|
||||
two_stage_keep_all_tokens=False,
|
||||
# evo of #anchors
|
||||
dec_layer_number=None,
|
||||
rm_dec_query_scale=True,
|
||||
rm_self_attn_layers=None,
|
||||
key_aware_type=None,
|
||||
# layer share
|
||||
layer_share_type=None,
|
||||
# for detach
|
||||
rm_detach=None,
|
||||
decoder_sa_type='sa',
|
||||
module_seq=['sa', 'ca', 'ffn'],
|
||||
# for dn
|
||||
embed_init_tgt=True,
|
||||
use_detached_boxes_dec_out=False),
|
||||
dn_components=dict(
|
||||
dn_number=100,
|
||||
dn_label_noise_ratio=0.5, # paper 0.5, release code 0.25
|
||||
dn_box_noise_scale=1.0,
|
||||
dn_labelbook_size=80,
|
||||
),
|
||||
num_classes=80,
|
||||
in_channels=[512, 1024, 2048],
|
||||
embed_dims=256,
|
||||
query_dim=4,
|
||||
num_queries=900,
|
||||
num_select=300,
|
||||
random_refpoints_xy=False,
|
||||
num_patterns=0,
|
||||
fix_refpoints_hw=-1,
|
||||
num_feature_levels=4,
|
||||
# two stage
|
||||
two_stage_type='standard', # ['no', 'standard']
|
||||
two_stage_add_query_num=0,
|
||||
dec_pred_class_embed_share=True,
|
||||
dec_pred_bbox_embed_share=True,
|
||||
two_stage_class_embed_share=False,
|
||||
two_stage_bbox_embed_share=False,
|
||||
decoder_sa_type='sa',
|
||||
temperatureH=20,
|
||||
temperatureW=20,
|
||||
cost_dict=dict(
|
||||
cost_class=2,
|
||||
cost_bbox=5,
|
||||
cost_giou=2,
|
||||
),
|
||||
weight_dict=dict(loss_ce=1, loss_bbox=5, loss_giou=2)))
|
|
@@ -0,0 +1,4 @@

_base_ = [
    './dino_4sc_r50.py', '../common/dataset/autoaug_coco_detection.py',
    './dino_schedule_1x.py'
]
@@ -0,0 +1,6 @@

_base_ = './dino_4sc_r50_12e_coco.py'

# learning policy
lr_config = dict(policy='step', step=[22])

total_epochs = 24
@@ -0,0 +1,6 @@

_base_ = './dino_4sc_r50_12e_coco.py'

# learning policy
lr_config = dict(policy='step', step=[27, 33])

total_epochs = 36
@ -0,0 +1,95 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='Detection',
|
||||
pretrained=
|
||||
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/warpper_swin_large_patch4_window12_384_22k.pth',
|
||||
backbone=dict(
|
||||
type='SwinTransformer',
|
||||
pretrain_img_size=384,
|
||||
embed_dim=192,
|
||||
depths=[2, 2, 18, 2],
|
||||
num_heads=[6, 12, 24, 48],
|
||||
window_size=12,
|
||||
out_indices=(1, 2, 3),
|
||||
use_checkpoint=True),
|
||||
head=dict(
|
||||
type='DINOHead',
|
||||
transformer=dict(
|
||||
type='DeformableTransformer',
|
||||
d_model=256,
|
||||
nhead=8,
|
||||
num_queries=900,
|
||||
num_encoder_layers=6,
|
||||
num_unicoder_layers=0,
|
||||
num_decoder_layers=6,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.0,
|
||||
activation='relu',
|
||||
normalize_before=False,
|
||||
return_intermediate_dec=True,
|
||||
query_dim=4,
|
||||
num_patterns=0,
|
||||
modulate_hw_attn=True,
|
||||
# for deformable encoder
|
||||
deformable_encoder=True,
|
||||
deformable_decoder=True,
|
||||
num_feature_levels=4,
|
||||
enc_n_points=4,
|
||||
dec_n_points=4,
|
||||
# init query
|
||||
decoder_query_perturber=None,
|
||||
add_channel_attention=False,
|
||||
random_refpoints_xy=False,
|
||||
# two stage
|
||||
two_stage_type=
|
||||
'standard', # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
|
||||
two_stage_pat_embed=0,
|
||||
two_stage_add_query_num=0,
|
||||
two_stage_learn_wh=False,
|
||||
two_stage_keep_all_tokens=False,
|
||||
# evo of #anchors
|
||||
dec_layer_number=None,
|
||||
rm_dec_query_scale=True,
|
||||
rm_self_attn_layers=None,
|
||||
key_aware_type=None,
|
||||
# layer share
|
||||
layer_share_type=None,
|
||||
# for detach
|
||||
rm_detach=None,
|
||||
decoder_sa_type='sa',
|
||||
module_seq=['sa', 'ca', 'ffn'],
|
||||
# for dn
|
||||
embed_init_tgt=True,
|
||||
use_detached_boxes_dec_out=False),
|
||||
dn_components=dict(
|
||||
dn_number=100,
|
||||
dn_label_noise_ratio=0.5, # paper 0.5, release code 0.25
|
||||
dn_box_noise_scale=1.0,
|
||||
dn_labelbook_size=80,
|
||||
),
|
||||
num_classes=80,
|
||||
in_channels=[384, 768, 1536],
|
||||
embed_dims=256,
|
||||
query_dim=4,
|
||||
num_queries=900,
|
||||
num_select=300,
|
||||
random_refpoints_xy=False,
|
||||
num_patterns=0,
|
||||
fix_refpoints_hw=-1,
|
||||
num_feature_levels=4,
|
||||
# two stage
|
||||
two_stage_type='standard', # ['no', 'standard']
|
||||
two_stage_add_query_num=0,
|
||||
dec_pred_class_embed_share=True,
|
||||
dec_pred_bbox_embed_share=True,
|
||||
two_stage_class_embed_share=False,
|
||||
two_stage_bbox_embed_share=False,
|
||||
decoder_sa_type='sa',
|
||||
temperatureH=20,
|
||||
temperatureW=20,
|
||||
cost_dict=dict(
|
||||
cost_class=2,
|
||||
cost_bbox=5,
|
||||
cost_giou=2,
|
||||
),
|
||||
weight_dict=dict(loss_ce=1, loss_bbox=5, loss_giou=2)))
|
|
@@ -0,0 +1,4 @@

_base_ = [
    './dino_4sc_swinl.py', '../common/dataset/autoaug_coco_detection.py',
    './dino_schedule_1x.py'
]
@@ -0,0 +1,6 @@

_base_ = './dino_4sc_swinl_12e_coco.py'

# learning policy
lr_config = dict(policy='step', step=[22])

total_epochs = 24
@@ -0,0 +1,6 @@

_base_ = './dino_4sc_swinl_12e_coco.py'

# learning policy
lr_config = dict(policy='step', step=[27, 33])

total_epochs = 36
@@ -0,0 +1,9 @@

_base_ = './dino_4sc_r50.py'

# model settings
model = dict(
    backbone=dict(out_indices=(1, 2, 3, 4)),
    head=dict(
        in_channels=[256, 512, 1024, 2048],
        num_feature_levels=5,
        transformer=dict(num_feature_levels=5)))
@@ -0,0 +1,4 @@

_base_ = [
    './dino_5sc_r50.py', '../common/dataset/autoaug_coco_detection.py',
    './dino_schedule_1x.py'
]
@@ -0,0 +1,6 @@

_base_ = './dino_5sc_r50_12e_coco.py'

# learning policy
lr_config = dict(policy='step', step=[20])

total_epochs = 24
@@ -0,0 +1,6 @@

_base_ = './dino_5sc_r50_12e_coco.py'

# learning policy
lr_config = dict(policy='step', step=[27, 33])

total_epochs = 36
@@ -0,0 +1,9 @@

_base_ = './dino_4sc_swinl.py'

# model settings
model = dict(
    backbone=dict(out_indices=(0, 1, 2, 3)),
    head=dict(
        in_channels=[192, 384, 768, 1536],
        num_feature_levels=5,
        transformer=dict(num_feature_levels=5)))
@@ -0,0 +1,4 @@

_base_ = [
    './dino_5sc_swinl.py', '../common/dataset/autoaug_coco_detection.py',
    './dino_schedule_1x.py'
]
@@ -0,0 +1,6 @@

_base_ = './dino_5sc_swinl_12e_coco.py'

# learning policy
lr_config = dict(policy='step', step=[20])

total_epochs = 24
@@ -0,0 +1,6 @@

_base_ = './dino_5sc_swinl_12e_coco.py'

# learning policy
lr_config = dict(policy='step', step=[27, 33])

total_epochs = 36
@@ -0,0 +1,19 @@

_base_ = 'configs/base.py'

checkpoint_config = dict(interval=10)
# optimizer
paramwise_options = {
    'backbone': dict(lr_mult=0.1),
}
optimizer = dict(
    type='AdamW',
    lr=1e-4,
    weight_decay=1e-4,
    paramwise_options=paramwise_options)
optimizer_config = dict(grad_clip=dict(max_norm=0.1, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[11])

total_epochs = 12

find_unused_parameters = False
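A worked note on the optimizer block above: `paramwise_options` applies `lr_mult=0.1` to parameters whose names match `'backbone'`, so with the base `lr=1e-4` the backbone effectively trains at about 1e-5 while the DINO head keeps 1e-4. The sketch below only illustrates that arithmetic with a simplified substring match; it is not EasyCV's actual optimizer builder.

```python
# Simplified illustration of lr_mult in paramwise_options (not the real EasyCV builder).
base_lr = 1e-4
paramwise_options = {'backbone': dict(lr_mult=0.1)}


def effective_lr(param_name):
    lr = base_lr
    for key, opts in paramwise_options.items():
        if key in param_name:  # substring match on the parameter name
            lr *= opts.get('lr_mult', 1.0)
    return lr


print(effective_lr('backbone.layer2.0.conv1.weight'))  # ~1e-05
print(effective_lr('head.transformer.decoder.layers.0.weight'))  # ~1e-04
```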
@@ -15,7 +15,7 @@ CLASSES = [

]

# dataset settings
data_root = '/root/data/coco/'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
@ -1,37 +1,4 @@
|
|||
_base_ = [
|
||||
'./_base_/models/vitdet.py', './_base_/datasets/coco_instance.py',
|
||||
'configs/base.py'
|
||||
'./vitdet_mask_rcnn.py', './lsj_coco_instance.py',
|
||||
'./vitdet_schedule_100e.py'
|
||||
]
|
||||
|
||||
log_config = dict(
|
||||
interval=50,
|
||||
hooks=[
|
||||
dict(type='TextLoggerHook'),
|
||||
# dict(type='TensorboardLoggerHook')
|
||||
])
|
||||
|
||||
checkpoint_config = dict(interval=10)
|
||||
# optimizer
|
||||
paramwise_options = {
|
||||
'norm': dict(weight_decay=0.),
|
||||
'bias': dict(weight_decay=0.),
|
||||
'pos_embed': dict(weight_decay=0.),
|
||||
'cls_token': dict(weight_decay=0.)
|
||||
}
|
||||
optimizer = dict(
|
||||
type='AdamW',
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=0.1,
|
||||
paramwise_options=paramwise_options)
|
||||
optimizer_config = dict(grad_clip=None, loss_scale=512.)
|
||||
# learning policy
|
||||
lr_config = dict(
|
||||
policy='step',
|
||||
warmup='linear',
|
||||
warmup_iters=250,
|
||||
warmup_ratio=0.067,
|
||||
step=[88, 96])
|
||||
total_epochs = 100
|
||||
|
||||
find_unused_parameters = False
|
||||
|
|
|
@ -1,37 +1,4 @@
|
|||
_base_ = [
|
||||
'./_base_/models/vitdet_faster_rcnn.py',
|
||||
'./_base_/datasets/coco_detection.py', 'configs/base.py'
|
||||
'./vitdet_faster_rcnn.py', './lsj_coco_detection.py',
|
||||
'./vitdet_schedule_100e.py'
|
||||
]
|
||||
|
||||
log_config = dict(
|
||||
interval=50,
|
||||
hooks=[
|
||||
dict(type='TextLoggerHook'),
|
||||
# dict(type='TensorboardLoggerHook')
|
||||
])
|
||||
|
||||
checkpoint_config = dict(interval=10)
|
||||
# optimizer
|
||||
paramwise_options = {
|
||||
'norm': dict(weight_decay=0.),
|
||||
'bias': dict(weight_decay=0.),
|
||||
'pos_embed': dict(weight_decay=0.),
|
||||
'cls_token': dict(weight_decay=0.)
|
||||
}
|
||||
optimizer = dict(
|
||||
type='AdamW',
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=0.1,
|
||||
paramwise_options=paramwise_options)
|
||||
optimizer_config = dict(grad_clip=None, loss_scale=512.)
|
||||
# learning policy
|
||||
lr_config = dict(
|
||||
policy='step',
|
||||
warmup='linear',
|
||||
warmup_iters=250,
|
||||
warmup_ratio=0.067,
|
||||
step=[88, 96])
|
||||
total_epochs = 100
|
||||
|
||||
find_unused_parameters = False
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
_base_ = 'configs/base.py'
|
||||
|
||||
checkpoint_config = dict(interval=10)
|
||||
# optimizer
|
||||
paramwise_options = {
|
||||
'norm': dict(weight_decay=0.),
|
||||
'bias': dict(weight_decay=0.),
|
||||
'pos_embed': dict(weight_decay=0.),
|
||||
'cls_token': dict(weight_decay=0.)
|
||||
}
|
||||
optimizer = dict(
|
||||
type='AdamW',
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=0.1,
|
||||
paramwise_options=paramwise_options)
|
||||
optimizer_config = dict(grad_clip=None, loss_scale=512.)
|
||||
# learning policy
|
||||
lr_config = dict(
|
||||
policy='step',
|
||||
warmup='linear',
|
||||
warmup_iters=250,
|
||||
warmup_ratio=0.067,
|
||||
step=[88, 96])
|
||||
total_epochs = 100
|
||||
|
||||
find_unused_parameters = False
|
|
@@ -1,5 +1,7 @@

# Detection Model Zoo

Inference time is measured on a V100 16G GPU by default.

## YOLOX-PAI

Pretrained on COCO2017 dataset.

@@ -36,3 +38,13 @@ Pretrained on COCO2017 dataset.

| DETR-r50 | [detr-r50](https://github.com/alibaba/EasyCV/tree/master/configs/detection/detr/detr_r50_8x2_150e_coco.py) | 23M/41M | 48.5ms | 39.92 | 60.52 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/detr/epoch_150.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/detr/20220609_101243.log.json) |
| DAB-DETR-r50 | [dab-detr-r50](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py) | 23M/43M | 58.5ms | 42.52 | 63.03 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/dab_detr_epoch_50.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/20220610_122811.log.json) |
| DN-DETR-r50 | [dn-detr-r50](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dab_detr/dn_detr_r50_8x2_50e_coco.py) | 23M/43M | 58.5ms | 44.39 | 64.66 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dn_detr/dn_detr_epoch_50.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dn_detr/20220713_105127.log.json) |

## DINO

| Algorithm | Config | Params<br/>(backbone/total) | Inference time (V100)<br/>(ms/img) | bbox_mAP<sup>val</sup><br/><sub>0.5:0.95</sub> | AP<sup>val</sup><br/><sub>50</sub> | Download | Comment |
| ---------- | ------ | --------------------------- | ---------------------------------- | ---------------------------------------------- | ---------------------------------- | -------- | ------- |
| DINO_4sc_r50_12e | [DINO_4sc_r50_12e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_4sc_r50_12e_coco.py) | 23M/47M | 184ms | 48.71 | 66.27 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_r50_12e/epoch_12.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_r50_12e/20220815_141403.log.json) | Inference on V100 32G |
| DINO_4sc_r50_36e | [DINO_4sc_r50_36e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_4sc_r50_36e_coco.py) | 23M/47M | 184ms | 50.69 | 68.60 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_r50_36e/epoch_29.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_r50_36e/20220817_101549.log.json) | Inference on V100 32G |
| DINO_4sc_swinl_12e | [DINO_4sc_swinl_12e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_4sc_swinl_12e_coco.py) | 195M/217M | 155ms | 56.86 | 75.61 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_12e/epoch_12.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_12e/20220815_211633.log.json) | Inference on V100 32G |
| DINO_4sc_swinl_36e | [DINO_4sc_swinl_36e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_4sc_swinl_36e_coco.py) | 195M/217M | 155ms | 58.04 | 76.76 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_36e/epoch_34.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_36e/20220817_101416.log.json) | Inference on V100 32G |
| DINO_5sc_swinl_36e | [DINO_5sc_swinl_36e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_5sc_swinl_36e_coco.py) | 195M/217M | 235ms | 58.47 | 77.10 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_5sc_swinl_36e/epoch_35.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_5sc_swinl_36e/20220820_215711.log.json) | Inference on V100 32G |
@@ -79,6 +79,16 @@

```

6. If you want to use MSDeformAttn, you need to compile the CUDA operators:

```shell
cd thirdparty/deformable_attention/
python setup.py build install
# unit test (should see all checking is True)
python test.py
cd ../../..
```

### Verification
@@ -1,4 +1,4 @@

# Copyright (c) Alibaba, Inc. and its affiliates.
from .test import multi_gpu_test, single_cpu_test, single_gpu_test
from .train import (build_optimizer, get_root_logger, set_random_seed,
                    train_model)
from .train import (build_optimizer, get_root_logger, init_random_seed,
                    set_random_seed, train_model)
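The new `init_random_seed` export pairs with `set_random_seed` (both live in `easycv/apis/train.py`; the definitions appear in the next hunk): the former chooses a single seed and broadcasts it to every rank, the latter applies it to Python, NumPy and PyTorch. Below is a minimal usage sketch, assuming an mm-style training entrypoint; the wrapper name and arguments are illustrative, not part of this commit.

```python
# Minimal sketch of how the two helpers are meant to be combined in a training script.
# 'requested_seed' and 'deterministic' are illustrative names, not part of this diff.
from easycv.apis import init_random_seed, set_random_seed


def setup_seed(requested_seed=None, deterministic=False):
    # Returns requested_seed if given; otherwise rank 0 draws a random seed and
    # broadcasts it so every distributed worker ends up with the same value.
    seed = init_random_seed(requested_seed)
    # Seeds random/np.random/torch, and optionally switches cuDNN to deterministic mode.
    set_random_seed(seed, deterministic=deterministic)
    return seed
```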
@@ -8,7 +8,7 @@ import torch

import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import DistSamplerSeedHook, obj_from_dict
from torch import optim
from mmcv.runner.dist_utils import get_dist_info

from easycv.apis.train_misc import build_yolo_optimizer
from easycv.core import optimizer

@@ -26,6 +26,36 @@ from easycv.utils.logger import get_root_logger, print_log

from easycv.utils.torchacc_util import is_torchacc_enabled


def init_random_seed(seed=None, device='cuda'):
    """Initialize random seed.

    If the seed is not set, the seed will be automatically randomized,
    and then broadcast to all processes to prevent some potential bugs.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()


def set_random_seed(seed, deterministic=False):
    """Set random seed.
@ -11,6 +11,7 @@ from mmcv.runner import get_dist_info
|
|||
from torch.utils.data import DataLoader, RandomSampler
|
||||
|
||||
from easycv.datasets.shared.odps_reader import set_dataloader_workid
|
||||
from easycv.utils.dist_utils import sync_random_seed
|
||||
from easycv.utils.torchacc_util import is_torchacc_enabled
|
||||
from .collate import CollateWrapper
|
||||
from .sampler import DistributedMPSampler, DistributedSampler
|
||||
|
@ -50,6 +51,7 @@ def build_dataloader(dataset,
|
|||
Default: True.
|
||||
replace (bool): Replace or not in random shuffle.
|
||||
It works on when shuffle is True.
|
||||
seed (int, Optional): The seed. Default to None.
|
||||
reuse_worker_cache (bool): If set true, will reuse worker process so that cached
|
||||
data in worker process can be reused.
|
||||
persistent_workers (bool) : After pytorch1.7, could use persistent_workers=True to
|
||||
|
@ -58,9 +60,10 @@ def build_dataloader(dataset,
|
|||
Returns:
|
||||
DataLoader: A PyTorch dataloader.
|
||||
"""
|
||||
rank, world_size = get_dist_info()
|
||||
|
||||
if dist:
|
||||
rank, world_size = get_dist_info()
|
||||
seed = sync_random_seed(seed)
|
||||
split_huge_listfile_byrank = getattr(dataset,
|
||||
'split_huge_listfile_byrank',
|
||||
False)
|
||||
|
@ -78,6 +81,7 @@ def build_dataloader(dataset,
|
|||
world_size,
|
||||
rank,
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
split_huge_listfile_byrank=split_huge_listfile_byrank)
|
||||
batch_size = imgs_per_gpu
|
||||
num_workers = workers_per_gpu
|
||||
|
@ -93,7 +97,12 @@ def build_dataloader(dataset,
|
|||
batch_size = num_gpus * imgs_per_gpu
|
||||
num_workers = num_gpus * workers_per_gpu
|
||||
|
||||
init_fn = partial(worker_init_fn, seed=seed, odps_config=odps_config)
|
||||
init_fn = partial(
|
||||
worker_init_fn,
|
||||
num_workers=num_workers,
|
||||
rank=rank,
|
||||
seed=seed,
|
||||
odps_config=odps_config) if seed is not None else None
|
||||
collate_fn = dataset.collate_fn if hasattr(
|
||||
dataset, 'collate_fn') else partial(
|
||||
collate, samples_per_gpu=imgs_per_gpu)
|
||||
|
@ -145,12 +154,13 @@ def build_dataloader(dataset,
|
|||
return data_loader
|
||||
|
||||
|
||||
def worker_init_fn(worker_id, seed=None, odps_config=None):
|
||||
if seed is not None:
|
||||
worker_seed = worker_id + seed
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
def worker_init_fn(worker_id, num_workers, rank, seed, odps_config=None):
|
||||
# The seed of each worker equals to
|
||||
# num_worker * rank + worker_id + user_seed
|
||||
worker_seed = num_workers * rank + worker_id + seed
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
|
||||
if odps_config is not None:
|
||||
# for odps to set correct offset in multi-process pytorch dataloader
|
||||
|
|
|
@ -161,6 +161,7 @@ class DistributedSampler(_DistributedSampler):
|
|||
num_replicas=None,
|
||||
rank=None,
|
||||
shuffle=True,
|
||||
seed=0,
|
||||
replace=False,
|
||||
split_huge_listfile_byrank=False,
|
||||
):
|
||||
|
@ -171,11 +172,13 @@ class DistributedSampler(_DistributedSampler):
|
|||
distributed training.
|
||||
rank (optional): Rank of the current process within num_replicas.
|
||||
shuffle (optional): If true (default), sampler will shuffle the indices
|
||||
seed (int, Optional): The seed. Default to 0.
|
||||
split_huge_listfile_byrank: if split, return all indice for each rank, because list for each rank has been
|
||||
split before build dataset in dist training
|
||||
"""
|
||||
super().__init__(dataset, num_replicas=num_replicas, rank=rank)
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
self.replace = replace
|
||||
self.unif_sampling_flag = False
|
||||
self.split_huge_listfile_byrank = split_huge_listfile_byrank
|
||||
|
@ -197,7 +200,7 @@ class DistributedSampler(_DistributedSampler):
|
|||
def generate_new_list(self):
|
||||
if self.shuffle:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
g.manual_seed(self.epoch + self.seed)
|
||||
if self.replace:
|
||||
indices = torch.randint(
|
||||
low=0,
|
||||
|
@ -299,6 +302,7 @@ class DistributedGroupSampler(Sampler):
|
|||
Dataset is assumed to be of constant size.
|
||||
Args:
|
||||
dataset: Dataset used for sampling.
|
||||
seed (int, Optional): The seed. Default to 0.
|
||||
num_replicas (optional): Number of processes participating in
|
||||
distributed training.
|
||||
rank (optional): Rank of the current process within num_replicas.
|
||||
|
@ -307,6 +311,7 @@ class DistributedGroupSampler(Sampler):
|
|||
def __init__(self,
|
||||
dataset,
|
||||
samples_per_gpu=1,
|
||||
seed=0,
|
||||
num_replicas=None,
|
||||
rank=None):
|
||||
_rank, _num_replicas = get_dist_info()
|
||||
|
@ -316,6 +321,7 @@ class DistributedGroupSampler(Sampler):
|
|||
rank = _rank
|
||||
self.dataset = dataset
|
||||
self.samples_per_gpu = samples_per_gpu
|
||||
self.seed = seed
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
|
@ -334,7 +340,7 @@ class DistributedGroupSampler(Sampler):
|
|||
def __iter__(self):
|
||||
# deterministically shuffle based on epoch
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
g.manual_seed(self.epoch + self.seed)
|
||||
|
||||
indices = []
|
||||
for i, size in enumerate(self.group_sizes):
|
||||
|
@ -436,7 +442,6 @@ class DistributedGivenIterationSampler(Sampler):
|
|||
self.indices = indices
|
||||
|
||||
def gen_new_list(self):
|
||||
|
||||
# each process shuffle all list with same seed, and pick one piece according to rank
|
||||
np.random.seed(0)
|
||||
|
||||
|
|
|
@ -34,10 +34,8 @@ class OptimizerHook(_OptimizerHook):
|
|||
'''
|
||||
ignore_key: [str,...], ignore_key[i], name of parameters, which's gradient will be set to zero before every optimizer step when epoch < ignore_key_epoch[i]
|
||||
ignore_key_epoch: [int,...], epoch < ignore_key_epoch[i], ignore_key[i]'s gradient will be set to zero.
|
||||
|
||||
multiply_key:[str,...] multiply_key[i], name of parameters, which will set different learning rate ratio by multipy_rate
|
||||
multiply_rate:[float,...] multiply_rate[i], different ratio
|
||||
|
||||
'''
|
||||
self.grad_clip = grad_clip
|
||||
self.coalesce = coalesce
|
||||
|
@ -48,9 +46,6 @@ class OptimizerHook(_OptimizerHook):
|
|||
self.multiply_key = multiply_key
|
||||
self.multiply_rate = multiply_rate
|
||||
|
||||
def before_run(self, runner):
|
||||
runner.optimizer.zero_grad()
|
||||
|
||||
def _get_module(self, runner):
|
||||
module = runner.model
|
||||
if is_module_wrapper(module):
|
||||
|
@ -152,8 +147,6 @@ class AMPFP16OptimizerHook(OptimizerHook):
|
|||
if hasattr(m, 'fp16_enabled'):
|
||||
m.fp16_enabled = True
|
||||
|
||||
runner.optimizer.zero_grad()
|
||||
|
||||
def after_train_iter(self, runner):
|
||||
loss = runner.outputs['loss'] / self.update_interval
|
||||
|
||||
|
|
|
@@ -18,5 +18,5 @@ from .resnet import ResNet

from .resnet_jit import ResNetJIT
from .resnext import ResNeXt
from .shuffle_transformer import ShuffleTransformer
from .swin_transformer_dynamic import SwinTransformer
from .swin_transformer import SwinTransformer
from .vitdet import ViTDet
@@ -8,8 +8,9 @@ import torch.nn as nn

from timm.models.layers import trunc_normal_

from easycv.models.registry import BACKBONES
from easycv.models.utils import DropPath
from easycv.models.utils.pos_embed import get_2d_sincos_pos_embed
from .vit_transfomer_dynamic import Block, DropPath
from .vit_transfomer_dynamic import Block


class PatchEmbed(nn.Module):
@@ -437,8 +437,8 @@ class ResNet(nn.Module):

        self.feat_dim = self.block.expansion * self.original_inplanes * 2**(
            len(self.stage_blocks) - 1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        if num_classes > 0:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(self.feat_dim, num_classes)

        self.default_pretrained_model_path = model_urls.get(
@ -0,0 +1,708 @@
|
|||
# ------------------------------------------------------------------------
|
||||
# DINO
|
||||
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
|
||||
# --------------------------------------------------------
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
|
||||
from easycv.models.utils import Mlp
|
||||
from ..registry import BACKBONES
|
||||
|
||||
|
||||
def window_partition(x, window_size):
|
||||
"""
|
||||
Args:
|
||||
x: (B, H, W, C)
|
||||
window_size (int): window size
|
||||
Returns:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
|
||||
C)
|
||||
windows = x.permute(0, 1, 3, 2, 4,
|
||||
5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
window_size (int): Window size
|
||||
H (int): Height of image
|
||||
W (int): Width of image
|
||||
Returns:
|
||||
x: (B, H, W, C)
|
||||
"""
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size,
|
||||
window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||
It supports both of shifted and non-shifted window.
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
||||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
||||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
window_size,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
# define a parameter table of relative position bias
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
|
||||
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :,
|
||||
None] - coords_flatten[:,
|
||||
None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :,
|
||||
0] += self.window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
self.register_buffer('relative_position_index',
|
||||
relative_position_index)
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
""" Forward function.
|
||||
Args:
|
||||
x: input features with shape of (num_windows*B, N, C)
|
||||
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
||||
"""
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[
|
||||
2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1],
|
||||
self.window_size[0] * self.window_size[1],
|
||||
-1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N,
|
||||
N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
attn = self.softmax(attn)
|
||||
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
|
||||
""" Swin Transformer Block.
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Window size.
|
||||
shift_size (int): Shift size for SW-MSA.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
shift_size=0,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.shift_size = shift_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
|
||||
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = WindowAttention(
|
||||
dim,
|
||||
window_size=to_2tuple(self.window_size),
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop)
|
||||
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop)
|
||||
|
||||
self.H = None
|
||||
self.W = None
|
||||
|
||||
def forward(self, x, mask_matrix):
|
||||
""" Forward function.
|
||||
Args:
|
||||
x: Input feature, tensor size (B, H*W, C).
|
||||
H, W: Spatial resolution of the input feature.
|
||||
mask_matrix: Attention mask for cyclic shift.
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
H, W = self.H, self.W
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# pad feature maps to multiples of window size
|
||||
pad_l = pad_t = 0
|
||||
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||
_, Hp, Wp, _ = x.shape
|
||||
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
shifted_x = torch.roll(
|
||||
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||
attn_mask = mask_matrix
|
||||
else:
|
||||
shifted_x = x
|
||||
attn_mask = None
|
||||
|
||||
# partition windows
|
||||
x_windows = window_partition(
|
||||
shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
||||
x_windows = x_windows.view(-1, self.window_size * self.window_size,
|
||||
C) # nW*B, window_size*window_size, C
|
||||
|
||||
# W-MSA/SW-MSA
|
||||
attn_windows = self.attn(
|
||||
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
||||
|
||||
# merge windows
|
||||
attn_windows = attn_windows.view(-1, self.window_size,
|
||||
self.window_size, C)
|
||||
shifted_x = window_reverse(attn_windows, self.window_size, Hp,
|
||||
Wp) # B H' W' C
|
||||
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(
|
||||
shifted_x,
|
||||
shifts=(self.shift_size, self.shift_size),
|
||||
dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
|
||||
if pad_r > 0 or pad_b > 0:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
# FFN
|
||||
x = shortcut + self.drop_path(x)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
""" Patch Merging Layer
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
|
||||
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
||||
self.norm = norm_layer(4 * dim)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
""" Forward function.
|
||||
Args:
|
||||
x: Input feature, tensor size (B, H*W, C).
|
||||
H, W: Spatial resolution of the input feature.
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# padding
|
||||
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
||||
if pad_input:
|
||||
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
||||
|
||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||
|
||||
x = self.norm(x)
|
||||
x = self.reduction(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
""" A basic Swin Transformer layer for one stage.
|
||||
Args:
|
||||
dim (int): Number of feature channels
|
||||
depth (int): Depths of this stage.
|
||||
num_heads (int): Number of attention head.
|
||||
window_size (int): Local window size. Default: 7.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
downsample=None,
|
||||
use_checkpoint=False):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.shift_size = window_size // 2
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
SwinTransformerBlock(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i]
|
||||
if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer) for i in range(depth)
|
||||
])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x, H, W):
|
||||
""" Forward function.
|
||||
Args:
|
||||
x: Input feature, tensor size (B, H*W, C).
|
||||
H, W: Spatial resolution of the input feature.
|
||||
"""
|
||||
|
||||
# calculate attention mask for SW-MSA
|
||||
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
||||
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
||||
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
w_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
mask_windows = window_partition(
|
||||
img_mask, self.window_size) # nW, window_size, window_size, 1
|
||||
mask_windows = mask_windows.view(-1,
|
||||
self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
||||
float(-100.0)).masked_fill(
|
||||
attn_mask == 0, float(0.0))
|
||||
|
||||
for blk in self.blocks:
|
||||
blk.H, blk.W = H, W
|
||||
if self.use_checkpoint:
|
||||
x = checkpoint.checkpoint(blk, x, attn_mask)
|
||||
else:
|
||||
x = blk(x, attn_mask)
|
||||
if self.downsample is not None:
|
||||
x_down = self.downsample(x, H, W)
|
||||
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
||||
return x, H, W, x_down, Wh, Ww
|
||||
else:
|
||||
return x, H, W, x, H, W
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" Image to Patch Embedding
|
||||
Args:
|
||||
patch_size (int): Patch token size. Default: 4.
|
||||
in_chans (int): Number of input image channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
patch_size=4,
|
||||
in_chans=3,
|
||||
embed_dim=96,
|
||||
norm_layer=None):
|
||||
super().__init__()
|
||||
patch_size = to_2tuple(patch_size)
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
if norm_layer is not None:
|
||||
self.norm = norm_layer(embed_dim)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
# padding
|
||||
_, _, H, W = x.size()
|
||||
if W % self.patch_size[1] != 0:
|
||||
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
||||
if H % self.patch_size[0] != 0:
|
||||
x = F.pad(x,
|
||||
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
||||
|
||||
x = self.proj(x) # B C Wh Ww
|
||||
if self.norm is not None:
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@BACKBONES.register_module
|
||||
class SwinTransformer(nn.Module):
|
||||
""" Swin Transformer backbone.
|
||||
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
||||
https://arxiv.org/pdf/2103.14030
|
||||
Args:
|
||||
pretrain_img_size (int): Input image size for training the pretrained model,
|
||||
used in absolute postion embedding. Default 224.
|
||||
patch_size (int | tuple(int)): Patch size. Default: 4.
|
||||
in_chans (int): Number of input image channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
depths (tuple[int]): Depths of each Swin Transformer stage.
|
||||
num_heads (tuple[int]): Number of attention head of each stage.
|
||||
window_size (int): Window size. Default: 7.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop_rate (float): Dropout rate.
|
||||
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
||||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
||||
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
||||
out_indices (Sequence[int]): Output from which stages.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters.
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
        dilation (bool): If True, the output is 16x downsampled; otherwise 32x downsampled.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrain_img_size=224,
|
||||
patch_size=4,
|
||||
in_chans=3,
|
||||
embed_dim=96,
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.2,
|
||||
norm_layer=nn.LayerNorm,
|
||||
ape=False,
|
||||
patch_norm=True,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
frozen_stages=-1,
|
||||
dilation=False,
|
||||
use_checkpoint=False):
|
||||
super().__init__()
|
||||
|
||||
self.pretrain_img_size = pretrain_img_size
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.ape = ape
|
||||
self.patch_norm = patch_norm
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
self.dilation = dilation
|
||||
|
||||
# split image into non-overlapping patches
|
||||
self.patch_embed = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
norm_layer=norm_layer if self.patch_norm else None)
|
||||
|
||||
# absolute position embedding
|
||||
if self.ape:
|
||||
pretrain_img_size = to_2tuple(pretrain_img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
patches_resolution = [
|
||||
pretrain_img_size[0] // patch_size[0],
|
||||
pretrain_img_size[1] // patch_size[1]
|
||||
]
|
||||
|
||||
self.absolute_pos_embed = nn.Parameter(
|
||||
torch.zeros(1, embed_dim, patches_resolution[0],
|
||||
patches_resolution[1]))
|
||||
trunc_normal_(self.absolute_pos_embed, std=.02)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
# stochastic depth
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
||||
] # stochastic depth decay rule
|
||||
|
||||
# build layers
|
||||
self.layers = nn.ModuleList()
|
||||
# prepare downsample list
|
||||
downsamplelist = [PatchMerging for i in range(self.num_layers)]
|
||||
downsamplelist[-1] = None
|
||||
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
|
||||
if self.dilation:
|
||||
downsamplelist[-2] = None
|
||||
num_features[-1] = int(embed_dim * 2**(self.num_layers - 1)) // 2
|
||||
for i_layer in range(self.num_layers):
|
||||
layer = BasicLayer(
|
||||
# dim=int(embed_dim * 2 ** i_layer),
|
||||
dim=num_features[i_layer],
|
||||
depth=depths[i_layer],
|
||||
num_heads=num_heads[i_layer],
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
||||
norm_layer=norm_layer,
|
||||
# downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
downsample=downsamplelist[i_layer],
|
||||
use_checkpoint=use_checkpoint)
|
||||
self.layers.append(layer)
|
||||
|
||||
# num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||
self.num_features = num_features
|
||||
|
||||
# add a norm layer for each output
|
||||
for i_layer in out_indices:
|
||||
layer = norm_layer(num_features[i_layer])
|
||||
layer_name = f'norm{i_layer}'
|
||||
self.add_module(layer_name, layer)
|
||||
|
||||
self._freeze_stages()
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages >= 0:
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if self.frozen_stages >= 1 and self.ape:
|
||||
self.absolute_pos_embed.requires_grad = False
|
||||
|
||||
if self.frozen_stages >= 2:
|
||||
self.pos_drop.eval()
|
||||
for i in range(0, self.frozen_stages - 1):
|
||||
m = self.layers[i]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def init_weights(self):
|
||||
|
||||
def _init_weights(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
self.apply(_init_weights)
|
||||
|
||||
def forward_raw(self, x):
|
||||
"""Forward function."""
|
||||
x = self.patch_embed(x)
|
||||
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
if self.ape:
|
||||
# interpolate the position embedding to the corresponding size
|
||||
absolute_pos_embed = F.interpolate(
|
||||
self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
||||
x = (x + absolute_pos_embed).flatten(2).transpose(1,
|
||||
2) # B Wh*Ww C
|
||||
else:
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.pos_drop(x)
|
||||
|
||||
outs = []
|
||||
for i in range(self.num_layers):
|
||||
layer = self.layers[i]
|
||||
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
x_out = norm_layer(x_out)
|
||||
|
||||
out = x_out.view(-1, H, W,
|
||||
self.num_features[i]).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
outs.append(out)
|
||||
# in:
|
||||
# torch.Size([2, 3, 1024, 1024])
|
||||
# outs:
|
||||
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
|
||||
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
|
||||
return tuple(outs)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
x = self.patch_embed(x)
|
||||
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
if self.ape:
|
||||
# interpolate the position embedding to the corresponding size
|
||||
absolute_pos_embed = F.interpolate(
|
||||
self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
||||
x = (x + absolute_pos_embed).flatten(2).transpose(1,
|
||||
2) # B Wh*Ww C
|
||||
else:
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.pos_drop(x)
|
||||
|
||||
outs = []
|
||||
for i in range(self.num_layers):
|
||||
layer = self.layers[i]
|
||||
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
||||
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
x_out = norm_layer(x_out)
|
||||
|
||||
out = x_out.view(-1, H, W,
|
||||
self.num_features[i]).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
outs.append(out)
|
||||
return outs
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode while keep layers freezed."""
|
||||
super(SwinTransformer, self).train(mode)
|
||||
self._freeze_stages()
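# --- Illustrative sketch (not part of the original diff) ---
# How this backbone might be instantiated for the Swin-L setting implied
# by the shape comments in forward_raw (embed_dim=192 gives per-stage
# channels 192/384/768/1536); the exact values are assumptions, not the
# repo's released dino_4sc_swinl config.
import torch

_swin_large = SwinTransformer(
    pretrain_img_size=224,
    embed_dim=192,
    depths=[2, 2, 18, 2],
    num_heads=[6, 12, 24, 48],
    window_size=7,
    out_indices=(0, 1, 2, 3))
_feats = _swin_large(torch.randn(2, 3, 1024, 1024))
# multi-scale outputs at strides 4/8/16/32:
# (2, 192, 256, 256), (2, 384, 128, 128), (2, 768, 64, 64), (2, 1536, 32, 32)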
|
|
@@ -15,67 +15,9 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
|
||||
from easycv.models.utils import Mlp
|
||||
from ..registry import BACKBONES
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.):
|
||||
super(Mlp, self).__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
def window_partition(x, window_size):
|
||||
"""
|
||||
Args:
|
||||
x: (B, H, W, C)
|
||||
window_size (int): window size
|
||||
|
||||
Returns:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
|
||||
C)
|
||||
windows = x.permute(0, 1, 3, 2, 4,
|
||||
5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
window_size (int): Window size
|
||||
H (int): Height of image
|
||||
W (int): Width of image
|
||||
|
||||
Returns:
|
||||
x: (B, H, W, C)
|
||||
"""
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size,
|
||||
window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
from .swin_transformer import window_partition, window_reverse
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
|
@@ -661,7 +603,7 @@ class PatchEmbed(nn.Module):
|
|||
|
||||
|
||||
@BACKBONES.register_module
|
||||
class SwinTransformer(nn.Module):
|
||||
class DynamicSwinTransformer(nn.Module):
|
||||
r""" Swin Transformer
|
||||
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
||||
https://arxiv.org/pdf/2103.14030
|
||||
|
@@ -984,7 +926,7 @@ class SwinTransformer(nn.Module):
|
|||
|
||||
|
||||
def dynamic_swin_tiny_p4_w7_224(pretrained=False, **kwargs):
|
||||
model = SwinTransformer(
|
||||
model = DynamicSwinTransformer(
|
||||
img_size=224,
|
||||
in_chans=3,
|
||||
num_classes=kwargs['num_classes'],
|
||||
|
@@ -1006,7 +948,7 @@ def dynamic_swin_tiny_p4_w7_224(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
def dynamic_swin_small_p4_w7_224(pretrained=False, **kwargs):
|
||||
model = SwinTransformer(
|
||||
model = DynamicSwinTransformer(
|
||||
img_size=224,
|
||||
in_chans=3,
|
||||
num_classes=kwargs['num_classes'],
|
||||
|
@@ -1028,7 +970,7 @@ def dynamic_swin_small_p4_w7_224(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
def dynamic_swin_base_p4_w7_224(pretrained=False, **kwargs):
|
||||
model = SwinTransformer(
|
||||
model = DynamicSwinTransformer(
|
||||
img_size=224,
|
||||
in_chans=3,
|
||||
num_classes=kwargs['num_classes'],
|
||||
|
|
|
@@ -14,55 +14,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from timm.models.layers import trunc_normal_
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0], ) + (1, ) * (
|
||||
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = keep_prob + torch.rand(
|
||||
shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = x.div(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
from easycv.models.utils import DropPath, Mlp
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
@@ -171,8 +123,8 @@ class PatchEmbed(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
""" Vision Transformer """
|
||||
class DynamicVisionTransformer(nn.Module):
|
||||
"""Dynamic Vision Transformer """
|
||||
|
||||
def __init__(self,
|
||||
img_size=[224],
|
||||
|
@@ -449,7 +401,7 @@ class VisionTransformer(nn.Module):
|
|||
|
||||
|
||||
def dynamic_deit_tiny_p16(patch_size=16, **kwargs):
|
||||
model = VisionTransformer(
|
||||
model = DynamicVisionTransformer(
|
||||
patch_size=patch_size,
|
||||
embed_dim=192,
|
||||
depth=12,
|
||||
|
@@ -462,7 +414,7 @@ def dynamic_deit_tiny_p16(patch_size=16, **kwargs):
|
|||
|
||||
|
||||
def dynamic_deit_small_p16(patch_size=16, **kwargs):
|
||||
model = VisionTransformer(
|
||||
model = DynamicVisionTransformer(
|
||||
patch_size=patch_size,
|
||||
embed_dim=384,
|
||||
depth=12,
|
||||
|
@@ -475,7 +427,7 @@ def dynamic_deit_small_p16(patch_size=16, **kwargs):
|
|||
|
||||
|
||||
def dynamic_vit_base_p16(patch_size=16, **kwargs):
|
||||
model = VisionTransformer(
|
||||
model = DynamicVisionTransformer(
|
||||
patch_size=patch_size,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
|
@@ -488,7 +440,7 @@ def dynamic_vit_base_p16(patch_size=16, **kwargs):
|
|||
|
||||
|
||||
def dynamic_vit_large_p16(patch_size=16, **kwargs):
|
||||
model = VisionTransformer(
|
||||
model = DynamicVisionTransformer(
|
||||
patch_size=patch_size,
|
||||
embed_dim=1024,
|
||||
depth=24,
|
||||
|
@@ -501,7 +453,7 @@ def dynamic_vit_large_p16(patch_size=16, **kwargs):
|
|||
|
||||
|
||||
def dynamic_vit_huge_p14(patch_size=14, **kwargs):
|
||||
model = VisionTransformer(
|
||||
model = DynamicVisionTransformer(
|
||||
patch_size=patch_size,
|
||||
embed_dim=1280,
|
||||
depth=32,
|
||||
|
|
|
@@ -9,56 +9,16 @@ import torch.nn.functional as F
|
|||
import torch.utils.checkpoint as checkpoint
|
||||
from mmcv.cnn import build_norm_layer, constant_init, kaiming_init
|
||||
from mmcv.runner import get_dist_info
|
||||
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
||||
from timm.models.layers import to_2tuple, trunc_normal_
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from easycv.models.utils import DropPath, Mlp
|
||||
from easycv.utils.checkpoint import load_checkpoint
|
||||
from easycv.utils.logger import get_root_logger
|
||||
from ..registry import BACKBONES
|
||||
from ..utils import build_conv_layer
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'p={}'.format(self.drop_prob)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
# x = self.drop(x)
|
||||
        # commented out to match the original BERT implementation
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
|
|
|
@@ -5,6 +5,8 @@ from easycv.models.detection.detectors.dab_detr import (DABDETRHead,
|
|||
DABDetrTransformer)
|
||||
from easycv.models.detection.detectors.detection import Detection
|
||||
from easycv.models.detection.detectors.detr import DETRHead, DetrTransformer
|
||||
from easycv.models.detection.detectors.dino import (DeformableTransformer,
|
||||
DINOHead)
|
||||
from easycv.models.detection.detectors.fcos import FCOSHead
|
||||
|
||||
try:
|
||||
|
|
|
@@ -2,23 +2,23 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from easycv.models.builder import HEADS, build_neck
|
||||
from easycv.models.detection.utils import (HungarianMatcher, SetCriterion,
|
||||
box_cxcywh_to_xyxy,
|
||||
box_xyxy_to_cxcywh, inverse_sigmoid)
|
||||
from easycv.models.detection.utils import (DetrPostProcess, box_xyxy_to_cxcywh,
|
||||
inverse_sigmoid)
|
||||
from easycv.models.loss import DNCriterion, HungarianMatcher, SetCriterion
|
||||
from easycv.models.utils import MLP
|
||||
from .dn_components import dn_post_process, prepare_for_dn
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class DABDETRHead(nn.Module):
|
||||
"""Implements the DETR transformer head.
|
||||
See `paper: End-to-End Object Detection with Transformers
|
||||
<https://arxiv.org/pdf/2005.12872>`_ for details.
|
||||
"""Implements the DAB-DETR head.
|
||||
See `paper: DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR
|
||||
<https://arxiv.org/abs/2201.12329> and DN-DETR: Accelerate DETR Training by Introducing Query DeNoising
|
||||
<https://arxiv.org/abs/2203.01305>`_ for details.
|
||||
Args:
|
||||
num_classes (int): Number of categories excluding the background.
|
||||
"""
|
||||
|
@@ -56,9 +56,10 @@ class DABDETRHead(nn.Module):
|
|||
matcher=self.matcher,
|
||||
weight_dict=weight_dict,
|
||||
losses=['labels', 'boxes'],
|
||||
loss_class_type='focal_loss',
|
||||
dn_components=dn_components)
|
||||
self.postprocess = PostProcess(num_select=num_select)
|
||||
loss_class_type='focal_loss')
|
||||
if dn_components is not None:
|
||||
self.dn_criterion = DNCriterion(weight_dict)
|
||||
self.postprocess = DetrPostProcess(num_select=num_select)
|
||||
self.transformer = build_neck(transformer)
|
||||
|
||||
self.class_embed = nn.Linear(embed_dims, num_classes)
|
||||
|
@@ -256,7 +257,10 @@ class DABDETRHead(nn.Module):
|
|||
attn_mask=attn_mask,
|
||||
mask_dict=mask_dict)
|
||||
|
||||
losses = self.criterion(outputs, targets, mask_dict)
|
||||
losses = self.criterion(outputs, targets)
|
||||
if self.dn_components:
|
||||
losses.update(
|
||||
self.dn_criterion(mask_dict, len(outputs['aux_outputs'])))
|
||||
|
||||
return losses
|
||||
|
||||
|
@@ -279,51 +283,3 @@ class DABDETRHead(nn.Module):
|
|||
|
||||
results = self.postprocess(outputs, orig_target_sizes, img_metas)
|
||||
return results
|
||||
|
||||
|
||||
class PostProcess(nn.Module):
|
||||
""" This module converts the model's output into the format expected by the coco api"""
|
||||
|
||||
def __init__(self, num_select=100) -> None:
|
||||
super().__init__()
|
||||
self.num_select = num_select
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, target_sizes, img_metas):
|
||||
""" Perform the computation
|
||||
Parameters:
|
||||
outputs: raw outputs of the model
|
||||
target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
|
||||
For evaluation, this must be the original image size (before any data augmentation)
|
||||
For visualization, this should be the image size after data augment, but before padding
|
||||
"""
|
||||
num_select = self.num_select
|
||||
out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
|
||||
|
||||
assert len(out_logits) == len(target_sizes)
|
||||
assert target_sizes.shape[1] == 2
|
||||
|
||||
prob = out_logits.sigmoid()
|
||||
topk_values, topk_indexes = torch.topk(
|
||||
prob.view(out_logits.shape[0], -1), num_select, dim=1)
|
||||
scores = topk_values
|
||||
topk_boxes = topk_indexes // out_logits.shape[2]
|
||||
labels = topk_indexes % out_logits.shape[2]
|
||||
boxes = box_cxcywh_to_xyxy(out_bbox)
|
||||
boxes = torch.gather(boxes, 1,
|
||||
topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
|
||||
|
||||
# and from relative [0, 1] to absolute [0, height] coordinates
|
||||
img_h, img_w = target_sizes.unbind(1)
|
||||
scale_fct = torch.stack([img_w, img_h, img_w, img_h],
|
||||
dim=1).to(boxes.device)
|
||||
boxes = boxes * scale_fct[:, None, :]
|
||||
|
||||
results = {
|
||||
'detection_boxes': [boxes[0].cpu().numpy()],
|
||||
'detection_scores': [scores[0].cpu().numpy()],
|
||||
'detection_classes': [labels[0].cpu().numpy().astype(np.int32)],
|
||||
'img_metas': img_metas
|
||||
}
|
||||
|
||||
return results
|
||||
|
|
|
@@ -15,19 +15,13 @@ import torch.nn.functional as F
|
|||
from torch import Tensor, nn
|
||||
|
||||
from easycv.models.builder import NECKS
|
||||
from easycv.models.detection.utils import inverse_sigmoid
|
||||
from easycv.models.utils import (MLP, TransformerEncoder,
|
||||
TransformerEncoderLayer, _get_activation_fn,
|
||||
_get_clones)
|
||||
from .attention import MultiheadAttention
|
||||
|
||||
|
||||
def inverse_sigmoid(x, eps=1e-3):
|
||||
x = x.clamp(min=0, max=1)
|
||||
x1 = x.clamp(min=eps)
|
||||
x2 = (1 - x).clamp(min=eps)
|
||||
return torch.log(x1 / x2)
|
||||
|
||||
|
||||
@NECKS.register_module
|
||||
class DABDetrTransformer(nn.Module):
|
||||
|
||||
|
|
|
@@ -1,14 +1,11 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from easycv.models.builder import HEADS, build_neck
|
||||
from easycv.models.detection.utils import (HungarianMatcher, SetCriterion,
|
||||
box_cxcywh_to_xyxy,
|
||||
box_xyxy_to_cxcywh)
|
||||
from easycv.models.detection.utils import DetrPostProcess, box_xyxy_to_cxcywh
|
||||
from easycv.models.loss import HungarianMatcher, SetCriterion
|
||||
from easycv.models.utils import MLP
|
||||
|
||||
|
||||
|
@@ -49,7 +46,7 @@ class DETRHead(nn.Module):
|
|||
weight_dict=weight_dict,
|
||||
eos_coef=eos_coef,
|
||||
losses=['labels', 'boxes'])
|
||||
self.postprocess = PostProcess()
|
||||
self.postprocess = DetrPostProcess()
|
||||
self.transformer = build_neck(transformer)
|
||||
|
||||
self.class_embed = nn.Linear(embed_dims, num_classes + 1)
|
||||
|
@@ -149,41 +146,3 @@ class DETRHead(nn.Module):
|
|||
|
||||
results = self.postprocess(outputs, orig_target_sizes, img_metas)
|
||||
return results
|
||||
|
||||
|
||||
class PostProcess(nn.Module):
|
||||
""" This module converts the model's output into the format expected by the coco api"""
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, target_sizes, img_metas):
|
||||
""" Perform the computation
|
||||
Parameters:
|
||||
outputs: raw outputs of the model
|
||||
target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
|
||||
For evaluation, this must be the original image size (before any data augmentation)
|
||||
For visualization, this should be the image size after data augment, but before padding
|
||||
"""
|
||||
out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
|
||||
|
||||
assert len(out_logits) == len(target_sizes)
|
||||
assert target_sizes.shape[1] == 2
|
||||
|
||||
prob = F.softmax(out_logits, -1)
|
||||
scores, labels = prob[..., :-1].max(-1)
|
||||
|
||||
# convert to [x0, y0, x1, y1] format
|
||||
boxes = box_cxcywh_to_xyxy(out_bbox)
|
||||
# and from relative [0, 1] to absolute [0, height] coordinates
|
||||
img_h, img_w = target_sizes.unbind(1)
|
||||
scale_fct = torch.stack([img_w, img_h, img_w, img_h],
|
||||
dim=1).to(boxes.device)
|
||||
boxes = boxes * scale_fct[:, None, :]
|
||||
|
||||
results = {
|
||||
'detection_boxes': [boxes[0].cpu().numpy()],
|
||||
'detection_scores': [scores[0].cpu().numpy()],
|
||||
'detection_classes': [labels[0].cpu().numpy().astype(np.int32)],
|
||||
'img_metas': img_metas
|
||||
}
|
||||
|
||||
return results
|
||||
|
|
|
@@ -0,0 +1,2 @@
|
|||
from .deformable_transformer import DeformableTransformer
from .dino_head import DINOHead
@@ -0,0 +1,169 @@
|
|||
# ------------------------------------------------------------------------
|
||||
# DN-DETR
|
||||
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
import torch
|
||||
|
||||
from easycv.models.detection.utils import inverse_sigmoid
|
||||
|
||||
|
||||
def prepare_for_cdn(dn_args, training, num_queries, num_classes, hidden_dim,
|
||||
label_enc):
|
||||
"""
|
||||
    A major difference of DINO from DN-DETR is that the authors process the pattern embedding in the detector's
    forward function and use a learnable tgt embedding, so we change this function a little bit.
|
||||
:param dn_args: targets, dn_number, label_noise_ratio, box_noise_scale
|
||||
:param training: if it is training or inference
|
||||
    :param num_queries: number of queries
|
||||
:param num_classes: number of classes
|
||||
:param hidden_dim: transformer hidden dim
|
||||
:param label_enc: encode labels in dn
|
||||
:return:
|
||||
"""
|
||||
if training:
|
||||
targets, dn_number, label_noise_ratio, box_noise_scale = dn_args
|
||||
# positive and negative dn queries
|
||||
dn_number = dn_number * 2
|
||||
known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
|
||||
batch_size = len(known)
|
||||
known_num = [sum(k) for k in known]
|
||||
if int(max(known_num)) == 0:
|
||||
dn_number = 1
|
||||
else:
|
||||
if dn_number >= 100:
|
||||
dn_number = dn_number // (int(max(known_num) * 2))
|
||||
elif dn_number < 1:
|
||||
dn_number = 1
|
||||
if dn_number == 0:
|
||||
dn_number = 1
|
||||
unmask_bbox = unmask_label = torch.cat(known)
|
||||
labels = torch.cat([t['labels'] for t in targets])
|
||||
boxes = torch.cat([t['boxes'] for t in targets])
|
||||
batch_idx = torch.cat([
|
||||
torch.full_like(t['labels'].long(), i)
|
||||
for i, t in enumerate(targets)
|
||||
])
|
||||
|
||||
known_indice = torch.nonzero(unmask_label + unmask_bbox)
|
||||
known_indice = known_indice.view(-1)
|
||||
|
||||
known_indice = known_indice.repeat(2 * dn_number, 1).view(-1)
|
||||
known_labels = labels.repeat(2 * dn_number, 1).view(-1)
|
||||
known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)
|
||||
known_bboxs = boxes.repeat(2 * dn_number, 1)
|
||||
known_labels_expaned = known_labels.clone()
|
||||
known_bbox_expand = known_bboxs.clone()
|
||||
|
||||
if label_noise_ratio > 0:
|
||||
p = torch.rand_like(known_labels_expaned.float())
|
||||
chosen_indice = torch.nonzero(p < (label_noise_ratio)).view(
|
||||
-1) # half of bbox prob
|
||||
new_label = torch.randint_like(
|
||||
chosen_indice, 0, num_classes) # randomly put a new one here
|
||||
known_labels_expaned.scatter_(0, chosen_indice, new_label)
|
||||
single_pad = int(max(known_num))
|
||||
|
||||
pad_size = int(single_pad * 2 * dn_number)
|
||||
positive_idx = torch.tensor(range(
|
||||
len(boxes))).long().cuda().unsqueeze(0).repeat(dn_number, 1)
|
||||
positive_idx += (torch.tensor(range(dn_number)) * len(boxes) *
|
||||
2).long().cuda().unsqueeze(1)
|
||||
positive_idx = positive_idx.flatten()
|
||||
negative_idx = positive_idx + len(boxes)
|
||||
if box_noise_scale > 0:
|
||||
known_bbox_ = torch.zeros_like(known_bboxs)
|
||||
known_bbox_[:, :2] = known_bboxs[:, :2] - known_bboxs[:, 2:] / 2
|
||||
known_bbox_[:, 2:] = known_bboxs[:, :2] + known_bboxs[:, 2:] / 2
|
||||
|
||||
diff = torch.zeros_like(known_bboxs)
|
||||
diff[:, :2] = known_bboxs[:, 2:] / 2
|
||||
diff[:, 2:] = known_bboxs[:, 2:] / 2
|
||||
|
||||
rand_sign = torch.randint_like(
|
||||
known_bboxs, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
|
||||
rand_part = torch.rand_like(known_bboxs)
|
||||
rand_part[negative_idx] += 1.0
|
||||
rand_part *= rand_sign
|
||||
known_bbox_ = known_bbox_ + torch.mul(
|
||||
rand_part, diff).cuda() * box_noise_scale
|
||||
known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0)
|
||||
known_bbox_expand[:, :2] = (known_bbox_[:, :2] +
|
||||
known_bbox_[:, 2:]) / 2
|
||||
known_bbox_expand[:, 2:] = known_bbox_[:, 2:] - known_bbox_[:, :2]
|
||||
|
||||
m = known_labels_expaned.long().to('cuda')
|
||||
input_label_embed = label_enc(m)
|
||||
input_bbox_embed = inverse_sigmoid(known_bbox_expand)
|
||||
|
||||
padding_label = torch.zeros(pad_size, hidden_dim).cuda()
|
||||
padding_bbox = torch.zeros(pad_size, 4).cuda()
|
||||
|
||||
input_query_label = padding_label.repeat(batch_size, 1, 1)
|
||||
input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)
|
||||
|
||||
map_known_indice = torch.tensor([]).to('cuda')
|
||||
if len(known_num):
|
||||
map_known_indice = torch.cat([
|
||||
torch.tensor(range(num)) for num in known_num
|
||||
]) # [1,2, 1,2,3]
|
||||
map_known_indice = torch.cat([
|
||||
map_known_indice + single_pad * i for i in range(2 * dn_number)
|
||||
]).long()
|
||||
if len(known_bid):
|
||||
input_query_label[(known_bid.long(),
|
||||
map_known_indice)] = input_label_embed
|
||||
input_query_bbox[(known_bid.long(),
|
||||
map_known_indice)] = input_bbox_embed
|
||||
|
||||
tgt_size = pad_size + num_queries
|
||||
attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
|
||||
        # matching queries cannot see the reconstruction (denoising) part
|
||||
attn_mask[pad_size:, :pad_size] = True
|
||||
        # denoising groups cannot see each other
|
||||
for i in range(dn_number):
|
||||
if i == 0:
|
||||
attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
|
||||
single_pad * 2 * (i + 1):pad_size] = True
|
||||
if i == dn_number - 1:
|
||||
attn_mask[single_pad * 2 * i:single_pad * 2 *
|
||||
(i + 1), :single_pad * i * 2] = True
|
||||
else:
|
||||
attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
|
||||
single_pad * 2 * (i + 1):pad_size] = True
|
||||
attn_mask[single_pad * 2 * i:single_pad * 2 *
|
||||
(i + 1), :single_pad * 2 * i] = True
|
||||
|
||||
dn_meta = {
|
||||
'pad_size': pad_size,
|
||||
'num_dn_group': dn_number,
|
||||
}
|
||||
else:
|
||||
|
||||
input_query_label = None
|
||||
input_query_bbox = None
|
||||
attn_mask = None
|
||||
dn_meta = None
|
||||
|
||||
return input_query_label, input_query_bbox, attn_mask, dn_meta
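# --- Illustrative sketch (not part of the original diff) ---
# Toy illustration of the attention-mask layout built above (True means
# "may not attend"): matching queries never see the denoising queries,
# and denoising groups never see each other. The group sizes here are
# made up; only the masking pattern mirrors the code above.
import torch

_single_pad, _dn_groups, _num_queries = 3, 2, 4
_pad_size = _single_pad * 2 * _dn_groups
_attn_mask = torch.zeros(_pad_size + _num_queries,
                         _pad_size + _num_queries, dtype=torch.bool)
# matching queries cannot attend to the reconstruction (denoising) part
_attn_mask[_pad_size:, :_pad_size] = True
# each denoising group is blind to every other denoising group
for _i in range(_dn_groups):
    _grp = slice(_single_pad * 2 * _i, _single_pad * 2 * (_i + 1))
    _attn_mask[_grp, :_single_pad * 2 * _i] = True
    _attn_mask[_grp, _single_pad * 2 * (_i + 1):_pad_size] = True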
|
||||
|
||||
|
||||
def cdn_post_process(outputs_class, outputs_coord, dn_meta, _set_aux_loss):
|
||||
"""
|
||||
post process of dn after output from the transformer
|
||||
put the dn part in the dn_meta
|
||||
"""
|
||||
if dn_meta and dn_meta['pad_size'] > 0:
|
||||
output_known_class = outputs_class[:, :, :dn_meta['pad_size'], :]
|
||||
output_known_coord = outputs_coord[:, :, :dn_meta['pad_size'], :]
|
||||
outputs_class = outputs_class[:, :, dn_meta['pad_size']:, :]
|
||||
outputs_coord = outputs_coord[:, :, dn_meta['pad_size']:, :]
|
||||
out = {
|
||||
'pred_logits': output_known_class[-1],
|
||||
'pred_boxes': output_known_coord[-1]
|
||||
}
|
||||
out['aux_outputs'] = _set_aux_loss(output_known_class,
|
||||
output_known_coord)
|
||||
dn_meta['output_known_lbs_bboxes'] = out
|
||||
return outputs_class, outputs_coord
|
File diff suppressed because it is too large
|
@@ -0,0 +1,474 @@
|
|||
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import copy
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from easycv.models.builder import HEADS, build_neck
|
||||
from easycv.models.detection.utils import (DetrPostProcess, box_xyxy_to_cxcywh,
|
||||
inverse_sigmoid)
|
||||
from easycv.models.loss import CDNCriterion, HungarianMatcher, SetCriterion
|
||||
from easycv.models.utils import (MLP, get_world_size,
|
||||
is_dist_avail_and_initialized)
|
||||
from ..dab_detr.dab_detr_transformer import PositionEmbeddingSineHW
|
||||
from .cdn_components import cdn_post_process, prepare_for_cdn
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class DINOHead(nn.Module):
|
||||
""" Initializes the DINO Head.
|
||||
See `paper: DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection
|
||||
<https://arxiv.org/abs/2203.03605>`_ for details.
|
||||
Parameters:
|
||||
backbone: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
num_classes: number of object classes
|
||||
        num_queries: number of object queries, i.e. detection slots. This is the maximal number of objects
|
||||
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
|
||||
        fix_refpoints_hw: -1 (default): learn w and h for each box separately
|
||||
>0 : given fixed number
|
||||
-2 : learn a shared w and h
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
embed_dims,
|
||||
in_channels=[512, 1024, 2048],
|
||||
query_dim=4,
|
||||
num_queries=300,
|
||||
num_select=300,
|
||||
random_refpoints_xy=False,
|
||||
num_patterns=0,
|
||||
dn_components=None,
|
||||
transformer=None,
|
||||
fix_refpoints_hw=-1,
|
||||
num_feature_levels=1,
|
||||
# two stage
|
||||
two_stage_type='standard', # ['no', 'standard']
|
||||
two_stage_add_query_num=0,
|
||||
dec_pred_class_embed_share=True,
|
||||
dec_pred_bbox_embed_share=True,
|
||||
two_stage_class_embed_share=True,
|
||||
two_stage_bbox_embed_share=True,
|
||||
decoder_sa_type='sa',
|
||||
temperatureH=20,
|
||||
temperatureW=20,
|
||||
cost_dict={
|
||||
'cost_class': 1,
|
||||
'cost_bbox': 5,
|
||||
'cost_giou': 2,
|
||||
},
|
||||
weight_dict={
|
||||
'loss_ce': 1,
|
||||
'loss_bbox': 5,
|
||||
'loss_giou': 2
|
||||
},
|
||||
**kwargs):
|
||||
|
||||
super(DINOHead, self).__init__()
|
||||
|
||||
self.matcher = HungarianMatcher(
|
||||
cost_dict=cost_dict, cost_class_type='focal_loss_cost')
|
||||
self.criterion = SetCriterion(
|
||||
num_classes,
|
||||
matcher=self.matcher,
|
||||
weight_dict=weight_dict,
|
||||
losses=['labels', 'boxes'],
|
||||
loss_class_type='focal_loss')
|
||||
if dn_components is not None:
|
||||
self.dn_criterion = CDNCriterion(
|
||||
num_classes,
|
||||
matcher=self.matcher,
|
||||
weight_dict=weight_dict,
|
||||
losses=['labels', 'boxes'],
|
||||
loss_class_type='focal_loss')
|
||||
self.postprocess = DetrPostProcess(num_select=num_select)
|
||||
self.transformer = build_neck(transformer)
|
||||
|
||||
self.positional_encoding = PositionEmbeddingSineHW(
|
||||
embed_dims // 2,
|
||||
temperatureH=temperatureH,
|
||||
temperatureW=temperatureW,
|
||||
normalize=True)
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.num_queries = num_queries
|
||||
self.embed_dims = embed_dims
|
||||
self.query_dim = query_dim
|
||||
self.dn_components = dn_components
|
||||
|
||||
self.random_refpoints_xy = random_refpoints_xy
|
||||
self.fix_refpoints_hw = fix_refpoints_hw
|
||||
|
||||
# for dn training
|
||||
self.dn_number = self.dn_components['dn_number']
|
||||
self.dn_box_noise_scale = self.dn_components['dn_box_noise_scale']
|
||||
self.dn_label_noise_ratio = self.dn_components['dn_label_noise_ratio']
|
||||
self.dn_labelbook_size = self.dn_components['dn_labelbook_size']
|
||||
self.label_enc = nn.Embedding(self.dn_labelbook_size + 1, embed_dims)
|
||||
|
||||
# prepare input projection layers
|
||||
self.num_feature_levels = num_feature_levels
|
||||
if num_feature_levels > 1:
|
||||
num_backbone_outs = len(in_channels)
|
||||
input_proj_list = []
|
||||
for i in range(num_backbone_outs):
|
||||
in_channels_i = in_channels[i]
|
||||
input_proj_list.append(
|
||||
nn.Sequential(
|
||||
nn.Conv2d(in_channels_i, embed_dims, kernel_size=1),
|
||||
nn.GroupNorm(32, embed_dims),
|
||||
))
|
||||
for _ in range(num_feature_levels - num_backbone_outs):
|
||||
input_proj_list.append(
|
||||
nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels_i,
|
||||
embed_dims,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1),
|
||||
nn.GroupNorm(32, embed_dims),
|
||||
))
|
||||
in_channels_i = embed_dims
|
||||
self.input_proj = nn.ModuleList(input_proj_list)
|
||||
else:
|
||||
assert two_stage_type == 'no', 'two_stage_type should be no if num_feature_levels=1 !!!'
|
||||
self.input_proj = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
nn.Conv2d(in_channels[-1], embed_dims, kernel_size=1),
|
||||
nn.GroupNorm(32, embed_dims),
|
||||
)
|
||||
])
|
||||
|
||||
# prepare pred layers
|
||||
self.dec_pred_class_embed_share = dec_pred_class_embed_share
|
||||
self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
|
||||
# prepare class & box embed
|
||||
_class_embed = nn.Linear(embed_dims, num_classes)
|
||||
_bbox_embed = MLP(embed_dims, embed_dims, 4, 3)
|
||||
# init the two embed layers
|
||||
prior_prob = 0.01
|
||||
bias_value = -math.log((1 - prior_prob) / prior_prob)
|
||||
_class_embed.bias.data = torch.ones(self.num_classes) * bias_value
|
||||
nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
|
||||
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
|
||||
|
||||
if dec_pred_bbox_embed_share:
|
||||
box_embed_layerlist = [
|
||||
_bbox_embed for i in range(transformer.num_decoder_layers)
|
||||
]
|
||||
else:
|
||||
box_embed_layerlist = [
|
||||
copy.deepcopy(_bbox_embed)
|
||||
for i in range(transformer.num_decoder_layers)
|
||||
]
|
||||
if dec_pred_class_embed_share:
|
||||
class_embed_layerlist = [
|
||||
_class_embed for i in range(transformer.num_decoder_layers)
|
||||
]
|
||||
else:
|
||||
class_embed_layerlist = [
|
||||
copy.deepcopy(_class_embed)
|
||||
for i in range(transformer.num_decoder_layers)
|
||||
]
|
||||
self.bbox_embed = nn.ModuleList(box_embed_layerlist)
|
||||
self.class_embed = nn.ModuleList(class_embed_layerlist)
|
||||
self.transformer.decoder.bbox_embed = self.bbox_embed
|
||||
self.transformer.decoder.class_embed = self.class_embed
|
||||
|
||||
# two stage
|
||||
self.two_stage_type = two_stage_type
|
||||
self.two_stage_add_query_num = two_stage_add_query_num
|
||||
assert two_stage_type in [
|
||||
'no', 'standard'
|
||||
], 'unknown param {} of two_stage_type'.format(two_stage_type)
|
||||
if two_stage_type != 'no':
|
||||
if two_stage_bbox_embed_share:
|
||||
assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
|
||||
self.transformer.enc_out_bbox_embed = _bbox_embed
|
||||
else:
|
||||
self.transformer.enc_out_bbox_embed = copy.deepcopy(
|
||||
_bbox_embed)
|
||||
|
||||
if two_stage_class_embed_share:
|
||||
assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
|
||||
self.transformer.enc_out_class_embed = _class_embed
|
||||
else:
|
||||
self.transformer.enc_out_class_embed = copy.deepcopy(
|
||||
_class_embed)
|
||||
|
||||
self.refpoint_embed = None
|
||||
if self.two_stage_add_query_num > 0:
|
||||
self.init_ref_points(two_stage_add_query_num)
|
||||
|
||||
self.decoder_sa_type = decoder_sa_type
|
||||
assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
|
||||
# self.replace_sa_with_double_ca = replace_sa_with_double_ca
|
||||
if decoder_sa_type == 'ca_label':
|
||||
self.label_embedding = nn.Embedding(num_classes, embed_dims)
|
||||
for layer in self.transformer.decoder.layers:
|
||||
layer.label_embedding = self.label_embedding
|
||||
else:
|
||||
for layer in self.transformer.decoder.layers:
|
||||
layer.label_embedding = None
|
||||
self.label_embedding = None
|
||||
|
||||
def init_weights(self):
|
||||
# init input_proj
|
||||
for proj in self.input_proj:
|
||||
nn.init.xavier_uniform_(proj[0].weight, gain=1)
|
||||
nn.init.constant_(proj[0].bias, 0)
|
||||
|
||||
def init_ref_points(self, use_num_queries):
|
||||
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
|
||||
|
||||
if self.random_refpoints_xy:
|
||||
# import ipdb; ipdb.set_trace()
|
||||
self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
|
||||
self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
|
||||
self.refpoint_embed.weight.data[:, :2])
|
||||
self.refpoint_embed.weight.data[:, :2].requires_grad = False
|
||||
|
||||
if self.fix_refpoints_hw > 0:
|
||||
print('fix_refpoints_hw: {}'.format(self.fix_refpoints_hw))
|
||||
assert self.random_refpoints_xy
|
||||
self.refpoint_embed.weight.data[:, 2:] = self.fix_refpoints_hw
|
||||
self.refpoint_embed.weight.data[:, 2:] = inverse_sigmoid(
|
||||
self.refpoint_embed.weight.data[:, 2:])
|
||||
self.refpoint_embed.weight.data[:, 2:].requires_grad = False
|
||||
elif int(self.fix_refpoints_hw) == -1:
|
||||
pass
|
||||
elif int(self.fix_refpoints_hw) == -2:
|
||||
print('learn a shared h and w')
|
||||
assert self.random_refpoints_xy
|
||||
self.refpoint_embed = nn.Embedding(use_num_queries, 2)
|
||||
self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
|
||||
self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
|
||||
self.refpoint_embed.weight.data[:, :2])
|
||||
self.refpoint_embed.weight.data[:, :2].requires_grad = False
|
||||
self.hw_embed = nn.Embedding(1, 1)
|
||||
else:
|
||||
raise NotImplementedError('Unknown fix_refpoints_hw {}'.format(
|
||||
self.fix_refpoints_hw))
|
||||
|
||||
def prepare(self, features, targets=None, mode='train'):
|
||||
|
||||
if self.dn_number > 0 or targets is not None:
|
||||
input_query_label, input_query_bbox, attn_mask, dn_meta =\
|
||||
prepare_for_cdn(dn_args=(targets, self.dn_number, self.dn_label_noise_ratio, self.dn_box_noise_scale),
|
||||
training=self.training, num_queries=self.num_queries, num_classes=self.num_classes,
|
||||
hidden_dim=self.embed_dims, label_enc=self.label_enc)
|
||||
else:
|
||||
assert targets is None
|
||||
input_query_bbox = input_query_label = attn_mask = dn_meta = None
|
||||
|
||||
return input_query_bbox, input_query_label, attn_mask, dn_meta
|
||||
|
||||
def forward(self,
|
||||
feats,
|
||||
img_metas,
|
||||
query_embed=None,
|
||||
tgt=None,
|
||||
attn_mask=None,
|
||||
dn_meta=None):
|
||||
"""Forward function.
|
||||
Args:
|
||||
feats (tuple[Tensor]): Features from the upstream network, each is
|
||||
a 4D-tensor.
|
||||
img_metas (list[dict]): List of image information.
|
||||
Returns:
|
||||
tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.
|
||||
- all_cls_scores_list (list[Tensor]): Classification scores \
|
||||
for each scale level. Each is a 4D-tensor with shape \
|
||||
[nb_dec, bs, num_query, cls_out_channels]. Note \
|
||||
                    `cls_out_channels` should include background.
|
||||
- all_bbox_preds_list (list[Tensor]): Sigmoid regression \
|
||||
outputs for each scale level. Each is a 4D-tensor with \
|
||||
normalized coordinate format (cx, cy, w, h) and shape \
|
||||
[nb_dec, bs, num_query, 4].
|
||||
"""
|
||||
        # construct binary masks which are used for the transformer.
        # NOTE following the official DETR repo, non-zero values represent
        # ignored positions, while zero values mean valid positions.
|
||||
bs = feats[0].size(0)
|
||||
input_img_h, input_img_w = img_metas[0]['batch_input_shape']
|
||||
img_masks = feats[0].new_ones((bs, input_img_h, input_img_w))
|
||||
for img_id in range(bs):
|
||||
img_h, img_w, _ = img_metas[img_id]['img_shape']
|
||||
img_masks[img_id, :img_h, :img_w] = 0
|
||||
|
||||
srcs = []
|
||||
masks = []
|
||||
poss = []
|
||||
for l, src in enumerate(feats):
|
||||
mask = F.interpolate(
|
||||
img_masks[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
|
||||
# position encoding
|
||||
pos_l = self.positional_encoding(mask) # [bs, embed_dim, h, w]
|
||||
srcs.append(self.input_proj[l](src))
|
||||
masks.append(mask)
|
||||
poss.append(pos_l)
|
||||
assert mask is not None
|
||||
if self.num_feature_levels > len(srcs):
|
||||
_len_srcs = len(srcs)
|
||||
for l in range(_len_srcs, self.num_feature_levels):
|
||||
if l == _len_srcs:
|
||||
src = self.input_proj[l](feats[-1])
|
||||
else:
|
||||
src = self.input_proj[l](srcs[-1])
|
||||
mask = F.interpolate(
|
||||
img_masks[None].float(),
|
||||
size=src.shape[-2:]).to(torch.bool)[0]
|
||||
# position encoding
|
||||
pos_l = self.positional_encoding(mask) # [bs, embed_dim, h, w]
|
||||
srcs.append(src)
|
||||
masks.append(mask)
|
||||
poss.append(pos_l)
|
||||
|
||||
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
|
||||
srcs, masks, query_embed, poss, tgt, attn_mask)
|
||||
# In case num object=0
|
||||
hs[0] += self.label_enc.weight[0, 0] * 0.0
|
||||
|
||||
# deformable-detr-like anchor update
|
||||
# reference_before_sigmoid = inverse_sigmoid(reference[:-1]) # n_dec, bs, nq, 4
|
||||
outputs_coord_list = []
|
||||
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
|
||||
zip(reference[:-1], self.bbox_embed, hs)):
|
||||
layer_delta_unsig = layer_bbox_embed(layer_hs)
|
||||
layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(
|
||||
layer_ref_sig)
|
||||
layer_outputs_unsig = layer_outputs_unsig.sigmoid()
|
||||
outputs_coord_list.append(layer_outputs_unsig)
|
||||
outputs_coord_list = torch.stack(outputs_coord_list)
|
||||
|
||||
# outputs_class = self.class_embed(hs)
|
||||
outputs_class = torch.stack([
|
||||
layer_cls_embed(layer_hs)
|
||||
for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
|
||||
])
|
||||
if self.dn_number > 0 and dn_meta is not None:
|
||||
outputs_class, outputs_coord_list = cdn_post_process(
|
||||
outputs_class, outputs_coord_list, dn_meta, self._set_aux_loss)
|
||||
out = {
|
||||
'pred_logits': outputs_class[-1],
|
||||
'pred_boxes': outputs_coord_list[-1]
|
||||
}
|
||||
|
||||
out['aux_outputs'] = self._set_aux_loss(outputs_class,
|
||||
outputs_coord_list)
|
||||
|
||||
# for encoder output
|
||||
if hs_enc is not None:
|
||||
# prepare intermediate outputs
|
||||
interm_coord = ref_enc[-1]
|
||||
interm_class = self.transformer.enc_out_class_embed(hs_enc[-1])
|
||||
out['interm_outputs'] = {
|
||||
'pred_logits': interm_class,
|
||||
'pred_boxes': interm_coord
|
||||
}
|
||||
out['interm_outputs_for_matching_pre'] = {
|
||||
'pred_logits': interm_class,
|
||||
'pred_boxes': init_box_proposal
|
||||
}
|
||||
|
||||
out['dn_meta'] = dn_meta
|
||||
|
||||
return out
|
||||
|
||||
@torch.jit.unused
|
||||
def _set_aux_loss(self, outputs_class, outputs_coord):
|
||||
# this is a workaround to make torchscript happy, as torchscript
|
||||
# doesn't support dictionary with non-homogeneous values, such
|
||||
# as a dict having both a Tensor and a list.
|
||||
return [{
|
||||
'pred_logits': a,
|
||||
'pred_boxes': b
|
||||
} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
|
||||
|
||||
# over-write because img_metas are needed as inputs for bbox_head.
|
||||
def forward_train(self, x, img_metas, gt_bboxes, gt_labels):
|
||||
"""Forward function for training mode.
|
||||
Args:
|
||||
x (list[Tensor]): Features from backbone.
|
||||
img_metas (list[dict]): Meta information of each image, e.g.,
|
||||
image size, scaling factor, etc.
|
||||
gt_bboxes (Tensor): Ground truth bboxes of the image,
|
||||
shape (num_gts, 4).
|
||||
gt_labels (Tensor): Ground truth labels of each box,
|
||||
shape (num_gts,).
|
||||
gt_bboxes_ignore (Tensor): Ground truth bboxes to be
|
||||
ignored, shape (num_ignored_gts, 4).
|
||||
proposal_cfg (mmcv.Config): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used.
|
||||
Returns:
|
||||
dict[str, Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
# prepare ground truth
|
||||
for i in range(len(img_metas)):
|
||||
img_h, img_w, _ = img_metas[i]['img_shape']
|
||||
            # DETR regresses the relative position of boxes (cxcywh) in the image.
            # Thus the learning target should be normalized by the image size, and
            # the box format should be converted from the default x1y1x2y2 to cxcywh.
|
||||
factor = gt_bboxes[i].new_tensor([img_w, img_h, img_w,
|
||||
img_h]).unsqueeze(0)
|
||||
gt_bboxes[i] = box_xyxy_to_cxcywh(gt_bboxes[i]) / factor
|
||||
|
||||
targets = []
|
||||
for gt_label, gt_bbox in zip(gt_labels, gt_bboxes):
|
||||
targets.append({'labels': gt_label, 'boxes': gt_bbox})
|
||||
|
||||
query_embed, tgt, attn_mask, dn_meta = self.prepare(
|
||||
x, targets=targets, mode='train')
|
||||
|
||||
outputs = self.forward(
|
||||
x,
|
||||
img_metas,
|
||||
query_embed=query_embed,
|
||||
tgt=tgt,
|
||||
attn_mask=attn_mask,
|
||||
dn_meta=dn_meta)
|
||||
|
||||
        # Avoid inconsistent num_boxes between set_criterion and dn_criterion.
        # Compute the average number of target boxes across all nodes, for normalization purposes.
|
||||
num_boxes = sum(len(t['labels']) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes],
|
||||
dtype=torch.float,
|
||||
device=next(iter(outputs.values())).device)
|
||||
if is_dist_avail_and_initialized():
|
||||
torch.distributed.all_reduce(num_boxes)
|
||||
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
|
||||
|
||||
losses = self.criterion(outputs, targets, num_boxes=num_boxes)
|
||||
losses.update(
|
||||
self.dn_criterion(outputs, targets, len(outputs['aux_outputs']),
|
||||
num_boxes))
|
||||
|
||||
return losses
|
||||
|
||||
def forward_test(self, x, img_metas):
|
||||
query_embed, tgt, attn_mask, dn_meta = self.prepare(x, mode='test')
|
||||
|
||||
outputs = self.forward(
|
||||
x,
|
||||
img_metas,
|
||||
query_embed=query_embed,
|
||||
tgt=tgt,
|
||||
attn_mask=attn_mask,
|
||||
dn_meta=dn_meta)
|
||||
|
||||
ori_shape_list = []
|
||||
for i in range(len(img_metas)):
|
||||
ori_h, ori_w, _ = img_metas[i]['ori_shape']
|
||||
ori_shape_list.append(torch.as_tensor([ori_h, ori_w]))
|
||||
orig_target_sizes = torch.stack(ori_shape_list, dim=0)
|
||||
|
||||
results = self.postprocess(outputs, orig_target_sizes, img_metas)
|
||||
return results
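# --- Illustrative sketch (not part of the original diff) ---
# A hedged example of how DINOHead might be declared in a config dict,
# based only on the constructor arguments above; the numeric values and
# the transformer stub are assumptions, not the released dino_4sc_r50
# config.
model_head = dict(
    type='DINOHead',
    num_classes=80,
    embed_dims=256,
    in_channels=[512, 1024, 2048],
    num_queries=300,
    num_select=300,
    num_feature_levels=4,
    two_stage_type='standard',
    dn_components=dict(
        dn_number=100,
        dn_label_noise_ratio=0.5,
        dn_box_noise_scale=1.0,
        dn_labelbook_size=80),
    # transformer-specific arguments omitted here; see DeformableTransformer
    transformer=dict(type='DeformableTransformer'))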
|
|
@@ -2,9 +2,9 @@
|
|||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
|
||||
from .boxes import (batched_nms, bbox2result, bbox_overlaps, bboxes_iou,
|
||||
box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, distance2bbox,
|
||||
generalized_box_iou, postprocess)
|
||||
fp16_clamp, generalized_box_iou)
|
||||
from .generator import MlvlPointGenerator
|
||||
from .matcher import HungarianMatcher
|
||||
from .misc import (accuracy, filter_scores_and_topk, fp16_clamp, interpolate,
|
||||
inverse_sigmoid, output_postprocess, select_single_mlvl)
|
||||
from .set_criterion import SetCriterion
|
||||
from .misc import (accuracy, filter_scores_and_topk,
|
||||
gen_encoder_output_proposals, gen_sineembed_for_position,
|
||||
interpolate, inverse_sigmoid, select_single_mlvl)
|
||||
from .postprocess import DetrPostProcess, output_postprocess, postprocess
|
||||
|
|
|
@@ -1,10 +1,7 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from torchvision.ops.boxes import box_area, nms
|
||||
|
||||
from easycv.models.detection.utils.misc import fp16_clamp
|
||||
|
@@ -36,55 +33,6 @@ def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
|
|||
return area_i / (area_a[:, None] + area_b - area_i)
|
||||
|
||||
|
||||
# refer to easycv/models/detection/detectors/yolox/postprocess.py and test.py to rebuild a torch-blade-trtplugin NMS, which is checked by zhoulou in test.py
|
||||
# infer docker images is : registry.cn-shanghai.aliyuncs.com/pai-ai-test/eas-service:easycv_blade_181_export
|
||||
|
||||
|
||||
def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
|
||||
box_corner = prediction.new(prediction.shape)
|
||||
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
|
||||
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
|
||||
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
|
||||
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
|
||||
prediction[:, :, :4] = box_corner[:, :, :4]
|
||||
|
||||
output = [None for _ in range(len(prediction))]
|
||||
for i, image_pred in enumerate(prediction):
|
||||
|
||||
# If none are remaining => process next image
|
||||
if not image_pred.numel():
|
||||
continue
|
||||
# Get score and class with highest confidence
|
||||
class_conf, class_pred = torch.max(
|
||||
image_pred[:, 5:5 + num_classes], 1, keepdim=True)
|
||||
|
||||
conf_mask = (image_pred[:, 4] * class_conf.squeeze() >=
|
||||
conf_thre).squeeze()
|
||||
# Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
|
||||
detections = torch.cat(
|
||||
(image_pred[:, :5], class_conf, class_pred.float()), 1)
|
||||
detections = detections[conf_mask]
|
||||
if not detections.numel():
|
||||
continue
|
||||
|
||||
if LooseVersion(torchvision.__version__) >= LooseVersion('0.8.0'):
|
||||
nms_out_index = torchvision.ops.batched_nms(
|
||||
detections[:, :4], detections[:, 4] * detections[:, 5],
|
||||
detections[:, 6], nms_thre)
|
||||
else:
|
||||
nms_out_index = torchvision.ops.nms(
|
||||
detections[:, :4], detections[:, 4] * detections[:, 5],
|
||||
nms_thre)
|
||||
|
||||
detections = detections[nms_out_index]
|
||||
if output[i] is None:
|
||||
output[i] = detections
|
||||
else:
|
||||
output[i] = torch.cat((output[i], detections))
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def bbox2result(bboxes, labels, num_classes):
|
||||
"""Convert detection results to a list of numpy arrays.
|
||||
Args:
|
||||
|
|
|
@@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
@@ -136,47 +137,6 @@ def filter_scores_and_topk(scores, score_thr, topk, results=None):
|
|||
return scores, labels, keep_idxs, filtered_results
|
||||
|
||||
|
||||
def output_postprocess(outputs, img_metas=None):
|
||||
detection_boxes = []
|
||||
detection_scores = []
|
||||
detection_classes = []
|
||||
img_metas_list = []
|
||||
|
||||
for i in range(len(outputs)):
|
||||
if img_metas:
|
||||
img_metas_list.append(img_metas[i])
|
||||
if outputs[i] is not None:
|
||||
bboxes = outputs[i][:, 0:4] if outputs[i] is not None else None
|
||||
if img_metas:
|
||||
bboxes /= img_metas[i]['scale_factor'][0]
|
||||
detection_boxes.append(bboxes.cpu().numpy())
|
||||
detection_scores.append(
|
||||
(outputs[i][:, 4] * outputs[i][:, 5]).cpu().numpy())
|
||||
detection_classes.append(outputs[i][:, 6].cpu().numpy().astype(
|
||||
np.int32))
|
||||
else:
|
||||
detection_boxes.append(None)
|
||||
detection_scores.append(None)
|
||||
detection_classes.append(None)
|
||||
|
||||
test_outputs = {
|
||||
'detection_boxes': detection_boxes,
|
||||
'detection_scores': detection_scores,
|
||||
'detection_classes': detection_classes,
|
||||
'img_metas': img_metas_list
|
||||
}
|
||||
|
||||
return test_outputs
|
||||
|
||||
|
||||
def fp16_clamp(x, min=None, max=None):
|
||||
if not x.is_cuda and x.dtype == torch.float16:
|
||||
# clamp for cpu float16, tensor fp16 has no clamp implementation
|
||||
return x.float().clamp(min, max).half()
|
||||
|
||||
return x.clamp(min, max)
|
||||
|
||||
|
||||
def inverse_sigmoid(x, eps=1e-3):
|
||||
x = x.clamp(min=0, max=1)
|
||||
x1 = x.clamp(min=eps)
|
||||
|
@@ -184,6 +144,114 @@
|
|||
return torch.log(x1 / x2)
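# --- Illustrative sketch (not part of the original diff) ---
# inverse_sigmoid is the (clamped) logit function, so applying sigmoid to
# its output recovers the input up to the eps clamping; a quick sanity
# check of that round trip.
import torch

_p = torch.tensor([0.1, 0.5, 0.9])
assert torch.allclose(torch.sigmoid(inverse_sigmoid(_p)), _p, atol=1e-4)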
|
||||
|
||||
|
||||
def gen_encoder_output_proposals(memory: Tensor,
                                 memory_padding_mask: Tensor,
                                 spatial_shapes: Tensor,
                                 learnedwh=None):
    """
    Input:
        - memory: bs, \sum{hw}, d_model
        - memory_padding_mask: bs, \sum{hw}
        - spatial_shapes: nlevel, 2
        - learnedwh: 2
    Output:
        - output_memory: bs, \sum{hw}, d_model
        - output_proposals: bs, \sum{hw}, 4
    """
    N_, S_, C_ = memory.shape
    base_scale = 4.0
    proposals = []
    _cur = 0
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(
            N_, H_, W_, 1)
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

        # import ipdb; ipdb.set_trace()

        grid_y, grid_x = torch.meshgrid(
            torch.linspace(
                0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
            torch.linspace(
                0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
        grid = torch.cat(
            [grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)  # H_, W_, 2

        scale = torch.cat([valid_W.unsqueeze(-1),
                           valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale

        if learnedwh is not None:
            # import ipdb; ipdb.set_trace()
            wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
        else:
            wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)

        # scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
        # grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
        # wh = torch.ones_like(grid) / scale
        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposals.append(proposal)
        _cur += (H_ * W_)
    # import ipdb; ipdb.set_trace()
    output_proposals = torch.cat(proposals, 1)
    output_proposals_valid = ((output_proposals > 0.01) &
                              (output_proposals < 0.99)).all(
                                  -1, keepdim=True)
    output_proposals = torch.log(output_proposals /
                                 (1 - output_proposals))  # unsigmoid
    output_proposals = output_proposals.masked_fill(
        memory_padding_mask.unsqueeze(-1), float('inf'))
    output_proposals = output_proposals.masked_fill(~output_proposals_valid,
                                                    float('inf'))

    output_memory = memory
    output_memory = output_memory.masked_fill(
        memory_padding_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid,
                                              float(0))

    # output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
    # output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf'))

    return output_memory, output_proposals

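In the two-stage Deformable-DETR/DINO flow, output_memory is scored by an encoder head and output_proposals (still in logit space) seeds the highest-scoring reference boxes. A minimal sketch with toy tensors and stand-in heads; the real heads and the top-k size live elsewhere in this PR, so the Linear layers and the value 100 here are illustrative assumptions only:

import torch
from torch import nn

bs, d_model = 2, 256
spatial_shapes = torch.tensor([[32, 32], [16, 16]])  # two toy feature levels
num_tokens = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
memory = torch.randn(bs, num_tokens, d_model)
memory_padding_mask = torch.zeros(bs, num_tokens, dtype=torch.bool)  # no padding

out_memory, out_proposals = gen_encoder_output_proposals(
    memory, memory_padding_mask, spatial_shapes)

enc_class_head = nn.Linear(d_model, 80)  # stand-in classification head
enc_bbox_head = nn.Linear(d_model, 4)    # stand-in box-delta head
enc_boxes = enc_bbox_head(out_memory) + out_proposals            # refine proposals in logit space
topk = enc_class_head(out_memory).max(-1)[0].topk(100, dim=1)[1]  # keep the 100 best tokens
init_reference = torch.gather(
    enc_boxes, 1, topk.unsqueeze(-1).repeat(1, 1, 4)).sigmoid()   # (bs, 100, 4), cxcywh in [0, 1]
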
def gen_sineembed_for_position(pos_tensor):
    # n_query, bs, _ = pos_tensor.size()
    # sineembed_tensor = torch.zeros(n_query, bs, 256)
    scale = 2 * math.pi
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000**(2 * (dim_t // 2) / 128)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()),
                        dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()),
                        dim=3).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()),
                            dim=3).flatten(2)

        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()),
                            dim=3).flatten(2)

        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError('Unknown pos_tensor shape(-1):{}'.format(
            pos_tensor.size(-1)))
    return pos

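Each box coordinate is expanded into a 128-dim interleaved sine/cosine embedding, so a (cx, cy, w, h) query becomes a 512-dim positional feature (or 256-dim for 2-dim points). A quick shape check using the helper above, with toy sizes:

import torch

pos_tensor = torch.rand(300, 2, 4)  # 300 queries, batch of 2, normalized cxcywh
pos = gen_sineembed_for_position(pos_tensor)
print(pos.shape)  # torch.Size([300, 2, 512])
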
class SigmoidGeometricMean(Function):
    """Forward and backward function of geometric mean of two sigmoid
    functions.
@ -211,3 +279,11 @@ class SigmoidGeometricMean(Function):


sigmoid_geometric_mean = SigmoidGeometricMean.apply


def fp16_clamp(x, min=None, max=None):
    if not x.is_cuda and x.dtype == torch.float16:
        # clamp for cpu float16, tensor fp16 has no clamp implementation
        return x.float().clamp(min, max).half()

    return x.clamp(min, max)

@ -0,0 +1,143 @@
from distutils.version import LooseVersion

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from easycv.models.detection.utils import box_cxcywh_to_xyxy


class DetrPostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""

    def __init__(self, num_select=None) -> None:
        super().__init__()
        self.num_select = num_select

    @torch.no_grad()
    def forward(self, outputs, target_sizes, img_metas):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        if self.num_select is None:
            prob = F.softmax(out_logits, -1)
            scores, labels = prob[..., :-1].max(-1)
            boxes = box_cxcywh_to_xyxy(out_bbox)
        else:
            prob = out_logits.sigmoid()
            topk_values, topk_indexes = torch.topk(
                prob.view(out_logits.shape[0], -1), self.num_select, dim=1)
            scores = topk_values
            topk_boxes = topk_indexes // out_logits.shape[2]
            labels = topk_indexes % out_logits.shape[2]
            boxes = box_cxcywh_to_xyxy(out_bbox)
            boxes = torch.gather(boxes, 1,
                                 topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h],
                                dim=1).to(boxes.device)
        boxes = boxes * scale_fct[:, None, :]

        results = {
            'detection_boxes': [boxes[0].cpu().numpy()],
            'detection_scores': [scores[0].cpu().numpy()],
            'detection_classes': [labels[0].cpu().numpy().astype(np.int32)],
            'img_metas': img_metas
        }

        return results

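A minimal sketch of how DetrPostProcess might be driven with dummy network outputs: num_select=None exercises the softmax branch used by vanilla DETR, while an integer takes the sigmoid top-k branch (the DINO test later in this diff expects 300 boxes per image, so 300 is used here; the tensor sizes are otherwise arbitrary):

import torch

post = DetrPostProcess(num_select=300)
outputs = {
    'pred_logits': torch.randn(1, 900, 80),  # (batch, queries, classes)
    'pred_boxes': torch.rand(1, 900, 4),     # normalized (cx, cy, w, h)
}
target_sizes = torch.tensor([[480, 640]])    # original (height, width) per image
results = post(outputs, target_sizes, img_metas=[{}])
print(results['detection_boxes'][0].shape)   # (300, 4), absolute xyxy pixels
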
def output_postprocess(outputs, img_metas=None):
    detection_boxes = []
    detection_scores = []
    detection_classes = []
    img_metas_list = []

    for i in range(len(outputs)):
        if img_metas:
            img_metas_list.append(img_metas[i])
        if outputs[i] is not None:
            bboxes = outputs[i][:, 0:4] if outputs[i] is not None else None
            if img_metas:
                bboxes /= img_metas[i]['scale_factor'][0]
            detection_boxes.append(bboxes.cpu().numpy())
            detection_scores.append(
                (outputs[i][:, 4] * outputs[i][:, 5]).cpu().numpy())
            detection_classes.append(outputs[i][:, 6].cpu().numpy().astype(
                np.int32))
        else:
            detection_boxes.append(None)
            detection_scores.append(None)
            detection_classes.append(None)

    test_outputs = {
        'detection_boxes': detection_boxes,
        'detection_scores': detection_scores,
        'detection_classes': detection_classes,
        'img_metas': img_metas_list
    }

    return test_outputs


# refer to easycv/models/detection/detectors/yolox/postprocess.py and test.py to rebuild a torch-blade-trtplugin NMS, which is checked by zhoulou in test.py
# infer docker images is : registry.cn-shanghai.aliyuncs.com/pai-ai-test/eas-service:easycv_blade_181_export

def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):

        # If none are remaining => process next image
        if not image_pred.numel():
            continue
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(
            image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        conf_mask = (image_pred[:, 4] * class_conf.squeeze() >=
                     conf_thre).squeeze()
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat(
            (image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.numel():
            continue

        if LooseVersion(torchvision.__version__) >= LooseVersion('0.8.0'):
            nms_out_index = torchvision.ops.batched_nms(
                detections[:, :4], detections[:, 4] * detections[:, 5],
                detections[:, 6], nms_thre)
        else:
            nms_out_index = torchvision.ops.nms(
                detections[:, :4], detections[:, 4] * detections[:, 5],
                nms_thre)

        detections = detections[nms_out_index]
        if output[i] is None:
            output[i] = detections
        else:
            output[i] = torch.cat((output[i], detections))

    return output

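The prediction tensor fed to postprocess is expected in YOLOX layout, (batch, num_boxes, 5 + num_classes) with (cx, cy, w, h, objectness, per-class scores), as the cxcywh-to-xyxy conversion above implies. A toy invocation (random scores, so most boxes get filtered out):

import torch

num_classes = 80
prediction = torch.rand(1, 1000, 5 + num_classes)
prediction[:, :, :2] = prediction[:, :, :2] * 640    # box centers in pixels
prediction[:, :, 2:4] = prediction[:, :, 2:4] * 100   # widths / heights in pixels
outputs = postprocess(prediction, num_classes, conf_thre=0.5, nms_thre=0.45)
# outputs[i] is None or a (k, 7) tensor: x1, y1, x2, y2, obj_conf, class_conf, class_pred
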
@ -4,3 +4,5 @@ from .focal_loss import FocalLoss, VarifocalLoss
from .iou_loss import GIoULoss, IoULoss, YOLOX_IOULoss
from .mse_loss import JointsMSELoss
from .pytorch_metric_learning import *
from .set_criterion import (CDNCriterion, DNCriterion, HungarianMatcher,
                            SetCriterion)

@ -150,9 +150,8 @@ class VarifocalLoss(nn.Module):
        return loss_cls


# This method is only for debugging
def py_sigmoid_focal_loss(pred,
                          target,
def py_sigmoid_focal_loss(inputs,
                          targets,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
@ -161,9 +160,9 @@ def py_sigmoid_focal_loss(pred,
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
        inputs (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes
        target (torch.Tensor): The learning label of the prediction.
        targets (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
@ -174,13 +173,15 @@ def py_sigmoid_focal_loss(pred,
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction='none')
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t)**gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):

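The rewritten py_sigmoid_focal_loss follows the standard sigmoid focal-loss form, FL = alpha_t * (1 - p_t)^gamma * BCE, with the alpha term skipped when alpha < 0. A small numeric cross-check with toy values:

import torch
import torch.nn.functional as F

inputs = torch.tensor([[2.0, -1.0]])
targets = torch.tensor([[1.0, 0.0]])
gamma, alpha = 2.0, 0.25

prob = inputs.sigmoid()
ce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce * (1 - p_t)**gamma
loss = (alpha * targets + (1 - alpha) * (1 - targets)) * loss
# Equals -alpha_t * (1 - p_t)**gamma * log(p_t), since BCE-with-logits is -log(p_t) elementwise here.
print(loss)
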
@ -0,0 +1,2 @@
from .matcher import HungarianMatcher
from .set_criterion import CDNCriterion, DNCriterion, SetCriterion

@ -22,8 +22,7 @@ class SetCriterion(nn.Module):
                 weight_dict,
                 losses,
                 eos_coef=None,
                 loss_class_type='ce',
                 dn_components=None):
                 loss_class_type='ce'):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
@ -41,8 +40,6 @@ class SetCriterion(nn.Module):
            empty_weight = torch.ones(self.num_classes + 1)
            empty_weight[-1] = eos_coef
            self.register_buffer('empty_weight', empty_weight)
        if dn_components is not None:
            self.dn_criterion = DNCriterion(self.weight_dict)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (Binary focal loss)
@ -63,8 +60,7 @@ class SetCriterion(nn.Module):

        if self.loss_class_type == 'ce':
            loss_ce = F.cross_entropy(
                src_logits.transpose(1, 2), target_classes,
                self.empty_weight) * self.weight_dict['loss_ce']
                src_logits.transpose(1, 2), target_classes, self.empty_weight)
        elif self.loss_class_type == 'focal_loss':
            target_classes_onehot = torch.zeros([
                src_logits.shape[0], src_logits.shape[1],
@ -78,12 +74,11 @@ class SetCriterion(nn.Module):

            loss_ce = py_sigmoid_focal_loss(
                src_logits,
                target_classes_onehot.long(),
                target_classes_onehot,
                alpha=0.25,
                gamma=2,
                reduction='none').mean(1).sum() / num_boxes
            loss_ce = loss_ce * src_logits.shape[1] * self.weight_dict[
                'loss_ce']
            loss_ce = loss_ce * src_logits.shape[1]
        losses = {'loss_ce': loss_ce}

        if log:
@ -122,15 +117,13 @@ class SetCriterion(nn.Module):
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum(
        ) / num_boxes * self.weight_dict['loss_bbox']
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(
            generalized_box_iou(
                box_cxcywh_to_xyxy(src_boxes),
                box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum(
        ) / num_boxes * self.weight_dict['loss_giou']
        losses['loss_giou'] = loss_giou.sum() / num_boxes

        return losses

@ -157,7 +150,7 @@ class SetCriterion(nn.Module):
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets, mask_dict=None, return_indices=False):
    def forward(self, outputs, targets, num_boxes=None, return_indices=False):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
@ -178,20 +171,26 @@ class SetCriterion(nn.Module):
        indices0_copy = indices
        indices_list = []

        # Compute the average number of target boxes accross all nodes, for normalization purposes
        num_boxes = sum(len(t['labels']) for t in targets)
        num_boxes = torch.as_tensor([num_boxes],
                                    dtype=torch.float,
                                    device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
        if num_boxes is None:
            # Compute the average number of target boxes accross all nodes, for normalization purposes
            num_boxes = sum(len(t['labels']) for t in targets)
            num_boxes = torch.as_tensor([num_boxes],
                                        dtype=torch.float,
                                        device=next(iter(
                                            outputs.values())).device)
            if is_dist_avail_and_initialized():
                torch.distributed.all_reduce(num_boxes)
            num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(
                self.get_loss(loss, outputs, targets, indices, num_boxes))
            l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
            l_dict = {
                k: v * (self.weight_dict[k] if k in self.weight_dict else 1.0)
                for k, v in l_dict.items()
            }
            losses.update(l_dict)

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
@ -209,17 +208,35 @@ class SetCriterion(nn.Module):
                    kwargs = {'log': False}
                l_dict = self.get_loss(loss, aux_outputs, targets, indices,
                                       num_boxes, **kwargs)
                l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                l_dict = {
                    k + f'_{i}': v *
                    (self.weight_dict[k] if k in self.weight_dict else 1.0)
                    for k, v in l_dict.items()
                }
                losses.update(l_dict)

        if mask_dict is not None:
            # dn loss computation
            aux_num = 0
            if 'aux_outputs' in outputs:
                aux_num = len(outputs['aux_outputs'])
            dn_losses = self.dn_criterion(mask_dict, self.training, aux_num,
                                          0.25)
            losses.update(dn_losses)
        # interm_outputs loss
        if 'interm_outputs' in outputs:
            interm_outputs = outputs['interm_outputs']
            indices = self.matcher(interm_outputs, targets)
            if return_indices:
                indices_list.append(indices)
            for loss in self.losses:
                if loss == 'masks':
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                kwargs = {}
                if loss == 'labels':
                    # Logging is enabled only for the last layer
                    kwargs = {'log': False}
                l_dict = self.get_loss(loss, interm_outputs, targets, indices,
                                       num_boxes, **kwargs)
                l_dict = {
                    k + '_interm':
                    v * (self.weight_dict[k] if k in self.weight_dict else 1.0)
                    for k, v in l_dict.items()
                }
                losses.update(l_dict)

        if return_indices:
            indices_list.append(indices0_copy)
@ -228,6 +245,120 @@ class SetCriterion(nn.Module):
        return losses

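Note that after this refactor the caller no longer multiplies by weight_dict itself: SetCriterion scales every matched, auxiliary and intermediate loss term internally. The scaling idiom in isolation, with toy values:

import torch

weight_dict = {'loss_ce': 1.0, 'loss_bbox': 5.0, 'loss_giou': 2.0}
raw = {'loss_ce': torch.tensor(0.7), 'loss_bbox': torch.tensor(0.2), 'class_error': torch.tensor(12.5)}
scaled = {k: v * (weight_dict[k] if k in weight_dict else 1.0) for k, v in raw.items()}
# class_error is left untouched (implicit weight 1.0); it is a logging metric, not an optimized loss.
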
class CDNCriterion(SetCriterion):
    """ This class computes the loss for Conditional DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self,
                 num_classes,
                 matcher,
                 weight_dict,
                 losses,
                 eos_coef=None,
                 loss_class_type='ce'):
        super().__init__(
            num_classes=num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            losses=losses,
            eos_coef=eos_coef,
            loss_class_type=loss_class_type)

    def prep_for_dn(self, dn_meta):
        output_known_lbs_bboxes = dn_meta['output_known_lbs_bboxes']
        num_dn_groups, pad_size = dn_meta['num_dn_group'], dn_meta['pad_size']
        assert pad_size % num_dn_groups == 0
        single_pad = pad_size // num_dn_groups

        return output_known_lbs_bboxes, single_pad, num_dn_groups

    def forward(self, outputs, targets, aux_num, num_boxes):
        # Compute the average number of target boxes accross all nodes, for normalization purposes

        dn_meta = outputs['dn_meta']
        losses = {}
        if self.training and dn_meta and 'output_known_lbs_bboxes' in dn_meta:
            output_known_lbs_bboxes, single_pad, scalar = self.prep_for_dn(
                dn_meta)

            dn_pos_idx = []
            dn_neg_idx = []
            for i in range(len(targets)):
                if len(targets[i]['labels']) > 0:
                    t = torch.range(0,
                                    len(targets[i]['labels']) -
                                    1).long().cuda()
                    t = t.unsqueeze(0).repeat(scalar, 1)
                    tgt_idx = t.flatten()
                    output_idx = (torch.tensor(range(scalar)) *
                                  single_pad).long().cuda().unsqueeze(1) + t
                    output_idx = output_idx.flatten()
                else:
                    output_idx = tgt_idx = torch.tensor([]).long().cuda()

                dn_pos_idx.append((output_idx, tgt_idx))
                dn_neg_idx.append((output_idx + single_pad // 2, tgt_idx))

            output_known_lbs_bboxes = dn_meta['output_known_lbs_bboxes']
            l_dict = {}
            for loss in self.losses:
                kwargs = {}
                if 'labels' in loss:
                    kwargs = {'log': False}
                l_dict.update(
                    self.get_loss(loss, output_known_lbs_bboxes, targets,
                                  dn_pos_idx, num_boxes * scalar, **kwargs))

            l_dict = {
                k + '_dn':
                v * (self.weight_dict[k] if k in self.weight_dict else 1.0)
                for k, v in l_dict.items()
            }
            losses.update(l_dict)
        else:
            l_dict = dict()
            l_dict['loss_bbox_dn'] = torch.as_tensor(0.).to('cuda')
            l_dict['loss_giou_dn'] = torch.as_tensor(0.).to('cuda')
            l_dict['loss_ce_dn'] = torch.as_tensor(0.).to('cuda')
            losses.update(l_dict)

        for i in range(aux_num):
            if self.training and dn_meta and 'output_known_lbs_bboxes' in dn_meta:
                aux_outputs_known = output_known_lbs_bboxes['aux_outputs'][i]
                l_dict = {}
                for loss in self.losses:
                    kwargs = {}
                    if 'labels' in loss:
                        kwargs = {'log': False}

                    l_dict.update(
                        self.get_loss(loss, aux_outputs_known, targets,
                                      dn_pos_idx, num_boxes * scalar,
                                      **kwargs))

                l_dict = {
                    k + f'_dn_{i}':
                    v * (self.weight_dict[k] if k in self.weight_dict else 1.0)
                    for k, v in l_dict.items()
                }
                losses.update(l_dict)
            else:
                l_dict = dict()
                l_dict['loss_bbox_dn'] = torch.as_tensor(0.).to('cuda')
                l_dict['loss_giou_dn'] = torch.as_tensor(0.).to('cuda')
                l_dict['loss_ce_dn'] = torch.as_tensor(0.).to('cuda')
                l_dict = {
                    k + f'_{i}':
                    v * (self.weight_dict[k] if k in self.weight_dict else 1.0)
                    for k, v in l_dict.items()
                }
                losses.update(l_dict)
        return losses

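The index bookkeeping above lays the denoising queries out in groups of single_pad slots: positives (noised ground-truth queries) sit at the start of each group and negatives are offset by single_pad // 2. A tiny worked example with 2 groups, single_pad=4 and 2 ground-truth boxes per image (torch.arange stands in for the torch.range/torch.tensor(range(...)) calls above):

import torch

scalar, single_pad, num_gt = 2, 4, 2
t = torch.arange(num_gt).unsqueeze(0).repeat(scalar, 1)            # [[0, 1], [0, 1]]
output_idx = (torch.arange(scalar) * single_pad).unsqueeze(1) + t  # [[0, 1], [4, 5]]
pos_idx = output_idx.flatten()                                     # tensor([0, 1, 4, 5])
neg_idx = (output_idx + single_pad // 2).flatten()                 # tensor([2, 3, 6, 7])
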
class DNCriterion(nn.Module):
    """ This class computes the loss for Conditional DETR.
    The process happens in two steps:
@ -281,21 +412,19 @@ class DNCriterion(nn.Module):
        """
        if len(tgt_boxes) == 0:
            return {
                'tgt_loss_bbox': torch.as_tensor(0.).to('cuda'),
                'tgt_loss_giou': torch.as_tensor(0.).to('cuda'),
                'loss_bbox': torch.as_tensor(0.).to('cuda'),
                'loss_giou': torch.as_tensor(0.).to('cuda'),
            }

        loss_bbox = F.l1_loss(src_boxes, tgt_boxes, reduction='none')

        losses = {}
        losses['tgt_loss_bbox'] = loss_bbox.sum(
        ) / num_tgt * self.weight_dict['loss_bbox']
        losses['loss_bbox'] = loss_bbox.sum() / num_tgt

        loss_giou = 1 - torch.diag(
            generalized_box_iou(
                box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(tgt_boxes)))
        losses['tgt_loss_giou'] = loss_giou.sum(
        ) / num_tgt * self.weight_dict['loss_giou']
        losses['loss_giou'] = loss_giou.sum() / num_tgt
        return losses

    def tgt_loss_labels(self,
@ -309,8 +438,8 @@ class DNCriterion(nn.Module):
        """
        if len(tgt_labels_) == 0:
            return {
                'tgt_loss_ce': torch.as_tensor(0.).to('cuda'),
                'tgt_class_error': torch.as_tensor(0.).to('cuda'),
                'loss_ce': torch.as_tensor(0.).to('cuda'),
                'class_error': torch.as_tensor(0.).to('cuda'),
            }

        src_logits, tgt_labels = src_logits_.unsqueeze(
@ -327,62 +456,78 @@ class DNCriterion(nn.Module):
        target_classes_onehot = target_classes_onehot[:, :, :-1]
        loss_ce = py_sigmoid_focal_loss(
            src_logits,
            target_classes_onehot.long(),
            target_classes_onehot,
            alpha=focal_alpha,
            gamma=2,
            reduction='none').mean(1).sum(
            ) / num_tgt * src_logits.shape[1] * self.weight_dict['loss_ce']
            reduction='none').mean(1).sum() / num_tgt * src_logits.shape[1]

        losses = {'tgt_loss_ce': loss_ce}
        losses = {'loss_ce': loss_ce}
        if log:
            losses['tgt_class_error'] = 100 - accuracy(src_logits_,
                                                       tgt_labels_)[0]
            losses['class_error'] = 100 - accuracy(src_logits_, tgt_labels_)[0]
        return losses

    def forward(self, mask_dict, training, aux_num, focal_alpha):
    def forward(self, mask_dict, aux_num):
        """
        compute dn loss in criterion
        Args:
            mask_dict: a dict for dn information
            training: training or inference flag
            aux_num: aux loss number
            focal_alpha: for focal loss
        """
        losses = {}
        if training and 'output_known_lbs_bboxes' in mask_dict:
        if self.training and 'output_known_lbs_bboxes' in mask_dict:
            known_labels, known_bboxs, output_known_class, output_known_coord, num_tgt = self.prepare_for_loss(
                mask_dict)
            losses.update(
                self.tgt_loss_labels(output_known_class[-1], known_labels,
                                     num_tgt, focal_alpha))
            losses.update(
                self.tgt_loss_boxes(output_known_coord[-1], known_bboxs,
                                    num_tgt))
            l_dict = self.tgt_loss_labels(output_known_class[-1], known_labels,
                                          num_tgt, 0.25)
            l_dict = {
                k + '_dn':
                v * (self.weight_dict[k] if k in self.weight_dict else 1.0)
                for k, v in l_dict.items()
            }
            losses.update(l_dict)
            l_dict = self.tgt_loss_boxes(output_known_coord[-1], known_bboxs,
                                         num_tgt)
            l_dict = {
                k + '_dn':
                v * (self.weight_dict[k] if k in self.weight_dict else 1.0)
                for k, v in l_dict.items()
            }
            losses.update(l_dict)
        else:
            losses['tgt_loss_bbox'] = torch.as_tensor(0.).to('cuda')
            losses['tgt_loss_giou'] = torch.as_tensor(0.).to('cuda')
            losses['tgt_loss_ce'] = torch.as_tensor(0.).to('cuda')
            losses['tgt_class_error'] = torch.as_tensor(0.).to('cuda')
            losses['loss_bbox_dn'] = torch.as_tensor(0.).to('cuda')
            losses['loss_giou_dn'] = torch.as_tensor(0.).to('cuda')
            losses['loss_ce_dn'] = torch.as_tensor(0.).to('cuda')

        if aux_num:
            for i in range(aux_num):
                # dn aux loss
                if training and 'output_known_lbs_bboxes' in mask_dict:
                if self.training and 'output_known_lbs_bboxes' in mask_dict:
                    l_dict = self.tgt_loss_labels(output_known_class[i],
                                                  known_labels, num_tgt,
                                                  focal_alpha)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                                                  known_labels, num_tgt, 0.25)
                    l_dict = {
                        k + f'_dn_{i}': v *
                        (self.weight_dict[k] if k in self.weight_dict else 1.0)
                        for k, v in l_dict.items()
                    }
                    losses.update(l_dict)
                    l_dict = self.tgt_loss_boxes(output_known_coord[i],
                                                 known_bboxs, num_tgt)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    l_dict = {
                        k + f'_dn_{i}': v *
                        (self.weight_dict[k] if k in self.weight_dict else 1.0)
                        for k, v in l_dict.items()
                    }
                    losses.update(l_dict)
                else:
                    l_dict = dict()
                    l_dict['tgt_loss_bbox'] = torch.as_tensor(0.).to('cuda')
                    l_dict['tgt_class_error'] = torch.as_tensor(0.).to('cuda')
                    l_dict['tgt_loss_giou'] = torch.as_tensor(0.).to('cuda')
                    l_dict['tgt_loss_ce'] = torch.as_tensor(0.).to('cuda')
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    l_dict['loss_bbox_dn'] = torch.as_tensor(0.).to('cuda')
                    l_dict['loss_giou_dn'] = torch.as_tensor(0.).to('cuda')
                    l_dict['loss_ce_dn'] = torch.as_tensor(0.).to('cuda')
                    l_dict = {
                        k + f'_{i}': v *
                        (self.weight_dict[k] if k in self.weight_dict else 1.0)
                        for k, v in l_dict.items()
                    }
                    losses.update(l_dict)
        return losses

@ -1,20 +1,15 @@
import copy
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Union

import numpy as np
import torch
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import functional as F
from torch.nn.init import constant_, normal_, uniform_, xavier_uniform_
from torch.nn.init import normal_

from .transformer_decoder import PositionEmbeddingSine, _get_activation_fn

try:
    from thirdparty.deformable_transformer.modules import MSDeformAttn
except:
    pass


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
@ -115,6 +110,7 @@ class MSDeformAttnTransformerEncoderOnly(nn.Module):
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            from thirdparty.deformable_attention.modules import MSDeformAttn
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        normal_(self.level_embed)
@ -180,6 +176,7 @@ class MSDeformAttnTransformerEncoderLayer(nn.Module):
        super().__init__()

        # self attention
        from thirdparty.deformable_attention.modules import MSDeformAttn
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

@ -15,8 +15,9 @@ from .scale import Scale
# from .weight_init import (bias_init_with_prob, kaiming_init, normal_init,
#                           uniform_init, xavier_init)
from .sobel import Sobel
from .transformer import (MLP, TransformerEncoder, TransformerEncoderLayer,
                          _get_activation_fn, _get_clones)
from .transformer import (MLP, DropPath, Mlp, TransformerEncoder,
                          TransformerEncoderLayer, _get_activation_fn,
                          _get_clones)

# __all__ = [
#     'conv_ws_2d', 'ConvWS2d', 'build_conv_layer', 'ConvModule',

@ -1,6 +1,7 @@
import copy
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
@ -22,6 +23,63 @@ class MLP(nn.Module):
        return x


class Mlp(nn.Module):
    """ Multilayer perceptron.
    Parameters:
        act_layer: Specify the activate function, default use nn.GELU.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0], ) + (1, ) * (
        x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        return 'p={}'.format(self.drop_prob)

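drop_path implements stochastic depth: during training each sample's residual branch is zeroed with probability drop_prob and the survivors are rescaled by 1 / keep_prob so the expected value is unchanged. A quick sanity check with toy numbers:

import torch

x = torch.ones(10000, 8)
y = drop_path(x, drop_prob=0.2, training=True)
print((y == 0).all(dim=1).float().mean())  # ~0.2: fraction of fully dropped samples
print(y.mean())                            # ~1.0: expectation preserved by the rescale
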
class TransformerEncoder(nn.Module):

    def __init__(self,
@ -110,8 +168,11 @@ class TransformerEncoderLayer(nn.Module):
        return src


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_clones(module, N, layer_share=False):
    if layer_share:
        return nn.ModuleList([module for i in range(N)])
    else:
        return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):

@ -169,12 +169,9 @@ class EVRunner(EpochBasedRunner):
            param groups. If the runner has a dict of optimizers, this
            method will return a dict.
        """
        # add interface to selfdefine current_lr_fn for lr_hook
        # so that runner can logging correct lrs
        if hasattr(self, 'current_lr_fn'):
            lr = self.current_lr_fn(self.optimizer)
        elif isinstance(self.optimizer, torch.optim.Optimizer):
            lr = [group['lr'] for group in self.optimizer.param_groups]
        if isinstance(self.optimizer, torch.optim.Optimizer):
            lr = sorted([group['lr'] for group in self.optimizer.param_groups],
                        reverse=True)  # avoid lr display error
        elif isinstance(self.optimizer, dict):
            lr = dict()
            for name, optim in self.optimizer.items():

@ -2,9 +2,11 @@
import functools
import os
import pickle
import random
from collections import OrderedDict
from contextlib import contextmanager

import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import data_parallel as mm_data_parallel
@ -134,3 +136,45 @@ def all_reduce_dict(py_dict, op='sum', group=None, to_float=True):
        torch.split(flatten_tensor, tensor_numels), tensor_shapes)
    ]
    return OrderedDict({k: v for k, v in zip(py_key, split_tensors)})


def get_device():
    """Returns an available device, cpu, cuda."""
    is_device_available = {'cuda': torch.cuda.is_available()}
    device_list = [k for k, v in is_device_available.items() if v]
    return device_list[0] if len(device_list) == 1 else 'cpu'


def sync_random_seed(seed=None, device='cuda'):
    """Make sure different ranks share the same seed.
    All workers must call this function, otherwise it will deadlock.
    This method is generally used in `DistributedSampler`,
    because the seed should be identical across all processes
    in the distributed group.
    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based
    on the same seed. Then different ranks could use different indices
    to select non-overlapped data from the same data list.
    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.
    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()

    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()

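sync_random_seed exists so every rank ends up with the same seed even when none is supplied: rank 0 draws it and broadcasts it to the others. A small usage sketch (in a single-process run it simply returns the drawn value):

seed = sync_random_seed(seed=None, device=get_device())
# With world_size == 1 this is just a random int; under DDP every rank receives rank 0's value.
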
@ -11,7 +11,7 @@ class EfficientFormerTest(unittest.TestCase):
    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def test_vitdet(self):
    def test_efficientformer(self):
        model = EfficientFormer(
            layers=[3, 2, 6, 4],
            embed_dims=[48, 96, 224, 448],

@ -237,6 +237,81 @@ class DETRTest(unittest.TestCase):
                          ]]),
            decimal=1)

    def test_dino(self):
        model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_r50_36e/epoch_29.pth'
        config_path = 'configs/detection/dino/dino_4sc_r50_36e_coco.py'
        img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
        dino = DetrPredictor(model_path, config_path)
        output = dino.predict(img)
        dino.visualize(img, output, out_file=None)

        self.assertIn('detection_boxes', output)
        self.assertIn('detection_scores', output)
        self.assertIn('detection_classes', output)
        self.assertIn('img_metas', output)
        self.assertEqual(len(output['detection_boxes'][0]), 300)
        self.assertEqual(len(output['detection_scores'][0]), 300)
        self.assertEqual(len(output['detection_classes'][0]), 300)

        self.assertListEqual(
            output['detection_classes'][0][:10].tolist(),
            np.array([13, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=np.int32).tolist())

        assert_array_almost_equal(
            output['detection_scores'][0][:10],
            np.array([
                0.8808171153068542, 0.8584598898887634, 0.8214247226715088,
                0.8156911134719849, 0.7707086801528931, 0.6717984080314636,
                0.6578451991081238, 0.6269607543945312, 0.6063129901885986,
                0.5223093628883362
            ],
                     dtype=np.float32),
            decimal=2)

        assert_array_almost_equal(
            output['detection_boxes'][0][:10],
            np.array([[
                222.15492248535156, 175.9025421142578, 456.3177490234375,
                382.48211669921875
            ],
                      [
                          295.12115478515625, 115.97019958496094,
                          378.97119140625, 150.2149658203125
                      ],
                      [
                          190.94241333007812, 108.94568634033203,
                          298.280517578125, 155.6221160888672
                      ],
                      [
                          167.8346405029297, 109.49150085449219,
                          211.50537109375, 140.08895874023438
                      ],
                      [
                          482.0719909667969, 110.47320556640625,
                          523.1851806640625, 130.19410705566406
                      ],
                      [
                          609.3395385742188, 113.26068115234375,
                          635.8460083007812, 136.93771362304688
                      ],
                      [
                          266.5657958984375, 105.04171752929688,
                          326.9735107421875, 127.39012145996094
                      ],
                      [
                          431.43096923828125, 105.18028259277344,
                          484.13787841796875, 131.9821319580078
                      ],
                      [
                          60.43342971801758, 94.02497100830078,
                          86.346435546875, 106.31623840332031
                      ],
                      [
                          139.32015991210938, 96.0668716430664,
                          167.1505126953125, 105.44377899169922
                      ]]),
            decimal=1)


if __name__ == '__main__':
    unittest.main()

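Outside the unit test, the same predictor drives DINO inference directly; this sketch mirrors the test above (with DetrPredictor imported as in that test module, and the checkpoint/config/image URLs copied from it):

model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_r50_36e/epoch_29.pth'
config_path = 'configs/detection/dino/dino_4sc_r50_36e_coco.py'
img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'

dino = DetrPredictor(model_path, config_path)
output = dino.predict(img)
dino.visualize(img, output, out_file='dino_result.jpg')  # out_file=None (as in the test) skips saving
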
@ -21,8 +21,8 @@ try:
except ModuleNotFoundError as e:
    info_string = (
        '\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n'
        '\t`cd mask2former/modeling/pixel_decoder/ops`\n'
        '\t`sh make.sh`\n')
        '\t`cd thirdparty/deformable_attention`\n'
        '\t`python setup.py build install`\n')
    raise ModuleNotFoundError(info_string)

@ -32,7 +32,7 @@ def _is_power_of_2(n):

class MSDeformAttn(nn.Module):

    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, im2col_step=128):
        """
        Multi-Scale Deformable Attention Module
        :param d_model      hidden dimension
@ -52,7 +52,7 @@ class MSDeformAttn(nn.Module):
            "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
            'which is more efficient in our CUDA implementation.')

        self.im2col_step = 128
        self.im2col_step = im2col_step

        self.d_model = d_model
        self.n_levels = n_levels
@ -140,11 +140,21 @@ class MSDeformAttn(nn.Module):
                'Last dim of reference_points must be 2 or 4, but get {} instead.'
                .format(reference_points.shape[-1]))
        try:
            output = MSDeformAttnFunction.apply(value, input_spatial_shapes,
                                                input_level_start_index,
                                                sampling_locations,
                                                attention_weights,
                                                self.im2col_step)
            # for amp
            if value.dtype == torch.float16:
                # for mixed precision
                output = MSDeformAttnFunction.apply(
                    value.to(torch.float32),
                    input_spatial_shapes, input_level_start_index,
                    sampling_locations.to(torch.float32), attention_weights,
                    self.im2col_step)
                output = output.to(torch.float16)
            else:
                output = MSDeformAttnFunction.apply(value, input_spatial_shapes,
                                                    input_level_start_index,
                                                    sampling_locations,
                                                    attention_weights,
                                                    self.im2col_step)
        except:
            # CPU
            output = ms_deform_attn_core_pytorch(value, input_spatial_shapes,
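The change above works around the CUDA kernel only supporting fp32: under AMP the value and sampling-location tensors are upcast before the op and the result is cast back to fp16. The same guard pattern in isolation, with a stand-in for MSDeformAttnFunction (illustrative only):

import torch

def fp32_only_op(value, sampling_locations):
    # Stand-in for a kernel that requires float32 inputs.
    assert value.dtype == torch.float32
    return value * sampling_locations.mean()

def amp_safe_op(value, sampling_locations):
    if value.dtype == torch.float16:
        out = fp32_only_op(value.to(torch.float32),
                           sampling_locations.to(torch.float32))
        return out.to(torch.float16)
    return fp32_only_op(value, sampling_locations)
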
@ -110,5 +110,6 @@ if __name__ == '__main__':
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()

    for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
    # If out of memory occurs, reduce the number of channels
    for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
        check_gradient_numerical(channels, True, True, True)

@ -23,10 +23,11 @@ if is_torchacc_enabled():
import time

import requests
import torch
import torch.distributed as dist
from mmcv.runner import init_dist

from easycv import __version__
from easycv.apis import set_random_seed, train_model
from easycv.apis import init_random_seed, set_random_seed, train_model
from easycv.datasets import build_dataloader, build_dataset
from easycv.datasets.utils import is_dali_dataset_type
from easycv.file import io
@ -37,6 +38,7 @@ from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
from easycv.utils.config_tools import traverse_replace
from easycv.utils.config_tools import (CONFIG_TEMPLATE_ZOO,
                                       mmcv_config_fromfile, rebuild_config)
from easycv.utils.dist_utils import get_device
from easycv.utils.setup_env import setup_multi_processes


@ -61,6 +63,10 @@ def parse_args():
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--diff-seed',
        action='store_true',
        help='Whether or not set different seeds for different ranks')
    parser.add_argument('--fp16', action='store_true', help='use fp16')
    parser.add_argument(
        '--deterministic',
@ -207,15 +213,18 @@ def main():
        logger.info('GPU INFO : {}'.format(torch.cuda.get_device_name(0)))

    # set random seeds
    # Using different seeds for different ranks may reduce accuracy
    seed = init_random_seed(args.seed, device=get_device())
    seed = seed + dist.get_rank() if args.diff_seed else seed
    if is_torchacc_enabled():
        assert args.seed is not None, 'Must provide `seed` to sync model initializer if use torchacc!'
        assert seed is not None, 'Must provide `seed` to sync model initializer if use torchacc!'

    if args.seed is not None:
    if seed is not None:
        logger.info('Set random seed to {}, deterministic: {}'.format(
            args.seed, args.deterministic))
        set_random_seed(args.seed, deterministic=args.deterministic)
        cfg.seed = args.seed
        meta['seed'] = args.seed
            seed, args.deterministic))
        set_random_seed(seed, deterministic=args.deterministic)
        cfg.seed = seed
        meta['seed'] = seed

    if args.pretrained is not None:
        assert isinstance(args.pretrained, str)

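With init_random_seed and the new --diff-seed flag wired in, the seeding flow reduces to the following sketch (argument names follow the diff above; the entry-point path is assumed to be the usual EasyCV training script):

# Rough restatement of the new seeding flow:
seed = init_random_seed(args.seed, device=get_device())  # draws and syncs a seed when --seed is omitted
if args.diff_seed:
    seed = seed + dist.get_rank()  # per-rank offset; the comment above warns this may reduce accuracy
set_random_seed(seed, deterministic=args.deterministic)

# Example launches (assuming the standard entry point, e.g. tools/train.py):
#   python tools/train.py configs/detection/dino/dino_4sc_r50_36e_coco.py --seed 42 --deterministic
#   python tools/train.py configs/detection/dino/dino_4sc_r50_36e_coco.py --seed 42 --diff-seed
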