efficientvit_msra refactor

This commit is contained in:
方曦 2023-08-03 17:45:50 +08:00
parent 047bab6ab2
commit e94c60b546

View File

@ -22,7 +22,7 @@ import itertools
class ConvBN(torch.nn.Sequential):
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
super().__init__()
self.add_module('c', torch.nn.Conv2d(
self.add_module('conv', torch.nn.Conv2d(
a, b, ks, stride, pad, dilation, groups, bias=False))
self.add_module('bn', torch.nn.BatchNorm2d(b))
torch.nn.init.constant_(self.bn.weight, bn_weight_init)
@ -46,10 +46,10 @@ class BNLinear(torch.nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_module('bn', torch.nn.BatchNorm1d(a))
self.add_module('l', torch.nn.Linear(a, b, bias=bias))
trunc_normal_(self.l.weight, std=std)
self.add_module('linear', torch.nn.Linear(a, b, bias=bias))
trunc_normal_(self.linear.weight, std=std)
if bias:
torch.nn.init.constant_(self.l.bias, 0)
torch.nn.init.constant_(self.linear.bias, 0)
@torch.no_grad()
def fuse(self):
@ -59,9 +59,9 @@ class BNLinear(torch.nn.Sequential):
self.bn.weight / (bn.running_var + bn.eps)**0.5
w = l.weight * w[None, :]
if l.bias is None:
b = b @ self.l.weight.T
b = b @ self.linear.weight.T
else:
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
b = (l.weight @ b[:, None]).view(-1) + self.linear.bias
m = torch.nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
@ -288,6 +288,38 @@ class EfficientViTBlock(torch.nn.Module):
return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
class EfficientViTStage(torch.nn.Module):
def __init__(self, do, pre_ed, ed, kd, nh=8,
ar=4,
resolution=14,
window_resolution=7,
kernels=[5, 5, 5, 5],
depth=1):
super().__init__()
if do[0] == 'subsample':
self.resolution = (resolution - 1) // do[1] + 1
down_blocks = []
down_blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(pre_ed, pre_ed, 3, 1, 1, groups=pre_ed)),
ResidualDrop(FFN(pre_ed, int(pre_ed * 2))),))
down_blocks.append(PatchMerging(pre_ed, ed))
down_blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed)),
ResidualDrop(FFN(ed, int(ed * 2))),))
self.downsample = nn.Sequential(*down_blocks)
else:
self.downsample = nn.Identity()
self.resolution = resolution
blocks = []
for d in range(depth):
blocks.append(EfficientViTBlock(ed, kd, nh, ar, self.resolution, window_resolution, kernels))
self.blocks = nn.Sequential(*blocks)
def forward(self, x):
x = self.downsample(x)
x = self.blocks(x)
return x
class PatchEmbedding(torch.nn.Sequential):
def __init__(self, in_chans, dim):
super().__init__()
@ -331,22 +363,13 @@ class EfficientViTMSRA(nn.Module):
# Build EfficientViT blocks
for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate(
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
blocks = []
pre_ed = embed_dim[i - 1]
stage = EfficientViTStage(do, pre_ed, ed, kd, nh, ar, resolution, wd, kernels, dpth)
if do[0] == 'subsample' and i != 0:
# Build EfficientViT downsample block
resolution_ = (resolution - 1) // do[1] + 1
blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(embed_dim[i - 1], embed_dim[i - 1], 3, 1, 1, groups=embed_dim[i - 1])),
ResidualDrop(FFN(embed_dim[i - 1], int(embed_dim[i - 1] * 2))),))
blocks.append(PatchMerging(*embed_dim[i - 1:i + 1]))
resolution = resolution_
blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i])),
ResidualDrop(FFN(embed_dim[i], int(embed_dim[i] * 2))),))
stride *= 2
for d in range(dpth):
blocks.append(EfficientViTBlock(ed, kd, nh, ar, resolution, wd, kernels))
stages.append(nn.Sequential(*blocks))
self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
resolution = stage.resolution
stages.append(stage)
self.feature_info += [dict(num_chs=ed, reduction=stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
@ -401,10 +424,24 @@ def checkpoint_filter_fn(state_dict, model):
if k.startswith('patch_embed'):
k = k.split('.')
k[1] = 'conv' + str(int(k[1]) // 2 + 1)
if k[2] == 'c':
k[2] = 'conv'
k = '.'.join(k)
elif k.startswith('blocks'):
pass
# k = k.split('.')
# k[0] = 'stages.' + str(int(k[0][6:]) - 1)
# if int(k[1]) >= 2:
# k[1] = 'block'
# else:
# k[1] = 'downsample.' + k[1]
# if k[-1] == 'c':
# k[-1] = 'conv'
# k = '.'.join(k)
elif k.startswith('head'):
k = k.split('.')
k[0] = 'stages.' + str(int(k[0][6:]) - 1)
if k[1] == 'l':
k[1] = 'linear'
k = '.'.join(k)
out_dict[k] = v
return out_dict
@ -416,8 +453,8 @@ def _cfg(url='', **kwargs):
'num_classes': 1000,
'mean': IMAGENET_DEFAULT_MEAN,
'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.in_conv.conv',
'classifier': 'head',
'first_conv': 'patch_embed.conv1.conv',
'classifier': 'head.linear',
**kwargs,
}