[Feature] Add InternImage Classification project (#1569)

* [Feature] add internimage project
* update license
* update internimage configs
* support internimage project

parent 8e9e880601
commit 3eaf719a64
mmpretrain/models/classifiers/image_classifier.py
@@ -78,6 +78,12 @@ class ImageClassifier(BaseClassifier):
        self.neck = neck
        self.head = head

        # If the model needs to load pretrained weights from a third party,
        # the key can be modified with this hook
        if hasattr(self.backbone, '_checkpoint_filter'):
            self._register_load_state_dict_pre_hook(
                self.backbone._checkpoint_filter)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None,
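For context, `_register_load_state_dict_pre_hook` is the PyTorch mechanism this change relies on: the registered callable receives the raw `state_dict` before it is matched against the module's parameters, so a backbone can remap third-party keys in place. A minimal sketch of the mechanism (the `LegacyNet` names here are hypothetical, not part of the PR):

```python
import torch
import torch.nn as nn


class LegacyNet(nn.Module):
    """Toy module whose old checkpoints use the key 'linear' instead of 'fc'."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        # The hook runs before key matching, so it can rename entries in place.
        self._register_load_state_dict_pre_hook(self._rename_keys)

    @staticmethod
    def _rename_keys(state_dict, prefix, local_metadata, strict,
                     missing_keys, unexpected_keys, error_msgs):
        for k in list(state_dict.keys()):
            if k.startswith('linear.'):
                state_dict[k.replace('linear.', 'fc.')] = state_dict.pop(k)


net = LegacyNet()
ckpt = {'linear.weight': torch.zeros(2, 4), 'linear.bias': torch.zeros(2)}
net.load_state_dict(ckpt)  # succeeds because the pre-hook remapped the keys
```

The hook signature above is exactly the one `_checkpoint_filter` implements in `intern_image.py` further down in this diff.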
projects/internimage_classification/README.md
@@ -0,0 +1,121 @@
# InternImage Classification

## Description

This is the implementation of [InternImage](https://arxiv.org/abs/2211.05778) for image classification.

## Usage

### Setup Environment

Please refer to the [Get Started](https://mmpretrain.readthedocs.io/en/latest/get_started.html) documentation of MMPretrain to finish the installation.

Then install DCNv3 by running the commands below, following the [InternImage official installation instructions](https://github.com/OpenGVLab/InternImage/blob/master/classification/README.md).

```shell
cd ops_dcnv3
sh ./make.sh
```
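A quick sanity check after the build (a minimal sketch; it only verifies that the compiled `DCNv3` extension is importable):

```shell
python -c 'import DCNv3'  # no output means the op was built and installed
```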
### Training and Test Commands

First, add the current folder to `PYTHONPATH` so that Python can find your model files. In the `projects/internimage_classification/` root directory, run the command below to add it.

```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```

#### Training

##### On Local Single GPU

```bash
# train with mim
mim train mmpretrain ${CONFIG} --work-dir ${WORK_DIR}

# a specific command example
mim train mmpretrain configs/internimage-tiny_8xb128_in1k-224.py \
    --work-dir work_dirs/internimage-tiny_8xb128_in1k-224/
```

##### On Multiple GPUs

```bash
# train with mim
mim train mmpretrain ${CONFIG} \
    --work-dir ${WORK_DIR} \
    --launcher pytorch --gpus 8
```

##### On Multiple GPUs with Slurm

```bash
# train with mim
mim train mmpretrain ${CONFIG} \
    --work-dir ${WORK_DIR} \
    --launcher slurm --gpus 16 --gpus-per-node 8 \
    --partition ${PARTITION}
```

#### Test

Please download the pretrained weights provided by [OpenGVLab](https://github.com/OpenGVLab/) from [here](https://huggingface.co/OpenGVLab/InternImage/tree/main).

##### On Local Single GPU

```bash
# test with mim
mim test mmpretrain ${CONFIG} -C ${CHECKPOINT}

# a specific command example
mim test mmpretrain configs/internimage-tiny_8xb128_in1k-224.py -C /PATH/TO/internimage_t_1k_224.pth
```

##### On Multiple GPUs

```bash
# test with mim
# a specific command example, 8 GPUs here
mim test mmpretrain configs/internimage-tiny_8xb128_in1k-224.py \
    -C /PATH/TO/internimage_t_1k_224.pth \
    --launcher pytorch --gpus 8
```

##### On Multiple GPUs with Slurm

```bash
# test with mim
mim test mmpretrain ${CONFIG} \
    -C ${CHECKPOINT} \
    --work-dir ${WORK_DIR} \
    --launcher slurm --gpus 8 --gpus-per-node 8 \
    --partition ${PARTITION} \
    $PY_ARGS
```

Note: `PY_ARGS` stands for other optional arguments.

## Results on ImageNet-1K

The accuracy of different models on ImageNet-1K:

| name | resolution | acc@1 | acc@5 | config | weight |
| :---: | :---: | :---: | :---: | :---: | :---: |
| InternImage-T | 224 | 83.4700 | 96.5340 | [config](./configs/internimage-tiny_8xb128_in1k-224.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_t_1k_224.pth) |
| InternImage-S | 224 | 84.1640 | 96.9320 | [config](./configs/internimage-small_8xb128_in1k-224.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_s_1k_224.pth) |
| InternImage-B | 224 | 84.8660 | 97.1820 | [config](./configs/internimage-base_8xb128_in1k-224.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_b_1k_224.pth) |
| InternImage-L | 384 | 87.7060 | 98.3820 | [config](./configs/internimage-large_8xb128_in1k-384.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_l_22kto1k_384.pth) |
| InternImage-XL | 384 | 88.0460 | 98.5620 | [config](./configs/internimage-xlarge_8xb128_in1k-384.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_xl_22kto1k_384.pth) |
| InternImage-H | 640 | 89.5500 | 98.8500 | [config](./configs/internimage-huge_8xb128_in1k-640.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_h_22kto1k_640.pth) |
| InternImage-G | 512 | 90.0580 | 98.9700 | [config](./configs/internimage-giant_8xb128_in1k-512.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_g_22kto1k_512.pth) |
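
Once a checkpoint is downloaded, the configs above can also be used for quick single-image inference. A minimal sketch using MMPretrain's inferencer API (the image and checkpoint paths are placeholders, and `PYTHONPATH` must already include this project so the custom `InternImage` backbone can be registered):

```python
from mmpretrain import ImageClassificationInferencer

# Both paths are placeholders; adjust them to your local files.
inferencer = ImageClassificationInferencer(
    'configs/internimage-tiny_8xb128_in1k-224.py',
    pretrained='/PATH/TO/internimage_t_1k_224.pth')
result = inferencer('demo.jpg')[0]
print(result['pred_class'], result['pred_score'])
```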
## Citation

```bibtex
@article{wang2022internimage,
  title={InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions},
  author={Wang, Wenhai and Dai, Jifeng and Chen, Zhe and Huang, Zhenhang and Li, Zhiqi and Zhu, Xizhou and Hu, Xiaowei and Lu, Tong and Lu, Lewei and Li, Hongsheng and others},
  journal={arXiv preprint arXiv:2211.05778},
  year={2022}
}
```
projects/internimage_classification/configs/_base_.py
@@ -0,0 +1,113 @@
_base_ = 'mmpretrain::_base_/default_runtime.py'

# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
    num_classes=1000,
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=224,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs'),
]

train_dataloader = dict(
    batch_size=128,
    num_workers=8,
    dataset=dict(
        type=dataset_type,
        data_root='../../data/imagenet',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
    batch_size=128,
    num_workers=8,
    dataset=dict(
        type=dataset_type,
        data_root='../../data/imagenet',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

test_dataloader = val_dataloader
test_evaluator = val_evaluator

# model setting
custom_imports = dict(imports='models')

model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='InternImage',
        stem_channels=64,
        drop_path_rate=0.1,
        stage_blocks=[4, 4, 18, 4],
        groups=[4, 8, 16, 32]),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5)))

# optimizer
# NOTE: weight_decay is an AdamW argument, so it belongs inside the
# optimizer dict rather than at the optim_wrapper level.
optim_wrapper = dict(
    optimizer=dict(
        type='AdamW',
        lr=1.25e-04,
        eps=1e-8,
        betas=(0.9, 0.999),
        weight_decay=0.05))

# learning policy
param_scheduler = [
    # warm up learning rate scheduler
    dict(
        type='LinearLR',
        by_epoch=True,
        begin=0,
        end=20,
        convert_to_iter_based=True),
    # main learning rate scheduler
    dict(
        type='CosineAnnealingLR',
        T_max=280,
        by_epoch=True,
        begin=20,
        end=300,
        eta_min=1.25e-06)
]

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
val_cfg = dict()
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=128 * 8)
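As a concrete example of what `auto_scale_lr` does (a sketch of the linear scaling rule MMEngine applies when training is launched with the `--auto-scale-lr` flag):

```python
# Linear scaling rule: lr is multiplied by actual_batch_size / base_batch_size.
base_batch_size = 128 * 8   # 1024, the batch size the base lr was tuned for
base_lr = 1.25e-04

# e.g. training on 4 GPUs with 128 images each:
actual_batch_size = 4 * 128  # 512
scaled_lr = base_lr * actual_batch_size / base_batch_size
print(scaled_lr)  # 6.25e-05
```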
projects/internimage_classification/configs/internimage-base_8xb128_in1k-224.py
@@ -0,0 +1,13 @@
_base_ = './_base_.py'

model = dict(
    backbone=dict(
        stem_channels=112,
        drop_path_rate=0.5,
        stage_blocks=[4, 4, 21, 4],
        groups=[7, 14, 28, 56],
        layer_scale=1e-5,
        post_norm=True),
    head=dict(in_channels=1344))

optim_wrapper = dict(optimizer=dict(lr=0.0005))
projects/internimage_classification/configs/internimage-giant_8xb128_in1k-512.py
@@ -0,0 +1,55 @@
_base_ = './_base_.py'

model = dict(
    backbone=dict(
        stem_channels=512,
        drop_path_rate=0.4,
        stage_blocks=[2, 2, 48, 4],
        groups=[16, 32, 64, 128],
        dw_kernel_size=5,
        level2_post_norm=True,
        level2_post_norm_block_ids=[5, 11, 17, 23, 29, 35, 41, 47],
        center_feature_scale=True,
        use_clip_projector=True,
    ),
    neck=None,
    head=dict(in_channels=768))

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=512,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=512,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=512),
    dict(type='PackInputs'),
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

optim_wrapper = dict(optimizer=dict(lr=5e-6))
param_scheduler = [
    dict(
        type='LinearLR',
        by_epoch=True,
        begin=0,
        end=2,
        convert_to_iter_based=True),
    dict(type='CosineAnnealingLR', T_max=18, by_epoch=True, begin=2, end=20)
]
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)
projects/internimage_classification/configs/internimage-huge_8xb128_in1k-640.py
@@ -0,0 +1,55 @@
_base_ = './_base_.py'

model = dict(
    backbone=dict(
        stem_channels=320,
        drop_path_rate=0.1,
        stage_blocks=[6, 6, 32, 6],
        groups=[10, 20, 40, 80],
        dw_kernel_size=5,
        res_post_norm=True,
        level2_post_norm=True,
        level2_post_norm_block_ids=[5, 11, 17, 23, 29],
        center_feature_scale=True,
        use_clip_projector=True,
    ),
    neck=None,
    head=dict(in_channels=768))

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=640,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=640,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=640),
    dict(type='PackInputs')
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

optim_wrapper = dict(optimizer=dict(lr=5e-6))
param_scheduler = [
    dict(
        type='LinearLR',
        by_epoch=True,
        begin=0,
        end=2,
        convert_to_iter_based=True),
    dict(type='CosineAnnealingLR', T_max=18, by_epoch=True, begin=2, end=20)
]
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)
projects/internimage_classification/configs/internimage-large_8xb128_in1k-384.py
@@ -0,0 +1,51 @@
_base_ = './_base_.py'

model = dict(
    backbone=dict(
        stem_channels=160,
        drop_path_rate=0.1,
        stage_blocks=[5, 5, 22, 5],
        groups=[10, 20, 40, 80],
        layer_scale=1e-5,
        offset_scale=2.0,
        post_norm=True),
    head=dict(in_channels=1920))

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=384,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=384),
    dict(type='PackInputs')
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

optim_wrapper = dict(optimizer=dict(lr=5e-6))
param_scheduler = [
    dict(
        type='LinearLR',
        by_epoch=True,
        begin=0,
        end=2,
        convert_to_iter_based=True),
    dict(type='CosineAnnealingLR', T_max=18, by_epoch=True, begin=2, end=20)
]
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)
projects/internimage_classification/configs/internimage-small_8xb128_in1k-224.py
@@ -0,0 +1,11 @@
_base_ = './_base_.py'

model = dict(
    backbone=dict(
        stem_channels=80,
        drop_path_rate=0.4,
        stage_blocks=[4, 4, 21, 4],
        groups=[5, 10, 20, 40],
        layer_scale=1e-5,
        post_norm=True),
    head=dict(in_channels=960))
projects/internimage_classification/configs/internimage-tiny_8xb128_in1k-224.py
@@ -0,0 +1,8 @@
_base_ = './_base_.py'

model = dict(
    backbone=dict(
        stem_channels=64,
        drop_path_rate=0.1,
        stage_blocks=[4, 4, 18, 4],
        groups=[4, 8, 16, 32]))
projects/internimage_classification/configs/internimage-xlarge_8xb128_in1k-384.py
@@ -0,0 +1,50 @@
_base_ = './_base_.py'

model = dict(
    backbone=dict(
        stem_channels=192,
        drop_path_rate=0.2,
        stage_blocks=[5, 5, 24, 5],
        groups=[12, 24, 48, 96],
        layer_scale=1e-5,
        offset_scale=2.0,
        post_norm=True),
    head=dict(in_channels=2304))

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=384,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=384),
    dict(type='PackInputs')
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

optim_wrapper = dict(optimizer=dict(lr=5e-6))
param_scheduler = [
    dict(
        type='LinearLR',
        by_epoch=True,
        begin=0,
        end=2,
        convert_to_iter_based=True),
    dict(type='CosineAnnealingLR', T_max=18, by_epoch=True, begin=2, end=20)
]
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)
projects/internimage_classification/models/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .intern_image import InternImage

__all__ = ['InternImage']
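Because `InternImage` is registered into MMPretrain's `MODELS` registry (see `intern_image.py` below), the backbone can be built directly from a config dict once this package is importable. A minimal sketch, assuming `PYTHONPATH` includes the project root, the DCNv3 op is built, and a GPU is available (the compiled op only runs on CUDA):

```python
import torch

import models  # noqa: F401, triggers @MODELS.register_module()
from mmpretrain.registry import MODELS

# Tiny variant; arguments mirror configs/internimage-tiny_8xb128_in1k-224.py.
backbone = MODELS.build(
    dict(
        type='InternImage',
        stem_channels=64,
        stage_blocks=[4, 4, 18, 4],
        groups=[4, 8, 16, 32])).cuda().eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224).cuda())
print(feats[0].shape)  # (1, 768, 7, 7): conv head expands 512 channels by 1.5
```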
projects/internimage_classification/models/intern_image.py
@@ -0,0 +1,636 @@
# Copyright (c) 2022 OpenGVLab
# Copyright (c) OpenMMLab. All rights reserved.
# modified from
# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/intern_image.py
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import DropPath, build_activation_layer
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model.weight_init import trunc_normal_
from ops_dcnv3 import modules as opsm

from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.models.utils import CrossMultiheadAttention
from mmpretrain.registry import MODELS


class to_channels_first(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 3, 1, 2)


class to_channels_last(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 2, 3, 1)


def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    layers = []
    if norm_layer == 'BN':
        if in_format == 'channels_last':
            layers.append(to_channels_first())
        layers.append(nn.BatchNorm2d(dim))
        if out_format == 'channels_last':
            layers.append(to_channels_last())
    elif norm_layer == 'LN':
        if in_format == 'channels_first':
            layers.append(to_channels_last())
        layers.append(nn.LayerNorm(dim, eps=eps))
        if out_format == 'channels_first':
            layers.append(to_channels_first())
    else:
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')
    return nn.Sequential(*layers)


class AttentiveBlock(nn.Module):
    """Attentive Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: False.
        qk_scale (float, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop (float, optional): Dropout rate. Default: 0.0.
        attn_drop (float, optional): Attention dropout rate. Default: 0.0.
        drop_path (float, optional): Stochastic depth rate. Default: 0.0.
        norm_cfg (dict, optional): Normalization layer.
            Default: dict(type='LN')
        out_dim (int, optional): Dimension of output. Default: None.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_cfg=dict(type='LN'),
                 out_dim=None):
        super().__init__()
        norm_layer = norm_cfg['type']
        self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6)

        self.cross_dcn = CrossMultiheadAttention(
            embed_dims=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        if out_dim and out_dim != dim:
            self.cross_dcn.proj = nn.Linear(dim, out_dim)

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x_q, x_kv, pos_q, pos_k):
        x_q = self.norm1_q(x_q + pos_q)
        x_k = self.norm1_k(x_kv + pos_k)
        x_v = self.norm1_v(x_kv)
        x = self.cross_dcn(x_q, k=x_k, v=x_v)
        return x


class AttentionPoolingBlock(AttentiveBlock):

    def forward(self, x):
        x_q = x.mean(1, keepdim=True)
        x_kv = x
        pos_q, pos_k = 0, 0
        x = super().forward(x_q, x_kv, pos_q, pos_k)
        x = x.squeeze(1)
        return x


class DownsampleLayer(nn.Module):
    """Downsample layer of InternImage.

    Args:
        channels (int): number of input channels
        norm_layer (str): normalization layer
    """

    def __init__(self, channels, norm_layer='LN'):
        super().__init__()
        self.conv = nn.Conv2d(
            channels,
            2 * channels,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)
        self.norm = build_norm_layer(2 * channels, norm_layer,
                                     'channels_first', 'channels_last')

    def forward(self, x):
        x = self.conv(x.permute(0, 3, 1, 2))
        x = self.norm(x)
        return x


class InternImageLayer(nn.Module):
    """Basic layer of InternImage.

    Args:
        core_op (nn.Module): core operation of InternImage
        channels (int): number of input channels
        groups (list): Groups of each block.
        mlp_ratio (float): ratio of mlp hidden features to input channels
        drop (float): dropout rate
        drop_path (float): drop path rate
        act_cfg (dict): activation layer
        norm_cfg (dict): normalization layer
        post_norm (bool): whether to use post normalization
        layer_scale (float): layer scale
        offset_scale (float): offset scale
        with_cp (bool): whether to use checkpoint
    """

    def __init__(
        self,
        core_op,
        channels,
        groups,
        mlp_ratio=4.,
        drop=0.,
        drop_path=0.,
        act_cfg=dict(type='GELU'),
        norm_cfg=dict(type='LN'),
        post_norm=False,
        layer_scale=None,
        offset_scale=1.0,
        with_cp=False,
        dw_kernel_size=None,
        res_post_norm=False,
        center_feature_scale=False,
        remove_center=False,
    ):
        super().__init__()
        self.channels = channels
        self.groups = groups
        self.mlp_ratio = mlp_ratio
        self.with_cp = with_cp

        self.norm1 = build_norm_layer(channels, 'LN')
        self.post_norm = post_norm
        self.dcn = core_op(
            channels=channels,
            kernel_size=3,
            stride=1,
            pad=1,
            dilation=1,
            group=groups,
            offset_scale=offset_scale,
            act_layer=act_cfg['type'],
            norm_layer=norm_cfg['type'],
            dw_kernel_size=dw_kernel_size,
            center_feature_scale=center_feature_scale,
            remove_center=remove_center,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.norm2 = build_norm_layer(channels, 'LN')

        self.mlp = FFN(
            embed_dims=channels,
            feedforward_channels=int(channels * mlp_ratio),
            act_cfg=act_cfg,
            ffn_drop=drop,
            add_identity=False)

        self.layer_scale = layer_scale is not None
        if self.layer_scale:
            self.gamma1 = nn.Parameter(
                layer_scale * torch.ones(channels), requires_grad=True)
            self.gamma2 = nn.Parameter(
                layer_scale * torch.ones(channels), requires_grad=True)
        self.res_post_norm = res_post_norm
        if res_post_norm:
            self.res_post_norm1 = build_norm_layer(channels, 'LN')
            self.res_post_norm2 = build_norm_layer(channels, 'LN')

    def forward(self, x):

        def _inner_forward(x):
            if not self.layer_scale:
                if self.post_norm:
                    x = x + self.drop_path(self.norm1(self.dcn(x)))
                    x = x + self.drop_path(self.norm2(self.mlp(x)))
                elif self.res_post_norm:
                    x = x + self.drop_path(
                        self.res_post_norm1(self.dcn(self.norm1(x))))
                    x = x + self.drop_path(
                        self.res_post_norm2(self.mlp(self.norm2(x))))
                else:
                    x = x + self.drop_path(self.dcn(self.norm1(x)))
                    x = x + self.drop_path(self.mlp(self.norm2(x)))
                return x
            if self.post_norm:
                x = x + self.drop_path(self.gamma1 * self.norm1(self.dcn(x)))
                x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x)))
            else:
                x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x)))
                x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
            return x

        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x


class InternImageBlock(nn.Module):
    """Block of InternImage.

    Args:
        core_op (nn.Module): core operation of InternImage
        channels (int): number of input channels
        depth (int): Depth of the block.
        groups (list): Groups of each block.
        mlp_ratio (float): ratio of mlp hidden features to input channels
        drop (float): dropout rate
        drop_path (float): drop path rate
        act_cfg (dict): activation layer
        norm_cfg (dict): normalization layer
        post_norm (bool): whether to use post normalization
        layer_scale (float): layer scale
        offset_scale (float): offset scale
        with_cp (bool): whether to use checkpoint
    """

    def __init__(
        self,
        core_op,
        channels,
        depth,
        groups,
        downsample=True,
        mlp_ratio=4.,
        drop=0.,
        drop_path=0.,
        act_cfg=dict(type='GELU'),
        norm_cfg=dict(type='LN'),
        post_norm=False,
        offset_scale=1.0,
        layer_scale=None,
        with_cp=False,
        dw_kernel_size=None,
        post_norm_block_ids=None,
        res_post_norm=False,
        center_feature_scale=False,
        remove_center=False,
    ):
        super().__init__()
        self.channels = channels
        self.depth = depth
        self.post_norm = post_norm
        self.center_feature_scale = center_feature_scale

        self.blocks = nn.ModuleList([
            InternImageLayer(
                core_op=core_op,
                channels=channels,
                groups=groups,
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=drop_path[i]
                if isinstance(drop_path, list) else drop_path,
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                post_norm=post_norm,
                layer_scale=layer_scale,
                offset_scale=offset_scale,
                with_cp=with_cp,
                dw_kernel_size=dw_kernel_size,
                res_post_norm=res_post_norm,
                center_feature_scale=center_feature_scale,
                remove_center=remove_center,
            ) for i in range(depth)
        ])
        if not self.post_norm or center_feature_scale:
            self.norm = build_norm_layer(channels, 'LN')
        self.post_norm_block_ids = post_norm_block_ids
        if post_norm_block_ids is not None:
            self.post_norms = nn.ModuleList([
                build_norm_layer(channels, 'LN', eps=1e-6)
                for _ in post_norm_block_ids
            ])
        self.downsample = DownsampleLayer(
            channels=channels,
            norm_layer=norm_cfg['type']) if downsample else None

    def forward(self, x, return_wo_downsample=False):
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (self.post_norm_block_ids
                    is not None) and (i in self.post_norm_block_ids):
                index = self.post_norm_block_ids.index(i)
                x = self.post_norms[index](x)
        if not self.post_norm or self.center_feature_scale:
            x = self.norm(x)
        if return_wo_downsample:
            x_ = x
        if self.downsample is not None:
            x = self.downsample(x)

        if return_wo_downsample:
            return x, x_
        return x


@MODELS.register_module()
class InternImage(BaseBackbone):
    """InternImage backbone.

    A PyTorch implementation of `InternImage: Exploring Large-Scale Vision
    Foundation Models with Deformable Convolutions
    <https://arxiv.org/abs/2211.05778>`_.

    Args:
        stem_channels (int): Number of channels of the first stage.
            Default: 64
        stage_blocks (list): Depth of each block. Default: [3, 4, 18, 5]
        groups (list): Groups of each block. Default: [3, 6, 12, 24]
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2
        drop_path_type (str): Decay rule of the stochastic depth,
            'linear' or 'uniform'. Default: 'linear'
        act_cfg (dict): Activation layer. Default: dict(type='GELU')
        norm_cfg (dict): Normalization layer. Default: dict(type='LN')
        layer_scale (float, optional): Init value of layer scale.
            Default: None
        cls_scale (float): Channel expansion ratio of the conv head.
            Default: 1.5
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False
        dw_kernel_size (int, optional): Size of the dwconv. Default: None
        use_clip_projector (bool): Whether to use clip projector.
            Default: False
        level2_post_norm (bool): Whether to use level2 post norm.
            Default: False
        level2_post_norm_block_ids (list, optional): Indexes of post norm
            blocks. Default: None
        res_post_norm (bool): Whether to use res post norm. Default: False
        center_feature_scale (bool): Whether to use center feature scale.
            Default: False
    """

    def __init__(self,
                 stem_channels=64,
                 stage_blocks=[3, 4, 18, 5],
                 groups=[3, 6, 12, 24],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.2,
                 drop_path_type='linear',
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 layer_scale=None,
                 offset_scale=1.0,
                 post_norm=False,
                 cls_scale=1.5,
                 with_cp=False,
                 dw_kernel_size=None,
                 use_clip_projector=False,
                 level2_post_norm=False,
                 level2_post_norm_block_ids=None,
                 res_post_norm=False,
                 center_feature_scale=False,
                 remove_center=False,
                 init_cfg=None):
        super(InternImage, self).__init__(init_cfg)

        self.core_op = 'DCNv3'
        self.num_stages = len(stage_blocks)
        self.num_features = int(stem_channels * 2**(self.num_stages - 1))
        self.post_norm = post_norm
        self.mlp_ratio = mlp_ratio
        self.use_clip_projector = use_clip_projector
        self.level2_post_norm_block_ids = level2_post_norm_block_ids
        self.remove_center = remove_center
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg

        # stem layer
        self._make_stem_layer(in_channels=3, stem_channels=stem_channels)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        total_depth = sum(stage_blocks)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]
        if drop_path_type == 'uniform':
            for i in range(len(dpr)):
                dpr[i] = drop_path_rate

        # InternImage Layers
        self.layers = nn.ModuleList()
        for i in range(self.num_stages):
            if level2_post_norm and i == 2:
                post_norm_block_ids = level2_post_norm_block_ids
            else:
                post_norm_block_ids = None

            layer = InternImageBlock(
                core_op=getattr(opsm, self.core_op),
                channels=int(stem_channels * 2**i),
                depth=stage_blocks[i],
                groups=groups[i],
                mlp_ratio=self.mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[sum(stage_blocks[:i]):sum(stage_blocks[:i + 1])],
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                post_norm=post_norm,
                downsample=(i < self.num_stages - 1),
                layer_scale=layer_scale,
                offset_scale=offset_scale,
                with_cp=with_cp,
                dw_kernel_size=dw_kernel_size,
                post_norm_block_ids=post_norm_block_ids,
                res_post_norm=res_post_norm,
                center_feature_scale=center_feature_scale,
                remove_center=remove_center,
            )
            self.layers.append(layer)

        # Conv Head
        if not use_clip_projector:
            self.conv_head = nn.Sequential(
                nn.Conv2d(
                    self.num_features,
                    int(self.num_features * cls_scale),
                    kernel_size=1,
                    bias=False),
                build_norm_layer(
                    int(self.num_features * cls_scale), 'BN',
                    'channels_first', 'channels_first'),
                build_activation_layer(act_cfg))
        else:
            pretrain_embed_dim, _stride, attnpool_num_heads, clip_embed_dim \
                = 1024, 2, 16, 768
            self.dcnv3_head_x4 = nn.Sequential(
                nn.Conv2d(
                    in_channels=self.num_features,
                    out_channels=pretrain_embed_dim * (_stride**2),
                    kernel_size=1), nn.PixelShuffle(_stride))
            self.dcnv3_head_x3 = nn.Conv2d(
                in_channels=self.num_features // 2,
                out_channels=pretrain_embed_dim,
                kernel_size=1)
            self.clip_projector = AttentionPoolingBlock(
                dim=pretrain_embed_dim,
                num_heads=attnpool_num_heads,
                qkv_bias=True,
                qk_scale=None,
                drop=0.,
                attn_drop=0.,
                norm_cfg=norm_cfg,
                out_dim=clip_embed_dim)
            norm_layer = norm_cfg['type']
            self.fc_norm = build_norm_layer(
                clip_embed_dim, norm_layer, eps=1e-6)

    def init_weights(self):
        super(InternImage, self).init_weights()

        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, getattr(opsm, self.core_op)):
                m._reset_parameters()

    def _make_stem_layer(self, in_channels, stem_channels):
        norm_layer = self.norm_cfg['type']
        self.patch_embed = nn.Sequential(
            nn.Conv2d(
                in_channels,
                stem_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1),
            build_norm_layer(stem_channels // 2, norm_layer,
                             'channels_first', 'channels_first'),
            build_activation_layer(self.act_cfg),
            nn.Conv2d(
                stem_channels // 2,
                stem_channels,
                kernel_size=3,
                stride=2,
                padding=1),
            build_norm_layer(stem_channels, norm_layer, 'channels_first',
                             'channels_last'),
        )

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.conv_head(x.permute(0, 3, 1, 2))
        return (x, )

    def forward_features_seq_out(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        seq_out = []
        for layer in self.layers:
            x, x_ = layer(x, return_wo_downsample=True)
            seq_out.append(x_)
        return seq_out

    def forward_clip_projector(self, x):  # for InternImage-H/G
        xs = self.forward_features_seq_out(x)
        x1, x2, x3, x4 = xs

        x1 = x1.permute(0, 3, 1, 2)  # NHWC -> NCHW
        x2 = x2.permute(0, 3, 1, 2)  # NHWC -> NCHW
        x3 = x3.permute(0, 3, 1, 2)  # NHWC -> NCHW
        x4 = x4.permute(0, 3, 1, 2)  # NHWC -> NCHW

        x4 = self.dcnv3_head_x4(x4)
        x = x4
        x3 = self.dcnv3_head_x3(x3)
        x = x + x3

        x = x.flatten(-2).transpose(1, 2).contiguous()
        x = self.clip_projector(x)
        x = self.fc_norm(x)

        return (x, )

    def forward(self, x):
        if not self.use_clip_projector:
            # for InternImage-T/S/B/L/XL
            return self.forward_features(x)
        else:
            # for InternImage-H/G
            return self.forward_clip_projector(x)

    @staticmethod
    def _checkpoint_filter(state_dict, prefix, local_metadata, strict,
                           missing_keys, unexpected_keys, error_msgs):

        def internimage_to_mmpretrain():
            for k, v in state_dict['model'].items():
                if 'head.' in k and 'conv_head' not in k:
                    if 'weight' in k:
                        new_k = 'head.fc.weight'
                    else:
                        new_k = 'head.fc.bias'
                elif 'patch_embed' in k:
                    map_fun = {
                        'conv1': '0',
                        'norm1': '1',
                        'conv2': '3',
                        'norm2': '4'
                    }
                    new_k = k
                    for old, new in map_fun.items():
                        new_k = new_k.replace(old, new)
                    new_k = 'backbone.' + new_k
                elif 'levels' in k:
                    new_k = k.replace('levels', 'layers')
                    if 'mlp' in new_k:
                        new_k = new_k.replace('fc1', 'layers.0.0')
                        new_k = new_k.replace('fc2', 'layers.1')
                    new_k = 'backbone.' + new_k
                elif 'clip_projector.cross_dcn.k_bias' in k:
                    continue
                else:
                    new_k = 'backbone.' + k

                state_dict[new_k] = state_dict['model'][k]
            del state_dict['model']

        # The original weights need to be converted to mmpretrain format.
        # Some module names in the original weights start with 'levels',
        # and in this implementation they are replaced with 'layers'.
        if 'model' in state_dict and 'levels.0.blocks.0.norm1.0.weight'\
                in state_dict['model']:
            internimage_to_mmpretrain()
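To illustrate how `_checkpoint_filter` comes into play (a minimal sketch; the checkpoint path is a placeholder): an original OpenGVLab checkpoint stores its weights under a `'model'` key with `levels.*` names, and because `ImageClassifier.__init__` registers the filter as a `load_state_dict` pre-hook, a plain `load_state_dict` call remaps those keys transparently.

```python
import torch
from mmengine.config import Config

from mmpretrain.registry import MODELS

# fromfile() handles the config's `custom_imports`, registering InternImage.
cfg = Config.fromfile('configs/internimage-tiny_8xb128_in1k-224.py')
model = MODELS.build(cfg.model)

# Original checkpoint layout: {'model': {'levels.0.blocks.0....': tensor, ...}}
ckpt = torch.load('/PATH/TO/internimage_t_1k_224.pth', map_location='cpu')
# The pre-hook rewrites keys ('levels' -> 'backbone.layers',
# 'head.weight' -> 'head.fc.weight', ...) before they are matched.
model.load_state_dict(ckpt, strict=False)
```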
projects/internimage_classification/ops_dcnv3/functions/__init__.py
@@ -0,0 +1,10 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

# Copied from
# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/

from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch  # noqa
projects/internimage_classification/ops_dcnv3/functions/dcnv3_func.py
@@ -0,0 +1,248 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

# Copied from
# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/

from __future__ import absolute_import, division, print_function

import pkg_resources

import DCNv3
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

dcn_version = float(pkg_resources.get_distribution('DCNv3').version)


class DCNv3Function(Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, input, offset, mask, kernel_h, kernel_w, stride_h,
                stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
                group_channels, offset_scale, im2col_step, remove_center):
        ctx.kernel_h = kernel_h
        ctx.kernel_w = kernel_w
        ctx.stride_h = stride_h
        ctx.stride_w = stride_w
        ctx.pad_h = pad_h
        ctx.pad_w = pad_w
        ctx.dilation_h = dilation_h
        ctx.dilation_w = dilation_w
        ctx.group = group
        ctx.group_channels = group_channels
        ctx.offset_scale = offset_scale
        ctx.im2col_step = im2col_step
        ctx.remove_center = remove_center

        args = [
            input, offset, mask, kernel_h, kernel_w, stride_h, stride_w,
            pad_h, pad_w, dilation_h, dilation_w, group, group_channels,
            offset_scale, ctx.im2col_step
        ]
        if remove_center or dcn_version > 1.0:
            args.append(remove_center)

        output = DCNv3.dcnv3_forward(*args)
        ctx.save_for_backward(input, offset, mask)

        return output

    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        input, offset, mask = ctx.saved_tensors

        args = [
            input, offset, mask, ctx.kernel_h, ctx.kernel_w, ctx.stride_h,
            ctx.stride_w, ctx.pad_h, ctx.pad_w, ctx.dilation_h,
            ctx.dilation_w, ctx.group, ctx.group_channels, ctx.offset_scale,
            grad_output.contiguous(), ctx.im2col_step
        ]
        if ctx.remove_center or dcn_version > 1.0:
            args.append(ctx.remove_center)

        grad_input, grad_offset, grad_mask = \
            DCNv3.dcnv3_backward(*args)

        return grad_input, grad_offset, grad_mask, \
            None, None, None, None, None, None, None, \
            None, None, None, None, None, None

    @staticmethod
    def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h,
                 stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
                 group_channels, offset_scale, im2col_step, remove_center):
        """Symbolic function for mmdeploy::DCNv3.

        Returns:
            DCNv3 op for onnx.
        """
        # NOTE: ONNX attribute names need a type suffix, hence the `_i`.
        return g.op(
            'mmdeploy::TRTDCNv3',
            input,
            offset,
            mask,
            kernel_h_i=int(kernel_h),
            kernel_w_i=int(kernel_w),
            stride_h_i=int(stride_h),
            stride_w_i=int(stride_w),
            pad_h_i=int(pad_h),
            pad_w_i=int(pad_w),
            dilation_h_i=int(dilation_h),
            dilation_w_i=int(dilation_w),
            group_i=int(group),
            group_channels_i=int(group_channels),
            offset_scale_f=float(offset_scale),
            im2col_step_i=int(im2col_step),
            remove_center_i=int(remove_center),
        )


def _get_reference_points(spatial_shapes,
                          device,
                          kernel_h,
                          kernel_w,
                          dilation_h,
                          dilation_w,
                          pad_h=0,
                          pad_w=0,
                          stride_h=1,
                          stride_w=1):
    _, H_, W_, _ = spatial_shapes
    H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
    W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1

    ref_y, ref_x = torch.meshgrid(
        torch.linspace(
            # pad_h + 0.5,
            # H_ - pad_h - 0.5,
            (dilation_h * (kernel_h - 1)) // 2 + 0.5,
            (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h,
            H_out,
            dtype=torch.float32,
            device=device),
        torch.linspace(
            # pad_w + 0.5,
            # W_ - pad_w - 0.5,
            (dilation_w * (kernel_w - 1)) // 2 + 0.5,
            (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w,
            W_out,
            dtype=torch.float32,
            device=device))
    ref_y = ref_y.reshape(-1)[None] / H_
    ref_x = ref_x.reshape(-1)[None] / W_

    ref = torch.stack((ref_x, ref_y), -1).reshape(1, H_out, W_out, 1, 2)

    return ref


def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h,
                             dilation_w, group, device):
    _, H_, W_, _ = spatial_shapes
    points_list = []
    x, y = torch.meshgrid(
        torch.linspace(
            -((dilation_w * (kernel_w - 1)) // 2),
            -((dilation_w * (kernel_w - 1)) // 2) +
            (kernel_w - 1) * dilation_w,
            kernel_w,
            dtype=torch.float32,
            device=device),
        torch.linspace(
            -((dilation_h * (kernel_h - 1)) // 2),
            -((dilation_h * (kernel_h - 1)) // 2) +
            (kernel_h - 1) * dilation_h,
            kernel_h,
            dtype=torch.float32,
            device=device))

    points_list.extend([x / W_, y / H_])
    grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\
        repeat(1, group, 1).permute(1, 0, 2)
    grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2)

    return grid


def remove_center_sampling_locations(sampling_locations, kernel_w, kernel_h):
    idx = list(range(sampling_locations.shape[-2]))
    C = (kernel_w * kernel_h - 1) // 2
    idx = [i for i in idx if i != C and (i - C) % (C * 2 + 1) != 0]
    sampling_locations = sampling_locations[:, :, :, idx, :]
    return sampling_locations


def dcnv3_core_pytorch(input, offset, mask, kernel_h, kernel_w, stride_h,
                       stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
                       group_channels, offset_scale, remove_center):
    # for debug and test only,
    # need to use cuda version instead
    if remove_center and (kernel_h % 2 == 0 or kernel_w % 2 == 0
                          or kernel_w != kernel_h):
        raise ValueError(
            'remove_center is only compatible with square odd kernel size.')

    input = F.pad(input, [0, 0, pad_h, pad_h, pad_w, pad_w])
    N_, H_in, W_in, _ = input.shape
    _, H_out, W_out, _ = offset.shape

    ref = _get_reference_points(input.shape, input.device, kernel_h, kernel_w,
                                dilation_h, dilation_w, pad_h, pad_w,
                                stride_h, stride_w)
    grid = _generate_dilation_grids(input.shape, kernel_h, kernel_w,
                                    dilation_h, dilation_w, group,
                                    input.device)
    spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\
        repeat(1, 1, 1, group*(kernel_h*kernel_w-remove_center)).\
        to(input.device)

    sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1)
    if remove_center:
        sampling_locations = remove_center_sampling_locations(
            sampling_locations, kernel_w=kernel_w, kernel_h=kernel_h)
    sampling_locations = sampling_locations.flatten(3, 4)
    sampling_locations = sampling_locations + \
        offset * offset_scale / spatial_norm

    P_ = kernel_h * kernel_w - remove_center
    sampling_grids = 2 * sampling_locations - 1
    # N_, H_in, W_in, group*group_channels ->
    # N_, H_in*W_in, group*group_channels ->
    # N_, group*group_channels, H_in*W_in ->
    # N_*group, group_channels, H_in, W_in
    input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\
        reshape(N_*group, group_channels, H_in, W_in)
    # N_, H_out, W_out, group*P_*2 ->
    # N_, H_out*W_out, group, P_, 2 ->
    # N_, group, H_out*W_out, P_, 2 ->
    # N_*group, H_out*W_out, P_, 2
    sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).\
        transpose(1, 2).flatten(0, 1)
    # N_*group, group_channels, H_out*W_out, P_
    sampling_input_ = F.grid_sample(
        input_,
        sampling_grid_,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False)

    # (N_, H_out, W_out, group*P_) ->
    # N_, H_out*W_out, group, P_ ->
    # (N_, group, H_out*W_out, P_) ->
    # (N_*group, 1, H_out*W_out, P_)
    mask = mask.view(N_, H_out*W_out, group, P_).transpose(1, 2).\
        reshape(N_*group, 1, H_out*W_out, P_)
    output = (sampling_input_ * mask).sum(-1).view(N_, group * group_channels,
                                                   H_out * W_out)

    return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous()
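For a quick shape check of this pure-PyTorch reference path (a sketch; the math itself runs on CPU, though importing the module still requires the compiled `DCNv3` package to be installed because of the top-level import):

```python
import torch

from ops_dcnv3.functions import dcnv3_core_pytorch

# Channels-last (N, H, W, C) input; 3x3 kernel, stride 1, pad 1.
N, H, W, group, group_channels = 2, 8, 8, 4, 16
P = 3 * 3  # sampling points per location (remove_center=0)
x = torch.randn(N, H, W, group * group_channels)
offset = torch.zeros(N, H, W, group * P * 2)     # zero offsets -> regular grid
mask = torch.full((N, H, W, group * P), 1. / P)  # uniform modulation weights

out = dcnv3_core_pytorch(
    x, offset, mask,
    kernel_h=3, kernel_w=3, stride_h=1, stride_w=1, pad_h=1, pad_w=1,
    dilation_h=1, dilation_w=1, group=group, group_channels=group_channels,
    offset_scale=1.0, remove_center=0)
print(out.shape)  # torch.Size([2, 8, 8, 64])
```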
projects/internimage_classification/ops_dcnv3/make.sh
@@ -0,0 +1,11 @@
#!/usr/bin/env bash
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

# Copied from
# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/

python setup.py build install
projects/internimage_classification/ops_dcnv3/modules/__init__.py
@@ -0,0 +1,10 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

# Copied from
# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/

from .dcnv3 import DCNv3, DCNv3_pytorch  # noqa
projects/internimage_classification/ops_dcnv3/modules/dcnv3.py
@@ -0,0 +1,360 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

# Copied from
# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/

from __future__ import absolute_import, division, print_function

import warnings

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import constant_, xavier_uniform_

from ..functions import DCNv3Function, dcnv3_core_pytorch


class to_channels_first(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 3, 1, 2)


class to_channels_last(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 2, 3, 1)


def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    layers = []
    if norm_layer == 'BN':
        if in_format == 'channels_last':
            layers.append(to_channels_first())
        layers.append(nn.BatchNorm2d(dim))
        if out_format == 'channels_last':
            layers.append(to_channels_last())
    elif norm_layer == 'LN':
        if in_format == 'channels_first':
            layers.append(to_channels_last())
        layers.append(nn.LayerNorm(dim, eps=eps))
        if out_format == 'channels_first':
            layers.append(to_channels_first())
    else:
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')
    return nn.Sequential(*layers)


def build_act_layer(act_layer):
    if act_layer == 'ReLU':
        return nn.ReLU(inplace=True)
    elif act_layer == 'SiLU':
        return nn.SiLU(inplace=True)
    elif act_layer == 'GELU':
        return nn.GELU()

    raise NotImplementedError(f'build_act_layer does not support {act_layer}')


def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError(
            'invalid input for _is_power_of_2: {} (type: {})'.format(
                n, type(n)))

    return (n & (n - 1) == 0) and n != 0


class CenterFeatureScaleModule(nn.Module):

    def forward(self, query, center_feature_scale_proj_weight,
                center_feature_scale_proj_bias):
        center_feature_scale = F.linear(
            query,
            weight=center_feature_scale_proj_weight,
            bias=center_feature_scale_proj_bias).sigmoid()
        return center_feature_scale


class DCNv3_pytorch(nn.Module):

    def __init__(
        self,
        channels=64,
        kernel_size=3,
        dw_kernel_size=None,
        stride=1,
        pad=1,
        dilation=1,
        group=4,
        offset_scale=1.0,
        act_layer='GELU',
        norm_layer='LN',
        center_feature_scale=False,
        remove_center=False,
    ):
        """DCNv3 Module.

        :param channels
        :param kernel_size
        :param stride
        :param pad
        :param dilation
        :param group
        :param offset_scale
        :param act_layer
        :param norm_layer
        """
        super().__init__()
        if channels % group != 0:
            raise ValueError(f'channels must be divisible by group, '
                             f'but got {channels} and {group}')
        _d_per_group = channels // group
        dw_kernel_size = dw_kernel_size if dw_kernel_size is not None\
            else kernel_size
        # you'd better set _d_per_group to a power of 2
        # which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 "
                'to make the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.channels = channels
        self.kernel_size = kernel_size
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = channels // group
        self.offset_scale = offset_scale
        self.center_feature_scale = center_feature_scale
        self.remove_center = int(remove_center)

        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=dw_kernel_size,
                stride=1,
                padding=(dw_kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(channels, norm_layer, 'channels_first',
                             'channels_last'), build_act_layer(act_layer))
        self.offset = nn.Linear(
            channels,
            group * (kernel_size * kernel_size - remove_center) * 2)
        self.mask = nn.Linear(
            channels, group * (kernel_size * kernel_size - remove_center))
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

        if center_feature_scale:
            self.center_feature_scale_proj_weight = nn.Parameter(
                torch.zeros((group, channels), dtype=torch.float))
            self.center_feature_scale_proj_bias = nn.Parameter(
                torch.tensor(0.0, dtype=torch.float).view(
                    (1, )).repeat(group, ))
            self.center_feature_scale_module = CenterFeatureScaleModule()

    def _reset_parameters(self):
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param query (N, H, W, C)
        :return output (N, H, W, C)
        """
        N, H, W, _ = input.shape

        x = self.input_proj(input)
        x_proj = x

        x1 = input.permute(0, 3, 1, 2)
        x1 = self.dw_conv(x1)
        offset = self.offset(x1)
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        mask = F.softmax(mask, -1).reshape(N, H, W, -1)

        x = dcnv3_core_pytorch(x, offset, mask, self.kernel_size,
                               self.kernel_size, self.stride, self.stride,
                               self.pad, self.pad, self.dilation,
                               self.dilation, self.group, self.group_channels,
                               self.offset_scale, self.remove_center)
        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
                x1, self.center_feature_scale_proj_weight,
                self.center_feature_scale_proj_bias)
            # N, H, W, groups ->
            # N, H, W, groups, 1 ->
            # N, H, W, groups, _d_per_group ->
            # N, H, W, channels
            center_feature_scale = center_feature_scale[..., None].repeat(
                1, 1, 1, 1, self.channels // self.group).flatten(-2)
            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
        x = self.output_proj(x)

        return x


class DCNv3(nn.Module):

    def __init__(
        self,
        channels=64,
        kernel_size=3,
        dw_kernel_size=None,
        stride=1,
        pad=1,
        dilation=1,
        group=4,
        offset_scale=1.0,
        act_layer='GELU',
        norm_layer='LN',
        center_feature_scale=False,
        remove_center=False,
    ):
        """DCNv3 Module.

        :param channels
        :param kernel_size
        :param stride
        :param pad
        :param dilation
        :param group
        :param offset_scale
        :param act_layer
        :param norm_layer
        """
        super().__init__()
        if channels % group != 0:
            raise ValueError(f'channels must be divisible by group, '
                             f'but got {channels} and {group}')
        _d_per_group = channels // group
        dw_kernel_size = dw_kernel_size if dw_kernel_size is not None\
            else kernel_size
        # you'd better set _d_per_group to a power of 2
        # which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 "
                'to make the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.channels = channels
        self.kernel_size = kernel_size
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = channels // group
        self.offset_scale = offset_scale
        self.center_feature_scale = center_feature_scale
        self.remove_center = int(remove_center)

        if self.remove_center and self.kernel_size % 2 == 0:
            raise ValueError(
                'remove_center is only compatible with odd kernel size.')

        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=dw_kernel_size,
                stride=1,
                padding=(dw_kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(channels, norm_layer, 'channels_first',
                             'channels_last'), build_act_layer(act_layer))
        self.offset = nn.Linear(
            channels,
            group * (kernel_size * kernel_size - remove_center) * 2)
        self.mask = nn.Linear(
            channels, group * (kernel_size * kernel_size - remove_center))
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

        if center_feature_scale:
            self.center_feature_scale_proj_weight = nn.Parameter(
                torch.zeros((group, channels), dtype=torch.float))
            self.center_feature_scale_proj_bias = nn.Parameter(
                torch.tensor(0.0, dtype=torch.float).view(
                    (1, )).repeat(group, ))
            self.center_feature_scale_module = CenterFeatureScaleModule()

    def _reset_parameters(self):
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param query (N, H, W, C)
        :return output (N, H, W, C)
        """
        N, H, W, _ = input.shape

        x = self.input_proj(input)
        x_proj = x
        dtype = x.dtype

        x1 = input.permute(0, 3, 1, 2)
        x1 = self.dw_conv(x1)
        offset = self.offset(x1)
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        mask = F.softmax(mask, -1)
        mask = mask.reshape(N, H, W, -1).type(dtype)

        x = DCNv3Function.apply(x, offset, mask, self.kernel_size,
                                self.kernel_size, self.stride, self.stride,
                                self.pad, self.pad, self.dilation,
                                self.dilation, self.group,
                                self.group_channels, self.offset_scale, 256,
                                self.remove_center)

        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
                x1, self.center_feature_scale_proj_weight,
                self.center_feature_scale_proj_bias)
            # N, H, W, groups ->
            # N, H, W, groups, 1 ->
            # N, H, W, groups, _d_per_group ->
            # N, H, W, channels
            center_feature_scale = center_feature_scale[..., None].repeat(
                1, 1, 1, 1, self.channels // self.group).flatten(-2)
            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
        x = self.output_proj(x)

        return x
|
|
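
A minimal usage sketch of the module above (assumptions: the DCNv3 CUDA extension has been built and a GPU is available; tensors follow the channel-last (N, H, W, C) convention used throughout this file):

```python
import torch

# Hypothetical smoke test for the CUDA-backed module defined above.
dcn = DCNv3(channels=64, group=4).cuda()
x = torch.rand(2, 56, 56, 64).cuda()  # (N, H, W, C)
y = dcn(x)
assert y.shape == x.shape  # stride=1, pad=1, kernel=3 preserve spatial size
```
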
@@ -0,0 +1,72 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

# Copied from
# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/

import glob
import os
from setuptools import find_packages, setup

import torch
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension

requirements = ['torch', 'torchvision']


def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, 'src')

    main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
    source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
    source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))

    sources = main_file + source_cpu
    extension = CppExtension
    extra_compile_args = {'cxx': []}
    define_macros = []

    if torch.cuda.is_available() and CUDA_HOME is not None:
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [('WITH_CUDA', None)]
        extra_compile_args['nvcc'] = [
            # "-DCUDA_HAS_FP16=1",
            # "-D__CUDA_NO_HALF_OPERATORS__",
            # "-D__CUDA_NO_HALF_CONVERSIONS__",
            # "-D__CUDA_NO_HALF2_OPERATORS__",
        ]
    else:
        raise NotImplementedError('CUDA is not available')

    sources = [os.path.join(extensions_dir, s) for s in sources]
    include_dirs = [extensions_dir]
    ext_modules = [
        extension(
            'DCNv3',
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
    return ext_modules


setup(
    name='DCNv3',
    version='1.1',
    author='InternImage',
    url='https://github.com/OpenGVLab/InternImage',
    description='PyTorch Wrapper for CUDA Functions of DCNv3',
    packages=find_packages(exclude=(
        'configs',
        'tests',
    )),
    ext_modules=get_extensions(),
    cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension},
)
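
Assuming the extension builds successfully (e.g. via `sh ./make.sh` in `ops_dcnv3/`, as described in the README), the compiled module can be sanity-checked from Python; the op names below are the ones registered in `vision.cpp`:

```python
# Quick post-build sanity check; 'DCNv3' is the extension name declared in
# setup.py, and dcnv3_forward/dcnv3_backward are bound in vision.cpp.
import DCNv3

assert hasattr(DCNv3, 'dcnv3_forward')
assert hasattr(DCNv3, 'dcnv3_backward')
print('DCNv3 extension loaded successfully')
```
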
@@ -0,0 +1,37 @@
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
* https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include <vector>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset,
                             const at::Tensor &mask, const int kernel_h,
                             const int kernel_w, const int stride_h,
                             const int stride_w, const int pad_h,
                             const int pad_w, const int dilation_h,
                             const int dilation_w, const int group,
                             const int group_channels, const float offset_scale,
                             const int im2col_step) {
    // The CPU path is intentionally a stub; only CUDA kernels are provided.
    AT_ERROR("Not implemented on the CPU");
}

std::vector<at::Tensor>
dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
                   const at::Tensor &mask, const int kernel_h,
                   const int kernel_w, const int stride_h, const int stride_w,
                   const int pad_h, const int pad_w, const int dilation_h,
                   const int dilation_w, const int group,
                   const int group_channels, const float offset_scale,
                   const at::Tensor &grad_output, const int im2col_step) {
    AT_ERROR("Not implemented on the CPU");
}
@@ -0,0 +1,31 @@
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
* https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#pragma once
#include <torch/extension.h>

at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset,
                             const at::Tensor &mask, const int kernel_h,
                             const int kernel_w, const int stride_h,
                             const int stride_w, const int pad_h,
                             const int pad_w, const int dilation_h,
                             const int dilation_w, const int group,
                             const int group_channels, const float offset_scale,
                             const int im2col_step);

std::vector<at::Tensor>
dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset,
                   const at::Tensor &mask, const int kernel_h,
                   const int kernel_w, const int stride_h, const int stride_w,
                   const int pad_h, const int pad_w, const int dilation_h,
                   const int dilation_w, const int group,
                   const int group_channels, const float offset_scale,
                   const at::Tensor &grad_output, const int im2col_step);
@@ -0,0 +1,174 @@
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
* https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include "cuda/dcnv3_im2col_cuda.cuh"
#include <vector>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/torch.h>

at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
                              const at::Tensor &mask, const int kernel_h,
                              const int kernel_w, const int stride_h,
                              const int stride_w, const int pad_h,
                              const int pad_w, const int dilation_h,
                              const int dilation_w, const int group,
                              const int group_channels,
                              const float offset_scale, const int im2col_step,
                              const int remove_center) {
    AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
    AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
    AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous");
    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");

    const int batch = input.size(0);
    const int height_in = input.size(1);
    const int width_in = input.size(2);
    const int channels = input.size(3);
    // Standard convolution output-size arithmetic.
    const int height_out =
        (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
        1;
    const int width_out =
        (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
        1;
    const int im2col_step_ = std::min(batch, im2col_step);

    AT_ASSERTM(batch % im2col_step_ == 0,
               "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
    AT_ASSERTM(
        channels == (group * group_channels),
        "Input channels and group times group channels won't match: (%d vs %d).",
        channels, group * group_channels);

    auto output =
        at::zeros({batch, height_out, width_out, group * group_channels},
                  input.options());

    const int batch_n = im2col_step_;
    auto output_n = output.view({batch / batch_n, batch_n, height_out,
                                 width_out, group * group_channels});
    auto per_input_size = height_in * width_in * group * group_channels;
    auto per_offset_size = height_out * width_out * group *
                           (kernel_h * kernel_w - remove_center) * 2;
    auto per_mask_size = height_out * width_out * group *
                         (kernel_h * kernel_w - remove_center);
    for (int n = 0; n < batch / im2col_step_; ++n) {
        auto columns = output_n.select(0, n);
        // AT_DISPATCH_FLOATING_TYPES(
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            input.type(), "ms_deform_attn_forward_cuda", ([&] {
                dcnv3_im2col_cuda(
                    at::cuda::getCurrentCUDAStream(),
                    input.data<scalar_t>() + n * im2col_step_ * per_input_size,
                    offset.data<scalar_t>() +
                        n * im2col_step_ * per_offset_size,
                    mask.data<scalar_t>() + n * im2col_step_ * per_mask_size,
                    columns.data<scalar_t>(), kernel_h, kernel_w, stride_h,
                    stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
                    group_channels, batch_n, height_in, width_in, height_out,
                    width_out, offset_scale, remove_center);
            }));
    }

    return output;
}

std::vector<at::Tensor>
dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
                    const at::Tensor &mask, const int kernel_h,
                    const int kernel_w, const int stride_h, const int stride_w,
                    const int pad_h, const int pad_w, const int dilation_h,
                    const int dilation_w, const int group,
                    const int group_channels, const float offset_scale,
                    const at::Tensor &grad_output, const int im2col_step,
                    const int remove_center) {

    AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous");
    AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous");
    AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous");
    AT_ASSERTM(grad_output.is_contiguous(),
               "grad_output tensor has to be contiguous");
    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");
    AT_ASSERTM(grad_output.type().is_cuda(),
               "grad_output must be a CUDA tensor");

    const int batch = input.size(0);
    const int height_in = input.size(1);
    const int width_in = input.size(2);
    const int channels = input.size(3);
    const int height_out =
        (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h +
        1;
    const int width_out =
        (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
        1;
    const int im2col_step_ = std::min(batch, im2col_step);

    AT_ASSERTM(batch % im2col_step_ == 0,
               "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
    AT_ASSERTM(
        channels == (group * group_channels),
        "Input channels and group times group channels won't match: (%d vs %d).",
        channels, group * group_channels);

    // Gradients are accumulated in float even for half inputs, then cast
    // back below.
    auto dtype = input.dtype();
    if (dtype == at::kHalf) {
        dtype = at::kFloat;
    }

    auto grad_input = at::zeros_like(input, dtype);
    auto grad_offset = at::zeros_like(offset, dtype);
    auto grad_mask = at::zeros_like(mask, dtype);

    const int batch_n = im2col_step_;
    auto per_input_size = height_in * width_in * group * group_channels;
    auto per_offset_size = height_out * width_out * group *
                           (kernel_h * kernel_w - remove_center) * 2;
    auto per_mask_size = height_out * width_out * group *
                         (kernel_h * kernel_w - remove_center);
    auto grad_output_n =
        grad_output.view({batch / im2col_step_, batch_n, height_out * width_out,
                          group, group_channels});

    for (int n = 0; n < batch / im2col_step_; ++n) {
        auto grad_output_g = grad_output_n.select(0, n);
        // AT_DISPATCH_FLOATING_TYPES(
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            input.type(), "ms_deform_attn_backward_cuda", ([&] {
                dcnv3_col2im_cuda(
                    at::cuda::getCurrentCUDAStream(),
                    grad_output_g.data<scalar_t>(),
                    input.data<scalar_t>() + n * im2col_step_ * per_input_size,
                    offset.data<scalar_t>() +
                        n * im2col_step_ * per_offset_size,
                    mask.data<scalar_t>() + n * im2col_step_ * per_mask_size,
                    kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
                    dilation_h, dilation_w, group, group_channels, batch_n,
                    height_in, width_in, height_out, width_out, offset_scale,
                    remove_center,
                    grad_input.data<opmath_t>() +
                        n * im2col_step_ * per_input_size,
                    grad_offset.data<opmath_t>() +
                        n * im2col_step_ * per_offset_size,
                    grad_mask.data<opmath_t>() +
                        n * im2col_step_ * per_mask_size);
            }));
    }

    if (input.dtype() == torch::kHalf) {
        return {grad_input.to(torch::kHalf), grad_offset.to(torch::kHalf),
                grad_mask.to(torch::kHalf)};
    } else {
        return {grad_input, grad_offset, grad_mask};
    }
}
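
The output spatial size above follows standard convolution arithmetic; a quick numeric check in plain Python, mirroring the constants used in the test script further below:

```python
# Convolution output-size arithmetic used by the forward/backward kernels.
def out_size(size, pad, dilation, kernel, stride):
    return (size + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1


# With the test script's defaults (H_in=8, pad=1, dilation=1, Kh=3, stride=1)
# the spatial size is preserved:
assert out_size(8, 1, 1, 3, 1) == 8
```
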
@@ -0,0 +1,31 @@
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
* https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#pragma once
#include <torch/extension.h>

at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset,
                              const at::Tensor &mask, const int kernel_h,
                              const int kernel_w, const int stride_h,
                              const int stride_w, const int pad_h,
                              const int pad_w, const int dilation_h,
                              const int dilation_w, const int group,
                              const int group_channels,
                              const float offset_scale, const int im2col_step,
                              const int remove_center);

std::vector<at::Tensor>
dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset,
                    const at::Tensor &mask, const int kernel_h,
                    const int kernel_w, const int stride_h, const int stride_w,
                    const int pad_h, const int pad_w, const int dilation_h,
                    const int dilation_w, const int group,
                    const int group_channels, const float offset_scale,
                    const at::Tensor &grad_output, const int im2col_step,
                    const int remove_center);
File diff suppressed because it is too large
@@ -0,0 +1,59 @@
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
* https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#pragma once

#include "cpu/dcnv3_cpu.h"

#ifdef WITH_CUDA
#include "cuda/dcnv3_cuda.h"
#endif

at::Tensor dcnv3_forward(const at::Tensor &input, const at::Tensor &offset,
                         const at::Tensor &mask, const int kernel_h,
                         const int kernel_w, const int stride_h,
                         const int stride_w, const int pad_h, const int pad_w,
                         const int dilation_h, const int dilation_w,
                         const int group, const int group_channels,
                         const float offset_scale, const int im2col_step,
                         const int remove_center) {
    // Dispatch to the CUDA implementation when available; there is no CPU
    // kernel.
    if (input.type().is_cuda()) {
#ifdef WITH_CUDA
        return dcnv3_cuda_forward(input, offset, mask, kernel_h, kernel_w,
                                  stride_h, stride_w, pad_h, pad_w, dilation_h,
                                  dilation_w, group, group_channels,
                                  offset_scale, im2col_step, remove_center);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}

std::vector<at::Tensor>
dcnv3_backward(const at::Tensor &input, const at::Tensor &offset,
               const at::Tensor &mask, const int kernel_h, const int kernel_w,
               const int stride_h, const int stride_w, const int pad_h,
               const int pad_w, const int dilation_h, const int dilation_w,
               const int group, const int group_channels,
               const float offset_scale, const at::Tensor &grad_output,
               const int im2col_step, const int remove_center) {
    if (input.type().is_cuda()) {
#ifdef WITH_CUDA
        return dcnv3_cuda_backward(input, offset, mask, kernel_h, kernel_w,
                                   stride_h, stride_w, pad_h, pad_w,
                                   dilation_h, dilation_w, group,
                                   group_channels, offset_scale, grad_output,
                                   im2col_step, remove_center);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    AT_ERROR("Not implemented on the CPU");
}
@@ -0,0 +1,17 @@
/*!
**************************************************************************************************
* InternImage
* Copyright (c) 2022 OpenGVLab
* Licensed under The MIT License [see LICENSE for details]
**************************************************************************************************
* Modified from
* https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

#include "dcnv3.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("dcnv3_forward", &dcnv3_forward, "dcnv3_forward");
    m.def("dcnv3_backward", &dcnv3_backward, "dcnv3_backward");
}
@@ -0,0 +1,255 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

# Copied from
# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/

from __future__ import absolute_import, division, print_function

import math  # noqa
import time

import torch
import torch.nn as nn  # noqa
from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch
from torch.autograd import gradcheck  # noqa

H_in, W_in = 8, 8
N, M, D = 2, 4, 16
Kh, Kw = 3, 3
remove_center = False
P = Kh * Kw - remove_center
offset_scale = 2.0
pad = 1
dilation = 1
stride = 1
H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

torch.manual_seed(3)


@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)

    output_pytorch = dcnv3_core_pytorch(input.double(), offset.double(),
                                        mask.double(), Kh, Kw, stride, stride,
                                        Kh // 2, Kw // 2, dilation, dilation,
                                        M, D, offset_scale,
                                        remove_center).detach().cpu()

    im2col_step = 2
    output_cuda = DCNv3Function.apply(input.double(), offset.double(),
                                      mask.double(), Kh, Kw, stride, stride,
                                      Kh // 2, Kw // 2, dilation, dilation, M,
                                      D, offset_scale, im2col_step,
                                      remove_center).detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    print('>>> forward double')
    print(f'* {fwdok} check_forward_equal_with_pytorch_double:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_forward_equal_with_pytorch_float():
    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)

    output_pytorch = dcnv3_core_pytorch(input, offset, mask, Kh, Kw, stride,
                                        stride, Kh // 2, Kw // 2, dilation,
                                        dilation, M, D, offset_scale,
                                        remove_center).detach().cpu()

    im2col_step = 2
    output_cuda = DCNv3Function.apply(input, offset, mask, Kh, Kw, stride,
                                      stride, Kh // 2, Kw // 2, dilation,
                                      dilation, M, D, offset_scale,
                                      im2col_step,
                                      remove_center).detach().cpu()

    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = (output_cuda - output_pytorch).abs().max()
    max_rel_err = ((output_cuda - output_pytorch).abs() /
                   output_pytorch.abs()).max()
    print('>>> forward float')
    print(f'* {fwdok} check_forward_equal_with_pytorch_float:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_backward_equal_with_pytorch_double(channels=4,
                                             grad_input=True,
                                             grad_offset=True,
                                             grad_mask=True):
    # H_in, W_in = 4, 4
    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    D = channels
    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M * P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(input0.double(), offset0.double(),
                                        mask0.double(), Kh, Kw, stride, stride,
                                        Kh // 2, Kw // 2, dilation, dilation,
                                        M, D, offset_scale, remove_center)
    output_pytorch.sum().backward()

    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(input1.double(), offset1.double(),
                                      mask1.double(), Kh, Kw, stride, stride,
                                      Kh // 2, Kw // 2, dilation, dilation, M,
                                      D, offset_scale, im2col_step,
                                      remove_center)
    output_cuda.sum().backward()

    print(f'>>> backward double: channels {D}')
    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (input0.grad - input1.grad).abs().max()
    max_rel_err = ((input0.grad - input1.grad).abs() / input0.grad.abs()).max()
    print(f'* {bwdok} input_grad check_backward_equal_with_pytorch_double:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (offset0.grad - offset1.grad).abs().max()
    max_rel_err = ((offset0.grad - offset1.grad).abs() /
                   offset0.grad.abs()).max()
    print(f'* {bwdok} offset_grad check_backward_equal_with_pytorch_double:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (mask0.grad - mask1.grad).abs().max()
    max_rel_err = ((mask0.grad - mask1.grad).abs() / mask0.grad.abs()).max()
    print(f'* {bwdok} mask_grad check_backward_equal_with_pytorch_double:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


def check_backward_equal_with_pytorch_float(channels=4,
                                            grad_input=True,
                                            grad_offset=True,
                                            grad_mask=True):
    # H_in, W_in = 4, 4
    N = 2
    M = 2
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    D = channels
    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask0 /= mask0.sum(-1, keepdim=True)
    mask0 = mask0.reshape(N, H_out, W_out, M * P)
    input0.requires_grad = grad_input
    offset0.requires_grad = grad_offset
    mask0.requires_grad = grad_mask

    output_pytorch = dcnv3_core_pytorch(input0, offset0, mask0, Kh, Kw, stride,
                                        stride, Kh // 2, Kw // 2, dilation,
                                        dilation, M, D, offset_scale,
                                        remove_center)
    output_pytorch.sum().backward()

    input1 = input0.detach()
    offset1 = offset0.detach()
    mask1 = mask0.detach()
    input1.requires_grad = grad_input
    offset1.requires_grad = grad_offset
    mask1.requires_grad = grad_mask

    im2col_step = 2
    output_cuda = DCNv3Function.apply(input1, offset1, mask1, Kh, Kw, stride,
                                      stride, Kh // 2, Kw // 2, dilation,
                                      dilation, M, D, offset_scale,
                                      im2col_step, remove_center)
    output_cuda.sum().backward()

    print(f'>>> backward float: channels {D}')
    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (input0.grad - input1.grad).abs().max()
    max_rel_err = ((input0.grad - input1.grad).abs() / input0.grad.abs()).max()
    print(f'* {bwdok} input_grad check_backward_equal_with_pytorch_float:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (offset0.grad - offset1.grad).abs().max()
    max_rel_err = ((offset0.grad - offset1.grad).abs() /
                   offset0.grad.abs()).max()
    print(f'* {bwdok} offset_grad check_backward_equal_with_pytorch_float:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')

    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
    max_abs_err = (mask0.grad - mask1.grad).abs().max()
    max_rel_err = ((mask0.grad - mask1.grad).abs() / mask0.grad.abs()).max()
    print(f'* {bwdok} mask_grad check_backward_equal_with_pytorch_float:'
          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')


@torch.no_grad()
def check_time_cost(im2col_step=128):
    N = 512
    H_in, W_in = 64, 64
    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1

    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
    mask /= mask.sum(-1, keepdim=True)
    mask = mask.reshape(N, H_out, W_out, M * P)
    print(f'>>> time cost: im2col_step {im2col_step};'
          f' input {input.shape}; points {P}')
    repeat = 100
    # Warm up the kernel before timing.
    for i in range(repeat):
        output_cuda = DCNv3Function.apply(  # noqa
            input, offset, mask, Kh, Kw, stride, stride, Kh // 2, Kw // 2,
            dilation, dilation, M, D, 1.0, im2col_step, remove_center)
    torch.cuda.synchronize()
    start = time.time()
    for i in range(repeat):
        output_cuda = DCNv3Function.apply(  # noqa
            input, offset, mask, Kh, Kw, stride, stride, Kh // 2, Kw // 2,
            dilation, dilation, M, D, 1.0, im2col_step, remove_center)
    torch.cuda.synchronize()
    print(f'forward time cost: {(time.time() - start) / repeat}')


if __name__ == '__main__':
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()
    for channels in [1, 16, 30, 32, 64, 71, 1025]:
        check_backward_equal_with_pytorch_double(channels, True, True, True)
    for channels in [1, 16, 30, 32, 64, 71, 1025]:
        check_backward_equal_with_pytorch_float(channels, True, True, True)
    for i in range(3):
        im2col_step = 128 * (2**i)
        check_time_cost(im2col_step)
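
Assuming this script is saved as `test.py` inside `ops_dcnv3/` (matching the upstream layout) and the extension has been compiled, the checks can be run directly:

```python
# Hypothetical invocation, assuming a GPU and the compiled DCNv3 extension:
#
#   python test.py
#
# Each check prints True on success along with max_abs_err / max_rel_err;
# a False indicates the CUDA kernel diverged from the PyTorch reference.
```
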