853 lines
31 KiB
Python
853 lines
31 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
|||
|
# modified from
|
|||
|
# https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch
|
|||
|
from typing import Sequence
|
|||
|
|
|||
|
import numpy as np
|
|||
|
import torch
|
|||
|
import torch.nn as nn
|
|||
|
import torch.nn.functional as F
|
|||
|
from mmcv.cnn import build_activation_layer
|
|||
|
from mmcv.cnn.bricks import DropPath
|
|||
|
from mmengine.model import ModuleList, Sequential
|
|||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
|||
|
|
|||
|
from mmpretrain.models.backbones.base_backbone import BaseBackbone
|
|||
|
from mmpretrain.registry import MODELS
|
|||
|
from ..utils import build_norm_layer
|
|||
|
|
|||
|
|
|||
|
def get_2d_relative_pos_embed(embed_dim, grid_size):
    """Build a pairwise relative position matrix from sin-cos embeddings.

    Args:
        embed_dim (int): Channel dimension of the intermediate sin-cos
            position embedding.
        grid_size (int): Height and width of the grid.

    Returns:
        np.ndarray: Matrix of shape
        ``[grid_size*grid_size, grid_size*grid_size]`` holding twice the dot
        product of every pair of position embeddings, divided by the
        embedding dimension.
    """
    sincos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size)
    gram = np.matmul(sincos_embed, sincos_embed.transpose())
    return 2 * gram / sincos_embed.shape[1]
