From e94c60b546a4544afda00fa44ebd27ca08158c41 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=96=B9=E6=9B=A6?=
Date: Thu, 3 Aug 2023 17:45:50 +0800
Subject: [PATCH] efficientvit_msra refactor

---
 timm/models/efficientvit_msra.py | 83 +++++++++++++++++++++++---------
 1 file changed, 60 insertions(+), 23 deletions(-)

diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py
index bef462e0..bba1530e 100644
--- a/timm/models/efficientvit_msra.py
+++ b/timm/models/efficientvit_msra.py
@@ -22,7 +22,7 @@ import itertools
 class ConvBN(torch.nn.Sequential):
     def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
         super().__init__()
-        self.add_module('c', torch.nn.Conv2d(
+        self.add_module('conv', torch.nn.Conv2d(
             a, b, ks, stride, pad, dilation, groups, bias=False))
         self.add_module('bn', torch.nn.BatchNorm2d(b))
         torch.nn.init.constant_(self.bn.weight, bn_weight_init)
@@ -46,10 +46,10 @@ class BNLinear(torch.nn.Sequential):
     def __init__(self, a, b, bias=True, std=0.02):
         super().__init__()
         self.add_module('bn', torch.nn.BatchNorm1d(a))
-        self.add_module('l', torch.nn.Linear(a, b, bias=bias))
-        trunc_normal_(self.l.weight, std=std)
+        self.add_module('linear', torch.nn.Linear(a, b, bias=bias))
+        trunc_normal_(self.linear.weight, std=std)
         if bias:
-            torch.nn.init.constant_(self.l.bias, 0)
+            torch.nn.init.constant_(self.linear.bias, 0)
 
     @torch.no_grad()
     def fuse(self):
@@ -59,9 +59,9 @@ class BNLinear(torch.nn.Sequential):
             self.bn.weight / (bn.running_var + bn.eps)**0.5
         w = l.weight * w[None, :]
         if l.bias is None:
-            b = b @ self.l.weight.T
+            b = b @ self.linear.weight.T
         else:
-            b = (l.weight @ b[:, None]).view(-1) + self.l.bias
+            b = (l.weight @ b[:, None]).view(-1) + self.linear.bias
         m = torch.nn.Linear(w.size(1), w.size(0))
         m.weight.data.copy_(w)
         m.bias.data.copy_(b)
@@ -288,6 +288,38 @@ class EfficientViTBlock(torch.nn.Module):
         return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
 
 
+class EfficientViTStage(torch.nn.Module):
+    def __init__(self, do, pre_ed, ed, kd, nh=8,
+                 ar=4,
+                 resolution=14,
+                 window_resolution=7,
+                 kernels=[5, 5, 5, 5],
+                 depth=1):
+        super().__init__()
+        if do[0] == 'subsample':
+            self.resolution = (resolution - 1) // do[1] + 1
+            down_blocks = []
+            down_blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(pre_ed, pre_ed, 3, 1, 1, groups=pre_ed)),
+                                                   ResidualDrop(FFN(pre_ed, int(pre_ed * 2))),))
+            down_blocks.append(PatchMerging(pre_ed, ed))
+            down_blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed)),
+                                                   ResidualDrop(FFN(ed, int(ed * 2))),))
+            self.downsample = nn.Sequential(*down_blocks)
+        else:
+            self.downsample = nn.Identity()
+            self.resolution = resolution
+
+        blocks = []
+        for d in range(depth):
+            blocks.append(EfficientViTBlock(ed, kd, nh, ar, self.resolution, window_resolution, kernels))
+        self.blocks = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        x = self.downsample(x)
+        x = self.blocks(x)
+        return x
+
+
 class PatchEmbedding(torch.nn.Sequential):
     def __init__(self, in_chans, dim):
         super().__init__()
@@ -331,22 +363,13 @@ class EfficientViTMSRA(nn.Module):
         # Build EfficientViT blocks
         for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate(
                 zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
-            blocks = []
+            pre_ed = embed_dim[i - 1]
+            stage = EfficientViTStage(do, pre_ed, ed, kd, nh, ar, resolution, wd, kernels, dpth)
             if do[0] == 'subsample' and i != 0:
-                # Build EfficientViT downsample block
-                resolution_ = (resolution - 1) // do[1] + 1
-                blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(embed_dim[i - 1], embed_dim[i - 1], 3, 1, 1, groups=embed_dim[i - 1])),
-                                                  ResidualDrop(FFN(embed_dim[i - 1], int(embed_dim[i - 1] * 2))),))
-                blocks.append(PatchMerging(*embed_dim[i - 1:i + 1]))
-                resolution = resolution_
-                blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i])),
-                                                  ResidualDrop(FFN(embed_dim[i], int(embed_dim[i] * 2))),))
                 stride *= 2
-            for d in range(dpth):
-                blocks.append(EfficientViTBlock(ed, kd, nh, ar, resolution, wd, kernels))
-            stages.append(nn.Sequential(*blocks))
-            self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
-
+            resolution = stage.resolution
+            stages.append(stage)
+            self.feature_info += [dict(num_chs=ed, reduction=stride, module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)
 
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
@@ -401,10 +424,24 @@ def checkpoint_filter_fn(state_dict, model):
         if k.startswith('patch_embed'):
             k = k.split('.')
             k[1] = 'conv' + str(int(k[1]) // 2 + 1)
+            if k[2] == 'c':
+                k[2] = 'conv'
             k = '.'.join(k)
         elif k.startswith('blocks'):
+            pass
+            # k = k.split('.')
+            # k[0] = 'stages.' + str(int(k[0][6:]) - 1)
+            # if int(k[1]) >= 2:
+            #     k[1] = 'block'
+            # else:
+            #     k[1] = 'downsample.' + k[1]
+            # if k[-1] == 'c':
+            #     k[-1] = 'conv'
+            # k = '.'.join(k)
+        elif k.startswith('head'):
             k = k.split('.')
-            k[0] = 'stages.' + str(int(k[0][6:]) - 1)
+            if k[1] == 'l':
+                k[1] = 'linear'
             k = '.'.join(k)
         out_dict[k] = v
     return out_dict
@@ -416,8 +453,8 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000,
         'mean': IMAGENET_DEFAULT_MEAN,
         'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.in_conv.conv',
-        'classifier': 'head',
+        'first_conv': 'patch_embed.conv1.conv',
+        'classifier': 'head.linear',
         **kwargs,
     }
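
Note (not part of the patch): a minimal sketch of exercising the new EfficientViTStage on its own once this patch is applied. The dimensions below are illustrative values roughly matching an early 'subsample' stage (64 -> 128 channels, 14x14 -> 7x7 features), not values taken from the patch itself.

    import torch
    from timm.models.efficientvit_msra import EfficientViTStage  # added by this patch

    # A 'subsample' stage: depthwise ConvBN + FFN, PatchMerging (stride 2), ConvBN + FFN,
    # followed by `depth` EfficientViT blocks at the reduced resolution.
    stage = EfficientViTStage(
        do=['subsample', 2],    # per-stage entry from down_ops
        pre_ed=64,              # channels produced by the previous stage
        ed=128,                 # channels produced by this stage
        kd=16,                  # key dim per attention head
        nh=4,                   # number of attention heads
        ar=4,                   # attention ratio
        resolution=14,          # incoming feature map resolution
        window_resolution=7,
        kernels=[5, 5, 5, 5],
        depth=1,
    )

    x = torch.randn(1, 64, 14, 14)
    with torch.no_grad():
        y = stage(x)
    print(y.shape)              # expected: torch.Size([1, 128, 7, 7])
    print(stage.resolution)     # expected: 7, i.e. (14 - 1) // 2 + 1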