# Copyright (c) OpenMMLab. All rights reserved. # modified from # https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch from typing import Sequence import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import build_activation_layer from mmcv.cnn.bricks import DropPath from mmengine.model import ModuleList, Sequential from torch.nn.modules.batchnorm import _BatchNorm from mmpretrain.models.backbones.base_backbone import BaseBackbone from mmpretrain.registry import MODELS from ..utils import build_norm_layer def get_2d_relative_pos_embed(embed_dim, grid_size): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, grid_size*grid_size] """ pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size) relative_pos = 2 * np.matmul(pos_embed, pos_embed.transpose()) / pos_embed.shape[1] return relative_pos def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size, grid_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token: pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) return pos_embed def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float32) omega /= embed_dim / 2. omega = 1. / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb def xy_pairwise_distance(x, y): """Compute pairwise distance of a point cloud. Args: x: tensor (batch_size, num_points, num_dims) y: tensor (batch_size, num_points, num_dims) Returns: pairwise distance: (batch_size, num_points, num_points) """ with torch.no_grad(): xy_inner = -2 * torch.matmul(x, y.transpose(2, 1)) x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True) return x_square + xy_inner + y_square.transpose(2, 1) def xy_dense_knn_matrix(x, y, k=16, relative_pos=None): """Get KNN based on the pairwise distance. Args: x: (batch_size, num_dims, num_points, 1) y: (batch_size, num_dims, num_points, 1) k: int relative_pos:Whether to use relative_pos Returns: nearest neighbors: (batch_size, num_points, k) (batch_size, num_points, k) """ with torch.no_grad(): x = x.transpose(2, 1).squeeze(-1) y = y.transpose(2, 1).squeeze(-1) batch_size, n_points, n_dims = x.shape dist = xy_pairwise_distance(x.detach(), y.detach()) if relative_pos is not None: dist += relative_pos _, nn_idx = torch.topk(-dist, k=k) center_idx = torch.arange( 0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1) return torch.stack((nn_idx, center_idx), dim=0) class DenseDilated(nn.Module): """Find dilated neighbor from neighbor list. edge_index: (2, batch_size, num_points, k) """ def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0): super(DenseDilated, self).__init__() self.dilation = dilation self.use_stochastic = use_stochastic self.epsilon = epsilon self.k = k def forward(self, edge_index): if self.use_stochastic: if torch.rand(1) < self.epsilon and self.training: num = self.k * self.dilation randnum = torch.randperm(num)[:self.k] edge_index = edge_index[:, :, :, randnum] else: edge_index = edge_index[:, :, :, ::self.dilation] else: edge_index = edge_index[:, :, :, ::self.dilation] return edge_index class DenseDilatedKnnGraph(nn.Module): """Find the neighbors' indices based on dilated knn.""" def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0): super(DenseDilatedKnnGraph, self).__init__() self.dilation = dilation self.use_stochastic = use_stochastic self.epsilon = epsilon self.k = k self._dilated = DenseDilated(k, dilation, use_stochastic, epsilon) def forward(self, x, y=None, relative_pos=None): if y is not None: x = F.normalize(x, p=2.0, dim=1) y = F.normalize(y, p=2.0, dim=1) edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation, relative_pos) else: x = F.normalize(x, p=2.0, dim=1) y = x.clone() edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation, relative_pos) return self._dilated(edge_index) class BasicConv(Sequential): def __init__(self, channels, act_cfg, norm_cfg=None, graph_conv_bias=True, drop=0.): m = [] for i in range(1, len(channels)): m.append( nn.Conv2d( channels[i - 1], channels[i], 1, bias=graph_conv_bias, groups=4)) if norm_cfg is not None: m.append(build_norm_layer(norm_cfg, channels[-1])) if act_cfg is not None: m.append(build_activation_layer(act_cfg)) if drop > 0: m.append(nn.Dropout2d(drop)) super(BasicConv, self).__init__(*m) def batched_index_select(x, idx): r"""fetches neighbors features from a given neighbor idx Args: x (Tensor): input feature Tensor :math: `\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`. idx (Tensor): edge_idx :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`. Returns: Tensor: output neighbors features :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`. """ batch_size, num_dims, num_vertices_reduced = x.shape[:3] _, num_vertices, k = idx.shape idx_base = torch.arange( 0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced idx = idx + idx_base idx = idx.contiguous().view(-1) x = x.transpose(2, 1) feature = x.contiguous().view(batch_size * num_vertices_reduced, -1)[idx, :] feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous() return feature class MRConv2d(nn.Module): """Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type.""" def __init__(self, in_channels, out_channels, act_cfg, norm_cfg=None, graph_conv_bias=True): super(MRConv2d, self).__init__() self.nn = BasicConv([in_channels * 2, out_channels], act_cfg, norm_cfg, graph_conv_bias) def forward(self, x, edge_index, y=None): x_i = batched_index_select(x, edge_index[1]) if y is not None: x_j = batched_index_select(y, edge_index[0]) else: x_j = batched_index_select(x, edge_index[0]) x_j, _ = torch.max(x_j - x_i, -1, keepdim=True) b, c, n, _ = x.shape x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)], dim=2).reshape(b, 2 * c, n, _) return self.nn(x) class EdgeConv2d(nn.Module): """Edge convolution layer (with activation, batch normalization) for dense data type.""" def __init__(self, in_channels, out_channels, act_cfg, norm_cfg=None, graph_conv_bias=True): super(EdgeConv2d, self).__init__() self.nn = BasicConv([in_channels * 2, out_channels], act_cfg, norm_cfg, graph_conv_bias) def forward(self, x, edge_index, y=None): x_i = batched_index_select(x, edge_index[1]) if y is not None: x_j = batched_index_select(y, edge_index[0]) else: x_j = batched_index_select(x, edge_index[0]) max_value, _ = torch.max( self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True) return max_value class GraphSAGE(nn.Module): """GraphSAGE Graph Convolution (Paper: https://arxiv.org/abs/1706.02216) for dense data type.""" def __init__(self, in_channels, out_channels, act_cfg, norm_cfg=None, graph_conv_bias=True): super(GraphSAGE, self).__init__() self.nn1 = BasicConv([in_channels, in_channels], act_cfg, norm_cfg, graph_conv_bias) self.nn2 = BasicConv([in_channels * 2, out_channels], act_cfg, norm_cfg, graph_conv_bias) def forward(self, x, edge_index, y=None): if y is not None: x_j = batched_index_select(y, edge_index[0]) else: x_j = batched_index_select(x, edge_index[0]) x_j, _ = torch.max(self.nn1(x_j), -1, keepdim=True) return self.nn2(torch.cat([x, x_j], dim=1)) class GINConv2d(nn.Module): """GIN Graph Convolution (Paper: https://arxiv.org/abs/1810.00826) for dense data type.""" def __init__(self, in_channels, out_channels, act_cfg, norm_cfg=None, graph_conv_bias=True): super(GINConv2d, self).__init__() self.nn = BasicConv([in_channels, out_channels], act_cfg, norm_cfg, graph_conv_bias) eps_init = 0.0 self.eps = nn.Parameter(torch.Tensor([eps_init])) def forward(self, x, edge_index, y=None): if y is not None: x_j = batched_index_select(y, edge_index[0]) else: x_j = batched_index_select(x, edge_index[0]) x_j = torch.sum(x_j, -1, keepdim=True) return self.nn((1 + self.eps) * x + x_j) class GraphConv2d(nn.Module): """Static graph convolution layer.""" def __init__(self, in_channels, out_channels, graph_conv_type, act_cfg, norm_cfg=None, graph_conv_bias=True): super(GraphConv2d, self).__init__() if graph_conv_type == 'edge': self.gconv = EdgeConv2d(in_channels, out_channels, act_cfg, norm_cfg, graph_conv_bias) elif graph_conv_type == 'mr': self.gconv = MRConv2d(in_channels, out_channels, act_cfg, norm_cfg, graph_conv_bias) elif graph_conv_type == 'sage': self.gconv = GraphSAGE(in_channels, out_channels, act_cfg, norm_cfg, graph_conv_bias) elif graph_conv_type == 'gin': self.gconv = GINConv2d(in_channels, out_channels, act_cfg, norm_cfg, graph_conv_bias) else: raise NotImplementedError( 'graph_conv_type:{} is not supported'.format(graph_conv_type)) def forward(self, x, edge_index, y=None): return self.gconv(x, edge_index, y) class DyGraphConv2d(GraphConv2d): """Dynamic graph convolution layer.""" def __init__(self, in_channels, out_channels, k=9, dilation=1, graph_conv_type='mr', act_cfg=dict(type='GELU'), norm_cfg=None, graph_conv_bias=True, use_stochastic=False, epsilon=0.2, r=1): super(DyGraphConv2d, self).__init__(in_channels, out_channels, graph_conv_type, act_cfg, norm_cfg, graph_conv_bias) self.k = k self.d = dilation self.r = r self.dilated_knn_graph = DenseDilatedKnnGraph(k, dilation, use_stochastic, epsilon) def forward(self, x, relative_pos=None): B, C, H, W = x.shape y = None if self.r > 1: y = F.avg_pool2d(x, self.r, self.r) y = y.reshape(B, C, -1, 1).contiguous() x = x.reshape(B, C, -1, 1).contiguous() edge_index = self.dilated_knn_graph(x, y, relative_pos) x = super(DyGraphConv2d, self).forward(x, edge_index, y) return x.reshape(B, -1, H, W).contiguous() class Grapher(nn.Module): """Grapher module with graph convolution and fc layers.""" def __init__(self, in_channels, k=9, dilation=1, graph_conv_type='mr', act_cfg=dict(type='GELU'), norm_cfg=None, graph_conv_bias=True, use_stochastic=False, epsilon=0.2, r=1, n=196, drop_path=0.0, relative_pos=False): super(Grapher, self).__init__() self.channels = in_channels self.n = n self.r = r self.fc1 = Sequential( nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0), build_norm_layer(dict(type='BN'), in_channels), ) self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, k, dilation, graph_conv_type, act_cfg, norm_cfg, graph_conv_bias, use_stochastic, epsilon, r) self.fc2 = Sequential( nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0), build_norm_layer(dict(type='BN'), in_channels), ) self.drop_path = DropPath( drop_path) if drop_path > 0. else nn.Identity() self.relative_pos = None if relative_pos: relative_pos_tensor = torch.from_numpy( np.float32( get_2d_relative_pos_embed(in_channels, int( n**0.5)))).unsqueeze(0).unsqueeze(1) relative_pos_tensor = F.interpolate( relative_pos_tensor, size=(n, n // (r * r)), mode='bicubic', align_corners=False) self.relative_pos = nn.Parameter( -relative_pos_tensor.squeeze(1), requires_grad=False) def _get_relative_pos(self, relative_pos, H, W): if relative_pos is None or H * W == self.n: return relative_pos else: N = H * W N_reduced = N // (self.r * self.r) return F.interpolate( relative_pos.unsqueeze(0), size=(N, N_reduced), mode='bicubic').squeeze(0) def forward(self, x): B, C, H, W = x.shape relative_pos = self._get_relative_pos(self.relative_pos, H, W) shortcut = x x = self.fc1(x) x = self.graph_conv(x, relative_pos) x = self.fc2(x) x = self.drop_path(x) + shortcut return x class FFN(nn.Module): """"out_features = out_features or in_features\n hidden_features = hidden_features or in_features""" def __init__(self, in_features, hidden_features=None, out_features=None, act_cfg=dict(type='GELU'), drop_path=0.0): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = Sequential( nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0), build_norm_layer(dict(type='BN'), hidden_features), ) self.act = build_activation_layer(act_cfg) self.fc2 = Sequential( nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0), build_norm_layer(dict(type='BN'), out_features), ) self.drop_path = DropPath( drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): shortcut = x x = self.fc1(x) x = self.act(x) x = self.fc2(x) x = self.drop_path(x) + shortcut return x @MODELS.register_module() class Vig(BaseBackbone): """Vision GNN backbone. A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes `_. Modified from the official implementation https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch Args: arch(str): Vision GNN architecture, choose from 'tiny', 'small' and 'base'. in_channels (int): The number of channels of input images. Defaults to 3. k (int): The number of KNN's k. Defaults to 9. out_indices (Sequence | int): Output from which blocks. Defaults to -1, means the last block. act_cfg (dict): The config of activative functions. Defaults to ``dict(type='GELU'))``. norm_cfg (dict): The config of normalization layers. Defaults to ``dict(type='BN', eps=1e-6)``. graph_conv_bias (bool): Whether to use bias in the convolution layers in Grapher. Defaults to True. graph_conv_type (str): The type of graph convolution,choose from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'. epsilon (float): Probability of random arrangement in KNN. It only works when ``use_dilation=True`` and ``use_stochastic=True``. Defaults to 0.2. use_dilation(bool): Whether to use dilation in KNN. Defaults to True. use_stochastic(bool): Whether to use stochastic in KNN. Defaults to False. drop_path (float): stochastic depth rate. Default 0.0 relative_pos(bool): Whether to use relative position embedding. Defaults to False. norm_eval (bool): Whether to set the normalization layer to eval mode. Defaults to False. frozen_stages (int): Blocks to be frozen (all param fixed). Defaults to 0, which means not freezing any parameters. init_cfg (dict, optional): The initialization configs. Defaults to None. """ # noqa: E501 arch_settings = { 'tiny': dict(num_blocks=12, channels=192), 'small': dict(num_blocks=16, channels=320), 'base': dict(num_blocks=16, channels=640), } def __init__(self, arch, in_channels=3, k=9, out_indices=-1, act_cfg=dict(type='GELU'), norm_cfg=dict(type='BN'), graph_conv_bias=True, graph_conv_type='mr', epsilon=0.2, use_dilation=True, use_stochastic=False, drop_path=0., relative_pos=False, norm_eval=False, frozen_stages=0, init_cfg=None): super().__init__(init_cfg=init_cfg) arch = self.arch_settings[arch] self.num_blocks = arch['num_blocks'] channels = arch['channels'] if isinstance(out_indices, int): out_indices = [out_indices] elif isinstance(out_indices, tuple): out_indices = list(out_indices) elif not isinstance(out_indices, list): raise TypeError('"out_indices" must by a tuple, list or int, ' f'get {type(out_indices)} instead.') for i, index in enumerate(out_indices): if index < 0: out_indices[i] = self.num_blocks + index assert 0 <= out_indices[i] <= self.num_blocks, \ f'Invalid out_indices {index}' self.out_indices = out_indices self.stem = Sequential( nn.Conv2d(in_channels, channels // 8, 3, stride=2, padding=1), build_norm_layer(norm_cfg, channels // 8), build_activation_layer(act_cfg), nn.Conv2d(channels // 8, channels // 4, 3, stride=2, padding=1), build_norm_layer(norm_cfg, channels // 4), build_activation_layer(act_cfg), nn.Conv2d(channels // 4, channels // 2, 3, stride=2, padding=1), build_norm_layer(norm_cfg, channels // 2), build_activation_layer(act_cfg), nn.Conv2d(channels // 2, channels, 3, stride=2, padding=1), build_norm_layer(norm_cfg, channels), build_activation_layer(act_cfg), nn.Conv2d(channels, channels, 3, stride=1, padding=1), build_norm_layer(norm_cfg, channels), ) # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)] # number of knn's k num_knn = [ int(x.item()) for x in torch.linspace(k, 2 * k, self.num_blocks) ] max_dilation = 196 // max(num_knn) self.pos_embed = nn.Parameter(torch.zeros(1, channels, 14, 14)) self.blocks = ModuleList([ Sequential( Grapher( in_channels=channels, k=num_knn[i], dilation=min(i // 4 + 1, max_dilation) if use_dilation else 1, graph_conv_type=graph_conv_type, act_cfg=act_cfg, norm_cfg=norm_cfg, graph_conv_bias=graph_conv_bias, use_stochastic=use_stochastic, epsilon=epsilon, drop_path=dpr[i], relative_pos=relative_pos), FFN(in_features=channels, hidden_features=channels * 4, act_cfg=act_cfg, drop_path=dpr[i])) for i in range(self.num_blocks) ]) self.norm_eval = norm_eval self.frozen_stages = frozen_stages def forward(self, inputs): outs = [] x = self.stem(inputs) + self.pos_embed for i, block in enumerate(self.blocks): x = block(x) if i in self.out_indices: outs.append(x) return tuple(outs) def _freeze_stages(self): self.stem.eval() for i in range(self.frozen_stages): m = self.blocks[i] m.eval() for param in m.parameters(): param.requires_grad = False def train(self, mode=True): super(Vig, self).train(mode) self._freeze_stages() if mode and self.norm_eval: for m in self.modules(): # trick: eval have effect on BatchNorm only if isinstance(m, _BatchNorm): m.eval() @MODELS.register_module() class PyramidVig(BaseBackbone): """Pyramid Vision GNN backbone. A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes `_. Modified from the official implementation https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch Args: arch (str): Vision GNN architecture, choose from 'tiny', 'small' and 'base'. in_channels (int): The number of channels of input images. Defaults to 3. k (int): The number of KNN's k. Defaults to 9. out_indices (Sequence | int): Output from which stages. Defaults to -1, means the last stage. act_cfg (dict): The config of activative functions. Defaults to ``dict(type='GELU'))``. norm_cfg (dict): The config of normalization layers. Defaults to ``dict(type='BN')``. graph_conv_bias (bool): Whether to use bias in the convolution layers in Grapher. Defaults to True. graph_conv_type (str): The type of graph convolution,choose from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'. epsilon (float): Probability of random arrangement in KNN. It only works when ``use_stochastic=True``. Defaults to 0.2. use_stochastic (bool): Whether to use stochastic in KNN. Defaults to False. drop_path (float): stochastic depth rate. Default 0.0 norm_eval (bool): Whether to set the normalization layer to eval mode. Defaults to False. frozen_stages (int): Stages to be frozen (all param fixed). Defaults to 0, which means not freezing any parameters. init_cfg (dict, optional): The initialization configs. Defaults to None. """ # noqa: E501 arch_settings = { 'tiny': dict(blocks=[2, 2, 6, 2], channels=[48, 96, 240, 384]), 'small': dict(blocks=[2, 2, 6, 2], channels=[80, 160, 400, 640]), 'medium': dict(blocks=[2, 2, 16, 2], channels=[96, 192, 384, 768]), 'base': dict(blocks=[2, 2, 18, 2], channels=[128, 256, 512, 1024]), } def __init__(self, arch, in_channels=3, k=9, out_indices=-1, act_cfg=dict(type='GELU'), norm_cfg=dict(type='BN'), graph_conv_bias=True, graph_conv_type='mr', epsilon=0.2, use_stochastic=False, drop_path=0., norm_eval=False, frozen_stages=0, init_cfg=None): super().__init__(init_cfg=init_cfg) arch = self.arch_settings[arch] self.blocks = arch['blocks'] self.num_blocks = sum(self.blocks) self.num_stages = len(self.blocks) channels = arch['channels'] self.channels = channels if isinstance(out_indices, int): out_indices = [out_indices] assert isinstance(out_indices, Sequence), \ f'"out_indices" must by a sequence or int, ' \ f'get {type(out_indices)} instead.' for i, index in enumerate(out_indices): if index < 0: out_indices[i] = self.num_stages + index assert 0 <= out_indices[i] <= self.num_stages, \ f'Invalid out_indices {index}' self.out_indices = out_indices self.stem = Sequential( nn.Conv2d(in_channels, channels[0] // 2, 3, stride=2, padding=1), build_norm_layer(norm_cfg, channels[0] // 2), build_activation_layer(act_cfg), nn.Conv2d(channels[0] // 2, channels[0], 3, stride=2, padding=1), build_norm_layer(norm_cfg, channels[0]), build_activation_layer(act_cfg), nn.Conv2d(channels[0], channels[0], 3, stride=1, padding=1), build_norm_layer(norm_cfg, channels[0]), ) # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)] # number of knn's k num_knn = [ int(x.item()) for x in torch.linspace(k, k, self.num_blocks) ] max_dilation = 49 // max(num_knn) self.pos_embed = nn.Parameter( torch.zeros(1, channels[0], 224 // 4, 224 // 4)) HW = 224 // 4 * 224 // 4 reduce_ratios = [4, 2, 1, 1] self.stages = ModuleList() block_idx = 0 for stage_idx, num_blocks in enumerate(self.blocks): mid_channels = channels[stage_idx] reduce_ratio = reduce_ratios[stage_idx] blocks = [] if stage_idx > 0: blocks.append( Sequential( nn.Conv2d( self.channels[stage_idx - 1], mid_channels, kernel_size=3, stride=2, padding=1), build_norm_layer(norm_cfg, mid_channels), )) HW = HW // 4 for _ in range(num_blocks): blocks.append( Sequential( Grapher( in_channels=mid_channels, k=num_knn[block_idx], dilation=min(block_idx // 4 + 1, max_dilation), graph_conv_type=graph_conv_type, act_cfg=act_cfg, norm_cfg=norm_cfg, graph_conv_bias=graph_conv_bias, use_stochastic=use_stochastic, epsilon=epsilon, r=reduce_ratio, n=HW, drop_path=dpr[block_idx], relative_pos=True), FFN(in_features=mid_channels, hidden_features=mid_channels * 4, act_cfg=act_cfg, drop_path=dpr[block_idx]))) block_idx += 1 self.stages.append(Sequential(*blocks)) self.norm_eval = norm_eval self.frozen_stages = frozen_stages def forward(self, inputs): outs = [] x = self.stem(inputs) + self.pos_embed for i, blocks in enumerate(self.stages): x = blocks(x) if i in self.out_indices: outs.append(x) return tuple(outs) def _freeze_stages(self): self.stem.eval() for i in range(self.frozen_stages): m = self.stages[i] m.eval() for param in m.parameters(): param.requires_grad = False def train(self, mode=True): super(PyramidVig, self).train(mode) self._freeze_stages() if mode and self.norm_eval: for m in self.modules(): # trick: eval have effect on BatchNorm only if isinstance(m, _BatchNorm): m.eval()