From 82d1e99e1aafaa841e10f5e90819e18679228ced Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E6=9B=A6?= Date: Tue, 1 Aug 2023 18:51:08 +0800 Subject: [PATCH] add efficientvit(msra) --- timm/models/__init__.py | 1 + timm/models/efficientvit_mit.py | 36 +-- timm/models/efficientvit_msra.py | 441 +++++++++++++++++++++++++++++++ 3 files changed, 461 insertions(+), 17 deletions(-) create mode 100644 timm/models/efficientvit_msra.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 125d4c1b..56ff246c 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -18,6 +18,7 @@ from .efficientformer import * from .efficientformer_v2 import * from .efficientnet import * from .efficientvit_mit import * +from .efficientvit_msra import * from .eva import * from .focalnet import * from .gcvit import * diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py index de175a68..252337d3 100644 --- a/timm/models/efficientvit_mit.py +++ b/timm/models/efficientvit_mit.py @@ -1,4 +1,4 @@ -""" EfficientViT +""" EfficientViT (by MIT Song Han's Lab) Paper: `Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition` - https://arxiv.org/abs/2205.14756 @@ -40,7 +40,7 @@ def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, . assert kernel_size % 2 > 0, "kernel size should be odd number" return kernel_size // 2 -class ConvLayer(nn.Module): +class ConvNormAct(nn.Module): def __init__( self, in_channels: int, @@ -54,7 +54,7 @@ class ConvLayer(nn.Module): norm=nn.BatchNorm2d, act_func=nn.ReLU, ): - super(ConvLayer, self).__init__() + super(ConvNormAct, self).__init__() padding = get_same_padding(kernel_size) padding *= dilation @@ -101,7 +101,7 @@ class DSConv(nn.Module): norm = val2tuple(norm, 2) act_func = val2tuple(act_func, 2) - self.depth_conv = ConvLayer( + self.depth_conv = ConvNormAct( in_channels, in_channels, kernel_size, @@ -111,7 +111,7 @@ class DSConv(nn.Module): act_func=act_func[0], use_bias=use_bias[0], ) - self.point_conv = ConvLayer( + self.point_conv = ConvNormAct( in_channels, out_channels, 1, @@ -146,7 +146,7 @@ class MBConv(nn.Module): act_func = val2tuple(act_func, 3) mid_channels = mid_channels or round(in_channels * expand_ratio) - self.inverted_conv = ConvLayer( + self.inverted_conv = ConvNormAct( in_channels, mid_channels, 1, @@ -155,7 +155,7 @@ class MBConv(nn.Module): act_func=act_func[0], use_bias=use_bias[0], ) - self.depth_conv = ConvLayer( + self.depth_conv = ConvNormAct( mid_channels, mid_channels, kernel_size, @@ -165,7 +165,7 @@ class MBConv(nn.Module): act_func=act_func[1], use_bias=use_bias[1], ) - self.point_conv = ConvLayer( + self.point_conv = ConvNormAct( mid_channels, out_channels, 1, @@ -207,7 +207,7 @@ class LiteMSA(nn.Module): act_func = val2tuple(act_func, 2) self.dim = dim - self.qkv = ConvLayer( + self.qkv = ConvNormAct( in_channels, 3 * total_dim, 1, @@ -233,7 +233,7 @@ class LiteMSA(nn.Module): ) self.kernel_func = kernel_func(inplace=False) - self.proj = ConvLayer( + self.proj = ConvNormAct( total_dim * (1 + len(scales)), out_channels, 1, @@ -368,7 +368,7 @@ class ClsHead(nn.Module): ): super(ClsHead, self).__init__() self.ops = nn.Sequential( - ConvLayer(in_channels, width_list[0], 1, norm=norm, act_func=act_func), + ConvNormAct(in_channels, width_list[0], 1, norm=norm, act_func=act_func), SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW'), nn.Linear(width_list[0], width_list[1], bias=False), nn.LayerNorm(width_list[1]), @@ -402,7 +402,7 @@ class EfficientViT(nn.Module): self.global_pool = global_pool # input stem input_stem = [ - ('in_conv', ConvLayer( + ('in_conv', ConvNormAct( in_channels=3, out_channels=width_list[0], kernel_size=3, @@ -425,23 +425,24 @@ class EfficientViT(nn.Module): stem_block += 1 in_channels = width_list[0] self.stem = nn.Sequential(OrderedDict(input_stem)) - + stride = 4 self.feature_info = [] stages = [] stage_idx = 0 for w, d in zip(width_list[1:3], depth_list[1:3]): stage = [] for i in range(d): - stride = 2 if i == 0 else 1 + stage_stride = 2 if i == 0 else 1 + stride *= stage_stride block = self.build_local_block( in_channels=in_channels, out_channels=w, - stride=stride, + stride=stage_stride, expand_ratio=expand_ratio, norm=norm, act_func=act_func, ) - block = ResidualBlock(block, nn.Identity() if stride == 1 else None) + block = ResidualBlock(block, nn.Identity() if stage_stride == 1 else None) stage.append(block) in_channels = w stages.append(nn.Sequential(*stage)) @@ -473,9 +474,10 @@ class EfficientViT(nn.Module): ) ) stages.append(nn.Sequential(*stage)) + stride *= 2 self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{stage_idx}')] stage_idx += 1 - + self.stages = nn.Sequential(*stages) self.num_features = in_channels self.head_width_list = head_width_list diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py new file mode 100644 index 00000000..f843c5ad --- /dev/null +++ b/timm/models/efficientvit_msra.py @@ -0,0 +1,441 @@ +""" EfficientViT (by MSRA) + +Paper: `EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention` + - https://arxiv.org/abs/2305.07027 + +Adapted from official impl at https://github.com/microsoft/Cream/tree/main/EfficientViT +""" + +__all__ = ['EfficientViTMSRA'] + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.vision_transformer import trunc_normal_ +from timm.models.layers import SqueezeExcite +from ._registry import register_model, generate_default_cfgs +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from functools import partial +from collections import OrderedDict +import itertools + + +class ConvBN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', torch.nn.BatchNorm2d(b)) + torch.nn.init.constant_(self.bn.weight, bn_weight_init) + torch.nn.init.constant_(self.bn.bias, 0) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + +class BNLinear(torch.nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', torch.nn.BatchNorm1d(a)) + self.add_module('l', torch.nn.Linear(a, b, bias=bias)) + trunc_normal_(self.l.weight, std=std) + if bias: + torch.nn.init.constant_(self.l.bias, 0) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + b = bn.bias - self.bn.running_mean * \ + self.bn.weight / (bn.running_var + bn.eps)**0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + +class PatchMerging(torch.nn.Module): + def __init__(self, dim, out_dim): + super().__init__() + hid_dim = int(dim * 4) + self.conv1 = ConvBN(dim, hid_dim, 1, 1, 0) + self.act = torch.nn.ReLU() + self.conv2 = ConvBN(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim) + self.se = SqueezeExcite(hid_dim, .25) + self.conv3 = ConvBN(hid_dim, out_dim, 1, 1, 0) + + def forward(self, x): + x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x)))))) + return x + +class ResidualDrop(torch.nn.Module): + def __init__(self, m, drop=0.): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1, + device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + +class FFN(torch.nn.Module): + def __init__(self, ed, h): + super().__init__() + self.pw1 = ConvBN(ed, h) + self.act = torch.nn.ReLU() + self.pw2 = ConvBN(h, ed, bn_weight_init=0) + + def forward(self, x): + x = self.pw2(self.act(self.pw1(x))) + return x + +class CascadedGroupAttention(torch.nn.Module): + r""" Cascaded Group Attention. + + Args: + dim (int): Number of input channels. + key_dim (int): The dimension for query and key. + num_heads (int): Number of attention heads. + attn_ratio (int): Multiplier for the query dim for value dimension. + resolution (int): Input resolution, correspond to the window size. + kernels (List[int]): The kernel size of the dw conv on query. + """ + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, + resolution=14, + kernels=[5, 5, 5, 5],): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.d = int(attn_ratio * key_dim) + self.attn_ratio = attn_ratio + + qkvs = [] + dws = [] + for i in range(num_heads): + qkvs.append(ConvBN(dim // (num_heads), self.key_dim * 2 + self.d)) + dws.append(ConvBN(self.key_dim, self.key_dim, kernels[i], 1, kernels[i]//2, groups=self.key_dim)) + self.qkvs = torch.nn.ModuleList(qkvs) + self.dws = torch.nn.ModuleList(dws) + self.proj = torch.nn.Sequential( + torch.nn.ReLU(), + ConvBN(self.d * num_heads, dim, bn_weight_init=0) + ) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N)) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,C,H,W) + B, C, H, W = x.shape + trainingab = self.attention_biases[:, self.attention_bias_idxs] + feats_in = x.chunk(len(self.qkvs), dim=1) + feats_out = [] + feat = feats_in[0] + for i, qkv in enumerate(self.qkvs): + if i > 0: # add the previous output to the input + feat = feat + feats_in[i] + feat = qkv(feat) + q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1) # B, C/h, H, W + q = self.dws[i](q) + q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) # B, C/h, N + attn = ( + (q.transpose(-2, -1) @ k) * self.scale + + + (trainingab[i] if self.training else self.ab[i]) + ) + attn = attn.softmax(dim=-1) # BNN + feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W) # BCHW + feats_out.append(feat) + x = self.proj(torch.cat(feats_out, 1)) + return x + +class LocalWindowAttention(torch.nn.Module): + r""" Local Window Attention. + + Args: + dim (int): Number of input channels. + key_dim (int): The dimension for query and key. + num_heads (int): Number of attention heads. + attn_ratio (int): Multiplier for the query dim for value dimension. + resolution (int): Input resolution. + window_resolution (int): Local window resolution. + kernels (List[int]): The kernel size of the dw conv on query. + """ + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=[5, 5, 5, 5],): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.resolution = resolution + assert window_resolution > 0, 'window_size must be greater than 0' + self.window_resolution = window_resolution + + window_resolution = min(window_resolution, resolution) + self.attn = CascadedGroupAttention(dim, key_dim, num_heads, + attn_ratio=attn_ratio, + resolution=window_resolution, + kernels=kernels,) + + def forward(self, x): + H = W = self.resolution + B, C, H_, W_ = x.shape + # Only check this for classifcation models + assert H == H_ and W == W_, 'input feature has wrong size, expect {}, got {}'.format((H, W), (H_, W_)) + + if H <= self.window_resolution and W <= self.window_resolution: + x = self.attn(x) + else: + x = x.permute(0, 2, 3, 1) + pad_b = (self.window_resolution - H % + self.window_resolution) % self.window_resolution + pad_r = (self.window_resolution - W % + self.window_resolution) % self.window_resolution + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_resolution + nW = pW // self.window_resolution + # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw + x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3).reshape( + B * nH * nW, self.window_resolution, self.window_resolution, C + ).permute(0, 3, 1, 2) + x = self.attn(x) + # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC + x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution, + C).transpose(2, 3).reshape(B, pH, pW, C) + if padding: + x = x[:, :H, :W].contiguous() + x = x.permute(0, 3, 1, 2) + return x + +class EfficientViTBlock(torch.nn.Module): + """ A basic EfficientViT building block. + + Args: + ed (int): Number of input channels. + kd (int): Dimension for query and key in the token mixer. + nh (int): Number of attention heads. + ar (int): Multiplier for the query dim for value dimension. + resolution (int): Input resolution. + window_resolution (int): Local window resolution. + kernels (List[int]): The kernel size of the dw conv on query. + """ + def __init__(self,ed, kd, nh=8, + ar=4, + resolution=14, + window_resolution=7, + kernels=[5, 5, 5, 5],): + super().__init__() + + self.dw0 = ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0.)) + self.ffn0 = ResidualDrop(FFN(ed, int(ed * 2))) + + self.mixer = ResidualDrop( + LocalWindowAttention(ed, kd, nh, attn_ratio=ar, resolution=resolution, + window_resolution=window_resolution, kernels=kernels)) + + self.dw1 = ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0.)) + self.ffn1 = ResidualDrop(FFN(ed, int(ed * 2))) + + def forward(self, x): + return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x))))) + +class PatchEmbedding(torch.nn.Sequential): + def __init__(self, in_chans, dim): + super().__init__() + self.add_module('conv1', ConvBN(in_chans, dim // 8, 3, 2, 1)) + self.add_module('relu1', torch.nn.ReLU()) + self.add_module('conv2', ConvBN(dim // 8, dim // 4, 3, 2, 1)) + self.add_module('relu2', torch.nn.ReLU()) + self.add_module('conv3', ConvBN(dim // 4, dim // 2, 3, 2, 1)) + self.add_module('relu3', torch.nn.ReLU()) + self.add_module('conv4', ConvBN(dim // 2, dim, 3, 2, 1)) + self.patch_size = 16 + +class EfficientViTMSRA(nn.Module): + def __init__( + self, + img_size=224, + in_chans=3, + num_classes=1000, + embed_dim=[64, 128, 192], + key_dim=[16, 16, 16], + depth=[1, 2, 3], + num_heads=[4, 4, 4], + window_size=[7, 7, 7], + kernels=[5, 5, 5, 5], + down_ops=[[''], ['subsample', 2], ['subsample', 2]], + global_pool='avg', + ): + super(EfficientViTMSRA, self).__init__() + self.grad_checkpointing = False + resolution = img_size + # Patch embedding + self.patch_embed = PatchEmbedding(in_chans, embed_dim[0]) + stride = self.patch_embed.patch_size + resolution = img_size // self.patch_embed.patch_size + + attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))] + self.feature_info = [] + stages = [] + self.feature_info += [dict(num_chs=embed_dim[0], reduction=stride, module=f'stages.{0}')] + + # Build EfficientViT blocks + for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)): + blocks = [] + if do[0] == 'subsample': + # Build EfficientViT downsample block + resolution_ = (resolution - 1) // do[1] + 1 + blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i])), + ResidualDrop(FFN(embed_dim[i], int(embed_dim[i] * 2))),)) + blocks.append(PatchMerging(*embed_dim[i:i + 2], resolution)) + resolution = resolution_ + blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1, groups=embed_dim[i + 1])), + ResidualDrop(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2))),)) + stride *= 2 + for d in range(dpth): + blocks.append(EfficientViTBlock(ed, kd, nh, ar, resolution, wd, kernels)) + stages.append(nn.Sequential(*blocks)) + self.feature_info += [dict(num_chs=embed_dim[i+1], reduction=stride, module=f'stages.{i+1}')] + + self.stages = nn.Sequential(*stages) + + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW') + self.head = BNLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^patch_embed', + blocks=[(r'^stages\.(\d+)', None)] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW') + self.head = BNLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + target_keys = list(model.state_dict().keys()) + if 'state_dict' in state_dict.keys(): + state_dict = state_dict['state_dict'] + out_dict = {} + for i, (k, v) in enumerate(state_dict.items()): + out_dict[target_keys[i]] = v + return out_dict + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.in_conv.conv', + 'classifier': 'head', + **kwargs, + } + + +default_cfgs = generate_default_cfgs( + { + 'efficientvit_m0.r224_in1k': _cfg( + url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth' + ), + } +) + + +def _create_efficientvit_msra(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + EfficientViTMSRA, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs + ) + return model + + +@register_model +def efficientvit_m0(pretrained=False, **kwargs): + model_args = dict(img_size=224, + embed_dim=[64, 128, 192], + depth=[1, 2, 3], + num_heads=[4, 4, 4], + window_size=[7, 7, 7], + kernels=[5, 5, 5, 5]) + return _create_efficientvit_msra('efficientvit_m0', pretrained=pretrained, **dict(model_args, **kwargs))