Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00

efficientvit_msra refactor

commit e94c60b546
parent 047bab6ab2

Changed file: timm/models/efficientvit_msra.py
@@ -22,7 +22,7 @@ import itertools
 class ConvBN(torch.nn.Sequential):
     def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
         super().__init__()
-        self.add_module('c', torch.nn.Conv2d(
+        self.add_module('conv', torch.nn.Conv2d(
             a, b, ks, stride, pad, dilation, groups, bias=False))
         self.add_module('bn', torch.nn.BatchNorm2d(b))
         torch.nn.init.constant_(self.bn.weight, bn_weight_init)
@@ -46,10 +46,10 @@ class BNLinear(torch.nn.Sequential):
     def __init__(self, a, b, bias=True, std=0.02):
         super().__init__()
         self.add_module('bn', torch.nn.BatchNorm1d(a))
-        self.add_module('l', torch.nn.Linear(a, b, bias=bias))
-        trunc_normal_(self.l.weight, std=std)
+        self.add_module('linear', torch.nn.Linear(a, b, bias=bias))
+        trunc_normal_(self.linear.weight, std=std)
         if bias:
-            torch.nn.init.constant_(self.l.bias, 0)
+            torch.nn.init.constant_(self.linear.bias, 0)
 
     @torch.no_grad()
     def fuse(self):
@@ -59,9 +59,9 @@ class BNLinear(torch.nn.Sequential):
             self.bn.weight / (bn.running_var + bn.eps)**0.5
         w = l.weight * w[None, :]
         if l.bias is None:
-            b = b @ self.l.weight.T
+            b = b @ self.linear.weight.T
         else:
-            b = (l.weight @ b[:, None]).view(-1) + self.l.bias
+            b = (l.weight @ b[:, None]).view(-1) + self.linear.bias
         m = torch.nn.Linear(w.size(1), w.size(0))
         m.weight.data.copy_(w)
         m.bias.data.copy_(b)
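The fuse() touched above folds the BatchNorm into the following Linear. A minimal self-contained check of that folding math (a standalone sketch using plain torch modules named after the renamed submodules, not an import from timm):

    import torch

    torch.manual_seed(0)
    bn = torch.nn.BatchNorm1d(8).eval()   # eval(): running stats, as fuse() assumes
    with torch.no_grad():
        bn.running_mean.uniform_(-1, 1)
        bn.running_var.uniform_(0.5, 2.0)
        bn.weight.uniform_(0.5, 1.5)
        bn.bias.uniform_(-0.5, 0.5)
    linear = torch.nn.Linear(8, 4)

    # BN in eval mode is affine per feature: bn(x) = w * x + b, with
    #   w = gamma / sqrt(running_var + eps),  b = beta - running_mean * w
    w = bn.weight / (bn.running_var + bn.eps) ** 0.5
    b = bn.bias - bn.running_mean * w

    # Fold into one Linear so that fused(x) == linear(bn(x)):
    fused = torch.nn.Linear(8, 4)
    fused.weight.data.copy_(linear.weight * w[None, :])   # scale input columns
    fused.bias.data.copy_(linear.weight @ b + linear.bias)

    x = torch.randn(2, 8)
    print(torch.allclose(linear(bn(x)), fused(x), atol=1e-5))  # True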
@@ -288,6 +288,38 @@ class EfficientViTBlock(torch.nn.Module):
         return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
 
 
+class EfficientViTStage(torch.nn.Module):
+    def __init__(self, do, pre_ed, ed, kd, nh=8,
+                 ar=4,
+                 resolution=14,
+                 window_resolution=7,
+                 kernels=[5, 5, 5, 5],
+                 depth=1):
+        super().__init__()
+        if do[0] == 'subsample':
+            self.resolution = (resolution - 1) // do[1] + 1
+            down_blocks = []
+            down_blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(pre_ed, pre_ed, 3, 1, 1, groups=pre_ed)),
+                                                   ResidualDrop(FFN(pre_ed, int(pre_ed * 2))),))
+            down_blocks.append(PatchMerging(pre_ed, ed))
+            down_blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed)),
+                                                   ResidualDrop(FFN(ed, int(ed * 2))),))
+            self.downsample = nn.Sequential(*down_blocks)
+        else:
+            self.downsample = nn.Identity()
+            self.resolution = resolution
+
+        blocks = []
+        for d in range(depth):
+            blocks.append(EfficientViTBlock(ed, kd, nh, ar, self.resolution, window_resolution, kernels))
+        self.blocks = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        x = self.downsample(x)
+        x = self.blocks(x)
+        return x
+
+
 class PatchEmbedding(torch.nn.Sequential):
     def __init__(self, in_chans, dim):
         super().__init__()
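A shape sketch of the new stage wiring (assumes a checkout at this commit where EfficientViTStage is importable under this name; the argument values are illustrative, loosely following the m-series configs, not a tested preset):

    import torch
    from timm.models.efficientvit_msra import EfficientViTStage  # name as of this commit

    stage = EfficientViTStage(
        do=('subsample', 2),   # downsample between stages, stride 2
        pre_ed=64, ed=128,     # incoming / outgoing embedding dims
        kd=16, nh=4, ar=4,     # key dim, heads, attention ratio
        resolution=14,         # feature map entering the stage
        window_resolution=7,
        depth=1,
    )
    x = torch.randn(1, 64, 14, 14)
    print(stage(x).shape)      # expected: torch.Size([1, 128, 7, 7])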
@@ -331,22 +363,13 @@ class EfficientViTMSRA(nn.Module):
         # Build EfficientViT blocks
         for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate(
                 zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
-            blocks = []
+            pre_ed = embed_dim[i - 1]
+            stage = EfficientViTStage(do, pre_ed, ed, kd, nh, ar, resolution, wd, kernels, dpth)
             if do[0] == 'subsample' and i != 0:
-                # Build EfficientViT downsample block
-                resolution_ = (resolution - 1) // do[1] + 1
-                blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(embed_dim[i - 1], embed_dim[i - 1], 3, 1, 1, groups=embed_dim[i - 1])),
-                                                  ResidualDrop(FFN(embed_dim[i - 1], int(embed_dim[i - 1] * 2))),))
-                blocks.append(PatchMerging(*embed_dim[i - 1:i + 1]))
-                resolution = resolution_
-                blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i])),
-                                                  ResidualDrop(FFN(embed_dim[i], int(embed_dim[i] * 2))),))
                 stride *= 2
-            for d in range(dpth):
-                blocks.append(EfficientViTBlock(ed, kd, nh, ar, resolution, wd, kernels))
-            stages.append(nn.Sequential(*blocks))
-            self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
+            resolution = stage.resolution
+            stages.append(stage)
+            self.feature_info += [dict(num_chs=ed, reduction=stride, module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)
 
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
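Two details worth noting: PatchMerging(pre_ed, ed) inside the new stage is the same call as the old PatchMerging(*embed_dim[i - 1:i + 1]), since that slice unpacks to exactly those two channel arguments, and (resolution - 1) // do[1] + 1 is ceil division. A pure-Python sketch of the resolution/stride bookkeeping the loop threads through the stages (values assume a 224px input, the stride-16 patch embedding, and the default m-series down_ops, all assumptions here):

    import math

    resolution, stride = 224 // 16, 16                         # 14x14 map after the stem
    down_ops = (('', 1), ('subsample', 2), ('subsample', 2))   # assumed m-series default
    for do in down_ops:
        if do[0] == 'subsample':
            assert (resolution - 1) // do[1] + 1 == math.ceil(resolution / do[1])
            resolution = (resolution - 1) // do[1] + 1
            stride *= 2
        print(resolution, stride)   # prints: 14 16, then 7 32, then 4 64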
@@ -401,10 +424,24 @@ def checkpoint_filter_fn(state_dict, model):
         if k.startswith('patch_embed'):
             k = k.split('.')
             k[1] = 'conv' + str(int(k[1]) // 2 + 1)
+            if k[2] == 'c':
+                k[2] = 'conv'
             k = '.'.join(k)
         elif k.startswith('blocks'):
-            k = k.split('.')
-            k[0] = 'stages.' + str(int(k[0][6:]) - 1)
-            k = '.'.join(k)
+            pass
+            # k = k.split('.')
+            # k[0] = 'stages.' + str(int(k[0][6:]) - 1)
+            # if int(k[1]) >= 2:
+            #     k[1] = 'block'
+            # else:
+            #     k[1] = 'downsample.' + k[1]
+            # if k[-1] == 'c':
+            #     k[-1] = 'conv'
+            # k = '.'.join(k)
+        elif k.startswith('head'):
+            k = k.split('.')
+            if k[1] == 'l':
+                k[1] = 'linear'
+            k = '.'.join(k)
         out_dict[k] = v
     return out_dict
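A quick check of the two active remapping branches, using hypothetical checkpoint keys that mirror the renames above ('c' to 'conv' in ConvBN, 'l' to 'linear' in BNLinear):

    # Standalone sketch of the patch_embed and head branches of
    # checkpoint_filter_fn; the input keys are made-up examples.
    def remap(k):
        if k.startswith('patch_embed'):
            p = k.split('.')
            p[1] = 'conv' + str(int(p[1]) // 2 + 1)
            if p[2] == 'c':
                p[2] = 'conv'
            return '.'.join(p)
        if k.startswith('head'):
            p = k.split('.')
            if p[1] == 'l':
                p[1] = 'linear'
            return '.'.join(p)
        return k

    print(remap('patch_embed.0.c.weight'))  # patch_embed.conv1.conv.weight
    print(remap('head.l.weight'))           # head.linear.weight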
@@ -416,8 +453,8 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000,
         'mean': IMAGENET_DEFAULT_MEAN,
         'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.in_conv.conv',
-        'classifier': 'head',
+        'first_conv': 'patch_embed.conv1.conv',
+        'classifier': 'head.linear',
         **kwargs,
     }
 
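The updated cfg keys line up with the renamed modules: 'patch_embed.conv1.conv' matches the checkpoint remapping above and 'head.linear' matches the BNLinear rename. A sketch of how these keys surface on a built model (assumes the efficientvit_m0 variant registered in this file; not tested against this exact commit):

    import timm

    # model name assumed from this file's registrations
    m = timm.create_model('efficientvit_m0', pretrained=False, num_classes=10)
    print(m.pretrained_cfg['first_conv'])   # patch_embed.conv1.conv
    print(m.pretrained_cfg['classifier'])   # head.linear
    print(m.head.linear)                    # Linear(..., out_features=10, ...)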