diff --git a/mmpretrain/models/classifiers/image.py b/mmpretrain/models/classifiers/image.py
index 37fef8bb..a4e9c9bc 100644
--- a/mmpretrain/models/classifiers/image.py
+++ b/mmpretrain/models/classifiers/image.py
@@ -78,6 +78,12 @@ class ImageClassifier(BaseClassifier):
         self.neck = neck
         self.head = head
 
+        # If the model needs to load pretrained weights from a third party,
+        # the keys can be modified with this hook
+        if hasattr(self.backbone, '_checkpoint_filter'):
+            self._register_load_state_dict_pre_hook(
+                self.backbone._checkpoint_filter)
+
     def forward(self,
                 inputs: torch.Tensor,
                 data_samples: Optional[List[DataSample]] = None,
diff --git a/projects/internimage_classification/README.md b/projects/internimage_classification/README.md
new file mode 100644
index 00000000..53b02133
--- /dev/null
+++ b/projects/internimage_classification/README.md
@@ -0,0 +1,121 @@
+# InternImage Classification
+
+## Description
+
+This is the implementation of [InternImage](https://arxiv.org/abs/2211.05778) for image classification.
+
+## Usage
+
+### Setup Environment
+
+Please refer to the [Get Started](https://mmpretrain.readthedocs.io/en/latest/get_started.html) documentation of MMPretrain to finish installation.
+
+Please install DCNv3. Run the command below following the [InternImage official installation instructions](https://github.com/OpenGVLab/InternImage/blob/master/classification/README.md).
+
+```shell
+cd ops_dcnv3
+sh ./make.sh
+```
+
+### Training and Test Commands
+
+First, add the current folder to `PYTHONPATH` so that Python can find your model files. In the `projects/internimage_classification/` root directory, run the command below to add it.
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+#### Training
+
+##### On Local Single GPU
+
+```bash
+# train with mim
+mim train mmpretrain ${CONFIG} --work-dir ${WORK_DIR}
+
+# a specific command example
+mim train mmpretrain configs/internimage-tiny_8xb128_in1k-224.py \
+    --work-dir work_dirs/internimage-tiny_8xb128_in1k-224/
+```
+
+##### On Multiple GPUs
+
+```bash
+# train with mim
+mim train mmpretrain ${CONFIG} \
+    --work-dir ${WORK_DIR} \
+    --launcher pytorch --gpus 8
+```
+
+##### On Multiple GPUs with Slurm
+
+```bash
+# train with mim
+mim train mmpretrain ${CONFIG} \
+    --work-dir ${WORK_DIR} \
+    --launcher slurm --gpus 16 --gpus-per-node 8 \
+    --partition ${PARTITION}
+```
+
+#### Test
+
+Please download the pretrained weights provided by [OpenGVLab](https://github.com/OpenGVLab/) from [here](https://huggingface.co/OpenGVLab/InternImage/tree/main).
+
+##### On Local Single GPU
+
+```bash
+# test with mim
+mim test mmpretrain ${CONFIG} -C ${CHECKPOINT}
+
+# a specific command example
+mim test mmpretrain configs/internimage-tiny_8xb128_in1k-224.py -C /PATH/TO/internimage_t_1k_224.pth
+```
+
+##### On Multiple GPUs
+
+```bash
+# test with mim
+# a specific command example, 8 GPUs here
+mim test mmpretrain configs/internimage-tiny_8xb128_in1k-224.py \
+    -C /PATH/TO/internimage_t_1k_224.pth \
+    --launcher pytorch --gpus 8
+```
+
+##### On Multiple GPUs with Slurm
+
+```bash
+# test with mim
+mim test mmpretrain ${CONFIG} \
+    -C ${CHECKPOINT} \
+    --work-dir ${WORK_DIR} \
+    --launcher slurm --gpus 8 --gpus-per-node 8 \
+    --partition ${PARTITION} \
+    $PY_ARGS
+```
+
+Note: `PY_ARGS` denotes other optional arguments.
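+
+The checkpoints published by OpenGVLab keep the original InternImage parameter names. `ImageClassifier` registers the backbone's `_checkpoint_filter` staticmethod as a `load_state_dict` pre-hook (see the change to `mmpretrain/models/classifiers/image.py` above), so the keys are remapped on the fly when a checkpoint is loaded and no offline conversion is needed. Below is a minimal, self-contained sketch of that pre-hook mechanism, using a toy module and hypothetical key names purely for illustration.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class ToyClassifier(nn.Module):
+    """Toy stand-in for ImageClassifier (hypothetical)."""
+
+    def __init__(self):
+        super().__init__()
+        self.backbone = nn.Linear(4, 4)
+        # Same pattern as ImageClassifier: rewrite third-party checkpoint
+        # keys before they are matched against the model's parameter names.
+        self._register_load_state_dict_pre_hook(self._checkpoint_filter)
+
+    @staticmethod
+    def _checkpoint_filter(state_dict, prefix, local_metadata, strict,
+                           missing_keys, unexpected_keys, error_msgs):
+        # Rename keys in place, e.g. 'encoder.*' -> 'backbone.*'.
+        for key in list(state_dict.keys()):
+            if key.startswith('encoder.'):
+                state_dict['backbone.' + key[len('encoder.'):]] = \
+                    state_dict.pop(key)
+
+
+model = ToyClassifier()
+third_party = {'encoder.weight': torch.zeros(4, 4), 'encoder.bias': torch.zeros(4)}
+model.load_state_dict(third_party)  # keys are rewritten to 'backbone.*' on the fly
+```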
+ +## Results on ImageNet1K + +The accuracy of different models on ImageNet1K, + +| name | resolution | acc@1 | acc@5 | config | weight | +| :------------: | :--------: | :-----: | :-----: | :-------------------------------------------------------: | :-----------------------------------------------------------------------------------------------: | +| InternImage-T | 224 | 83.4700 | 96.5340 | [config](./configs/internimage-tiny_8xb128_in1k-224.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_t_1k_224.pth) | +| InternImage-S | 224 | 84.1640 | 96.9320 | [config](./configs/internimage-small_8xb128_in1k-224.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_s_1k_224.pth) | +| InternImage-B | 224 | 84.8660 | 97.1820 | [config](./configs/internimage-base_8xb128_in1k-224.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_b_1k_224.pth) | +| InternImage-L | 384 | 87.7060 | 98.3820 | [config](./configs/internimage-large_8xb128_in1k-384.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_l_22kto1k_384.pth) | +| InternImage-XL | 384 | 88.0460 | 98.5620 | [config](./configs/internimage-xlagre_8xb128_in1k-384.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_xl_22kto1k_384.pth) | +| InternImage-H | 640 | 89.5500 | 98.8500 | [config](./configs/internimage-huge_8xb128_in1k-640.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_h_22kto1k_640.pth) | +| InternImage-G | 512 | 90.0580 | 98.9700 | [config](./configs/internimage-giant_8xb128_in1k-512.py) | [model](https://huggingface.co/OpenGVLab/InternImage/resolve/main/internimage_g_22kto1k_512.pth) | + +## Citation + +```bibtex +@article{wang2022internimage, + title={InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions}, + author={Wang, Wenhai and Dai, Jifeng and Chen, Zhe and Huang, Zhenhang and Li, Zhiqi and Zhu, Xizhou and Hu, Xiaowei and Lu, Tong and Lu, Lewei and Li, Hongsheng and others}, + journal={arXiv preprint arXiv:2211.05778}, + year={2022} +} +``` diff --git a/projects/internimage_classification/configs/_base_.py b/projects/internimage_classification/configs/_base_.py new file mode 100644 index 00000000..4e9b2ada --- /dev/null +++ b/projects/internimage_classification/configs/_base_.py @@ -0,0 +1,113 @@ +_base_ = 'mmpretrain::_base_/default_runtime.py' + +# dataset settings +dataset_type = 'ImageNet' +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='PackInputs'), +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeEdge', + scale=224, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=224), + dict(type='PackInputs'), +] + +train_dataloader = dict( + batch_size=128, + num_workers=8, + dataset=dict( + type=dataset_type, + data_root='../../data/imagenet', + data_prefix='train', + pipeline=train_pipeline), + sampler=dict(type='DefaultSampler', shuffle=True), +) + +val_dataloader = dict( + batch_size=128, + num_workers=8, + dataset=dict( + type=dataset_type, + 
+        data_root='../../data/imagenet',
+        data_prefix='val',
+        pipeline=test_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+)
+val_evaluator = dict(type='Accuracy', topk=(1, 5))
+
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
+
+# model setting
+custom_imports = dict(imports='models')
+
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='InternImage',
+        stem_channels=64,
+        drop_path_rate=0.1,
+        stage_blocks=[4, 4, 18, 4],
+        groups=[4, 8, 16, 32]),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=768,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5)))
+
+# optimizer
+optim_wrapper = dict(
+    optimizer=dict(type='AdamW', lr=1.25e-04, eps=1e-8,
+                   betas=(0.9, 0.999), weight_decay=0.05))
+
+# learning policy
+param_scheduler = [
+    # warm up learning rate scheduler
+    dict(
+        type='LinearLR',
+        by_epoch=True,
+        begin=0,
+        end=20,
+        convert_to_iter_based=True),
+    # main learning rate scheduler
+    dict(
+        type='CosineAnnealingLR',
+        T_max=280,
+        by_epoch=True,
+        begin=20,
+        end=300,
+        eta_min=1.25e-06)
+]
+
+# train, val, test setting
+train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
+val_cfg = dict()
+test_cfg = dict()
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR,
+# based on the actual training batch size.
+auto_scale_lr = dict(base_batch_size=128 * 8)
diff --git a/projects/internimage_classification/configs/internimage-base_8xb128_in1k-224.py b/projects/internimage_classification/configs/internimage-base_8xb128_in1k-224.py
new file mode 100644
index 00000000..735d17e5
--- /dev/null
+++ b/projects/internimage_classification/configs/internimage-base_8xb128_in1k-224.py
@@ -0,0 +1,13 @@
+_base_ = './_base_.py'
+
+model = dict(
+    backbone=dict(
+        stem_channels=112,
+        drop_path_rate=0.5,
+        stage_blocks=[4, 4, 21, 4],
+        groups=[7, 14, 28, 56],
+        layer_scale=1e-5,
+        post_norm=True),
+    head=dict(in_channels=1344))
+
+optim_wrapper = dict(optimizer=dict(lr=0.0005))
diff --git a/projects/internimage_classification/configs/internimage-giant_8xb128_in1k-512.py b/projects/internimage_classification/configs/internimage-giant_8xb128_in1k-512.py
new file mode 100644
index 00000000..4ccd34ee
--- /dev/null
+++ b/projects/internimage_classification/configs/internimage-giant_8xb128_in1k-512.py
@@ -0,0 +1,55 @@
+_base_ = './_base_.py'
+
+model = dict(
+    backbone=dict(
+        stem_channels=512,
+        drop_path_rate=0.4,
+        stage_blocks=[2, 2, 48, 4],
+        groups=[16, 32, 64, 128],
+        dw_kernel_size=5,
+        level2_post_norm=True,
+        level2_post_norm_block_ids=[5, 11, 17, 23, 29, 35, 41, 47],
+        center_feature_scale=True,
+        use_clip_projector=True,
+    ),
+    neck=None,
+    head=dict(in_channels=768))
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        scale=512,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(type='PackInputs'),
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeEdge',
+        scale=512,
+        edge='short',
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=512),
+    dict(type='PackInputs'),
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+optim_wrapper = dict(optimizer=dict(lr=5e-6))
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        by_epoch=True,
+        begin=0,
+        end=2,
+        convert_to_iter_based=True),
dict(type='CosineAnnealingLR', T_max=18, by_epoch=True, begin=2, end=20) +] +train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1) diff --git a/projects/internimage_classification/configs/internimage-huge_8xb128_in1k-640.py b/projects/internimage_classification/configs/internimage-huge_8xb128_in1k-640.py new file mode 100644 index 00000000..0e7c8e77 --- /dev/null +++ b/projects/internimage_classification/configs/internimage-huge_8xb128_in1k-640.py @@ -0,0 +1,55 @@ +_base_ = './_base_.py' + +model = dict( + backbone=dict( + stem_channels=320, + drop_path_rate=0.1, + stage_blocks=[6, 6, 32, 6], + groups=[10, 20, 40, 80], + dw_kernel_size=5, + res_post_norm=True, + level2_post_norm=True, + level2_post_norm_block_ids=[5, 11, 17, 23, 29], + center_feature_scale=True, + use_clip_projector=True, + ), + neck=None, + head=dict(in_channels=768)) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + scale=640, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='PackInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeEdge', + scale=640, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=640), + dict(type='PackInputs') +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +optim_wrapper = dict(optimizer=dict(lr=5e-6)) +param_scheduler = [ + dict( + type='LinearLR', + by_epoch=True, + begin=0, + end=2, + convert_to_iter_based=True), + dict(type='CosineAnnealingLR', T_max=18, by_epoch=True, begin=2, end=20) +] +train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1) diff --git a/projects/internimage_classification/configs/internimage-large_8xb128_in1k-384.py b/projects/internimage_classification/configs/internimage-large_8xb128_in1k-384.py new file mode 100644 index 00000000..838ec95b --- /dev/null +++ b/projects/internimage_classification/configs/internimage-large_8xb128_in1k-384.py @@ -0,0 +1,51 @@ +_base_ = './_base_.py' + +model = dict( + backbone=dict( + stem_channels=160, + drop_path_rate=0.1, + stage_blocks=[5, 5, 22, 5], + groups=[10, 20, 40, 80], + layer_scale=1e-5, + offset_scale=2.0, + post_norm=True), + head=dict(in_channels=1920)) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + scale=384, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='PackInputs') +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeEdge', + scale=384, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=384), + dict(type='PackInputs') +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +optim_wrapper = dict(optimizer=dict(lr=5e-6)) +param_scheduler = [ + dict( + type='LinearLR', + by_epoch=True, + begin=0, + end=2, + convert_to_iter_based=True), + dict(type='CosineAnnealingLR', T_max=18, by_epoch=True, begin=2, end=20) +] +train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1) diff --git a/projects/internimage_classification/configs/internimage-small_8xb128_in1k-224.py b/projects/internimage_classification/configs/internimage-small_8xb128_in1k-224.py new file mode 100644 index 00000000..ba2075e4 
--- /dev/null +++ b/projects/internimage_classification/configs/internimage-small_8xb128_in1k-224.py @@ -0,0 +1,11 @@ +_base_ = './_base_.py' + +model = dict( + backbone=dict( + stem_channels=80, + drop_path_rate=0.4, + stage_blocks=[4, 4, 21, 4], + groups=[5, 10, 20, 40], + layer_scale=1e-5, + post_norm=True), + head=dict(in_channels=960)) diff --git a/projects/internimage_classification/configs/internimage-tiny_8xb128_in1k-224.py b/projects/internimage_classification/configs/internimage-tiny_8xb128_in1k-224.py new file mode 100644 index 00000000..abee278a --- /dev/null +++ b/projects/internimage_classification/configs/internimage-tiny_8xb128_in1k-224.py @@ -0,0 +1,8 @@ +_base_ = './_base_.py' + +model = dict( + backbone=dict( + stem_channels=64, + drop_path_rate=0.1, + stage_blocks=[4, 4, 18, 4], + groups=[4, 8, 16, 32])) diff --git a/projects/internimage_classification/configs/internimage-xlagre_8xb128_in1k-384.py b/projects/internimage_classification/configs/internimage-xlagre_8xb128_in1k-384.py new file mode 100644 index 00000000..dfd494bb --- /dev/null +++ b/projects/internimage_classification/configs/internimage-xlagre_8xb128_in1k-384.py @@ -0,0 +1,50 @@ +_base_ = './_base_.py' + +model = dict( + backbone=dict( + stem_channels=192, + drop_path_rate=0.2, + stage_blocks=[5, 5, 24, 5], + groups=[12, 24, 48, 96], + layer_scale=1e-5, + offset_scale=2.0, + post_norm=True), + head=dict(in_channels=2304)) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + scale=384, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='PackInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeEdge', + scale=384, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=384), + dict(type='PackInputs') +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +optim_wrapper = dict(optimizer=dict(lr=5e-6)) +param_scheduler = [ + dict( + type='LinearLR', + by_epoch=True, + begin=0, + end=2, + convert_to_iter_based=True), + dict(type='CosineAnnealingLR', T_max=18, by_epoch=True, begin=2, end=20) +] +train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1) diff --git a/projects/internimage_classification/models/__init__.py b/projects/internimage_classification/models/__init__.py new file mode 100644 index 00000000..99cbcf62 --- /dev/null +++ b/projects/internimage_classification/models/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .intern_image import InternImage + +__all__ = ['InternImage'] diff --git a/projects/internimage_classification/models/intern_image.py b/projects/internimage_classification/models/intern_image.py new file mode 100644 index 00000000..41c42cc5 --- /dev/null +++ b/projects/internimage_classification/models/intern_image.py @@ -0,0 +1,636 @@ +# Copyright (c) 2022 OpenGVLab +# Copyright (c) OpenMMLab. All rights reserved. 
+# modified from +# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/intern_image.py +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn.bricks import DropPath, build_activation_layer +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model.weight_init import trunc_normal_ +from ops_dcnv3 import modules as opsm + +from mmpretrain.models.backbones.base_backbone import BaseBackbone +from mmpretrain.models.utils import CrossMultiheadAttention +from mmpretrain.registry import MODELS + + +class to_channels_first(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.permute(0, 3, 1, 2) + + +class to_channels_last(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.permute(0, 2, 3, 1) + + +def build_norm_layer(dim, + norm_layer, + in_format='channels_last', + out_format='channels_last', + eps=1e-6): + layers = [] + if norm_layer == 'BN': + if in_format == 'channels_last': + layers.append(to_channels_first()) + layers.append(nn.BatchNorm2d(dim)) + if out_format == 'channels_last': + layers.append(to_channels_last()) + elif norm_layer == 'LN': + if in_format == 'channels_first': + layers.append(to_channels_last()) + layers.append(nn.LayerNorm(dim, eps=eps)) + if out_format == 'channels_first': + layers.append(to_channels_first()) + else: + raise NotImplementedError( + f'build_norm_layer does not support {norm_layer}') + return nn.Sequential(*layers) + + +class AttentiveBlock(nn.Module): + """Attentive Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: False. + qk_scale (float, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop (float, optional): Dropout rate. Default: 0.0. + attn_drop (float, optional): Attention dropout rate. Default: 0.0. + drop_path (float, optional): Stochastic depth rate. Default: 0.0. + norm_cfg (dict, optional): Normalization layer. + Default: dict(type='LN') + out_dim (int, optional): Dimension of output. Default: None. + """ + + def __init__(self, + dim, + num_heads, + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_cfg=dict(type='LN'), + out_dim=None): + super().__init__() + norm_layer = norm_cfg['type'] + self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6) + self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6) + self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6) + + self.cross_dcn = CrossMultiheadAttention( + embed_dims=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + if out_dim and out_dim != dim: + self.cross_dcn.proj = nn.Linear(dim, out_dim) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x_q, x_kv, pos_q, pos_k): + x_q = self.norm1_q(x_q + pos_q) + x_k = self.norm1_k(x_kv + pos_k) + x_v = self.norm1_v(x_kv) + x = self.cross_dcn(x_q, k=x_k, v=x_v) + return x + + +class AttentionPoolingBlock(AttentiveBlock): + + def forward(self, x): + x_q = x.mean(1, keepdim=True) + x_kv = x + pos_q, pos_k = 0, 0 + x = super().forward(x_q, x_kv, pos_q, pos_k) + x = x.squeeze(1) + return x + + +class DownsampleLayer(nn.Module): + """Downsample layer of InternImage. 
+ + Args: + channels (int): number of input channels + norm_layer (str): normalization layer + """ + + def __init__(self, channels, norm_layer='LN'): + super().__init__() + self.conv = nn.Conv2d( + channels, + 2 * channels, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.norm = build_norm_layer(2 * channels, norm_layer, + 'channels_first', 'channels_last') + + def forward(self, x): + x = self.conv(x.permute(0, 3, 1, 2)) + x = self.norm(x) + return x + + +class InternImageLayer(nn.Module): + """Basic layer of InternImage. + + Args: + core_op (nn.Module): core operation of InternImage + channels (int): number of input channels + groups (list): Groups of each block. + mlp_ratio (float): ratio of mlp hidden features to input channels + drop (float): dropout rate + drop_path (float): drop path rate + act_cfg (dict): activation layer + norm_cfg (dict): normalization layer + post_norm (bool): whether to use post normalization + layer_scale (float): layer scale + offset_scale (float): offset scale + with_cp (bool): whether to use checkpoint + """ + + def __init__( + self, + core_op, + channels, + groups, + mlp_ratio=4., + drop=0., + drop_path=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + post_norm=False, + layer_scale=None, + offset_scale=1.0, + with_cp=False, + dw_kernel_size=None, + res_post_norm=False, + center_feature_scale=False, + remove_center=False, + ): + super().__init__() + self.channels = channels + self.groups = groups + self.mlp_ratio = mlp_ratio + self.with_cp = with_cp + + self.norm1 = build_norm_layer(channels, 'LN') + self.post_norm = post_norm + self.dcn = core_op( + channels=channels, + kernel_size=3, + stride=1, + pad=1, + dilation=1, + group=groups, + offset_scale=offset_scale, + act_layer=act_cfg['type'], + norm_layer=norm_cfg['type'], + dw_kernel_size=dw_kernel_size, + center_feature_scale=center_feature_scale, + remove_center=remove_center, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
\ + else nn.Identity() + self.norm2 = build_norm_layer(channels, 'LN') + + self.mlp = FFN( + embed_dims=channels, + feedforward_channels=int(channels * mlp_ratio), + act_cfg=act_cfg, + ffn_drop=drop, + add_identity=False) + + self.layer_scale = layer_scale is not None + if self.layer_scale: + self.gamma1 = nn.Parameter( + layer_scale * torch.ones(channels), requires_grad=True) + self.gamma2 = nn.Parameter( + layer_scale * torch.ones(channels), requires_grad=True) + self.res_post_norm = res_post_norm + if res_post_norm: + self.res_post_norm1 = build_norm_layer(channels, 'LN') + self.res_post_norm2 = build_norm_layer(channels, 'LN') + + def forward(self, x): + + def _inner_forward(x): + if not self.layer_scale: + if self.post_norm: + x = x + self.drop_path(self.norm1(self.dcn(x))) + x = x + self.drop_path(self.norm2(self.mlp(x))) + elif self.res_post_norm: + x = x + self.drop_path( + self.res_post_norm1(self.dcn(self.norm1(x)))) + x = x + self.drop_path( + self.res_post_norm2(self.mlp(self.norm2(x)))) + else: + x = x + self.drop_path(self.dcn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + if self.post_norm: + x = x + self.drop_path(self.gamma1 * self.norm1(self.dcn(x))) + x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x))) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +class InternImageBlock(nn.Module): + """Block of InternImage. + + Args: + core_op (nn.Module): core operation of InternImage + channels (int): number of input channels + depths (list): Depth of each block. + groups (list): Groups of each block. 
+ mlp_ratio (float): ratio of mlp hidden features to input channels + drop (float): dropout rate + drop_path (float): drop path rate + act_cfg (dict): activation layer + norm_cfg (dict): normalization layer + post_norm (bool): whether to use post normalization + layer_scale (float): layer scale + offset_scale (float): offset scale + with_cp (bool): whether to use checkpoint + """ + + def __init__( + self, + core_op, + channels, + depth, + groups, + downsample=True, + mlp_ratio=4., + drop=0., + drop_path=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + post_norm=False, + offset_scale=1.0, + layer_scale=None, + with_cp=False, + dw_kernel_size=None, + post_norm_block_ids=None, + res_post_norm=False, + center_feature_scale=False, + remove_center=False, + ): + super().__init__() + self.channels = channels + self.depth = depth + self.post_norm = post_norm + self.center_feature_scale = center_feature_scale + + self.blocks = nn.ModuleList([ + InternImageLayer( + core_op=core_op, + channels=channels, + groups=groups, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + post_norm=post_norm, + layer_scale=layer_scale, + offset_scale=offset_scale, + with_cp=with_cp, + dw_kernel_size=dw_kernel_size, + res_post_norm=res_post_norm, + center_feature_scale=center_feature_scale, + remove_center=remove_center, + ) for i in range(depth) + ]) + if not self.post_norm or center_feature_scale: + self.norm = build_norm_layer(channels, 'LN') + self.post_norm_block_ids = post_norm_block_ids + if post_norm_block_ids is not None: + self.post_norms = nn.ModuleList([ + build_norm_layer(channels, 'LN', eps=1e-6) + for _ in post_norm_block_ids + ]) + self.downsample = DownsampleLayer( + channels=channels, + norm_layer=norm_cfg['type']) if downsample else None + + def forward(self, x, return_wo_downsample=False): + for i, blk in enumerate(self.blocks): + x = blk(x) + if (self.post_norm_block_ids + is not None) and (i in self.post_norm_block_ids): + index = self.post_norm_block_ids.index(i) + x = self.post_norms[index](x) + if not self.post_norm or self.center_feature_scale: + x = self.norm(x) + if return_wo_downsample: + x_ = x + if self.downsample is not None: + x = self.downsample(x) + + if return_wo_downsample: + return x, x_ + return x + + +@MODELS.register_module() +class InternImage(BaseBackbone): + """ InternImage + A PyTorch impl of : `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` - + https://arxiv.org/pdf/2103.14030 + + Args: + core_op (str): Core operator. Default: 'DCNv3' + stem_channels (int): Number of the first stage. Default: 64 + stage_blocks (list): Depth of each block. Default: [3, 4, 18, 5] + groups (list): Groups of each block. Default: [3, 6, 12, 24] + num_classes (int): Number of classes. Default: 1000 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + drop_rate (float): Probability of an element to be zeroed. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0. + act_cfg (dict): Activation layer. Default: dict(type='GELU') + norm_cfg (dict): Normalization layer. Default: dict(type='LN') + layer_scale (bool): Whether to use layer scale. Default: False + cls_scale (bool): Whether to use class scale. Default: False + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + dw_kernel_size (int): Size of the dwconv. Default: None + use_clip_projector (bool): Whether to use clip projector. 
Default: False + level2_post_norm (bool): Whether to use level2 post norm. Default: False + level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None + res_post_norm (bool): Whether to use res post norm. Default: False + center_feature_scale (bool): Whether to use center feature scale. Default: False + """ # noqa: E501 + + def __init__(self, + stem_channels=64, + stage_blocks=[3, 4, 18, 5], + groups=[3, 6, 12, 24], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.2, + drop_path_type='linear', + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + layer_scale=None, + offset_scale=1.0, + post_norm=False, + cls_scale=1.5, + with_cp=False, + dw_kernel_size=None, + use_clip_projector=False, + level2_post_norm=False, + level2_post_norm_block_ids=None, + res_post_norm=False, + center_feature_scale=False, + remove_center=False, + init_cfg=None): + super(InternImage, self).__init__(init_cfg) + + self.core_op = 'DCNv3' + self.num_stages = len(stage_blocks) + self.num_features = int(stem_channels * 2**(self.num_stages - 1)) + self.post_norm = post_norm + self.mlp_ratio = mlp_ratio + self.use_clip_projector = use_clip_projector + self.level2_post_norm_block_ids = level2_post_norm_block_ids + self.remove_center = remove_center + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + + # stem layer + self._make_stem_layer(in_channels=3, stem_channels=stem_channels) + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth decay rule + total_depth = sum(stage_blocks) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] + if drop_path_type == 'uniform': + for i in range(len(dpr)): + dpr[i] = drop_path_rate + + # InternImage Layers + self.layers = nn.ModuleList() + for i in range(self.num_stages): + if level2_post_norm and i == 2: + post_norm_block_ids = level2_post_norm_block_ids + else: + post_norm_block_ids = None + + layer = InternImageBlock( + core_op=getattr(opsm, self.core_op), + channels=int(stem_channels * 2**i), + depth=stage_blocks[i], + groups=groups[i], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + drop_path=dpr[sum(stage_blocks[:i]):sum(stage_blocks[:i + 1])], + act_cfg=act_cfg, + norm_cfg=norm_cfg, + post_norm=post_norm, + downsample=(i < self.num_stages - 1), + layer_scale=layer_scale, + offset_scale=offset_scale, + with_cp=with_cp, + dw_kernel_size=dw_kernel_size, + post_norm_block_ids=post_norm_block_ids, + res_post_norm=res_post_norm, + center_feature_scale=center_feature_scale, + remove_center=remove_center, + ) + self.layers.append(layer) + + # Conv Head + if not use_clip_projector: + self.conv_head = nn.Sequential( + nn.Conv2d( + self.num_features, + int(self.num_features * cls_scale), + kernel_size=1, + bias=False), + build_norm_layer( + int(self.num_features * cls_scale), 'BN', 'channels_first', + 'channels_first'), build_activation_layer(act_cfg)) + + else: + pretrain_embed_dim, _stride, attnpool_num_heads, clip_embed_dim \ + = 1024, 2, 16, 768 + self.dcnv3_head_x4 = nn.Sequential( + nn.Conv2d( + in_channels=self.num_features, + out_channels=pretrain_embed_dim * (_stride**2), + kernel_size=1), nn.PixelShuffle(_stride)) + self.dcnv3_head_x3 = nn.Conv2d( + in_channels=self.num_features // 2, + out_channels=pretrain_embed_dim, + kernel_size=1) + self.clip_projector = AttentionPoolingBlock( + dim=pretrain_embed_dim, + num_heads=attnpool_num_heads, + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + norm_cfg=norm_cfg, + out_dim=clip_embed_dim) + norm_layer = norm_cfg['type'] + self.fc_norm = build_norm_layer( + 
clip_embed_dim, norm_layer, eps=1e-6) + + def init_weights(self): + super(InternImage, self).init_weights() + + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + elif isinstance(m, getattr(opsm, self.core_op)): + m._reset_parameters() + + def _make_stem_layer(self, in_channels, stem_channels): + norm_layer = self.norm_cfg['type'] + self.patch_embed = nn.Sequential( + nn.Conv2d( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1), + build_norm_layer(stem_channels // 2, norm_layer, 'channels_first', + 'channels_first'), + build_activation_layer(self.act_cfg), + nn.Conv2d( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=2, + padding=1), + build_norm_layer(stem_channels, norm_layer, 'channels_first', + 'channels_last'), + ) + + def forward_features(self, x): + x = self.patch_embed(x) + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.conv_head(x.permute(0, 3, 1, 2)) + return (x, ) + + def forward_features_seq_out(self, x): + x = self.patch_embed(x) + x = self.pos_drop(x) + + seq_out = [] + for layer in self.layers: + x, x_ = layer(x, return_wo_downsample=True) + seq_out.append(x_) + return seq_out + + def forward_clip_projector(self, x): # for InternImage-H/G + xs = self.forward_features_seq_out(x) + x1, x2, x3, x4 = xs + + x1 = x1.permute(0, 3, 1, 2) # NHWC -> NCHW + x2 = x2.permute(0, 3, 1, 2) # NHWC -> NCHW + x3 = x3.permute(0, 3, 1, 2) # NHWC -> NCHW + x4 = x4.permute(0, 3, 1, 2) # NHWC -> NCHW + + x4 = self.dcnv3_head_x4(x4) + x = x4 + x3 = self.dcnv3_head_x3(x3) + x = x + x3 + + x = x.flatten(-2).transpose(1, 2).contiguous() + x = self.clip_projector(x) + x = self.fc_norm(x) + + return (x, ) + + def forward(self, x): + if not self.use_clip_projector: + # for InternImage-T/S/B/L/XL + return self.forward_features(x) + else: + # for InternImage-H/G + return self.forward_clip_projector(x) + + @staticmethod + def _checkpoint_filter(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + + def internimage_to_mmpretrain(): + for k, v in state_dict['model'].items(): + if 'head.' in k and 'conv_head' not in k: + if 'weight' in k: + new_k = 'head.fc.weight' + else: + new_k = 'head.fc.bias' + elif 'patch_embed' in k: + map_fun = { + 'conv1': '0', + 'norm1': '1', + 'conv2': '3', + 'norm2': '4' + } + new_k = k + for old, new in map_fun.items(): + new_k = new_k.replace(old, new) + new_k = 'backbone.' + new_k + + elif 'levels' in k: + new_k = k.replace('levels', 'layers') + if 'mlp' in new_k: + new_k = new_k.replace('fc1', 'layers.0.0') + new_k = new_k.replace('fc2', 'layers.1') + new_k = 'backbone.' + new_k + elif 'clip_projector.cross_dcn.k_bias' in k: + continue + else: + new_k = 'backbone.' + k + + state_dict[new_k] = state_dict['model'][k] + del state_dict['model'] + + # The original weights need to be converted to mmpretrain format. + # Some modules in the original weights starts with 'levels', + # and in this implement they are replaced with 'layers'. 
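+        # For example, 'levels.0.blocks.0.norm1.0.weight' in the official
+        # checkpoint becomes 'backbone.layers.0.blocks.0.norm1.0.weight',
+        # and 'levels.0.blocks.0.mlp.fc1.weight' becomes
+        # 'backbone.layers.0.blocks.0.mlp.layers.0.0.weight' to match the
+        # FFN module used here.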
+ if 'model' in state_dict and 'levels.0.blocks.0.norm1.0.weight'\ + in state_dict['model']: + internimage_to_mmpretrain() diff --git a/projects/internimage_classification/ops_dcnv3/functions/__init__.py b/projects/internimage_classification/ops_dcnv3/functions/__init__.py new file mode 100644 index 00000000..bc1f5b6e --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/functions/__init__.py @@ -0,0 +1,10 @@ +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +# Copied from +# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/ + +from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch # noqa diff --git a/projects/internimage_classification/ops_dcnv3/functions/dcnv3_func.py b/projects/internimage_classification/ops_dcnv3/functions/dcnv3_func.py new file mode 100644 index 00000000..da1b6afe --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/functions/dcnv3_func.py @@ -0,0 +1,248 @@ +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +# Copied from +# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/ + +from __future__ import absolute_import, division, print_function +import pkg_resources + +import DCNv3 +import torch +import torch.nn.functional as F +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +dcn_version = float(pkg_resources.get_distribution('DCNv3').version) + + +class DCNv3Function(Function): + + @staticmethod + @custom_fwd + def forward(ctx, input, offset, mask, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, offset_scale, im2col_step, remove_center): + ctx.kernel_h = kernel_h + ctx.kernel_w = kernel_w + ctx.stride_h = stride_h + ctx.stride_w = stride_w + ctx.pad_h = pad_h + ctx.pad_w = pad_w + ctx.dilation_h = dilation_h + ctx.dilation_w = dilation_w + ctx.group = group + ctx.group_channels = group_channels + ctx.offset_scale = offset_scale + ctx.im2col_step = im2col_step + ctx.remove_center = remove_center + + args = [ + input, offset, mask, kernel_h, kernel_w, stride_h, stride_w, pad_h, + pad_w, dilation_h, dilation_w, group, group_channels, offset_scale, + ctx.im2col_step + ] + if remove_center or dcn_version > 1.0: + args.append(remove_center) + + output = DCNv3.dcnv3_forward(*args) + ctx.save_for_backward(input, offset, mask) + + return output + + @staticmethod + @once_differentiable + @custom_bwd + def backward(ctx, grad_output): + input, offset, mask = ctx.saved_tensors + + args = [ + input, offset, mask, ctx.kernel_h, ctx.kernel_w, ctx.stride_h, + ctx.stride_w, ctx.pad_h, ctx.pad_w, ctx.dilation_h, ctx.dilation_w, + ctx.group, ctx.group_channels, ctx.offset_scale, + grad_output.contiguous(), ctx.im2col_step + ] + if ctx.remove_center or dcn_version > 1.0: + args.append(ctx.remove_center) + + grad_input, grad_offset, grad_mask = \ + DCNv3.dcnv3_backward(*args) + + return grad_input, grad_offset, grad_mask, \ + None, None, None, None, None, None, None,\ + None, None, None, None, None, None + + @staticmethod + def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h, + stride_w, pad_h, 
pad_w, dilation_h, dilation_w, group, + group_channels, offset_scale, im2col_step, remove_center): + """Symbolic function for mmdeploy::DCNv3. + + Returns: + DCNv3 op for onnx. + """ + return g.op( + 'mmdeploy::TRTDCNv3', + input, + offset, + mask, + kernel_h_i=int(kernel_h), + kernel_w_i=int(kernel_w), + stride_h_i=int(stride_h), + stride_w_i=int(stride_w), + pad_h_i=int(pad_h), + pad_w_i=int(pad_w), + dilation_h_i=int(dilation_h), + dilation_w_i=int(dilation_w), + group_i=int(group), + group_channels_i=int(group_channels), + offset_scale_f=float(offset_scale), + im2col_step_i=int(im2col_step), + remove_center=int(remove_center), + ) + + +def _get_reference_points(spatial_shapes, + device, + kernel_h, + kernel_w, + dilation_h, + dilation_w, + pad_h=0, + pad_w=0, + stride_h=1, + stride_w=1): + _, H_, W_, _ = spatial_shapes + H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 + W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 + + ref_y, ref_x = torch.meshgrid( + torch.linspace( + # pad_h + 0.5, + # H_ - pad_h - 0.5, + (dilation_h * (kernel_h - 1)) // 2 + 0.5, + (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h, + H_out, + dtype=torch.float32, + device=device), + torch.linspace( + # pad_w + 0.5, + # W_ - pad_w - 0.5, + (dilation_w * (kernel_w - 1)) // 2 + 0.5, + (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w, + W_out, + dtype=torch.float32, + device=device)) + ref_y = ref_y.reshape(-1)[None] / H_ + ref_x = ref_x.reshape(-1)[None] / W_ + + ref = torch.stack((ref_x, ref_y), -1).reshape(1, H_out, W_out, 1, 2) + + return ref + + +def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, + dilation_w, group, device): + _, H_, W_, _ = spatial_shapes + points_list = [] + x, y = torch.meshgrid( + torch.linspace( + -((dilation_w * (kernel_w - 1)) // 2), + -((dilation_w * (kernel_w - 1)) // 2) + + (kernel_w - 1) * dilation_w, + kernel_w, + dtype=torch.float32, + device=device), + torch.linspace( + -((dilation_h * (kernel_h - 1)) // 2), + -((dilation_h * (kernel_h - 1)) // 2) + + (kernel_h - 1) * dilation_h, + kernel_h, + dtype=torch.float32, + device=device)) + + points_list.extend([x / W_, y / H_]) + grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\ + repeat(1, group, 1).permute(1, 0, 2) + grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2) + + return grid + + +def remove_center_sampling_locations(sampling_locations, kernel_w, kernel_h): + idx = list(range(sampling_locations.shape[-2])) + C = (kernel_w * kernel_h - 1) // 2 + idx = [i for i in idx if i != C and (i - C) % (C * 2 + 1) != 0] + sampling_locations = sampling_locations[:, :, :, idx, :] + return sampling_locations + + +def dcnv3_core_pytorch(input, offset, mask, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, offset_scale, remove_center): + # for debug and test only, + # need to use cuda version instead + + if remove_center and (kernel_h % 2 == 0 or kernel_w % 2 == 0 + or kernel_w != kernel_h): + raise ValueError( + 'remove_center is only compatible with square odd kernel size.') + + input = F.pad(input, [0, 0, pad_h, pad_h, pad_w, pad_w]) + N_, H_in, W_in, _ = input.shape + _, H_out, W_out, _ = offset.shape + + ref = _get_reference_points(input.shape, input.device, kernel_h, kernel_w, + dilation_h, dilation_w, pad_h, pad_w, stride_h, + stride_w) + grid = _generate_dilation_grids(input.shape, kernel_h, kernel_w, + dilation_h, dilation_w, group, + input.device) + spatial_norm = 
torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\ + repeat(1, 1, 1, group*(kernel_h*kernel_w-remove_center)).\ + to(input.device) + + sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1) + if remove_center: + sampling_locations = remove_center_sampling_locations( + sampling_locations, kernel_w=kernel_w, kernel_h=kernel_h) + sampling_locations = sampling_locations.flatten(3, 4) + sampling_locations = sampling_locations + \ + offset * offset_scale / spatial_norm + + P_ = kernel_h * kernel_w - remove_center + sampling_grids = 2 * sampling_locations - 1 + # N_, H_in, W_in, group*group_channels -> + # N_, H_in*W_in, group*group_channels -> + # N_, group*group_channels, H_in*W_in -> + # N_*group, group_channels, H_in, W_in + input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\ + reshape(N_*group, group_channels, H_in, W_in) + # N_, H_out, W_out, group*P_*2 -> + # N_, H_out*W_out, group, P_, 2 -> + # N_, group, H_out*W_out, P_, 2 -> + # N_*group, H_out*W_out, P_, 2 + sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).\ + transpose(1, 2).flatten(0, 1) + # N_*group, group_channels, H_out*W_out, P_ + sampling_input_ = F.grid_sample( + input_, + sampling_grid_, + mode='bilinear', + padding_mode='zeros', + align_corners=False) + + # (N_, H_out, W_out, group*P_) -> + # N_, H_out*W_out, group, P_ -> + # (N_, group, H_out*W_out, P_) -> + # (N_*group, 1, H_out*W_out, P_) + mask = mask.view(N_, H_out*W_out, group, P_).transpose(1, 2).\ + reshape(N_*group, 1, H_out*W_out, P_) + output = (sampling_input_ * mask).sum(-1).view(N_, group * group_channels, + H_out * W_out) + + return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous() diff --git a/projects/internimage_classification/ops_dcnv3/make.sh b/projects/internimage_classification/ops_dcnv3/make.sh new file mode 100755 index 00000000..31ba0f9b --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/make.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +# Copied from +# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/ + +python setup.py build install diff --git a/projects/internimage_classification/ops_dcnv3/modules/__init__.py b/projects/internimage_classification/ops_dcnv3/modules/__init__.py new file mode 100644 index 00000000..930cd3ff --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/modules/__init__.py @@ -0,0 +1,10 @@ +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +# Copied from +# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/ + +from .dcnv3 import DCNv3, DCNv3_pytorch # noqa diff --git a/projects/internimage_classification/ops_dcnv3/modules/dcnv3.py b/projects/internimage_classification/ops_dcnv3/modules/dcnv3.py new file mode 100644 index 00000000..47a369aa --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/modules/dcnv3.py @@ -0,0 +1,360 @@ +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +# Copied from +# 
https://github.com/OpenGVLab/InternImage/blob/master/classification/models/ + +from __future__ import absolute_import, division, print_function +import warnings + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.init import constant_, xavier_uniform_ + +from ..functions import DCNv3Function, dcnv3_core_pytorch + + +class to_channels_first(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.permute(0, 3, 1, 2) + + +class to_channels_last(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.permute(0, 2, 3, 1) + + +def build_norm_layer(dim, + norm_layer, + in_format='channels_last', + out_format='channels_last', + eps=1e-6): + layers = [] + if norm_layer == 'BN': + if in_format == 'channels_last': + layers.append(to_channels_first()) + layers.append(nn.BatchNorm2d(dim)) + if out_format == 'channels_last': + layers.append(to_channels_last()) + elif norm_layer == 'LN': + if in_format == 'channels_first': + layers.append(to_channels_last()) + layers.append(nn.LayerNorm(dim, eps=eps)) + if out_format == 'channels_first': + layers.append(to_channels_first()) + else: + raise NotImplementedError( + f'build_norm_layer does not support {norm_layer}') + return nn.Sequential(*layers) + + +def build_act_layer(act_layer): + if act_layer == 'ReLU': + return nn.ReLU(inplace=True) + elif act_layer == 'SiLU': + return nn.SiLU(inplace=True) + elif act_layer == 'GELU': + return nn.GELU() + + raise NotImplementedError(f'build_act_layer does not support {act_layer}') + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError( + 'invalid input for _is_power_of_2: {} (type: {})'.format( + n, type(n))) + + return (n & (n - 1) == 0) and n != 0 + + +class CenterFeatureScaleModule(nn.Module): + + def forward(self, query, center_feature_scale_proj_weight, + center_feature_scale_proj_bias): + center_feature_scale = F.linear( + query, + weight=center_feature_scale_proj_weight, + bias=center_feature_scale_proj_bias).sigmoid() + return center_feature_scale + + +class DCNv3_pytorch(nn.Module): + + def __init__( + self, + channels=64, + kernel_size=3, + dw_kernel_size=None, + stride=1, + pad=1, + dilation=1, + group=4, + offset_scale=1.0, + act_layer='GELU', + norm_layer='LN', + center_feature_scale=False, + remove_center=False, + ): + """DCNv3 Module. 
+ + :param channels + :param kernel_size + :param stride + :param pad + :param dilation + :param group + :param offset_scale + :param act_layer + :param norm_layer + """ + super().__init__() + if channels % group != 0: + raise ValueError(f'channels must be divisible by group, ' + f'but got {channels} and {group}') + _d_per_group = channels // group + dw_kernel_size = dw_kernel_size if dw_kernel_size is not None\ + else kernel_size + # you'd better set _d_per_group to a power of 2 + # which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_group): + warnings.warn( + "You'd better set channels in DCNv3 " + 'to make the dimension of each attention head a power of 2 ' + 'which is more efficient in our CUDA implementation.') + + self.offset_scale = offset_scale + self.channels = channels + self.kernel_size = kernel_size + self.dw_kernel_size = dw_kernel_size + self.stride = stride + self.dilation = dilation + self.pad = pad + self.group = group + self.group_channels = channels // group + self.offset_scale = offset_scale + self.center_feature_scale = center_feature_scale + self.remove_center = int(remove_center) + + self.dw_conv = nn.Sequential( + nn.Conv2d( + channels, + channels, + kernel_size=dw_kernel_size, + stride=1, + padding=(dw_kernel_size - 1) // 2, + groups=channels), + build_norm_layer(channels, norm_layer, 'channels_first', + 'channels_last'), build_act_layer(act_layer)) + self.offset = nn.Linear( + channels, + group * (kernel_size * kernel_size - remove_center) * 2) + self.mask = nn.Linear( + channels, group * (kernel_size * kernel_size - remove_center)) + self.input_proj = nn.Linear(channels, channels) + self.output_proj = nn.Linear(channels, channels) + self._reset_parameters() + + if center_feature_scale: + self.center_feature_scale_proj_weight = nn.Parameter( + torch.zeros((group, channels), dtype=torch.float)) + self.center_feature_scale_proj_bias = nn.Parameter( + torch.tensor(0.0, dtype=torch.float).view( + (1, )).repeat(group, )) + self.center_feature_scale_module = CenterFeatureScaleModule() + + def _reset_parameters(self): + constant_(self.offset.weight.data, 0.) + constant_(self.offset.bias.data, 0.) + constant_(self.mask.weight.data, 0.) + constant_(self.mask.bias.data, 0.) + xavier_uniform_(self.input_proj.weight.data) + constant_(self.input_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) 
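+        # NOTE: zero-initialising the offset and mask branches above means
+        # the module starts by sampling the regular (dilated) kernel grid
+        # with uniform softmax weights, i.e. from a non-deformed state.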
+ + def forward(self, input): + """ + :param query (N, H, W, C) + :return output (N, H, W, C) + """ + N, H, W, _ = input.shape + + x = self.input_proj(input) + x_proj = x + + x1 = input.permute(0, 3, 1, 2) + x1 = self.dw_conv(x1) + offset = self.offset(x1) + mask = self.mask(x1).reshape(N, H, W, self.group, -1) + mask = F.softmax(mask, -1).reshape(N, H, W, -1) + + x = dcnv3_core_pytorch(x, offset, mask, self.kernel_size, + self.kernel_size, self.stride, self.stride, + self.pad, self.pad, self.dilation, + self.dilation, self.group, self.group_channels, + self.offset_scale, self.remove_center) + if self.center_feature_scale: + center_feature_scale = self.center_feature_scale_module( + x1, self.center_feature_scale_proj_weight, + self.center_feature_scale_proj_bias) + # N, H, W, groups -> + # N, H, W, groups, 1 -> + # N, H, W, groups, _d_per_group -> + # N, H, W, channels + center_feature_scale = center_feature_scale[..., None].repeat( + 1, 1, 1, 1, self.channels // self.group).flatten(-2) + x = x * (1 - center_feature_scale) + x_proj * center_feature_scale + x = self.output_proj(x) + + return x + + +class DCNv3(nn.Module): + + def __init__( + self, + channels=64, + kernel_size=3, + dw_kernel_size=None, + stride=1, + pad=1, + dilation=1, + group=4, + offset_scale=1.0, + act_layer='GELU', + norm_layer='LN', + center_feature_scale=False, + remove_center=False, + ): + """DCNv3 Module. + + :param channels + :param kernel_size + :param stride + :param pad + :param dilation + :param group + :param offset_scale + :param act_layer + :param norm_layer + """ + super().__init__() + if channels % group != 0: + raise ValueError(f'channels must be divisible by group, ' + f'but got {channels} and {group}') + _d_per_group = channels // group + dw_kernel_size = dw_kernel_size if dw_kernel_size is not None\ + else kernel_size + # you'd better set _d_per_group to a power of 2 + # which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_group): + warnings.warn( + "You'd better set channels in DCNv3 " + 'to make the dimension of each attention head a power of 2 ' + 'which is more efficient in our CUDA implementation.') + + self.offset_scale = offset_scale + self.channels = channels + self.kernel_size = kernel_size + self.dw_kernel_size = dw_kernel_size + self.stride = stride + self.dilation = dilation + self.pad = pad + self.group = group + self.group_channels = channels // group + self.offset_scale = offset_scale + self.center_feature_scale = center_feature_scale + self.remove_center = int(remove_center) + + if self.remove_center and self.kernel_size % 2 == 0: + raise ValueError( + 'remove_center is only compatible with odd kernel size.') + + self.dw_conv = nn.Sequential( + nn.Conv2d( + channels, + channels, + kernel_size=dw_kernel_size, + stride=1, + padding=(dw_kernel_size - 1) // 2, + groups=channels), + build_norm_layer(channels, norm_layer, 'channels_first', + 'channels_last'), build_act_layer(act_layer)) + self.offset = nn.Linear( + channels, + group * (kernel_size * kernel_size - remove_center) * 2) + self.mask = nn.Linear( + channels, group * (kernel_size * kernel_size - remove_center)) + self.input_proj = nn.Linear(channels, channels) + self.output_proj = nn.Linear(channels, channels) + self._reset_parameters() + + if center_feature_scale: + self.center_feature_scale_proj_weight = nn.Parameter( + torch.zeros((group, channels), dtype=torch.float)) + self.center_feature_scale_proj_bias = nn.Parameter( + torch.tensor(0.0, dtype=torch.float).view( + (1, )).repeat(group, )) + 
self.center_feature_scale_module = CenterFeatureScaleModule() + + def _reset_parameters(self): + constant_(self.offset.weight.data, 0.) + constant_(self.offset.bias.data, 0.) + constant_(self.mask.weight.data, 0.) + constant_(self.mask.bias.data, 0.) + xavier_uniform_(self.input_proj.weight.data) + constant_(self.input_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) + + def forward(self, input): + """ + :param query (N, H, W, C) + :return output (N, H, W, C) + """ + N, H, W, _ = input.shape + + x = self.input_proj(input) + x_proj = x + dtype = x.dtype + + x1 = input.permute(0, 3, 1, 2) + x1 = self.dw_conv(x1) + offset = self.offset(x1) + mask = self.mask(x1).reshape(N, H, W, self.group, -1) + mask = F.softmax(mask, -1) + mask = mask.reshape(N, H, W, -1).type(dtype) + + x = DCNv3Function.apply(x, offset, mask, self.kernel_size, + self.kernel_size, self.stride, self.stride, + self.pad, self.pad, self.dilation, + self.dilation, self.group, self.group_channels, + self.offset_scale, 256, self.remove_center) + + if self.center_feature_scale: + center_feature_scale = self.center_feature_scale_module( + x1, self.center_feature_scale_proj_weight, + self.center_feature_scale_proj_bias) + # N, H, W, groups -> + # N, H, W, groups, 1 -> + # N, H, W, groups, _d_per_group -> + # N, H, W, channels + center_feature_scale = center_feature_scale[..., None].repeat( + 1, 1, 1, 1, self.channels // self.group).flatten(-2) + x = x * (1 - center_feature_scale) + x_proj * center_feature_scale + x = self.output_proj(x) + + return x diff --git a/projects/internimage_classification/ops_dcnv3/setup.py b/projects/internimage_classification/ops_dcnv3/setup.py new file mode 100644 index 00000000..34f8e7a3 --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/setup.py @@ -0,0 +1,72 @@ +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +# Copied from +# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/ + +import glob +import os +from setuptools import find_packages, setup + +import torch +from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension + +requirements = ['torch', 'torchvision'] + + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, 'src') + + main_file = glob.glob(os.path.join(extensions_dir, '*.cpp')) + source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp')) + source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu')) + + sources = main_file + source_cpu + extension = CppExtension + extra_compile_args = {'cxx': []} + define_macros = [] + + if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [('WITH_CUDA', None)] + extra_compile_args['nvcc'] = [ + # "-DCUDA_HAS_FP16=1", + # "-D__CUDA_NO_HALF_OPERATORS__", + # "-D__CUDA_NO_HALF_CONVERSIONS__", + # "-D__CUDA_NO_HALF2_OPERATORS__", + ] + else: + raise NotImplementedError('Cuda is not availabel') + + sources = [os.path.join(extensions_dir, s) for s in sources] + include_dirs = [extensions_dir] + ext_modules = [ + extension( + 'DCNv3', + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + return ext_modules + + +setup( + name='DCNv3', + 
version='1.1', + author='InternImage', + url='https://github.com/OpenGVLab/InternImage', + description='PyTorch Wrapper for CUDA Functions of DCNv3', + packages=find_packages(exclude=( + 'configs', + 'tests', + )), + ext_modules=get_extensions(), + cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension}, +) diff --git a/projects/internimage_classification/ops_dcnv3/src/cpu/dcnv3_cpu.cpp b/projects/internimage_classification/ops_dcnv3/src/cpu/dcnv3_cpu.cpp new file mode 100644 index 00000000..a3bddc18 --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/src/cpu/dcnv3_cpu.cpp @@ -0,0 +1,37 @@ +/*! +************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include + +#include +#include + +at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, const float offset_scale, + const int im2col_step) { + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, const float offset_scale, + const at::Tensor &grad_output, const int im2col_step) { + AT_ERROR("Not implemented on the CPU"); +} diff --git a/projects/internimage_classification/ops_dcnv3/src/cpu/dcnv3_cpu.h b/projects/internimage_classification/ops_dcnv3/src/cpu/dcnv3_cpu.h new file mode 100644 index 00000000..d457bcbd --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/src/cpu/dcnv3_cpu.h @@ -0,0 +1,31 @@ +/*!
+************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, const float offset_scale, + const int im2col_step); + +std::vector +dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, const float offset_scale, + const at::Tensor &grad_output, const int im2col_step); diff --git a/projects/internimage_classification/ops_dcnv3/src/cuda/dcnv3_cuda.cu b/projects/internimage_classification/ops_dcnv3/src/cuda/dcnv3_cuda.cu new file mode 100644 index 00000000..f7932480 --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/src/cuda/dcnv3_cuda.cu @@ -0,0 +1,174 @@ +/*! +************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "cuda/dcnv3_im2col_cuda.cuh" +#include + +#include +#include +#include +#include +#include + +at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, + const float offset_scale, const int im2col_step, const int remove_center) { + AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous"); + AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous"); + AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous"); + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); + AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); + + const int batch = input.size(0); + const int height_in = input.size(1); + const int width_in = input.size(2); + const int channels = input.size(3); + const int height_out = + (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + + 1; + const int width_out = + (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + + 1; + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, + "batch(%d) must divide 
im2col_step(%d)", batch, im2col_step_); + AT_ASSERTM( + channels == (group * group_channels), + "Input channels and group times group channels won't match: (%d vs %d).", + channels, group * group_channels); + + auto output = + at::zeros({batch, height_out, width_out, group * group_channels}, + input.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch / batch_n, batch_n, height_out, + width_out, group * group_channels}); + auto per_input_size = height_in * width_in * group * group_channels; + auto per_offset_size = + height_out * width_out * group * (kernel_h * kernel_w - remove_center) * 2; + auto per_mask_size = height_out * width_out * group * (kernel_h * kernel_w - remove_center); + for (int n = 0; n < batch / im2col_step_; ++n) { + auto columns = output_n.select(0, n); + // AT_DISPATCH_FLOATING_TYPES( + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.type(), "ms_deform_attn_forward_cuda", ([&] { + dcnv3_im2col_cuda( + at::cuda::getCurrentCUDAStream(), + input.data() + n * im2col_step_ * per_input_size, + offset.data() + + n * im2col_step_ * per_offset_size, + mask.data() + n * im2col_step_ * per_mask_size, + columns.data(), kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, batch_n, height_in, width_in, height_out, + width_out, offset_scale, remove_center); + })); + } + + return output; +} + +std::vector +dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, const float offset_scale, + const at::Tensor &grad_output, const int im2col_step, const int remove_center) { + + AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous"); + AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous"); + AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), + "grad_output tensor has to be contiguous"); + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); + AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); + AT_ASSERTM(grad_output.type().is_cuda(), + "grad_output must be a CUDA tensor"); + + const int batch = input.size(0); + const int height_in = input.size(1); + const int width_in = input.size(2); + const int channels = input.size(3); + const int height_out = + (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + + 1; + const int width_out = + (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + + 1; + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, + "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + AT_ASSERTM( + channels == (group * group_channels), + "Input channels and group times group channels won't match: (%d vs %d).", + channels, group * group_channels); + + auto dtype = input.dtype(); + if (dtype == at::kHalf) { + dtype = at::kFloat; + } + + auto grad_input = at::zeros_like(input, dtype); + auto grad_offset = at::zeros_like(offset, dtype); + auto grad_mask = at::zeros_like(mask, dtype); + + const int batch_n = im2col_step_; + auto per_input_size = height_in * width_in * group * group_channels; + auto per_offset_size = + height_out * width_out * group * (kernel_h * kernel_w - 
remove_center) * 2; + auto per_mask_size = height_out * width_out * group * (kernel_h * kernel_w - remove_center); + auto grad_output_n = + grad_output.view({batch / im2col_step_, batch_n, height_out * width_out, + group, group_channels}); + + for (int n = 0; n < batch / im2col_step_; ++n) { + auto grad_output_g = grad_output_n.select(0, n); + // AT_DISPATCH_FLOATING_TYPES( + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.type(), "ms_deform_attn_backward_cuda", ([&] { + dcnv3_col2im_cuda( + at::cuda::getCurrentCUDAStream(), + grad_output_g.data(), + input.data() + n * im2col_step_ * per_input_size, + offset.data() + + n * im2col_step_ * per_offset_size, + mask.data() + n * im2col_step_ * per_mask_size, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, batch_n, + height_in, width_in, height_out, width_out, offset_scale, remove_center, + grad_input.data() + + n * im2col_step_ * per_input_size, + grad_offset.data() + + n * im2col_step_ * per_offset_size, + grad_mask.data() + + n * im2col_step_ * per_mask_size); + })); + } + + if (input.dtype() == torch::kHalf) { + return {grad_input.to(torch::kHalf), grad_offset.to(torch::kHalf), + grad_mask.to(torch::kHalf)}; + } else { + return {grad_input, grad_offset, grad_mask}; + } +} diff --git a/projects/internimage_classification/ops_dcnv3/src/cuda/dcnv3_cuda.h b/projects/internimage_classification/ops_dcnv3/src/cuda/dcnv3_cuda.h new file mode 100644 index 00000000..d7ac0244 --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/src/cuda/dcnv3_cuda.h @@ -0,0 +1,31 @@ +/*! +************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, + const float offset_scale, const int im2col_step, const int remove_center); + +std::vector +dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, + const int group_channels, const float offset_scale, + const at::Tensor &grad_output, const int im2col_step, const int remove_center); diff --git a/projects/internimage_classification/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh b/projects/internimage_classification/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh new file mode 100644 index 00000000..ab6da3c5 --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/src/cuda/dcnv3_im2col_cuda.cuh @@ -0,0 +1,1094 @@ +/*! 
+************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include +#include + +#include +#include +#include +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 256; +inline int GET_BLOCKS(const int N, const int num_threads) { + return (N + num_threads - 1) / num_threads; +} + +#define opmath_t at::opmath_type + +template +__device__ opmath_t dcnv3_im2col_bilinear(const scalar_t *&bottom_data, + const int &height, const int &width, + const int &group, + const int &group_channels, + const opmath_t &h, const opmath_t &w, + const int &g, const int &c) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const opmath_t lh = h - h_low; + const opmath_t lw = w - w_low; + const opmath_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = group * group_channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = g * group_channels + c; + + opmath_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + opmath_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + opmath_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + opmath_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const opmath_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ void dcnv3_col2im_bilinear( + const scalar_t *&bottom_data, const int &height, const int &width, + const int &nheads, const int &group_channels, const opmath_t &h, + const opmath_t &w, const int &m, const int &c, const opmath_t offset_scale, + const opmath_t &top_grad, const opmath_t &mask, opmath_t *&grad_im, + opmath_t *grad_offset, opmath_t *grad_mask) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const opmath_t lh = h - h_low; + const opmath_t lw = w - w_low; + const opmath_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * group_channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * group_channels + c; + + const opmath_t w1 = hh * hw, w2 = 
hh * lw, w3 = lh * hw, w4 = lh * lw; + const opmath_t top_grad_im = top_grad * mask; + opmath_t grad_h_weight = 0, grad_w_weight = 0; + + opmath_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_im + ptr1, w1 * top_grad_im); + } + opmath_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_im + ptr2, w2 * top_grad_im); + } + opmath_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_im + ptr3, w3 * top_grad_im); + } + opmath_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_im + ptr4, w4 * top_grad_im); + } + + const opmath_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_mask = top_grad * val; + *grad_offset = offset_scale * grad_w_weight * top_grad_im; + *(grad_offset + 1) = offset_scale * grad_h_weight * top_grad_im; +} + +template +__device__ void dcnv3_col2im_bilinear_gm( + const scalar_t *&bottom_data, const int &height, const int &width, + const int &nheads, const int &group_channels, const opmath_t &h, + const opmath_t &w, const int &m, const int &c, const opmath_t offset_scale, + const opmath_t &top_grad, const opmath_t &mask, opmath_t *&grad_im, + opmath_t *grad_offset, opmath_t *grad_mask) { + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const opmath_t lh = h - h_low; + const opmath_t lw = w - w_low; + const opmath_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * group_channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * group_channels + c; + + const opmath_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const opmath_t top_grad_im = top_grad * mask; + opmath_t grad_h_weight = 0, grad_w_weight = 0; + + opmath_t v1 = 0; + if (h_low >= 0 && w_low >= 0) { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_im + ptr1, w1 * top_grad_im); + } + opmath_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_im + ptr2, w2 * top_grad_im); + } + opmath_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_im + ptr3, w3 * top_grad_im); + } + opmath_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + 
grad_w_weight += lh * v4; + atomicAdd(grad_im + ptr4, w4 * top_grad_im); + } + + const opmath_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_mask, top_grad * val); + atomicAdd(grad_offset, offset_scale * grad_w_weight * top_grad_im); + atomicAdd(grad_offset + 1, offset_scale * grad_h_weight * top_grad_im); +} + +template +__global__ void dcnv3_im2col_gpu_kernel( + const int num_kernels, const scalar_t *data_im, const scalar_t *data_offset, + const scalar_t *data_mask, scalar_t *data_col, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale, const int remove_center) { + CUDA_KERNEL_LOOP(index, num_kernels) { + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const int input_size = height_in * width_in; + scalar_t *data_col_ptr = data_col + index; + const int kernel_size = kernel_h * kernel_w - remove_center; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = group * group_channels; + opmath_t col = 0; + const scalar_t *data_im_ptr = data_im + b_col * input_size * qid_stride; + // top-left + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + + const int center_h = kernel_h / 2; + const int center_w = kernel_w / 2; + + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + // if not remove center, or remove center and not the center + if (i!=center_w || j!=center_h || !remove_center) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + col += dcnv3_im2col_bilinear( + data_im_ptr, height_in, width_in, group, + group_channels, loc_h, loc_w, g_col, c_col) * + weight; + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + } + *data_col_ptr = col; + } +} + +// debug +template +__global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1( + const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset, + opmath_t *grad_mask) { + CUDA_KERNEL_LOOP(index, num_kernels) { + __shared__ 
opmath_t cache_grad_offset[blockSize * 2]; + __shared__ opmath_t cache_grad_mask[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const opmath_t top_grad = grad_col[index]; + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w - remove_center; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_offset += grad_sampling_ptr << 1; + grad_mask += grad_sampling_ptr; + const int qid_stride = group * group_channels; + const int im_ptr_offset = b_col * input_size * qid_stride; + const scalar_t *data_im_ptr = data_im + im_ptr_offset; + opmath_t *grad_im_ptr = grad_im + im_ptr_offset; + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + + const int center_h = kernel_h / 2; + const int center_w = kernel_w / 2; + + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + // if not remove center, or remove center and not the center + if (i!=center_w || j!=center_h || !remove_center) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + *(cache_grad_offset + (threadIdx.x << 1)) = 0; + *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_mask + threadIdx.x) = 0; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + dcnv3_col2im_bilinear( + data_im_ptr, height_in, width_in, group, group_channels, + loc_h, loc_w, g_col, c_col, offset_scale, top_grad, + weight, grad_im_ptr, + cache_grad_offset + (threadIdx.x << 1), + cache_grad_mask + threadIdx.x); + } + + __syncthreads(); + if (tid == 0) { + opmath_t _grad_w = cache_grad_offset[0], + _grad_h = cache_grad_offset[1], + _grad_a = cache_grad_mask[0]; + int sid = 2; + for (unsigned int tid = 1; tid < blockSize; ++tid) { + _grad_w += cache_grad_offset[sid]; + _grad_h += cache_grad_offset[sid + 1]; + _grad_a += cache_grad_mask[tid]; + sid += 2; + } + + *grad_offset = _grad_w; + *(grad_offset + 1) = _grad_h; + *grad_mask = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_mask += 1; + grad_offset += 2; + } + } + } + } +} + +template +__global__ void dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2( + const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale, const int remove_center, opmath_t 
*grad_im, opmath_t *grad_offset, + opmath_t *grad_mask) { + CUDA_KERNEL_LOOP(index, num_kernels) { + __shared__ opmath_t cache_grad_offset[blockSize * 2]; + __shared__ opmath_t cache_grad_mask[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const opmath_t top_grad = grad_col[index]; + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w - remove_center; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_offset += grad_sampling_ptr << 1; + grad_mask += grad_sampling_ptr; + const int qid_stride = group * group_channels; + const int im_ptr_offset = b_col * input_size * qid_stride; + const scalar_t *data_im_ptr = data_im + im_ptr_offset; + opmath_t *grad_im_ptr = grad_im + im_ptr_offset; + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + + const int center_h = kernel_h / 2; + const int center_w = kernel_w / 2; + + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + // if not remove center, or remove center and not the center + if (i!=center_w || j!=center_h || !remove_center) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + *(cache_grad_offset + (threadIdx.x << 1)) = 0; + *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_mask + threadIdx.x) = 0; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + dcnv3_col2im_bilinear( + data_im_ptr, height_in, width_in, group, group_channels, + loc_h, loc_w, g_col, c_col, offset_scale, top_grad, + weight, grad_im_ptr, + cache_grad_offset + (threadIdx.x << 1), + cache_grad_mask + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockSize / 2; s > 0; s >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_mask[tid] += cache_grad_mask[tid + s]; + cache_grad_offset[xid1] += cache_grad_offset[xid2]; + cache_grad_offset[xid1 + 1] += + cache_grad_offset[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) { + *grad_offset = cache_grad_offset[0]; + *(grad_offset + 1) = cache_grad_offset[1]; + *grad_mask = cache_grad_mask[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_mask += 1; + grad_offset += 2; + } + } + } + } +} + +template +__global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v1( + const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int 
group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset, + opmath_t *grad_mask) { + CUDA_KERNEL_LOOP(index, num_kernels) { + extern __shared__ int _s[]; + opmath_t *cache_grad_offset = (opmath_t *)_s; + opmath_t *cache_grad_mask = cache_grad_offset + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const opmath_t top_grad = grad_col[index]; + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w - remove_center; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_offset += grad_sampling_ptr << 1; + grad_mask += grad_sampling_ptr; + const int qid_stride = group * group_channels; + const int im_ptr_offset = b_col * input_size * qid_stride; + const scalar_t *data_im_ptr = data_im + im_ptr_offset; + opmath_t *grad_im_ptr = grad_im + im_ptr_offset; + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + + const int center_h = kernel_h / 2; + const int center_w = kernel_w / 2; + + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + // if not remove center, or remove center and not the center + if (i!=center_w || j!=center_h || !remove_center) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + *(cache_grad_offset + (threadIdx.x << 1)) = 0; + *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_mask + threadIdx.x) = 0; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + dcnv3_col2im_bilinear( + data_im_ptr, height_in, width_in, group, group_channels, + loc_h, loc_w, g_col, c_col, offset_scale, top_grad, + weight, grad_im_ptr, + cache_grad_offset + (threadIdx.x << 1), + cache_grad_mask + threadIdx.x); + } + + __syncthreads(); + if (tid == 0) { + opmath_t _grad_w = cache_grad_offset[0], + _grad_h = cache_grad_offset[1], + _grad_a = cache_grad_mask[0]; + int sid = 2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) { + _grad_w += cache_grad_offset[sid]; + _grad_h += cache_grad_offset[sid + 1]; + _grad_a += cache_grad_mask[tid]; + sid += 2; + } + + *grad_offset = _grad_w; + *(grad_offset + 1) = _grad_h; + *grad_mask = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_mask += 1; + grad_offset += 2; + } + } + } + } +} + +template +__global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2( + const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int 
kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset, + opmath_t *grad_mask) { + CUDA_KERNEL_LOOP(index, num_kernels) { + extern __shared__ int _s[]; + opmath_t *cache_grad_offset = (opmath_t *)_s; + opmath_t *cache_grad_mask = cache_grad_offset + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const opmath_t top_grad = grad_col[index]; + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w - remove_center; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_offset += grad_sampling_ptr << 1; + grad_mask += grad_sampling_ptr; + const int qid_stride = group * group_channels; + const int im_ptr_offset = b_col * input_size * qid_stride; + const scalar_t *data_im_ptr = data_im + im_ptr_offset; + opmath_t *grad_im_ptr = grad_im + im_ptr_offset; + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + + const int center_h = kernel_h / 2; + const int center_w = kernel_w / 2; + + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + // if not remove center, or remove center and not the center + if (i!=center_w || j!=center_h || !remove_center) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + *(cache_grad_offset + (threadIdx.x << 1)) = 0; + *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_mask + threadIdx.x) = 0; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + dcnv3_col2im_bilinear( + data_im_ptr, height_in, width_in, group, group_channels, + loc_h, loc_w, g_col, c_col, offset_scale, top_grad, + weight, grad_im_ptr, + cache_grad_offset + (threadIdx.x << 1), + cache_grad_mask + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; + s >>= 1, spre >>= 1) { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_mask[tid] += cache_grad_mask[tid + s]; + cache_grad_offset[xid1] += cache_grad_offset[xid2]; + cache_grad_offset[xid1 + 1] += + cache_grad_offset[xid2 + 1]; + if (tid + (s << 1) < spre) { + cache_grad_mask[tid] += + cache_grad_mask[tid + (s << 1)]; + cache_grad_offset[xid1] += + cache_grad_offset[xid2 + (s << 1)]; + cache_grad_offset[xid1 + 1] += + cache_grad_offset[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid 
== 0) { + *grad_offset = cache_grad_offset[0]; + *(grad_offset + 1) = cache_grad_offset[1]; + *grad_mask = cache_grad_mask[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_mask += 1; + grad_offset += 2; + } + } + } + } +} + +template +__global__ void dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks( + const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset, + opmath_t *grad_mask) { + CUDA_KERNEL_LOOP(index, num_kernels) { + extern __shared__ int _s[]; + opmath_t *cache_grad_offset = (opmath_t *)_s; + opmath_t *cache_grad_mask = cache_grad_offset + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const opmath_t top_grad = grad_col[index]; + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w - remove_center; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_offset += grad_sampling_ptr << 1; + grad_mask += grad_sampling_ptr; + const int qid_stride = group * group_channels; + const int im_ptr_offset = b_col * input_size * qid_stride; + const scalar_t *data_im_ptr = data_im + im_ptr_offset; + opmath_t *grad_im_ptr = grad_im + im_ptr_offset; + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + + const int center_h = kernel_h / 2; + const int center_w = kernel_w / 2; + + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + // if not remove center, or remove center and not the center + if (i!=center_w || j!=center_h || !remove_center) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + *(cache_grad_offset + (threadIdx.x << 1)) = 0; + *(cache_grad_offset + ((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_mask + threadIdx.x) = 0; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + dcnv3_col2im_bilinear( + data_im_ptr, height_in, width_in, group, group_channels, + loc_h, loc_w, g_col, c_col, offset_scale, top_grad, + weight, grad_im_ptr, + cache_grad_offset + (threadIdx.x << 1), + cache_grad_mask + threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; + s >>= 1, spre >>= 1) { + if (tid < s) { + const unsigned int xid1 
= tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_mask[tid] += cache_grad_mask[tid + s]; + cache_grad_offset[xid1] += cache_grad_offset[xid2]; + cache_grad_offset[xid1 + 1] += + cache_grad_offset[xid2 + 1]; + if (tid + (s << 1) < spre) { + cache_grad_mask[tid] += + cache_grad_mask[tid + (s << 1)]; + cache_grad_offset[xid1] += + cache_grad_offset[xid2 + (s << 1)]; + cache_grad_offset[xid1 + 1] += + cache_grad_offset[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) { + atomicAdd(grad_offset, cache_grad_offset[0]); + atomicAdd(grad_offset + 1, cache_grad_offset[1]); + atomicAdd(grad_mask, cache_grad_mask[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_mask += 1; + grad_offset += 2; + } + } + } + } +} + +template +__global__ void dcnv3_col2im_gpu_kernel_gm( + const int num_kernels, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int height_in, + const int width_in, const int height_out, const int width_out, + const opmath_t offset_scale, const int remove_center, opmath_t *grad_im, opmath_t *grad_offset, + opmath_t *grad_mask) { + CUDA_KERNEL_LOOP(index, num_kernels) { + int _temp = index; + const int c_col = _temp % group_channels; + _temp /= group_channels; + const int sampling_index = _temp; + const int g_col = _temp % group; + _temp /= group; + const int p0_w = ((dilation_w * (kernel_w - 1)) >> 1) - pad_w + + (_temp % width_out) * stride_w; + _temp /= width_out; + const int p0_h = ((dilation_h * (kernel_h - 1)) >> 1) - pad_h + + (_temp % height_out) * stride_h; + _temp /= height_out; + const int b_col = _temp; + + const opmath_t top_grad = grad_col[index]; + const int input_size = height_in * width_in; + const int kernel_size = kernel_h * kernel_w - remove_center; + int data_weight_ptr = sampling_index * kernel_size; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_offset += grad_sampling_ptr << 1; + grad_mask += grad_sampling_ptr; + const int qid_stride = group * group_channels; + const int im_ptr_offset = b_col * input_size * qid_stride; + const scalar_t *data_im_ptr = data_im + im_ptr_offset; + opmath_t *grad_im_ptr = grad_im + im_ptr_offset; + const opmath_t p0_w_ = + p0_w - ((dilation_w * (kernel_w - 1)) >> 1) * offset_scale; + const opmath_t p0_h_ = + p0_h - ((dilation_h * (kernel_h - 1)) >> 1) * offset_scale; + + const int center_h = kernel_h / 2; + const int center_w = kernel_w / 2; + + for (int i = 0; i < kernel_w; ++i) { + for (int j = 0; j < kernel_h; ++j) { + // if not remove center, or remove center and not the center + if (i!=center_w || j!=center_h || !remove_center) { + const opmath_t offset_w = data_offset[data_loc_w_ptr]; + const opmath_t offset_h = data_offset[data_loc_w_ptr + 1]; + const opmath_t loc_w = + p0_w_ + (i * dilation_w + offset_w) * offset_scale; + const opmath_t loc_h = + p0_h_ + (j * dilation_h + offset_h) * offset_scale; + const opmath_t weight = data_mask[data_weight_ptr]; + if (loc_h > -1 && loc_w > -1 && loc_h < height_in && + loc_w < width_in) { + dcnv3_col2im_bilinear_gm( + data_im_ptr, height_in, width_in, group, group_channels, + loc_h, loc_w, g_col, c_col, offset_scale, top_grad, + weight, grad_im_ptr, grad_offset, grad_mask); + } + data_weight_ptr += 1; 
+ data_loc_w_ptr += 2; + grad_mask += 1; + grad_offset += 2; + } + } + } + } +} + +template +void dcnv3_im2col_cuda(cudaStream_t stream, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + scalar_t *data_col, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int group, const int group_channels, + const int batch_n, const int height_in, + const int width_in, const int height_out, + const int width_out, const opmath_t offset_scale, const int remove_center) { + const int num_kernels = + batch_n * height_out * width_out * group * group_channels; + const int num_actual_kernels = + batch_n * height_out * width_out * group * group_channels; + const int num_threads = CUDA_NUM_THREADS; + dcnv3_im2col_gpu_kernel + <<>>(num_kernels, data_im, data_offset, data_mask, data_col, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, height_in, + width_in, height_out, width_out, offset_scale, remove_center); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in dcnv3_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +template +void dcnv3_col2im_cuda( + cudaStream_t stream, const scalar_t *grad_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, const int batch_n, + const int height_in, const int width_in, const int height_out, + const int width_out, const opmath_t offset_scale, const int remove_center, + opmath_t *grad_im, opmath_t *grad_offset, opmath_t *grad_mask) { + const int num_threads = + (group_channels > CUDA_NUM_THREADS) ? 
CUDA_NUM_THREADS : group_channels; + const int num_kernels = + batch_n * height_out * width_out * group * group_channels; + const int num_actual_kernels = + batch_n * height_out * width_out * group * group_channels; + if (group_channels > 1024) { + if ((group_channels & 1023) == 0) { + dcnv3_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, grad_col, data_im, data_offset, data_mask, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, height_in, + width_in, height_out, width_out, offset_scale, remove_center, grad_im, + grad_offset, grad_mask); + } else { + dcnv3_col2im_gpu_kernel_gm + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, remove_center, grad_im, grad_offset, + grad_mask); + } + } else { + switch (group_channels) { + case 1: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, remove_center, grad_im, grad_offset, + grad_mask); + break; + case 2: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, remove_center, grad_im, grad_offset, + grad_mask); + break; + case 4: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, remove_center, grad_im, grad_offset, + grad_mask); + break; + case 8: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, remove_center, grad_im, grad_offset, + grad_mask); + break; + case 16: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, remove_center, grad_im, grad_offset, + grad_mask); + break; + case 32: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, remove_center, grad_im, grad_offset, + grad_mask); + break; + case 64: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, remove_center, grad_im, grad_offset, + grad_mask); + break; + case 128: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + 
<<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, remove_center, grad_im, grad_offset, + grad_mask); + break; + case 256: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, remove_center, grad_im, grad_offset, + grad_mask); + break; + case 512: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, remove_center, grad_im, grad_offset, + grad_mask); + break; + case 1024: + dcnv3_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>(num_kernels, grad_col, data_im, data_offset, + data_mask, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, + group_channels, height_in, width_in, height_out, + width_out, offset_scale, remove_center, grad_im, grad_offset, + grad_mask); + break; + default: + if (group_channels < 64) { + dcnv3_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, grad_col, data_im, data_offset, data_mask, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, + height_in, width_in, height_out, width_out, + offset_scale, remove_center, grad_im, grad_offset, grad_mask); + } else { + dcnv3_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, grad_col, data_im, data_offset, data_mask, + kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, + dilation_h, dilation_w, group, group_channels, + height_in, width_in, height_out, width_out, + offset_scale, remove_center, grad_im, grad_offset, grad_mask); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in dcnv3_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} diff --git a/projects/internimage_classification/ops_dcnv3/src/dcnv3.h b/projects/internimage_classification/ops_dcnv3/src/dcnv3.h new file mode 100644 index 00000000..ce4500fa --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/src/dcnv3.h @@ -0,0 +1,59 @@ +/*! 
+************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once + +#include "cpu/dcnv3_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/dcnv3_cuda.h" +#endif + +at::Tensor dcnv3_forward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int group, const int group_channels, + const float offset_scale, const int im2col_step, const int remove_center) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return dcnv3_cuda_forward(input, offset, mask, kernel_h, kernel_w, + stride_h, stride_w, pad_h, pad_w, dilation_h, + dilation_w, group, group_channels, + offset_scale, im2col_step, remove_center); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +dcnv3_backward(const at::Tensor &input, const at::Tensor &offset, + const at::Tensor &mask, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_h, + const int pad_w, const int dilation_h, const int dilation_w, + const int group, const int group_channels, + const float offset_scale, const at::Tensor &grad_output, + const int im2col_step, const int remove_center) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return dcnv3_cuda_backward(input, offset, mask, kernel_h, kernel_w, + stride_h, stride_w, pad_h, pad_w, dilation_h, + dilation_w, group, group_channels, + offset_scale, grad_output, im2col_step, remove_center); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} diff --git a/projects/internimage_classification/ops_dcnv3/src/vision.cpp b/projects/internimage_classification/ops_dcnv3/src/vision.cpp new file mode 100644 index 00000000..1f7a9087 --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/src/vision.cpp @@ -0,0 +1,17 @@ +/*! 
+************************************************************************************************** +* InternImage +* Copyright (c) 2022 OpenGVLab +* Licensed under The MIT License [see LICENSE for details] +************************************************************************************************** +* Modified from +*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "dcnv3.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dcnv3_forward", &dcnv3_forward, "dcnv3_forward"); + m.def("dcnv3_backward", &dcnv3_backward, "dcnv3_backward"); +} diff --git a/projects/internimage_classification/ops_dcnv3/test.py b/projects/internimage_classification/ops_dcnv3/test.py new file mode 100644 index 00000000..b9c2204d --- /dev/null +++ b/projects/internimage_classification/ops_dcnv3/test.py @@ -0,0 +1,255 @@ +# -------------------------------------------------------- +# InternImage +# Copyright (c) 2022 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +# Copied from +# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/ + +from __future__ import absolute_import, division, print_function +import math # noqa +import time + +import torch +import torch.nn as nn # noqa +from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch +from torch.autograd import gradcheck # noqa + +H_in, W_in = 8, 8 +N, M, D = 2, 4, 16 +Kh, Kw = 3, 3 +remove_center = False +P = Kh * Kw - remove_center +offset_scale = 2.0 +pad = 1 +dilation = 1 +stride = 1 +H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1 +W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 + +torch.manual_seed(3) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_double(): + input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01 + offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10 + mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 + mask /= mask.sum(-1, keepdim=True) + mask = mask.reshape(N, H_out, W_out, M * P) + + output_pytorch = dcnv3_core_pytorch(input.double(), offset.double(), + mask.double(), Kh, Kw, stride, stride, + Kh // 2, Kw // 2, dilation, dilation, + M, D, offset_scale, + remove_center).detach().cpu() + + im2col_step = 2 + output_cuda = DCNv3Function.apply(input.double(), offset.double(), + mask.double(), Kh, Kw, stride, stride, + Kh // 2, Kw // 2, dilation, dilation, M, + D, offset_scale, im2col_step, + remove_center).detach().cpu() + + fwdok = torch.allclose(output_cuda, output_pytorch) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / + output_pytorch.abs()).max() + print('>>> forward double') + print(f'* {fwdok} check_forward_equal_with_pytorch_double:' + f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +@torch.no_grad() +def check_forward_equal_with_pytorch_float(): + input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01 + offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10 + mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 + mask /= mask.sum(-1, keepdim=True) + mask = mask.reshape(N, H_out, W_out, M * P) + + output_pytorch = dcnv3_core_pytorch(input, offset, mask, Kh, Kw, stride, + stride, Kh // 2, Kw // 2, dilation, + dilation, M, D, offset_scale, + remove_center).detach().cpu() + + im2col_step = 2 + output_cuda = 
+                                      stride, Kh // 2, Kw // 2, dilation,
+                                      dilation, M, D, offset_scale,
+                                      im2col_step,
+                                      remove_center).detach().cpu()
+
+    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
+    max_abs_err = (output_cuda - output_pytorch).abs().max()
+    max_rel_err = ((output_cuda - output_pytorch).abs() /
+                   output_pytorch.abs()).max()
+    print('>>> forward float')
+    print(f'* {fwdok} check_forward_equal_with_pytorch_float:'
+          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+def check_backward_equal_with_pytorch_double(channels=4,
+                                             grad_input=True,
+                                             grad_offset=True,
+                                             grad_mask=True):
+    # H_in, W_in = 4, 4
+    N = 2
+    M = 2
+    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
+    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
+
+    D = channels
+    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
+    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
+    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
+    mask0 /= mask0.sum(-1, keepdim=True)
+    mask0 = mask0.reshape(N, H_out, W_out, M * P)
+    input0.requires_grad = grad_input
+    offset0.requires_grad = grad_offset
+    mask0.requires_grad = grad_mask
+
+    output_pytorch = dcnv3_core_pytorch(input0.double(), offset0.double(),
+                                        mask0.double(), Kh, Kw, stride, stride,
+                                        Kh // 2, Kw // 2, dilation, dilation,
+                                        M, D, offset_scale, remove_center)
+    output_pytorch.sum().backward()
+
+    input1 = input0.detach()
+    offset1 = offset0.detach()
+    mask1 = mask0.detach()
+    input1.requires_grad = grad_input
+    offset1.requires_grad = grad_offset
+    mask1.requires_grad = grad_mask
+
+    im2col_step = 2
+    output_cuda = DCNv3Function.apply(input1.double(), offset1.double(),
+                                      mask1.double(), Kh, Kw, stride, stride,
+                                      Kh // 2, Kw // 2, dilation, dilation, M,
+                                      D, offset_scale, im2col_step,
+                                      remove_center)
+    output_cuda.sum().backward()
+
+    print(f'>>> backward double: channels {D}')
+    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
+    max_abs_err = (input0.grad - input1.grad).abs().max()
+    max_rel_err = ((input0.grad - input1.grad).abs() / input0.grad.abs()).max()
+    print(f'* {bwdok} input_grad check_backward_equal_with_pytorch_double:'
+          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
+    max_abs_err = (offset0.grad - offset1.grad).abs().max()
+    max_rel_err = ((offset0.grad - offset1.grad).abs() /
+                   offset0.grad.abs()).max()
+    print(f'* {bwdok} offset_grad check_backward_equal_with_pytorch_double:'
+          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
+    max_abs_err = (mask0.grad - mask1.grad).abs().max()
+    max_rel_err = ((mask0.grad - mask1.grad).abs() / mask0.grad.abs()).max()
+    print(f'* {bwdok} mask_grad check_backward_equal_with_pytorch_double:'
+          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+def check_backward_equal_with_pytorch_float(channels=4,
+                                            grad_input=True,
+                                            grad_offset=True,
+                                            grad_mask=True):
+    # H_in, W_in = 4, 4
+    N = 2
+    M = 2
+    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
+    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
+
+    D = channels
+    input0 = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
+    offset0 = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
+    mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
+    mask0 /= mask0.sum(-1, keepdim=True)
+    mask0 = mask0.reshape(N, H_out, W_out, M * P)
+    input0.requires_grad = grad_input
+    offset0.requires_grad = grad_offset
+    mask0.requires_grad = grad_mask
+
+    output_pytorch = dcnv3_core_pytorch(input0, offset0, mask0, Kh, Kw, stride,
+                                        stride, Kh // 2, Kw // 2, dilation,
+                                        dilation, M, D, offset_scale,
+                                        remove_center)
+    output_pytorch.sum().backward()
+
+    input1 = input0.detach()
+    offset1 = offset0.detach()
+    mask1 = mask0.detach()
+    input1.requires_grad = grad_input
+    offset1.requires_grad = grad_offset
+    mask1.requires_grad = grad_mask
+
+    im2col_step = 2
+    output_cuda = DCNv3Function.apply(input1, offset1, mask1, Kh, Kw, stride,
+                                      stride, Kh // 2, Kw // 2, dilation,
+                                      dilation, M, D, offset_scale,
+                                      im2col_step, remove_center)
+    output_cuda.sum().backward()
+
+    print(f'>>> backward float: channels {D}')
+    bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3)
+    max_abs_err = (input0.grad - input1.grad).abs().max()
+    max_rel_err = ((input0.grad - input1.grad).abs() / input0.grad.abs()).max()
+    print(f'* {bwdok} input_grad check_backward_equal_with_pytorch_float:'
+          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+    bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3)
+    max_abs_err = (offset0.grad - offset1.grad).abs().max()
+    max_rel_err = ((offset0.grad - offset1.grad).abs() /
+                   offset0.grad.abs()).max()
+    print(f'* {bwdok} offset_grad check_backward_equal_with_pytorch_float:'
+          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+    bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3)
+    max_abs_err = (mask0.grad - mask1.grad).abs().max()
+    max_rel_err = ((mask0.grad - mask1.grad).abs() / mask0.grad.abs()).max()
+    print(f'* {bwdok} mask_grad check_backward_equal_with_pytorch_float:'
+          f' max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+@torch.no_grad()
+def check_time_cost(im2col_step=128):
+    N = 512
+    H_in, W_in = 64, 64
+    H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1
+    W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1
+
+    input = torch.rand(N, H_in, W_in, M * D).cuda() * 0.01
+    offset = torch.rand(N, H_out, W_out, M * P * 2).cuda() * 10
+    mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5
+    mask /= mask.sum(-1, keepdim=True)
+    mask = mask.reshape(N, H_out, W_out, M * P)
+    print(f'>>> time cost: im2col_step {im2col_step};'
+          f' input {input.shape}; points {P} ')
+    repeat = 100
+    # warm-up runs before timing
+    for i in range(repeat):
+        output_cuda = DCNv3Function.apply(input, offset, mask, Kh, Kw, stride,
+                                          stride, Kh // 2, Kw // 2, dilation,
+                                          dilation, M, D, 1.0, im2col_step,
+                                          remove_center)
+    torch.cuda.synchronize()
+    start = time.time()
+    for i in range(repeat):
+        output_cuda = DCNv3Function.apply(  # noqa
+            input, offset, mask, Kh, Kw, stride, stride, Kh // 2, Kw // 2,
+            dilation, dilation, M, D, 1.0, im2col_step, remove_center)
+    torch.cuda.synchronize()
+    print(f'forward time cost: {(time.time() - start) / repeat}')
+
+
+if __name__ == '__main__':
+    check_forward_equal_with_pytorch_double()
+    check_forward_equal_with_pytorch_float()
+    for channels in [1, 16, 30, 32, 64, 71, 1025]:
+        check_backward_equal_with_pytorch_double(channels, True, True, True)
+    for channels in [1, 16, 30, 32, 64, 71, 1025]:
+        check_backward_equal_with_pytorch_float(channels, True, True, True)
+    for i in range(3):
+        im2col_step = 128 * (2**i)
+        check_time_cost(im2col_step)
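+
+# Note: running this script end-to-end requires a CUDA-capable GPU and the
+# compiled DCNv3 extension. It first checks that the CUDA forward pass matches
+# the pure-PyTorch reference (dcnv3_core_pytorch) in double and float
+# precision, then compares input/offset/mask gradients against the reference
+# for several channel sizes, and finally benchmarks the forward pass for
+# im2col_step values of 128, 256 and 512.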