fix efficientvit_msra pretrained load

This commit is contained in:
方曦 2023-08-03 18:44:38 +08:00
parent e94c60b546
commit a56e2bbf19

View File

@ -17,6 +17,7 @@ from ._registry import register_model, generate_default_cfgs
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
import itertools
from collections import OrderedDict
class ConvBN(torch.nn.Sequential):
@ -53,15 +54,15 @@ class BNLinear(torch.nn.Sequential):
@torch.no_grad()
def fuse(self):
bn, l = self._modules.values()
bn, linear = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps)**0.5
b = bn.bias - self.bn.running_mean * \
self.bn.weight / (bn.running_var + bn.eps)**0.5
w = l.weight * w[None, :]
if l.bias is None:
w = linear.weight * w[None, :]
if linear.bias is None:
b = b @ self.linear.weight.T
else:
b = (l.weight @ b[:, None]).view(-1) + self.linear.bias
b = (linear.weight @ b[:, None]).view(-1) + self.linear.bias
m = torch.nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
@ -299,16 +300,16 @@ class EfficientViTStage(torch.nn.Module):
if do[0] == 'subsample':
self.resolution = (resolution - 1) // do[1] + 1
down_blocks = []
down_blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(pre_ed, pre_ed, 3, 1, 1, groups=pre_ed)),
ResidualDrop(FFN(pre_ed, int(pre_ed * 2))),))
down_blocks.append(PatchMerging(pre_ed, ed))
down_blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed)),
ResidualDrop(FFN(ed, int(ed * 2))),))
self.downsample = nn.Sequential(*down_blocks)
down_blocks.append(('res1', torch.nn.Sequential(ResidualDrop(ConvBN(pre_ed, pre_ed, 3, 1, 1, groups=pre_ed)),
ResidualDrop(FFN(pre_ed, int(pre_ed * 2))),)))
down_blocks.append(('patchmerge', PatchMerging(pre_ed, ed)))
down_blocks.append(('res2', torch.nn.Sequential(ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed)),
ResidualDrop(FFN(ed, int(ed * 2))),)))
self.downsample = nn.Sequential(OrderedDict(down_blocks))
else:
self.downsample = nn.Identity()
self.resolution = resolution
blocks = []
for d in range(depth):
blocks.append(EfficientViTBlock(ed, kd, nh, ar, self.resolution, window_resolution, kernels))
@ -355,12 +356,11 @@ class EfficientViTMSRA(nn.Module):
self.patch_embed = PatchEmbedding(in_chans, embed_dim[0])
stride = self.patch_embed.patch_size
resolution = img_size // self.patch_embed.patch_size
attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]
self.feature_info = []
stages = []
# Build EfficientViT blocks
self.feature_info = []
stages = []
for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate(
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
pre_ed = embed_dim[i - 1]
@ -419,30 +419,28 @@ class EfficientViTMSRA(nn.Module):
def checkpoint_filter_fn(state_dict, model):
if 'model' in state_dict.keys():
state_dict = state_dict['model']
tmp_dict = {}
out_dict = {}
target_keys = model.state_dict().keys()
target_keys = [k for k in target_keys if k.startswith('stages.')]
for k, v in state_dict.items():
k = k.split('.')
if k[-2] == 'c':
k[-2] = 'conv'
if k[-2] == 'l':
k[-2] = 'linear'
k = '.'.join(k)
tmp_dict[k] = v
for k, v in tmp_dict.items():
if k.startswith('patch_embed'):
k = k.split('.')
k[1] = 'conv' + str(int(k[1]) // 2 + 1)
if k[2] == 'c':
k[2] = 'conv'
k = '.'.join(k)
elif k.startswith('blocks'):
pass
# k = k.split('.')
# k[0] = 'stages.' + str(int(k[0][6:]) - 1)
# if int(k[1]) >= 2:
# k[1] = 'block'
# else:
# k[1] = 'downsample.' + k[1]
# if k[-1] == 'c':
# k[-1] = 'conv'
# k = '.'.join(k)
elif k.startswith('head'):
k = k.split('.')
if k[1] == 'l':
k[1] = 'linear'
k = '.'.join(k)
kw = '.'.join(k.split('.')[2:])
find_kw = [a for a in list(sorted(tmp_dict.keys())) if kw in a]
idx = find_kw.index(k)
k = [a for a in target_keys if kw in a][idx]
out_dict[k] = v
return out_dict