diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py index bba1530e..57d5cf6d 100644 --- a/timm/models/efficientvit_msra.py +++ b/timm/models/efficientvit_msra.py @@ -17,6 +17,7 @@ from ._registry import register_model, generate_default_cfgs from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq import itertools +from collections import OrderedDict class ConvBN(torch.nn.Sequential): @@ -53,15 +54,15 @@ class BNLinear(torch.nn.Sequential): @torch.no_grad() def fuse(self): - bn, l = self._modules.values() + bn, linear = self._modules.values() w = bn.weight / (bn.running_var + bn.eps)**0.5 b = bn.bias - self.bn.running_mean * \ self.bn.weight / (bn.running_var + bn.eps)**0.5 - w = l.weight * w[None, :] - if l.bias is None: + w = linear.weight * w[None, :] + if linear.bias is None: b = b @ self.linear.weight.T else: - b = (l.weight @ b[:, None]).view(-1) + self.linear.bias + b = (linear.weight @ b[:, None]).view(-1) + self.linear.bias m = torch.nn.Linear(w.size(1), w.size(0)) m.weight.data.copy_(w) m.bias.data.copy_(b) @@ -299,16 +300,16 @@ class EfficientViTStage(torch.nn.Module): if do[0] == 'subsample': self.resolution = (resolution - 1) // do[1] + 1 down_blocks = [] - down_blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(pre_ed, pre_ed, 3, 1, 1, groups=pre_ed)), - ResidualDrop(FFN(pre_ed, int(pre_ed * 2))),)) - down_blocks.append(PatchMerging(pre_ed, ed)) - down_blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed)), - ResidualDrop(FFN(ed, int(ed * 2))),)) - self.downsample = nn.Sequential(*down_blocks) + down_blocks.append(('res1', torch.nn.Sequential(ResidualDrop(ConvBN(pre_ed, pre_ed, 3, 1, 1, groups=pre_ed)), + ResidualDrop(FFN(pre_ed, int(pre_ed * 2))),))) + down_blocks.append(('patchmerge', PatchMerging(pre_ed, ed))) + down_blocks.append(('res2', torch.nn.Sequential(ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed)), + ResidualDrop(FFN(ed, int(ed * 2))),))) + self.downsample = nn.Sequential(OrderedDict(down_blocks)) else: self.downsample = nn.Identity() self.resolution = resolution - + blocks = [] for d in range(depth): blocks.append(EfficientViTBlock(ed, kd, nh, ar, self.resolution, window_resolution, kernels)) @@ -355,12 +356,11 @@ class EfficientViTMSRA(nn.Module): self.patch_embed = PatchEmbedding(in_chans, embed_dim[0]) stride = self.patch_embed.patch_size resolution = img_size // self.patch_embed.patch_size - attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))] - self.feature_info = [] - stages = [] # Build EfficientViT blocks + self.feature_info = [] + stages = [] for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate( zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)): pre_ed = embed_dim[i - 1] @@ -419,30 +419,28 @@ class EfficientViTMSRA(nn.Module): def checkpoint_filter_fn(state_dict, model): if 'model' in state_dict.keys(): state_dict = state_dict['model'] + tmp_dict = {} out_dict = {} + target_keys = model.state_dict().keys() + target_keys = [k for k in target_keys if k.startswith('stages.')] for k, v in state_dict.items(): + k = k.split('.') + if k[-2] == 'c': + k[-2] = 'conv' + if k[-2] == 'l': + k[-2] = 'linear' + k = '.'.join(k) + tmp_dict[k] = v + for k, v in tmp_dict.items(): if k.startswith('patch_embed'): k = k.split('.') k[1] = 'conv' + str(int(k[1]) // 2 + 1) - if k[2] == 'c': - k[2] = 'conv' k = '.'.join(k) elif k.startswith('blocks'): - pass - # k = k.split('.') - # k[0] = 'stages.' + str(int(k[0][6:]) - 1) - # if int(k[1]) >= 2: - # k[1] = 'block' - # else: - # k[1] = 'downsample.' + k[1] - # if k[-1] == 'c': - # k[-1] = 'conv' - # k = '.'.join(k) - elif k.startswith('head'): - k = k.split('.') - if k[1] == 'l': - k[1] = 'linear' - k = '.'.join(k) + kw = '.'.join(k.split('.')[2:]) + find_kw = [a for a in list(sorted(tmp_dict.keys())) if kw in a] + idx = find_kw.index(k) + k = [a for a in target_keys if kw in a][idx] out_dict[k] = v return out_dict