From c28324a150a32b47651c0121750bbad5c8f8649f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 18 Aug 2023 16:39:58 -0700 Subject: [PATCH] Update efficient_vit (msra), hf hub weights --- timm/models/efficientvit_mit.py | 4 +- timm/models/efficientvit_msra.py | 491 ++++++++++++++++++------------- 2 files changed, 291 insertions(+), 204 deletions(-) diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py index fa97cd1e..0b9ddfc3 100644 --- a/timm/models/efficientvit_mit.py +++ b/timm/models/efficientvit_mit.py @@ -527,7 +527,7 @@ class EfficientVit(nn.Module): ) else: if self.global_pool == 'avg': - self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW') + self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) else: self.head = nn.Identity() @@ -561,7 +561,7 @@ class EfficientVit(nn.Module): ) else: if self.global_pool == 'avg': - self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW') + self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) else: self.head = nn.Identity() diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py index ca93283e..7a631163 100644 --- a/timm/models/efficientvit_msra.py +++ b/timm/models/efficientvit_msra.py @@ -6,55 +6,57 @@ Paper: `EfficientViT: Memory Efficient Vision Transformer with Cascaded Group At Adapted from official impl at https://github.com/microsoft/Cream/tree/main/EfficientViT """ -__all__ = ['EfficientViTMSRA'] - -import torch -import torch.nn as nn -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.vision_transformer import trunc_normal_ -from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d -from ._registry import register_model, generate_default_cfgs -from ._builder import build_model_with_cfg -from ._manipulate import checkpoint_seq +__all__ = ['EfficientVitMsra'] import itertools from collections import OrderedDict +import torch +import torch.nn as nn -class ConvBN(torch.nn.Sequential): - def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_ +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq +from ._registry import register_model, generate_default_cfgs + + +class ConvNorm(torch.nn.Sequential): + def __init__(self, in_chs, out_chs, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): super().__init__() - self.add_module('conv', torch.nn.Conv2d( - a, b, ks, stride, pad, dilation, groups, bias=False)) - self.add_module('bn', torch.nn.BatchNorm2d(b)) + self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False) + self.bn = nn.BatchNorm2d(out_chs) torch.nn.init.constant_(self.bn.weight, bn_weight_init) torch.nn.init.constant_(self.bn.bias, 0) @torch.no_grad() def fuse(self): - c, bn = self._modules.values() + c, bn = self.conv, self.bn w = bn.weight / (bn.running_var + bn.eps)**0.5 w = c.weight * w[:, None, None, None] b = bn.bias - bn.running_mean * bn.weight / \ (bn.running_var + bn.eps)**0.5 - m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( - 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m = torch.nn.Conv2d( + w.size(1) * self.c.groups, w.size(0), w.shape[2:], + stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) m.weight.data.copy_(w) m.bias.data.copy_(b) return m -class BNLinear(torch.nn.Sequential): - def __init__(self, a, b, bias=True, std=0.02): +class NormLinear(torch.nn.Sequential): + def __init__(self, in_features, out_features, bias=True, std=0.02, drop=0.): super().__init__() - self.add_module('bn', torch.nn.BatchNorm1d(a)) - self.add_module('linear', torch.nn.Linear(a, b, bias=bias)) + self.bn = nn.BatchNorm1d(in_features) + self.drop = nn.Dropout(drop) + self.linear = nn.Linear(in_features, out_features, bias=bias) + trunc_normal_(self.linear.weight, std=std) - if bias: - torch.nn.init.constant_(self.linear.bias, 0) + if self.linear.bias is not None: + nn.init.constant_(self.linear.bias, 0) @torch.no_grad() def fuse(self): - bn, linear = self._modules.values() + bn, linear = self.bn, self.linear w = bn.weight / (bn.running_var + bn.eps)**0.5 b = bn.bias - self.bn.running_mean * \ self.bn.weight / (bn.running_var + bn.eps)**0.5 @@ -73,11 +75,11 @@ class PatchMerging(torch.nn.Module): def __init__(self, dim, out_dim): super().__init__() hid_dim = int(dim * 4) - self.conv1 = ConvBN(dim, hid_dim, 1, 1, 0) + self.conv1 = ConvNorm(dim, hid_dim, 1, 1, 0) self.act = torch.nn.ReLU() - self.conv2 = ConvBN(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim) + self.conv2 = ConvNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim) self.se = SqueezeExcite(hid_dim, .25) - self.conv3 = ConvBN(hid_dim, out_dim, 1, 1, 0) + self.conv3 = ConvNorm(hid_dim, out_dim, 1, 1, 0) def forward(self, x): x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x)))))) @@ -92,18 +94,18 @@ class ResidualDrop(torch.nn.Module): def forward(self, x): if self.training and self.drop > 0: - return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1, - device=x.device).ge_(self.drop).div(1 - self.drop).detach() + return x + self.m(x) * torch.rand( + x.size(0), 1, 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() else: return x + self.m(x) -class FFN(torch.nn.Module): +class ConvMlp(torch.nn.Module): def __init__(self, ed, h): super().__init__() - self.pw1 = ConvBN(ed, h) + self.pw1 = ConvNorm(ed, h) self.act = torch.nn.ReLU() - self.pw2 = ConvBN(h, ed, bn_weight_init=0) + self.pw2 = ConvNorm(h, ed, bn_weight_init=0) def forward(self, x): x = self.pw2(self.act(self.pw1(x))) @@ -121,10 +123,15 @@ class CascadedGroupAttention(torch.nn.Module): resolution (int): Input resolution, correspond to the window size. kernels (List[int]): The kernel size of the dw conv on query. """ - def __init__(self, dim, key_dim, num_heads=8, - attn_ratio=4, - resolution=14, - kernels=[5, 5, 5, 5],): + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=14, + kernels=(5, 5, 5, 5), + ): super().__init__() self.num_heads = num_heads self.scale = key_dim ** -0.5 @@ -135,13 +142,13 @@ class CascadedGroupAttention(torch.nn.Module): qkvs = [] dws = [] for i in range(num_heads): - qkvs.append(ConvBN(dim // (num_heads), self.key_dim * 2 + self.d)) - dws.append(ConvBN(self.key_dim, self.key_dim, kernels[i], 1, kernels[i]//2, groups=self.key_dim)) + qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.d)) + dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim)) self.qkvs = torch.nn.ModuleList(qkvs) self.dws = torch.nn.ModuleList(dws) self.proj = torch.nn.Sequential( torch.nn.ReLU(), - ConvBN(self.d * num_heads, dim, bn_weight_init=0) + ConvNorm(self.d * num_heads, dim, bn_weight_init=0) ) points = list(itertools.product(range(resolution), range(resolution))) @@ -156,8 +163,7 @@ class CascadedGroupAttention(torch.nn.Module): idxs.append(attention_offsets[offset]) self.attention_biases = torch.nn.Parameter( torch.zeros(num_heads, len(attention_offsets))) - self.register_buffer('attention_bias_idxs', - torch.LongTensor(idxs).view(N, N)) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False) @torch.no_grad() def train(self, mode=True): @@ -169,24 +175,23 @@ class CascadedGroupAttention(torch.nn.Module): def forward(self, x): B, C, H, W = x.shape - trainingab = self.attention_biases[:, self.attention_bias_idxs] feats_in = x.chunk(len(self.qkvs), dim=1) feats_out = [] feat = feats_in[0] for i, qkv in enumerate(self.qkvs): + attn_bias = self.attention_biases[:, self.attention_bias_idxs][i] if self.training else self.ab[i] if i > 0: feat = feat + feats_in[i] feat = qkv(feat) q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1) q = self.dws[i](q) q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) - attn = ( - (q.transpose(-2, -1) @ k) * self.scale - + - (trainingab[i] if self.training else self.ab[i]) - ) + q = q * self.scale + attn = q.transpose(-2, -1) @ k + attn = attn + attn_bias attn = attn.softmax(dim=-1) - feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W) + feat = v @ attn.transpose(-2, -1) + feat = feat.view(B, self.d, H, W) feats_out.append(feat) x = self.proj(torch.cat(feats_out, 1)) return x @@ -204,11 +209,16 @@ class LocalWindowAttention(torch.nn.Module): window_resolution (int): Local window resolution. kernels (List[int]): The kernel size of the dw conv on query. """ - def __init__(self, dim, key_dim, num_heads=8, - attn_ratio=4, - resolution=14, - window_resolution=7, - kernels=[5, 5, 5, 5],): + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=(5, 5, 5, 5), + ): super().__init__() self.dim = dim self.num_heads = num_heads @@ -216,10 +226,12 @@ class LocalWindowAttention(torch.nn.Module): assert window_resolution > 0, 'window_size must be greater than 0' self.window_resolution = window_resolution window_resolution = min(window_resolution, resolution) - self.attn = CascadedGroupAttention(dim, key_dim, num_heads, - attn_ratio=attn_ratio, - resolution=window_resolution, - kernels=kernels,) + self.attn = CascadedGroupAttention( + dim, key_dim, num_heads, + attn_ratio=attn_ratio, + resolution=window_resolution, + kernels=kernels, + ) def forward(self, x): H = W = self.resolution @@ -231,88 +243,111 @@ class LocalWindowAttention(torch.nn.Module): x = self.attn(x) else: x = x.permute(0, 2, 3, 1) - pad_b = (self.window_resolution - H % - self.window_resolution) % self.window_resolution - pad_r = (self.window_resolution - W % - self.window_resolution) % self.window_resolution - padding = pad_b > 0 or pad_r > 0 - - if padding: - x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + pad_b = (self.window_resolution - H % self.window_resolution) % self.window_resolution + pad_r = (self.window_resolution - W % self.window_resolution) % self.window_resolution + x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b)) pH, pW = H + pad_b, W + pad_r nH = pH // self.window_resolution nW = pW // self.window_resolution # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw - x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3).reshape( - B * nH * nW, self.window_resolution, self.window_resolution, C - ).permute(0, 3, 1, 2) + x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3) + x = x.reshape(B * nH * nW, self.window_resolution, self.window_resolution, C).permute(0, 3, 1, 2) x = self.attn(x) # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC - x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution, C).transpose(2, 3).reshape(B, pH, pW, C) - if padding: - x = x[:, :H, :W].contiguous() + x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution, C) + x = x.transpose(2, 3).reshape(B, pH, pW, C) + x = x[:, :H, :W].contiguous() x = x.permute(0, 3, 1, 2) return x -class EfficientViTBlock(torch.nn.Module): - """ A basic EfficientViT building block. +class EfficientVitBlock(torch.nn.Module): + """ A basic EfficientVit building block. Args: - ed (int): Number of input channels. - kd (int): Dimension for query and key in the token mixer. - nh (int): Number of attention heads. - ar (int): Multiplier for the query dim for value dimension. + dim (int): Number of input channels. + key_dim (int): Dimension for query and key in the token mixer. + num_heads (int): Number of attention heads. + attn_ratio (int): Multiplier for the query dim for value dimension. resolution (int): Input resolution. window_resolution (int): Local window resolution. kernels (List[int]): The kernel size of the dw conv on query. """ - def __init__(self, ed, kd, nh=8, - ar=4, - resolution=14, - window_resolution=7, - kernels=[5, 5, 5, 5],): + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=[5, 5, 5, 5], + ): super().__init__() - self.dw0 = ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0.)) - self.ffn0 = ResidualDrop(FFN(ed, int(ed * 2))) + self.dw0 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.)) + self.ffn0 = ResidualDrop(ConvMlp(dim, int(dim * 2))) self.mixer = ResidualDrop( - LocalWindowAttention(ed, kd, nh, attn_ratio=ar, resolution=resolution, - window_resolution=window_resolution, kernels=kernels)) + LocalWindowAttention( + dim, key_dim, num_heads, + attn_ratio=attn_ratio, + resolution=resolution, + window_resolution=window_resolution, + kernels=kernels, + ) + ) - self.dw1 = ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0.)) - self.ffn1 = ResidualDrop(FFN(ed, int(ed * 2))) + self.dw1 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0.)) + self.ffn1 = ResidualDrop(ConvMlp(dim, int(dim * 2))) def forward(self, x): return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x))))) -class EfficientViTStage(torch.nn.Module): - def __init__(self, do, pre_ed, ed, kd, nh=8, - ar=4, - resolution=14, - window_resolution=7, - kernels=[5, 5, 5, 5], - depth=1): +class EfficientVitStage(torch.nn.Module): + def __init__( + self, + in_dim, + out_dim, + key_dim, + downsample=('', 1), + num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=[5, 5, 5, 5], + depth=1, + ): super().__init__() - if do[0] == 'subsample': - self.resolution = (resolution - 1) // do[1] + 1 + if downsample[0] == 'subsample': + self.resolution = (resolution - 1) // downsample[1] + 1 down_blocks = [] - down_blocks.append(('res1', torch.nn.Sequential(ResidualDrop(ConvBN(pre_ed, pre_ed, 3, 1, 1, groups=pre_ed)), - ResidualDrop(FFN(pre_ed, int(pre_ed * 2))),))) - down_blocks.append(('patchmerge', PatchMerging(pre_ed, ed))) - down_blocks.append(('res2', torch.nn.Sequential(ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed)), - ResidualDrop(FFN(ed, int(ed * 2))),))) + down_blocks.append(( + 'res1', + torch.nn.Sequential( + ResidualDrop(ConvNorm(in_dim, in_dim, 3, 1, 1, groups=in_dim)), + ResidualDrop(ConvMlp(in_dim, int(in_dim * 2))), + ) + )) + down_blocks.append(('patchmerge', PatchMerging(in_dim, out_dim))) + down_blocks.append(( + 'res2', + torch.nn.Sequential( + ResidualDrop(ConvNorm(out_dim, out_dim, 3, 1, 1, groups=out_dim)), + ResidualDrop(ConvMlp(out_dim, int(out_dim * 2))), + ) + )) self.downsample = nn.Sequential(OrderedDict(down_blocks)) else: + assert in_dim == out_dim self.downsample = nn.Identity() self.resolution = resolution blocks = [] for d in range(depth): - blocks.append(EfficientViTBlock(ed, kd, nh, ar, self.resolution, window_resolution, kernels)) + blocks.append(EfficientVitBlock(out_dim, key_dim, num_heads, attn_ratio, self.resolution, window_resolution, kernels)) self.blocks = nn.Sequential(*blocks) def forward(self, x): @@ -324,57 +359,77 @@ class EfficientViTStage(torch.nn.Module): class PatchEmbedding(torch.nn.Sequential): def __init__(self, in_chans, dim): super().__init__() - self.add_module('conv1', ConvBN(in_chans, dim // 8, 3, 2, 1)) + self.add_module('conv1', ConvNorm(in_chans, dim // 8, 3, 2, 1)) self.add_module('relu1', torch.nn.ReLU()) - self.add_module('conv2', ConvBN(dim // 8, dim // 4, 3, 2, 1)) + self.add_module('conv2', ConvNorm(dim // 8, dim // 4, 3, 2, 1)) self.add_module('relu2', torch.nn.ReLU()) - self.add_module('conv3', ConvBN(dim // 4, dim // 2, 3, 2, 1)) + self.add_module('conv3', ConvNorm(dim // 4, dim // 2, 3, 2, 1)) self.add_module('relu3', torch.nn.ReLU()) - self.add_module('conv4', ConvBN(dim // 2, dim, 3, 2, 1)) + self.add_module('conv4', ConvNorm(dim // 2, dim, 3, 2, 1)) self.patch_size = 16 -class EfficientViTMSRA(nn.Module): +class EfficientVitMsra(nn.Module): def __init__( - self, - img_size=224, - in_chans=3, - num_classes=1000, - embed_dim=[64, 128, 192], - key_dim=[16, 16, 16], - depth=[1, 2, 3], - num_heads=[4, 4, 4], - window_size=[7, 7, 7], - kernels=[5, 5, 5, 5], - down_ops=[[''], ['subsample', 2], ['subsample', 2]], - global_pool='avg', + self, + img_size=224, + in_chans=3, + num_classes=1000, + embed_dim=(64, 128, 192), + key_dim=(16, 16, 16), + depth=(1, 2, 3), + num_heads=(4, 4, 4), + window_size=(7, 7, 7), + kernels=(5, 5, 5, 5), + down_ops=(('', 1), ('subsample', 2), ('subsample', 2)), + global_pool='avg', + drop_rate=0., ): - super(EfficientViTMSRA, self).__init__() + super(EfficientVitMsra, self).__init__() self.grad_checkpointing = False - resolution = img_size + self.num_classes = num_classes + self.drop_rate = drop_rate + # Patch embedding self.patch_embed = PatchEmbedding(in_chans, embed_dim[0]) stride = self.patch_embed.patch_size resolution = img_size // self.patch_embed.patch_size attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))] - # Build EfficientViT blocks + # Build EfficientVit blocks self.feature_info = [] stages = [] + pre_ed = embed_dim[0] for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate( zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)): - pre_ed = embed_dim[i - 1] - stage = EfficientViTStage(do, pre_ed, ed, kd, nh, ar, resolution, wd, kernels, dpth) + stage = EfficientVitStage( + in_dim=pre_ed, + out_dim=ed, + key_dim=kd, + downsample=do, + num_heads=nh, + attn_ratio=ar, + resolution=resolution, + window_resolution=wd, + kernels=kernels, + depth=dpth, + ) + pre_ed = ed if do[0] == 'subsample' and i != 0: - stride *= 2 + stride *= do[1] resolution = stage.resolution stages.append(stage) self.feature_info += [dict(num_chs=ed, reduction=stride, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW') + if global_pool == 'avg': + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + else: + assert num_classes == 0 + self.global_pool = nn.Identity() self.num_features = embed_dim[-1] - self.head = BNLinear(self.num_features, num_classes) if num_classes > 0 else torch.nn.Identity() + self.head = NormLinear( + self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity() @torch.jit.ignore def group_matcher(self, coarse=False): @@ -395,8 +450,13 @@ class EfficientViTMSRA(nn.Module): def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes if global_pool is not None: - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW') - self.head = BNLinear(self.num_features, num_classes) if num_classes > 0 else torch.nn.Identity() + if global_pool == 'avg': + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) + else: + assert num_classes == 0 + self.global_pool = nn.Identity() + self.head = NormLinear( + self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity() def forward_features(self, x): x = self.patch_embed(x) @@ -416,33 +476,38 @@ class EfficientViTMSRA(nn.Module): return x -def checkpoint_filter_fn(state_dict, model): - if 'model' in state_dict.keys(): - state_dict = state_dict['model'] - tmp_dict = {} - out_dict = {} - target_keys = model.state_dict().keys() - target_keys = [k for k in target_keys if k.startswith('stages.')] - for k, v in state_dict.items(): - k = k.split('.') - if k[-2] == 'c': - k[-2] = 'conv' - if k[-2] == 'l': - k[-2] = 'linear' - k = '.'.join(k) - tmp_dict[k] = v - for k, v in tmp_dict.items(): - if k.startswith('patch_embed'): - k = k.split('.') - k[1] = 'conv' + str(int(k[1]) // 2 + 1) - k = '.'.join(k) - elif k.startswith('blocks'): - kw = '.'.join(k.split('.')[2:]) - find_kw = [a for a in list(sorted(tmp_dict.keys())) if kw in a] - idx = find_kw.index(k) - k = [a for a in target_keys if kw in a][idx] - out_dict[k] = v - return out_dict +# def checkpoint_filter_fn(state_dict, model): +# if 'model' in state_dict.keys(): +# state_dict = state_dict['model'] +# tmp_dict = {} +# out_dict = {} +# target_keys = model.state_dict().keys() +# target_keys = [k for k in target_keys if k.startswith('stages.')] +# +# for k, v in state_dict.items(): +# if 'attention_bias_idxs' in k: +# continue +# k = k.split('.') +# if k[-2] == 'c': +# k[-2] = 'conv' +# if k[-2] == 'l': +# k[-2] = 'linear' +# k = '.'.join(k) +# tmp_dict[k] = v +# +# for k, v in tmp_dict.items(): +# if k.startswith('patch_embed'): +# k = k.split('.') +# k[1] = 'conv' + str(int(k[1]) // 2 + 1) +# k = '.'.join(k) +# elif k.startswith('blocks'): +# kw = '.'.join(k.split('.')[2:]) +# find_kw = [a for a in list(sorted(tmp_dict.keys())) if kw in a] +# idx = find_kw.index(k) +# k = [a for a in target_keys if kw in a][idx] +# out_dict[k] = v +# +# return out_dict def _cfg(url='', **kwargs): @@ -460,22 +525,28 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs( { 'efficientvit_m0.r224_in1k': _cfg( - url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth' + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth' ), 'efficientvit_m1.r224_in1k': _cfg( - url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth' + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth' ), 'efficientvit_m2.r224_in1k': _cfg( - url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth' + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth' ), 'efficientvit_m3.r224_in1k': _cfg( - url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth' + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth' ), 'efficientvit_m4.r224_in1k': _cfg( - url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth' + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth' ), 'efficientvit_m5.r224_in1k': _cfg( - url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth' + hf_hub_id='timm/', + #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth' ), } ) @@ -484,10 +555,9 @@ default_cfgs = generate_default_cfgs( def _create_efficientvit_msra(variant, pretrained=False, **kwargs): out_indices = kwargs.pop('out_indices', (0, 1, 2)) model = build_model_with_cfg( - EfficientViTMSRA, + EfficientVitMsra, variant, pretrained, - pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs ) @@ -496,60 +566,77 @@ def _create_efficientvit_msra(variant, pretrained=False, **kwargs): @register_model def efficientvit_m0(pretrained=False, **kwargs): - model_args = dict(img_size=224, - embed_dim=[64, 128, 192], - depth=[1, 2, 3], - num_heads=[4, 4, 4], - window_size=[7, 7, 7], - kernels=[5, 5, 5, 5]) + model_args = dict( + img_size=224, + embed_dim=[64, 128, 192], + depth=[1, 2, 3], + num_heads=[4, 4, 4], + window_size=[7, 7, 7], + kernels=[5, 5, 5, 5] + ) return _create_efficientvit_msra('efficientvit_m0', pretrained=pretrained, **dict(model_args, **kwargs)) + @register_model def efficientvit_m1(pretrained=False, **kwargs): - model_args = dict(img_size=224, - embed_dim=[128, 144, 192], - depth=[1, 2, 3], - num_heads=[2, 3, 3], - window_size=[7, 7, 7], - kernels=[7, 5, 3, 3]) + model_args = dict( + img_size=224, + embed_dim=[128, 144, 192], + depth=[1, 2, 3], + num_heads=[2, 3, 3], + window_size=[7, 7, 7], + kernels=[7, 5, 3, 3] + ) return _create_efficientvit_msra('efficientvit_m1', pretrained=pretrained, **dict(model_args, **kwargs)) + @register_model def efficientvit_m2(pretrained=False, **kwargs): - model_args = dict(img_size=224, - embed_dim=[128, 192, 224], - depth=[1, 2, 3], - num_heads=[4, 3, 2], - window_size=[7, 7, 7], - kernels=[7, 5, 3, 3]) + model_args = dict( + img_size=224, + embed_dim=[128, 192, 224], + depth=[1, 2, 3], + num_heads=[4, 3, 2], + window_size=[7, 7, 7], + kernels=[7, 5, 3, 3] + ) return _create_efficientvit_msra('efficientvit_m2', pretrained=pretrained, **dict(model_args, **kwargs)) + @register_model def efficientvit_m3(pretrained=False, **kwargs): - model_args = dict(img_size=224, - embed_dim=[128, 240, 320], - depth=[1, 2, 3], - num_heads=[4, 3, 4], - window_size=[7, 7, 7], - kernels=[5, 5, 5, 5]) + model_args = dict( + img_size=224, + embed_dim=[128, 240, 320], + depth=[1, 2, 3], + num_heads=[4, 3, 4], + window_size=[7, 7, 7], + kernels=[5, 5, 5, 5] + ) return _create_efficientvit_msra('efficientvit_m3', pretrained=pretrained, **dict(model_args, **kwargs)) + @register_model def efficientvit_m4(pretrained=False, **kwargs): - model_args = dict(img_size=224, - embed_dim=[128, 256, 384], - depth=[1, 2, 3], - num_heads=[4, 4, 4], - window_size=[7, 7, 7], - kernels=[7, 5, 3, 3]) + model_args = dict( + img_size=224, + embed_dim=[128, 256, 384], + depth=[1, 2, 3], + num_heads=[4, 4, 4], + window_size=[7, 7, 7], + kernels=[7, 5, 3, 3] + ) return _create_efficientvit_msra('efficientvit_m4', pretrained=pretrained, **dict(model_args, **kwargs)) + @register_model def efficientvit_m5(pretrained=False, **kwargs): - model_args = dict(img_size=224, - embed_dim=[192, 288, 384], - depth=[1, 3, 4], - num_heads=[3, 3, 4], - window_size=[7, 7, 7], - kernels=[7, 5, 3, 3]) + model_args = dict( + img_size=224, + embed_dim=[192, 288, 384], + depth=[1, 3, 4], + num_heads=[3, 3, 4], + window_size=[7, 7, 7], + kernels=[7, 5, 3, 3] + ) return _create_efficientvit_msra('efficientvit_m5', pretrained=pretrained, **dict(model_args, **kwargs))