|
|||
|
|
|||
|
|
|||
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Create a 2D sin-cos position embedding for a square grid.

    Args:
        embed_dim (int): Channel dimension of the embedding.
        grid_size (int): Height and width of the grid.
        cls_token (bool): If True, an all-zero row is prepended to the
            embedding. Defaults to False.

    Returns:
        np.ndarray: ``[grid_size*grid_size, embed_dim]`` without
        ``cls_token``, or ``[1+grid_size*grid_size, embed_dim]`` with it.
    """
    axis_coords = np.arange(grid_size, dtype=np.float32)
    # Both axes share the same coordinates; meshgrid receives the width
    # axis first.
    mesh = np.meshgrid(axis_coords, axis_coords)
    grid = np.stack(mesh, axis=0).reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return pos_embed
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, pos_embed], axis=0)
|
|||
|
|
|||
|
|
|||
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a two-channel coordinate grid into a sin-cos embedding.

    Args:
        embed_dim (int): Output channel dimension; must be even because
            each grid channel is encoded with half of it.
        grid (np.ndarray): Coordinate grid whose first axis indexes the two
            coordinate channels.

    Returns:
        np.ndarray: Embedding of shape (H*W, embed_dim), the encoding of
        ``grid[0]`` followed by the encoding of ``grid[1]`` along axis 1.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # Each coordinate channel gets half of the channels: (H*W, D/2) each.
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis])
        for axis in (0, 1)
    ]
    return np.concatenate(per_axis, axis=1)  # (H*W, D)
|
|||
|
|
|||
|
|
|||
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode 1D positions with sines and cosines.

    Args:
        embed_dim (int): Output dimension for each position; must be even.
        pos (np.ndarray): Positions to encode; flattened to size (M,).

    Returns:
        np.ndarray: Array of shape (M, embed_dim) with the sine terms in
        the first half of the channels and the cosine terms in the second.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2

    # Frequencies 1 / 10000^(i / (D/2)) for i in [0, D/2).
    exponents = np.arange(half_dim, dtype=np.float32) / (embed_dim / 2.)
    freqs = 1. / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    # Outer product of positions and frequencies: (M, D/2).
    angles = flat_pos[:, None] * freqs[None, :]

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|
|||
|
|
|||
|
|
|||
|
def xy_pairwise_distance(x, y):
    """Compute pairwise squared distances between two batched point sets.

    Uses the expansion ``|x|^2 - 2 * x . y + |y|^2`` and runs without
    gradient tracking.

    Args:
        x: tensor (batch_size, num_points, num_dims)
        y: tensor (batch_size, num_points, num_dims)

    Returns:
        pairwise distance: (batch_size, num_points, num_points), where the
        second axis indexes points of ``x`` and the third points of ``y``.
    """
    with torch.no_grad():
        x_sq_norm = (x * x).sum(dim=-1, keepdim=True)
        y_sq_norm = (y * y).sum(dim=-1, keepdim=True)
        cross_term = -2 * torch.matmul(x, y.transpose(2, 1))
        return x_sq_norm + cross_term + y_sq_norm.transpose(2, 1)
|
|||
|
|
|||
|
|
|||
|
def xy_dense_knn_matrix(x, y, k=16, relative_pos=None):
    """Find the k nearest neighbors in ``y`` for every point in ``x``.

    Args:
        x: (batch_size, num_dims, num_points, 1)
        y: (batch_size, num_dims, num_points, 1)
        k: int, number of neighbors to keep per point.
        relative_pos: Optional tensor added to the pairwise distances
            before the neighbors are selected. Defaults to None.

    Returns:
        Tensor: ``nn_idx`` and ``center_idx`` stacked along a new leading
        axis, each of shape (batch_size, num_points, k). ``nn_idx`` holds
        the neighbor indices and ``center_idx`` the index of the point
        itself, repeated k times.
    """
    with torch.no_grad():
        x_pts = x.transpose(2, 1).squeeze(-1)
        y_pts = y.transpose(2, 1).squeeze(-1)
        batch_size, n_points, _ = x_pts.shape

        dist = xy_pairwise_distance(x_pts.detach(), y_pts.detach())
        if relative_pos is not None:
            dist += relative_pos

        # topk on the negated distances picks the smallest distances.
        _, nn_idx = torch.topk(-dist, k=k)

        point_ids = torch.arange(0, n_points, device=x_pts.device)
        center_idx = point_ids.repeat(batch_size, k, 1).transpose(2, 1)
        return torch.stack((nn_idx, center_idx), dim=0)
|
|||
|
|
|||
|
|
|||
|
class DenseDilated(nn.Module):
    """Pick dilated neighbors out of a neighbor list.

    The input ``edge_index`` has shape (2, batch_size, num_points, k).

    Args:
        k (int): Number of neighbors kept by the stochastic branch.
            Defaults to 9.
        dilation (int): Step used when striding over the neighbor axis.
            Defaults to 1.
        use_stochastic (bool): Whether random neighbor selection may be
            used while training. Defaults to False.
        epsilon (float): Probability of taking the random branch.
            Defaults to 0.0.
    """

    def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0):
        super().__init__()
        self.k = k
        self.dilation = dilation
        self.use_stochastic = use_stochastic
        self.epsilon = epsilon

    def forward(self, edge_index):
        # The random draw happens before the training check, so it is made
        # whenever ``use_stochastic`` is set, in eval mode as well.
        if (self.use_stochastic and torch.rand(1) < self.epsilon
                and self.training):
            pool_size = self.k * self.dilation
            chosen = torch.randperm(pool_size)[:self.k]
            return edge_index[:, :, :, chosen]
        return edge_index[:, :, :, ::self.dilation]
|
|||
|
|
|||
|
|
|||
|
class DenseDilatedKnnGraph(nn.Module):
    """Compute neighbor indices with a dilated k-nearest-neighbor search.

    Args:
        k (int): Number of neighbors kept after dilation. Defaults to 9.
        dilation (int): Dilation rate; ``k * dilation`` candidates are
            searched before dilation is applied. Defaults to 1.
        use_stochastic (bool): Forwarded to :class:`DenseDilated`.
            Defaults to False.
        epsilon (float): Forwarded to :class:`DenseDilated`.
            Defaults to 0.0.
    """

    def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0):
        super().__init__()
        self.k = k
        self.dilation = dilation
        self.use_stochastic = use_stochastic
        self.epsilon = epsilon
        self._dilated = DenseDilated(k, dilation, use_stochastic, epsilon)

    def forward(self, x, y=None, relative_pos=None):
        # Features are L2-normalized along the channel axis before search.
        x = F.normalize(x, p=2.0, dim=1)
        if y is None:
            # Without a separate candidate set, search within x itself.
            y = x.clone()
        else:
            y = F.normalize(y, p=2.0, dim=1)

        num_candidates = self.k * self.dilation
        edge_index = xy_dense_knn_matrix(x, y, num_candidates, relative_pos)
        return self._dilated(edge_index)
|
|||
|
|
|||
|
|
|||
|
class BasicConv(Sequential):
    """Stack of grouped 1x1 convolutions with optional norm/act/dropout.

    Args:
        channels (list[int]): Channel sizes; one convolution is built for
            every consecutive pair.
        act_cfg (dict | None): Activation config; skipped when None.
        norm_cfg (dict | None): Normalization config; skipped when None.
            Defaults to None.
        graph_conv_bias (bool): Whether the convolutions use a bias.
            Defaults to True.
        drop (float): Dropout2d probability; no dropout layer is added
            when it is not positive. Defaults to 0.
    """

    def __init__(self,
                 channels,
                 act_cfg,
                 norm_cfg=None,
                 graph_conv_bias=True,
                 drop=0.):
        layers = []
        for in_ch, out_ch in zip(channels, channels[1:]):
            conv = nn.Conv2d(
                in_ch, out_ch, 1, bias=graph_conv_bias, groups=4)
            layers.append(conv)
            # The norm layer is always sized by the final channel count.
            if norm_cfg is not None:
                layers.append(build_norm_layer(norm_cfg, channels[-1]))
            if act_cfg is not None:
                layers.append(build_activation_layer(act_cfg))
            if drop > 0:
                layers.append(nn.Dropout2d(drop))

        super().__init__(*layers)
|
|||
|
|
|||
|
|
|||
|
def batched_index_select(x, idx):
    r"""Gather neighbor features for every vertex from a neighbor index.

    Args:
        x (Tensor): input feature Tensor
            :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`.
        idx (Tensor): edge_idx
            :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`.

    Returns:
        Tensor: output neighbors features
            :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`.
    """
    batch_size, num_dims, num_vertices_reduced = x.shape[:3]
    _, num_vertices, k = idx.shape

    # Shift each sample's indices so they address the flattened
    # (batch * vertices) row table built below.
    batch_offsets = torch.arange(
        0, batch_size, device=idx.device).view(-1, 1, 1)
    flat_idx = (idx + batch_offsets * num_vertices_reduced).reshape(-1)

    vertex_rows = x.transpose(2, 1).reshape(
        batch_size * num_vertices_reduced, -1)
    gathered = vertex_rows[flat_idx, :]
    gathered = gathered.view(batch_size, num_vertices, k, num_dims)
    return gathered.permute(0, 3, 1, 2).contiguous()
|
|||
|
|
|||
|
|
|||
|
class MRConv2d(nn.Module):
    """Max-Relative Graph Convolution for dense data.

    Paper: https://arxiv.org/abs/1904.03751

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        act_cfg (dict | None): Activation config passed to BasicConv.
        norm_cfg (dict | None): Norm config passed to BasicConv.
            Defaults to None.
        graph_conv_bias (bool): Whether the convolution uses a bias.
            Defaults to True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 act_cfg,
                 norm_cfg=None,
                 graph_conv_bias=True):
        super().__init__()
        self.nn = BasicConv([in_channels * 2, out_channels], act_cfg,
                            norm_cfg, graph_conv_bias)

    def forward(self, x, edge_index, y=None):
        # Neighbors come from ``y`` when given, otherwise from ``x``.
        neighbor_source = x if y is None else y
        center_feats = batched_index_select(x, edge_index[1])
        neighbor_feats = batched_index_select(neighbor_source, edge_index[0])

        # Largest neighbor-minus-center difference per vertex and channel.
        max_relative, _ = torch.max(
            neighbor_feats - center_feats, -1, keepdim=True)

        b, c, n, last = x.shape
        # Stacking on a new axis before the reshape interleaves the
        # original and max-relative channels.
        paired = torch.cat(
            [x.unsqueeze(2), max_relative.unsqueeze(2)], dim=2)
        return self.nn(paired.reshape(b, 2 * c, n, last))
|
|||
|
|
|||
|
|
|||
|
class EdgeConv2d(nn.Module):
    """Edge convolution layer for dense data.

    The activation and normalization are supplied by BasicConv.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        act_cfg (dict | None): Activation config passed to BasicConv.
        norm_cfg (dict | None): Norm config passed to BasicConv.
            Defaults to None.
        graph_conv_bias (bool): Whether the convolution uses a bias.
            Defaults to True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 act_cfg,
                 norm_cfg=None,
                 graph_conv_bias=True):
        super().__init__()
        self.nn = BasicConv([in_channels * 2, out_channels], act_cfg,
                            norm_cfg, graph_conv_bias)

    def forward(self, x, edge_index, y=None):
        # Neighbors come from ``y`` when given, otherwise from ``x``.
        neighbor_source = x if y is None else y
        center_feats = batched_index_select(x, edge_index[1])
        neighbor_feats = batched_index_select(neighbor_source, edge_index[0])

        # Each edge is described by the center feature and the offset to
        # the neighbor; the result is max-pooled over the neighbor axis.
        edge_feats = self.nn(
            torch.cat([center_feats, neighbor_feats - center_feats], dim=1))
        pooled, _ = torch.max(edge_feats, -1, keepdim=True)
        return pooled
|
|||
|
|
|||
|
|
|||
|
class GraphSAGE(nn.Module):
    """GraphSAGE Graph Convolution for dense data.

    Paper: https://arxiv.org/abs/1706.02216

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        act_cfg (dict | None): Activation config passed to BasicConv.
        norm_cfg (dict | None): Norm config passed to BasicConv.
            Defaults to None.
        graph_conv_bias (bool): Whether the convolutions use a bias.
            Defaults to True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 act_cfg,
                 norm_cfg=None,
                 graph_conv_bias=True):
        super().__init__()
        self.nn1 = BasicConv([in_channels, in_channels], act_cfg, norm_cfg,
                             graph_conv_bias)
        self.nn2 = BasicConv([in_channels * 2, out_channels], act_cfg,
                             norm_cfg, graph_conv_bias)

    def forward(self, x, edge_index, y=None):
        # Neighbors come from ``y`` when given, otherwise from ``x``.
        neighbor_source = x if y is None else y
        neighbor_feats = batched_index_select(neighbor_source, edge_index[0])

        # Transform the neighbors, then max-pool over the neighbor axis.
        pooled, _ = torch.max(self.nn1(neighbor_feats), -1, keepdim=True)
        return self.nn2(torch.cat([x, pooled], dim=1))
|
|||
|
|
|||
|
|
|||
|
class GINConv2d(nn.Module):
    """GIN Graph Convolution for dense data.

    Paper: https://arxiv.org/abs/1810.00826

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        act_cfg (dict | None): Activation config passed to BasicConv.
        norm_cfg (dict | None): Norm config passed to BasicConv.
            Defaults to None.
        graph_conv_bias (bool): Whether the convolution uses a bias.
            Defaults to True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 act_cfg,
                 norm_cfg=None,
                 graph_conv_bias=True):
        super().__init__()
        self.nn = BasicConv([in_channels, out_channels], act_cfg, norm_cfg,
                            graph_conv_bias)
        # Learnable weight on the center feature, initialized to zero.
        eps_init = 0.0
        self.eps = nn.Parameter(torch.Tensor([eps_init]))

    def forward(self, x, edge_index, y=None):
        # Neighbors come from ``y`` when given, otherwise from ``x``.
        neighbor_source = x if y is None else y
        neighbor_feats = batched_index_select(neighbor_source, edge_index[0])
        neighbor_sum = torch.sum(neighbor_feats, -1, keepdim=True)
        return self.nn((1 + self.eps) * x + neighbor_sum)
|
|||
|
|
|||
|
|
|||
|
class GraphConv2d(nn.Module):
    """Static graph convolution layer.

    Wraps one of the graph convolutions defined in this file.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        graph_conv_type (str): One of 'edge', 'mr', 'sage' and 'gin'.
        act_cfg (dict | None): Activation config for the wrapped layer.
        norm_cfg (dict | None): Norm config for the wrapped layer.
            Defaults to None.
        graph_conv_bias (bool): Whether the convolutions use a bias.
            Defaults to True.

    Raises:
        NotImplementedError: If ``graph_conv_type`` is not supported.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 graph_conv_type,
                 act_cfg,
                 norm_cfg=None,
                 graph_conv_bias=True):
        super().__init__()
        conv_classes = (
            ('edge', EdgeConv2d),
            ('mr', MRConv2d),
            ('sage', GraphSAGE),
            ('gin', GINConv2d),
        )
        for type_name, conv_cls in conv_classes:
            if graph_conv_type == type_name:
                self.gconv = conv_cls(in_channels, out_channels, act_cfg,
                                      norm_cfg, graph_conv_bias)
                break
        else:
            raise NotImplementedError(
                'graph_conv_type:{} is not supported'.format(graph_conv_type))

    def forward(self, x, edge_index, y=None):
        return self.gconv(x, edge_index, y)
|
|||
|
|
|||
|
|
|||
|
class DyGraphConv2d(GraphConv2d):
    """Dynamic graph convolution layer.

    The neighbor graph is rebuilt from the input features on every forward
    pass using a dilated KNN search.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        k (int): Number of neighbors. Defaults to 9.
        dilation (int): Dilation rate of the KNN search. Defaults to 1.
        graph_conv_type (str): One of 'edge', 'mr', 'sage' and 'gin'.
            Defaults to 'mr'.
        act_cfg (dict): Activation config. Defaults to
            ``dict(type='GELU')``.
        norm_cfg (dict | None): Norm config. Defaults to None.
        graph_conv_bias (bool): Whether the convolutions use a bias.
            Defaults to True.
        use_stochastic (bool): Whether to use stochastic neighbor
            selection. Defaults to False.
        epsilon (float): Probability used by the stochastic selection.
            Defaults to 0.2.
        r (int): Pooling ratio; when larger than 1 the neighbor candidates
            are taken from an average-pooled copy of the input.
            Defaults to 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 k=9,
                 dilation=1,
                 graph_conv_type='mr',
                 act_cfg=dict(type='GELU'),
                 norm_cfg=None,
                 graph_conv_bias=True,
                 use_stochastic=False,
                 epsilon=0.2,
                 r=1):
        super().__init__(in_channels, out_channels, graph_conv_type, act_cfg,
                         norm_cfg, graph_conv_bias)
        self.k = k
        self.d = dilation
        self.r = r
        self.dilated_knn_graph = DenseDilatedKnnGraph(
            k, dilation, use_stochastic, epsilon)

    def forward(self, x, relative_pos=None):
        batch, channels, height, width = x.shape

        pooled_nodes = None
        if self.r > 1:
            pooled = F.avg_pool2d(x, self.r, self.r)
            pooled_nodes = pooled.reshape(batch, channels, -1, 1).contiguous()

        # Flatten the spatial grid into a node axis: (B, C, H*W, 1).
        nodes = x.reshape(batch, channels, -1, 1).contiguous()
        edge_index = self.dilated_knn_graph(nodes, pooled_nodes, relative_pos)
        out = super().forward(nodes, edge_index, pooled_nodes)
        return out.reshape(batch, -1, height, width).contiguous()
|
|||
|
|
|||
|
|
|||
|
class Grapher(nn.Module):
    """Grapher module: graph convolution wrapped by two fc layers.

    Args:
        in_channels (int): Number of input and output channels.
        k (int): Number of neighbors. Defaults to 9.
        dilation (int): Dilation rate of the KNN search. Defaults to 1.
        graph_conv_type (str): One of 'edge', 'mr', 'sage' and 'gin'.
            Defaults to 'mr'.
        act_cfg (dict): Activation config. Defaults to
            ``dict(type='GELU')``.
        norm_cfg (dict | None): Norm config of the graph convolution.
            Defaults to None.
        graph_conv_bias (bool): Whether the graph convolution uses a bias.
            Defaults to True.
        use_stochastic (bool): Whether to use stochastic neighbor
            selection. Defaults to False.
        epsilon (float): Probability used by the stochastic selection.
            Defaults to 0.2.
        r (int): Pooling ratio of the graph convolution. Defaults to 1.
        n (int): Number of nodes the relative position table is built for.
            Defaults to 196.
        drop_path (float): Stochastic depth rate. Defaults to 0.0.
        relative_pos (bool): Whether to build a relative position table.
            Defaults to False.
    """

    def __init__(self,
                 in_channels,
                 k=9,
                 dilation=1,
                 graph_conv_type='mr',
                 act_cfg=dict(type='GELU'),
                 norm_cfg=None,
                 graph_conv_bias=True,
                 use_stochastic=False,
                 epsilon=0.2,
                 r=1,
                 n=196,
                 drop_path=0.0,
                 relative_pos=False):
        super().__init__()
        self.channels = in_channels
        self.n = n
        self.r = r

        self.fc1 = Sequential(
            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
            build_norm_layer(dict(type='BN'), in_channels),
        )
        self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, k,
                                        dilation, graph_conv_type, act_cfg,
                                        norm_cfg, graph_conv_bias,
                                        use_stochastic, epsilon, r)
        self.fc2 = Sequential(
            nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
            build_norm_layer(dict(type='BN'), in_channels),
        )
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.relative_pos = None
        if relative_pos:
            grid_size = int(n**0.5)
            table = np.float32(
                get_2d_relative_pos_embed(in_channels, grid_size))
            table = torch.from_numpy(table).unsqueeze(0).unsqueeze(1)
            # Resize the table to (n, n / r^2) so its columns match the
            # pooled candidate nodes.
            table = F.interpolate(
                table,
                size=(n, n // (r * r)),
                mode='bicubic',
                align_corners=False)
            # Stored negated and excluded from gradient updates.
            self.relative_pos = nn.Parameter(
                -table.squeeze(1), requires_grad=False)

    def _get_relative_pos(self, relative_pos, H, W):
        """Resize the relative position table when the input size differs.

        The table is returned unchanged when it is None or when ``H * W``
        equals the node count it was built for.
        """
        num_nodes = H * W
        if relative_pos is None or num_nodes == self.n:
            return relative_pos

        num_reduced = num_nodes // (self.r * self.r)
        resized = F.interpolate(
            relative_pos.unsqueeze(0),
            size=(num_nodes, num_reduced),
            mode='bicubic')
        return resized.squeeze(0)

    def forward(self, x):
        _, _, height, width = x.shape
        relative_pos = self._get_relative_pos(self.relative_pos, height,
                                              width)
        out = self.fc1(x)
        out = self.graph_conv(out, relative_pos)
        out = self.fc2(out)
        # Residual connection around the whole branch.
        return self.drop_path(out) + x
|
|||
|
|
|||
|
|
|||
|
class FFN(nn.Module):
    """Feed-forward block of two 1x1 conv + BN layers with a residual.

    Args:
        in_features (int): Number of input channels.
        hidden_features (int, optional): Number of hidden channels; falls
            back to ``in_features`` when not given. Defaults to None.
        out_features (int, optional): Number of output channels; falls
            back to ``in_features`` when not given. Defaults to None.
        act_cfg (dict): Activation config. Defaults to
            ``dict(type='GELU')``.
        drop_path (float): Stochastic depth rate. Defaults to 0.0.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_cfg=dict(type='GELU'),
                 drop_path=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = Sequential(
            nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0),
            build_norm_layer(dict(type='BN'), hidden_features),
        )
        self.act = build_activation_layer(act_cfg)
        self.fc2 = Sequential(
            nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0),
            build_norm_layer(dict(type='BN'), out_features),
        )
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        out = self.fc2(self.act(self.fc1(x)))
        # Residual connection around the whole block.
        return self.drop_path(out) + x
|
|||
|
|
|||
|
|
|||
|
@MODELS.register_module()
class Vig(BaseBackbone):
    """Vision GNN backbone.

    A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes
    <https://arxiv.org/abs/2206.00272>`_.

    Modified from the official implementation
    https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch

    Args:
        arch(str): Vision GNN architecture,
            choose from 'tiny', 'small' and 'base'.
        in_channels (int): The number of channels of input images.
            Defaults to 3.
        k (int): The number of KNN's k. Defaults to 9.
        out_indices (Sequence | int): Output from which blocks. Negative
            values count from the last block; every index must address an
            existing block. Defaults to -1, means the last block.
        act_cfg (dict): The config of activative functions.
            Defaults to ``dict(type='GELU'))``.
        norm_cfg (dict): The config of normalization layers.
            Defaults to ``dict(type='BN')``.
        graph_conv_bias (bool): Whether to use bias in the convolution
            layers in Grapher. Defaults to True.
        graph_conv_type (str): The type of graph convolution,choose
            from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'.
        epsilon (float): Probability of random arrangement in KNN. It only
            works when ``use_dilation=True`` and ``use_stochastic=True``.
            Defaults to 0.2.
        use_dilation(bool): Whether to use dilation in KNN. Defaults to True.
        use_stochastic(bool): Whether to use stochastic in KNN.
            Defaults to False.
        drop_path (float): stochastic depth rate. Default 0.0
        relative_pos(bool): Whether to use relative position embedding.
            Defaults to False.
        norm_eval (bool): Whether to set the normalization layer to eval mode.
            Defaults to False.
        frozen_stages (int): Blocks to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        init_cfg (dict, optional): The initialization configs.
            Defaults to None.
    """  # noqa: E501

    arch_settings = {
        'tiny': dict(num_blocks=12, channels=192),
        'small': dict(num_blocks=16, channels=320),
        'base': dict(num_blocks=16, channels=640),
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 k=9,
                 out_indices=-1,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='BN'),
                 graph_conv_bias=True,
                 graph_conv_type='mr',
                 epsilon=0.2,
                 use_dilation=True,
                 use_stochastic=False,
                 drop_path=0.,
                 relative_pos=False,
                 norm_eval=False,
                 frozen_stages=0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        arch = self.arch_settings[arch]
        self.num_blocks = arch['num_blocks']
        channels = arch['channels']

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        elif isinstance(out_indices, (tuple, list)):
            # Always work on a private copy so that normalizing negative
            # indices never modifies the caller's (config) object.
            out_indices = list(out_indices)
        else:
            raise TypeError('"out_indices" must by a tuple, list or int, '
                            f'get {type(out_indices)} instead.')
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_blocks + index
            # Blocks are numbered 0 .. num_blocks - 1; an index equal to
            # num_blocks would never produce an output in ``forward``.
            assert 0 <= out_indices[i] < self.num_blocks, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.stem = Sequential(
            nn.Conv2d(in_channels, channels // 8, 3, stride=2, padding=1),
            build_norm_layer(norm_cfg, channels // 8),
            build_activation_layer(act_cfg),
            nn.Conv2d(channels // 8, channels // 4, 3, stride=2, padding=1),
            build_norm_layer(norm_cfg, channels // 4),
            build_activation_layer(act_cfg),
            nn.Conv2d(channels // 4, channels // 2, 3, stride=2, padding=1),
            build_norm_layer(norm_cfg, channels // 2),
            build_activation_layer(act_cfg),
            nn.Conv2d(channels // 2, channels, 3, stride=2, padding=1),
            build_norm_layer(norm_cfg, channels),
            build_activation_layer(act_cfg),
            nn.Conv2d(channels, channels, 3, stride=1, padding=1),
            build_norm_layer(norm_cfg, channels),
        )

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)]
        # number of knn's k, growing linearly from k to 2 * k over the blocks
        num_knn = [
            int(x.item()) for x in torch.linspace(k, 2 * k, self.num_blocks)
        ]
        # 196 is the node count of the 14 x 14 grid used by ``pos_embed``
        max_dilation = 196 // max(num_knn)

        self.pos_embed = nn.Parameter(torch.zeros(1, channels, 14, 14))

        self.blocks = ModuleList([
            Sequential(
                Grapher(
                    in_channels=channels,
                    k=num_knn[i],
                    dilation=min(i // 4 +
                                 1, max_dilation) if use_dilation else 1,
                    graph_conv_type=graph_conv_type,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    graph_conv_bias=graph_conv_bias,
                    use_stochastic=use_stochastic,
                    epsilon=epsilon,
                    drop_path=dpr[i],
                    relative_pos=relative_pos),
                FFN(in_features=channels,
                    hidden_features=channels * 4,
                    act_cfg=act_cfg,
                    drop_path=dpr[i])) for i in range(self.num_blocks)
        ])

        self.norm_eval = norm_eval
        self.frozen_stages = frozen_stages

    def forward(self, inputs):
        """Run the stem and all blocks; return the selected block outputs.

        Returns:
            tuple[Tensor]: Outputs of the blocks listed in ``out_indices``,
            in block order.
        """
        outs = []
        x = self.stem(inputs) + self.pos_embed

        for i, block in enumerate(self.blocks):
            x = block(x)

            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def _freeze_stages(self):
        """Put the stem and the first ``frozen_stages`` blocks in eval mode
        and stop gradient updates for those blocks."""
        # NOTE(review): the stem is switched to eval mode unconditionally,
        # even when ``frozen_stages`` is 0, while its parameters stay
        # trainable - confirm this is intended.
        self.stem.eval()
        for i in range(self.frozen_stages):
            m = self.blocks[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set the training mode while keeping frozen parts frozen.

        Returns:
            Vig: ``self``, matching ``torch.nn.Module.train``.
        """
        super(Vig, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self
|
|||
|
|
|||
|
|
|||
|
@MODELS.register_module()
class PyramidVig(BaseBackbone):
    """Pyramid Vision GNN backbone.

    A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes
    <https://arxiv.org/abs/2206.00272>`_.

    Modified from the official implementation
    https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch

    Args:
        arch (str): Vision GNN architecture, choose from 'tiny',
            'small', 'medium' and 'base'.
        in_channels (int): The number of channels of input images.
            Defaults to 3.
        k (int): The number of KNN's k. Defaults to 9.
        out_indices (Sequence | int): Output from which stages. Negative
            values count from the last stage; every index must address an
            existing stage. Defaults to -1, means the last stage.
        act_cfg (dict): The config of activative functions.
            Defaults to ``dict(type='GELU'))``.
        norm_cfg (dict): The config of normalization layers.
            Defaults to ``dict(type='BN')``.
        graph_conv_bias (bool): Whether to use bias in the convolution
            layers in Grapher. Defaults to True.
        graph_conv_type (str): The type of graph convolution,choose
            from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'.
        epsilon (float): Probability of random arrangement in KNN. It only
            works when ``use_stochastic=True``. Defaults to 0.2.
        use_stochastic (bool): Whether to use stochastic in KNN.
            Defaults to False.
        drop_path (float): stochastic depth rate. Default 0.0
        norm_eval (bool): Whether to set the normalization layer to eval mode.
            Defaults to False.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        init_cfg (dict, optional): The initialization configs.
            Defaults to None.
    """  # noqa: E501
    arch_settings = {
        'tiny': dict(blocks=[2, 2, 6, 2], channels=[48, 96, 240, 384]),
        'small': dict(blocks=[2, 2, 6, 2], channels=[80, 160, 400, 640]),
        'medium': dict(blocks=[2, 2, 16, 2], channels=[96, 192, 384, 768]),
        'base': dict(blocks=[2, 2, 18, 2], channels=[128, 256, 512, 1024]),
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 k=9,
                 out_indices=-1,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='BN'),
                 graph_conv_bias=True,
                 graph_conv_type='mr',
                 epsilon=0.2,
                 use_stochastic=False,
                 drop_path=0.,
                 norm_eval=False,
                 frozen_stages=0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        arch = self.arch_settings[arch]
        self.blocks = arch['blocks']
        self.num_blocks = sum(self.blocks)
        self.num_stages = len(self.blocks)
        channels = arch['channels']
        self.channels = channels

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Work on a private list: tuples do not support the item assignment
        # below, and a caller's list must not be modified in place.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_stages + index
            # Stages are numbered 0 .. num_stages - 1; an index equal to
            # num_stages would never produce an output in ``forward``.
            assert 0 <= out_indices[i] < self.num_stages, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.stem = Sequential(
            nn.Conv2d(in_channels, channels[0] // 2, 3, stride=2, padding=1),
            build_norm_layer(norm_cfg, channels[0] // 2),
            build_activation_layer(act_cfg),
            nn.Conv2d(channels[0] // 2, channels[0], 3, stride=2, padding=1),
            build_norm_layer(norm_cfg, channels[0]),
            build_activation_layer(act_cfg),
            nn.Conv2d(channels[0], channels[0], 3, stride=1, padding=1),
            build_norm_layer(norm_cfg, channels[0]),
        )

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)]
        # number of knn's k (constant over all blocks)
        num_knn = [
            int(x.item()) for x in torch.linspace(k, k, self.num_blocks)
        ]
        max_dilation = 49 // max(num_knn)

        # The stem downsamples by 4, so a 224 x 224 input yields a 56 x 56
        # grid; the position embedding is sized for that resolution.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, channels[0], 224 // 4, 224 // 4))
        HW = (224 // 4) ** 2
        reduce_ratios = [4, 2, 1, 1]

        self.stages = ModuleList()
        block_idx = 0
        for stage_idx, num_blocks in enumerate(self.blocks):
            mid_channels = channels[stage_idx]
            reduce_ratio = reduce_ratios[stage_idx]
            blocks = []
            if stage_idx > 0:
                # Downsampling layer between stages: halves each spatial
                # side, so the node count shrinks by a factor of 4.
                blocks.append(
                    Sequential(
                        nn.Conv2d(
                            self.channels[stage_idx - 1],
                            mid_channels,
                            kernel_size=3,
                            stride=2,
                            padding=1),
                        build_norm_layer(norm_cfg, mid_channels),
                    ))
                HW = HW // 4
            for _ in range(num_blocks):
                blocks.append(
                    Sequential(
                        Grapher(
                            in_channels=mid_channels,
                            k=num_knn[block_idx],
                            dilation=min(block_idx // 4 + 1, max_dilation),
                            graph_conv_type=graph_conv_type,
                            act_cfg=act_cfg,
                            norm_cfg=norm_cfg,
                            graph_conv_bias=graph_conv_bias,
                            use_stochastic=use_stochastic,
                            epsilon=epsilon,
                            r=reduce_ratio,
                            n=HW,
                            drop_path=dpr[block_idx],
                            relative_pos=True),
                        FFN(in_features=mid_channels,
                            hidden_features=mid_channels * 4,
                            act_cfg=act_cfg,
                            drop_path=dpr[block_idx])))
                block_idx += 1
            self.stages.append(Sequential(*blocks))

        self.norm_eval = norm_eval
        self.frozen_stages = frozen_stages

    def forward(self, inputs):
        """Run the stem and all stages; return the selected stage outputs.

        Returns:
            tuple[Tensor]: Outputs of the stages listed in ``out_indices``,
            in stage order.
        """
        outs = []
        x = self.stem(inputs) + self.pos_embed

        for i, blocks in enumerate(self.stages):
            x = blocks(x)

            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def _freeze_stages(self):
        """Put the stem and the first ``frozen_stages`` stages in eval mode
        and stop gradient updates for those stages."""
        # NOTE(review): the stem is switched to eval mode unconditionally,
        # even when ``frozen_stages`` is 0, while its parameters stay
        # trainable - confirm this is intended.
        self.stem.eval()
        for i in range(self.frozen_stages):
            m = self.stages[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set the training mode while keeping frozen parts frozen.

        Returns:
            PyramidVig: ``self``, matching ``torch.nn.Module.train``.
        """
        super(PyramidVig, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self
|