mirror of
https://github.com/open-mmlab/mmpretrain.git
synced 2025-06-03 14:59:18 +08:00
* [Feature] add internimage project * [Feature] add internimage project * update license * [Feature] add internimage project * [Feature] add internimage project * [Feature] add internimage project * [Feature] add internimage project * [Feature] add internimage project * [Feature] add internimage project * update license * [Feature] add internimage project * [Feature] add internimage project * [Feature] add internimage project * [Feature] add internimage project * update internimage configs * support internimage project * support internimage project * support internimage project * internimage
361 lines
12 KiB
Python
361 lines
12 KiB
Python
# --------------------------------------------------------
|
|
# InternImage
|
|
# Copyright (c) 2022 OpenGVLab
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
# --------------------------------------------------------
|
|
|
|
# Copied from
|
|
# https://github.com/OpenGVLab/InternImage/blob/master/classification/models/
|
|
|
|
from __future__ import absolute_import, division, print_function
|
|
import warnings
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
from torch.nn.init import constant_, xavier_uniform_
|
|
|
|
from ..functions import DCNv3Function, dcnv3_core_pytorch
|
|
|
|
|
|
class to_channels_first(nn.Module):
    """Move the last axis of a 4-D tensor to position 1.

    The axes are reordered as ``(0, 3, 1, 2)``. The module holds no
    parameters or buffers.
    """

    # Axis order handed to ``permute``; the trailing axis becomes axis 1.
    _AXIS_ORDER = (0, 3, 1, 2)

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with its axes permuted to ``(0, 3, 1, 2)``."""
        return x.permute(*self._AXIS_ORDER)
|
|
|
|
|
|
class to_channels_last(nn.Module):
    """Move axis 1 of a 4-D tensor to the last position.

    The axes are reordered as ``(0, 2, 3, 1)``. The module holds no
    parameters or buffers.
    """

    # Axis order handed to ``permute``; axis 1 becomes the trailing axis.
    _AXIS_ORDER = (0, 2, 3, 1)

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with its axes permuted to ``(0, 2, 3, 1)``."""
        return x.permute(*self._AXIS_ORDER)
|
|
|
|
|
|
def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    """Build a normalisation layer wrapped with any needed layout permutes.

    ``'BN'`` creates ``nn.BatchNorm2d(dim)`` and ``'LN'`` creates
    ``nn.LayerNorm(dim, eps=eps)``. A permutation module is placed before
    and/or after the norm whenever ``in_format`` / ``out_format`` differs
    from the layout the norm itself operates on (channels-first for BN,
    channels-last for LN).

    :param dim: number of features handed to the norm layer
    :param norm_layer: ``'BN'`` or ``'LN'``
    :param in_format: ``'channels_last'`` or ``'channels_first'``
    :param out_format: ``'channels_last'`` or ``'channels_first'``
    :param eps: epsilon for ``nn.LayerNorm``; not forwarded to BatchNorm2d
    :return: an ``nn.Sequential`` holding one to three modules
    :raises NotImplementedError: if ``norm_layer`` is neither value above
    """
    use_batch_norm = norm_layer == 'BN'
    if not use_batch_norm and norm_layer != 'LN':
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')

    # The layout the chosen norm does NOT work on natively; data arriving
    # or leaving in this layout needs a permute around the norm.
    foreign_format = 'channels_last' if use_batch_norm else 'channels_first'

    pipeline = []
    if in_format == foreign_format:
        pipeline.append(
            to_channels_first() if use_batch_norm else to_channels_last())
    pipeline.append(
        nn.BatchNorm2d(dim) if use_batch_norm else nn.LayerNorm(dim, eps=eps))
    if out_format == foreign_format:
        pipeline.append(
            to_channels_last() if use_batch_norm else to_channels_first())
    return nn.Sequential(*pipeline)
