mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

commit a56e2bbf19 (parent e94c60b546)

    fix efficientvit_msra pretrained load

--- a/timm/models/efficientvit_msra.py
+++ b/timm/models/efficientvit_msra.py
@@ -17,6 +17,7 @@ from ._registry import register_model, generate_default_cfgs
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 import itertools
+from collections import OrderedDict
 
 
 class ConvBN(torch.nn.Sequential):
@@ -53,15 +54,15 @@ class BNLinear(torch.nn.Sequential):
 
     @torch.no_grad()
     def fuse(self):
-        bn, l = self._modules.values()
+        bn, linear = self._modules.values()
         w = bn.weight / (bn.running_var + bn.eps)**0.5
         b = bn.bias - self.bn.running_mean * \
             self.bn.weight / (bn.running_var + bn.eps)**0.5
-        w = l.weight * w[None, :]
-        if l.bias is None:
+        w = linear.weight * w[None, :]
+        if linear.bias is None:
             b = b @ self.linear.weight.T
         else:
-            b = (l.weight @ b[:, None]).view(-1) + self.linear.bias
+            b = (linear.weight @ b[:, None]).view(-1) + self.linear.bias
         m = torch.nn.Linear(w.size(1), w.size(0))
         m.weight.data.copy_(w)
         m.bias.data.copy_(b)
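
The fuse() method above folds the preceding BatchNorm1d into the Linear that follows it: in eval mode bn(x) is a per-feature affine map w_bn * x + b_bn, so linear(bn(x)) collapses into a single Linear with weight W * w_bn and bias W @ b_bn + b. A minimal standalone sketch of that algebra (hypothetical bn/linear modules, not timm's BNLinear), checking that the fused layer reproduces the original pair:

import torch

torch.manual_seed(0)
bn = torch.nn.BatchNorm1d(8).eval()   # eval(): use running statistics
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
linear = torch.nn.Linear(8, 4)

with torch.no_grad():
    # Same folding as BNLinear.fuse(): bn(x) = w_bn * x + b_bn per feature.
    w_bn = bn.weight / (bn.running_var + bn.eps) ** 0.5
    b_bn = bn.bias - bn.running_mean * w_bn
    fused = torch.nn.Linear(8, 4)
    fused.weight.copy_(linear.weight * w_bn[None, :])
    fused.bias.copy_(linear.weight @ b_bn + linear.bias)

x = torch.randn(2, 8)
assert torch.allclose(linear(bn(x)), fused(x), atol=1e-5)
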
@@ -299,16 +300,16 @@ class EfficientViTStage(torch.nn.Module):
         if do[0] == 'subsample':
             self.resolution = (resolution - 1) // do[1] + 1
             down_blocks = []
-            down_blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(pre_ed, pre_ed, 3, 1, 1, groups=pre_ed)),
-                                                   ResidualDrop(FFN(pre_ed, int(pre_ed * 2))),))
-            down_blocks.append(PatchMerging(pre_ed, ed))
-            down_blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed)),
-                                                   ResidualDrop(FFN(ed, int(ed * 2))),))
-            self.downsample = nn.Sequential(*down_blocks)
+            down_blocks.append(('res1', torch.nn.Sequential(ResidualDrop(ConvBN(pre_ed, pre_ed, 3, 1, 1, groups=pre_ed)),
+                                                            ResidualDrop(FFN(pre_ed, int(pre_ed * 2))),)))
+            down_blocks.append(('patchmerge', PatchMerging(pre_ed, ed)))
+            down_blocks.append(('res2', torch.nn.Sequential(ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed)),
+                                                            ResidualDrop(FFN(ed, int(ed * 2))),)))
+            self.downsample = nn.Sequential(OrderedDict(down_blocks))
         else:
             self.downsample = nn.Identity()
             self.resolution = resolution
 
         blocks = []
         for d in range(depth):
             blocks.append(EfficientViTBlock(ed, kd, nh, ar, self.resolution, window_resolution, kernels))
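
Registering the downsample blocks through an OrderedDict is what gives each submodule a stable name, so the state_dict keys read 'downsample.res1.…', 'downsample.patchmerge.…', 'downsample.res2.…' instead of the positional 'downsample.0.…'. A minimal sketch with placeholder conv layers (not the timm blocks):

import torch
from collections import OrderedDict

anonymous = torch.nn.Sequential(
    torch.nn.Conv2d(8, 8, 3, padding=1),
    torch.nn.Conv2d(8, 16, 1),
)
named = torch.nn.Sequential(OrderedDict([
    ('res1', torch.nn.Conv2d(8, 8, 3, padding=1)),
    ('patchmerge', torch.nn.Conv2d(8, 16, 1)),
]))
print(list(anonymous.state_dict())[:2])  # ['0.weight', '0.bias']
print(list(named.state_dict())[:2])      # ['res1.weight', 'res1.bias']

These named keys are what the rewritten checkpoint_filter_fn below matches pretrained weights against.
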
@@ -355,12 +356,11 @@ class EfficientViTMSRA(nn.Module):
         self.patch_embed = PatchEmbedding(in_chans, embed_dim[0])
         stride = self.patch_embed.patch_size
         resolution = img_size // self.patch_embed.patch_size
 
         attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]
-        self.feature_info = []
-        stages = []
-
         # Build EfficientViT blocks
+        self.feature_info = []
+        stages = []
         for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate(
                 zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
             pre_ed = embed_dim[i - 1]
@@ -419,30 +419,28 @@ class EfficientViTMSRA(nn.Module):
 def checkpoint_filter_fn(state_dict, model):
     if 'model' in state_dict.keys():
         state_dict = state_dict['model']
+    tmp_dict = {}
     out_dict = {}
+    target_keys = model.state_dict().keys()
+    target_keys = [k for k in target_keys if k.startswith('stages.')]
     for k, v in state_dict.items():
+        k = k.split('.')
+        if k[-2] == 'c':
+            k[-2] = 'conv'
+        if k[-2] == 'l':
+            k[-2] = 'linear'
+        k = '.'.join(k)
+        tmp_dict[k] = v
+    for k, v in tmp_dict.items():
         if k.startswith('patch_embed'):
             k = k.split('.')
             k[1] = 'conv' + str(int(k[1]) // 2 + 1)
-            if k[2] == 'c':
-                k[2] = 'conv'
             k = '.'.join(k)
         elif k.startswith('blocks'):
-            pass
-            # k = k.split('.')
-            # k[0] = 'stages.' + str(int(k[0][6:]) - 1)
-            # if int(k[1]) >= 2:
-            #     k[1] = 'block'
-            # else:
-            #     k[1] = 'downsample.' + k[1]
-            # if k[-1] == 'c':
-            #     k[-1] = 'conv'
-            # k = '.'.join(k)
-        elif k.startswith('head'):
-            k = k.split('.')
-            if k[1] == 'l':
-                k[1] = 'linear'
-            k = '.'.join(k)
+            kw = '.'.join(k.split('.')[2:])
+            find_kw = [a for a in list(sorted(tmp_dict.keys())) if kw in a]
+            idx = find_kw.index(k)
+            k = [a for a in target_keys if kw in a][idx]
         out_dict[k] = v
     return out_dict
 
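
The rewritten filter runs two passes: the first expands the abbreviated second-to-last path segment ('c' → 'conv', 'l' → 'linear') into tmp_dict; the second maps each 'blocks*' checkpoint key onto the model's 'stages.*' key that shares the same trailing path, pairing repeated suffixes up by sorted position. A toy sketch of that second pass, with hypothetical key names rather than a real checkpoint:

tmp_dict = {
    'blocks1.0.mixer.qkv.conv.weight': 1,
    'blocks2.0.mixer.qkv.conv.weight': 2,
}
target_keys = [
    'stages.0.blocks.0.mixer.qkv.conv.weight',
    'stages.1.blocks.0.mixer.qkv.conv.weight',
]
out_dict = {}
for k, v in tmp_dict.items():
    kw = '.'.join(k.split('.')[2:])                    # suffix shared with the model key
    find_kw = [a for a in sorted(tmp_dict.keys()) if kw in a]
    idx = find_kw.index(k)                             # rank among same-suffix checkpoint keys
    out_dict[[a for a in target_keys if kw in a][idx]] = v
print(out_dict)
# {'stages.0.blocks.0.mixer.qkv.conv.weight': 1,
#  'stages.1.blocks.0.mixer.qkv.conv.weight': 2}

The positional pairing relies on sorted() ordering checkpoint and model keys consistently, which is what lets the flat 'blocksN' layout of the original weights line up with the new staged layout.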