|
|
|
|
|
|
def build_act_layer(act_layer):
    """Create a fresh activation module from its name.

    :param act_layer: ``'ReLU'``, ``'SiLU'`` or ``'GELU'``
    :return: ``nn.ReLU(inplace=True)``, ``nn.SiLU(inplace=True)`` or
        ``nn.GELU()`` respectively
    :raises NotImplementedError: for any other value
    """
    # (name, factory) pairs; a factory is only invoked on a match so every
    # call returns a new module instance.
    factories = (
        ('ReLU', lambda: nn.ReLU(inplace=True)),
        ('SiLU', lambda: nn.SiLU(inplace=True)),
        ('GELU', lambda: nn.GELU()),
    )
    for name, make_layer in factories:
        if act_layer == name:
            return make_layer()

    raise NotImplementedError(f'build_act_layer does not support {act_layer}')
|
|
|
|
|
|
def _is_power_of_2(n):
|
|
if (not isinstance(n, int)) or (n < 0):
|
|
raise ValueError(
|
|
'invalid input for _is_power_of_2: {} (type: {})'.format(
|
|
n, type(n)))
|
|
|
|
return (n & (n - 1) == 0) and n != 0
|
|
|
|
|
|
class CenterFeatureScaleModule(nn.Module):
    """Parameter-free helper computing a sigmoid-gated linear projection.

    The projection weight and bias are supplied by the caller on every
    call; the module itself owns no parameters.
    """

    def forward(self, query, center_feature_scale_proj_weight,
                center_feature_scale_proj_bias):
        """Return ``sigmoid(F.linear(query, weight, bias))``.

        :param query: input tensor whose last dimension matches the
            weight's second dimension
        :param center_feature_scale_proj_weight: projection weight
        :param center_feature_scale_proj_bias: projection bias
        :return: the projected tensor squashed into ``(0, 1)`` by sigmoid
        """
        projected = F.linear(
            query,
            weight=center_feature_scale_proj_weight,
            bias=center_feature_scale_proj_bias)
        return projected.sigmoid()
|
|
|
|
|
|
class DCNv3_pytorch(nn.Module):
    """DCNv3 module that delegates sampling to ``dcnv3_core_pytorch``.

    Offsets and modulation masks are predicted from a depth-wise
    convolution branch, while the features to be sampled come from a linear
    projection of the input.
    """

    def __init__(
        self,
        channels=64,
        kernel_size=3,
        dw_kernel_size=None,
        stride=1,
        pad=1,
        dilation=1,
        group=4,
        offset_scale=1.0,
        act_layer='GELU',
        norm_layer='LN',
        center_feature_scale=False,
        remove_center=False,
    ):
        """DCNv3 Module.

        :param channels: feature width; must be divisible by ``group``
        :param kernel_size: sampling kernel size, used for both kernel
            arguments of the core function
        :param dw_kernel_size: kernel size of the depth-wise convolution;
            falls back to ``kernel_size`` when ``None``
        :param stride: stride forwarded (twice) to the core function
        :param pad: padding forwarded (twice) to the core function
        :param dilation: dilation forwarded (twice) to the core function
        :param group: number of groups the channels are split into
        :param offset_scale: scale forwarded to the core function
        :param act_layer: activation name understood by
            ``build_act_layer``
        :param norm_layer: normalisation name understood by
            ``build_norm_layer``
        :param center_feature_scale: if true, blend the sampled output with
            the projected input using a learned per-group gate
        :param remove_center: if true, one sampling point per group is
            dropped from the offset/mask predictions; requires an odd
            ``kernel_size``
        :raises ValueError: if ``channels`` is not divisible by ``group``,
            or ``remove_center`` is combined with an even ``kernel_size``
        """
        super().__init__()

        # Validate every argument before any state is created.
        if channels % group != 0:
            raise ValueError(f'channels must be divisible by group, '
                             f'but got {channels} and {group}')
        remove_center_flag = int(remove_center)
        # Same guard as the DCNv3 class in this file, so that both
        # variants reject the configuration at construction time.
        if remove_center_flag and kernel_size % 2 == 0:
            raise ValueError(
                'remove_center is only compatible with odd kernel size.')

        _d_per_group = channels // group
        if dw_kernel_size is None:
            dw_kernel_size = kernel_size
        # you'd better set _d_per_group to a power of 2
        # which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 "
                'to make the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = _d_per_group
        self.center_feature_scale = center_feature_scale
        self.remove_center = remove_center_flag

        # Depth-wise conv (one filter per channel) -> norm -> activation.
        # The norm wrapper consumes channels-first data and emits
        # channels-last data, which is what the linear heads below expect.
        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=dw_kernel_size,
                stride=1,
                padding=(dw_kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(channels, norm_layer, 'channels_first',
                             'channels_last'), build_act_layer(act_layer))

        # Sampling points predicted per group.
        points_per_group = kernel_size * kernel_size - remove_center_flag
        # Two values per sampling point for the offsets, one for the mask.
        self.offset = nn.Linear(channels, group * points_per_group * 2)
        self.mask = nn.Linear(channels, group * points_per_group)
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

        if center_feature_scale:
            self.center_feature_scale_proj_weight = nn.Parameter(
                torch.zeros((group, channels), dtype=torch.float))
            self.center_feature_scale_proj_bias = nn.Parameter(
                torch.tensor(0.0, dtype=torch.float).view(
                    (1, )).repeat(group, ))
            self.center_feature_scale_module = CenterFeatureScaleModule()

    def _reset_parameters(self):
        """Zero the offset/mask heads and Xavier-init the projections."""
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param input (N, H, W, C)
        :return output (N, H, W, C)
        """
        N, H, W, _ = input.shape

        x = self.input_proj(input)
        # Kept for the optional center-feature blend below.
        x_proj = x

        # The depth-wise conv needs channels-first input; dw_conv returns
        # channels-last features again.
        x1 = input.permute(0, 3, 1, 2)
        x1 = self.dw_conv(x1)
        offset = self.offset(x1)
        # Normalise the mask over the sampling points of each group.
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        mask = F.softmax(mask, -1).reshape(N, H, W, -1)

        x = dcnv3_core_pytorch(x, offset, mask, self.kernel_size,
                               self.kernel_size, self.stride, self.stride,
                               self.pad, self.pad, self.dilation,
                               self.dilation, self.group, self.group_channels,
                               self.offset_scale, self.remove_center)
        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
                x1, self.center_feature_scale_proj_weight,
                self.center_feature_scale_proj_bias)
            # N, H, W, groups ->
            # N, H, W, groups, 1 ->
            # N, H, W, groups, _d_per_group ->
            # N, H, W, channels
            center_feature_scale = center_feature_scale[..., None].repeat(
                1, 1, 1, 1, self.group_channels).flatten(-2)
            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
        x = self.output_proj(x)

        return x
|
|
|
|
|
|
class DCNv3(nn.Module):
    """DCNv3 module that delegates sampling to ``DCNv3Function``.

    Offsets and modulation masks are predicted from a depth-wise
    convolution branch, while the features to be sampled come from a linear
    projection of the input.
    """

    def __init__(
        self,
        channels=64,
        kernel_size=3,
        dw_kernel_size=None,
        stride=1,
        pad=1,
        dilation=1,
        group=4,
        offset_scale=1.0,
        act_layer='GELU',
        norm_layer='LN',
        center_feature_scale=False,
        remove_center=False,
    ):
        """DCNv3 Module.

        :param channels: feature width; must be divisible by ``group``
        :param kernel_size: sampling kernel size, used for both kernel
            arguments of ``DCNv3Function``
        :param dw_kernel_size: kernel size of the depth-wise convolution;
            falls back to ``kernel_size`` when ``None``
        :param stride: stride forwarded (twice) to ``DCNv3Function``
        :param pad: padding forwarded (twice) to ``DCNv3Function``
        :param dilation: dilation forwarded (twice) to ``DCNv3Function``
        :param group: number of groups the channels are split into
        :param offset_scale: scale forwarded to ``DCNv3Function``
        :param act_layer: activation name understood by
            ``build_act_layer``
        :param norm_layer: normalisation name understood by
            ``build_norm_layer``
        :param center_feature_scale: if true, blend the sampled output with
            the projected input using a learned per-group gate
        :param remove_center: if true, one sampling point per group is
            dropped from the offset/mask predictions; requires an odd
            ``kernel_size``
        :raises ValueError: if ``channels`` is not divisible by ``group``,
            or ``remove_center`` is combined with an even ``kernel_size``
        """
        super().__init__()

        # Validate every argument before any state is created.
        if channels % group != 0:
            raise ValueError(f'channels must be divisible by group, '
                             f'but got {channels} and {group}')
        remove_center_flag = int(remove_center)
        if remove_center_flag and kernel_size % 2 == 0:
            raise ValueError(
                'remove_center is only compatible with odd kernel size.')

        _d_per_group = channels // group
        if dw_kernel_size is None:
            dw_kernel_size = kernel_size
        # you'd better set _d_per_group to a power of 2
        # which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 "
                'to make the dimension of each attention head a power of 2 '
                'which is more efficient in our CUDA implementation.')

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = _d_per_group
        self.center_feature_scale = center_feature_scale
        self.remove_center = remove_center_flag

        # Depth-wise conv (one filter per channel) -> norm -> activation.
        # The norm wrapper consumes channels-first data and emits
        # channels-last data, which is what the linear heads below expect.
        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=dw_kernel_size,
                stride=1,
                padding=(dw_kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(channels, norm_layer, 'channels_first',
                             'channels_last'), build_act_layer(act_layer))

        # Sampling points predicted per group.
        points_per_group = kernel_size * kernel_size - remove_center_flag
        # Two values per sampling point for the offsets, one for the mask.
        self.offset = nn.Linear(channels, group * points_per_group * 2)
        self.mask = nn.Linear(channels, group * points_per_group)
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

        if center_feature_scale:
            self.center_feature_scale_proj_weight = nn.Parameter(
                torch.zeros((group, channels), dtype=torch.float))
            self.center_feature_scale_proj_bias = nn.Parameter(
                torch.tensor(0.0, dtype=torch.float).view(
                    (1, )).repeat(group, ))
            self.center_feature_scale_module = CenterFeatureScaleModule()

    def _reset_parameters(self):
        """Zero the offset/mask heads and Xavier-init the projections."""
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param input (N, H, W, C)
        :return output (N, H, W, C)
        """
        N, H, W, _ = input.shape

        x = self.input_proj(input)
        # Kept for the optional center-feature blend below.
        x_proj = x
        dtype = x.dtype

        # The depth-wise conv needs channels-first input; dw_conv returns
        # channels-last features again.
        x1 = input.permute(0, 3, 1, 2)
        x1 = self.dw_conv(x1)
        offset = self.offset(x1)
        # Normalise the mask over the sampling points of each group.
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        mask = F.softmax(mask, -1)
        # Cast back to the dtype of the projected features before handing
        # the mask to DCNv3Function.
        mask = mask.reshape(N, H, W, -1).type(dtype)

        # NOTE(review): the literal 256 is presumably the op's
        # ``im2col_step`` argument -- confirm against DCNv3Function.
        x = DCNv3Function.apply(x, offset, mask, self.kernel_size,
                                self.kernel_size, self.stride, self.stride,
                                self.pad, self.pad, self.dilation,
                                self.dilation, self.group, self.group_channels,
                                self.offset_scale, 256, self.remove_center)

        if self.center_feature_scale:
            center_feature_scale = self.center_feature_scale_module(
                x1, self.center_feature_scale_proj_weight,
                self.center_feature_scale_proj_bias)
            # N, H, W, groups ->
            # N, H, W, groups, 1 ->
            # N, H, W, groups, _d_per_group ->
            # N, H, W, channels
            center_feature_scale = center_feature_scale[..., None].repeat(
                1, 1, 1, 1, self.group_channels).flatten(-2)
            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
        x = self.output_proj(x)

        return x
|