fix efficientvits

pull/1894/head
方曦 2023-08-02 14:12:37 +08:00
parent 82d1e99e1a
commit 43443f64eb
2 changed files with 129 additions and 45 deletions

View File

@ -15,7 +15,6 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ._registry import register_model, generate_default_cfgs
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from functools import partial
from timm.layers import SelectAdaptivePool2d
from collections import OrderedDict
@ -25,6 +24,7 @@ def val2list(x: list or tuple or any, repeat_time=1):
return list(x)
return [x for _ in range(repeat_time)]
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1):
# repeat elements if necessary
x = val2list(x)
@ -33,6 +33,7 @@ def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1):
return tuple(x)
def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
if isinstance(kernel_size, tuple):
return tuple([get_same_padding(ks) for ks in kernel_size])
@ -40,6 +41,7 @@ def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, .
assert kernel_size % 2 > 0, "kernel size should be odd number"
return kernel_size // 2
class ConvNormAct(nn.Module):
def __init__(
self,
@ -263,9 +265,9 @@ class LiteMSA(nn.Module):
)
multi_scale_qkv = torch.transpose(multi_scale_qkv, -1, -2)
q, k, v = (
multi_scale_qkv[..., 0 : self.dim],
multi_scale_qkv[..., self.dim : 2 * self.dim],
multi_scale_qkv[..., 2 * self.dim :],
multi_scale_qkv[..., 0: self.dim],
multi_scale_qkv[..., self.dim: 2 * self.dim],
multi_scale_qkv[..., 2 * self.dim:],
)
# lightweight global attention
@ -286,6 +288,7 @@ class LiteMSA(nn.Module):
return out
class EfficientViTBlock(nn.Module):
def __init__(
self,
@ -355,6 +358,7 @@ class ResidualBlock(nn.Module):
res = self.post_act(res)
return res
class ClsHead(nn.Module):
def __init__(
self,
@ -425,7 +429,7 @@ class EfficientViT(nn.Module):
stem_block += 1
in_channels = width_list[0]
self.stem = nn.Sequential(OrderedDict(input_stem))
stride = 4
stride = 2
self.feature_info = []
stages = []
stage_idx = 0
@ -448,7 +452,7 @@ class EfficientViT(nn.Module):
stages.append(nn.Sequential(*stage))
self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{stage_idx}')]
stage_idx += 1
for w, d in zip(width_list[3:], depth_list[3:]):
stage = []
block = self.build_local_block(
@ -625,11 +629,13 @@ default_cfgs = generate_default_cfgs(
def _create_efficientvit(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
model = build_model_with_cfg(
EfficientViT,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs
)
return model

View File

@ -10,15 +10,12 @@ __all__ = ['EfficientViTMSRA']
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.vision_transformer import trunc_normal_
from timm.models.layers import SqueezeExcite
from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d
from ._registry import register_model, generate_default_cfgs
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from functools import partial
from collections import OrderedDict
import itertools
@ -44,6 +41,7 @@ class ConvBN(torch.nn.Sequential):
m.bias.data.copy_(b)
return m
class BNLinear(torch.nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
@ -69,6 +67,7 @@ class BNLinear(torch.nn.Sequential):
m.bias.data.copy_(b)
return m
class PatchMerging(torch.nn.Module):
def __init__(self, dim, out_dim):
super().__init__()
@ -83,6 +82,7 @@ class PatchMerging(torch.nn.Module):
x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
return x
class ResidualDrop(torch.nn.Module):
def __init__(self, m, drop=0.):
super().__init__()
@ -96,6 +96,7 @@ class ResidualDrop(torch.nn.Module):
else:
return x + self.m(x)
class FFN(torch.nn.Module):
def __init__(self, ed, h):
super().__init__()
@ -107,6 +108,7 @@ class FFN(torch.nn.Module):
x = self.pw2(self.act(self.pw1(x)))
return x
class CascadedGroupAttention(torch.nn.Module):
r""" Cascaded Group Attention.
@ -164,30 +166,31 @@ class CascadedGroupAttention(torch.nn.Module):
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # x (B,C,H,W)
def forward(self, x):
B, C, H, W = x.shape
trainingab = self.attention_biases[:, self.attention_bias_idxs]
feats_in = x.chunk(len(self.qkvs), dim=1)
feats_out = []
feat = feats_in[0]
for i, qkv in enumerate(self.qkvs):
if i > 0: # add the previous output to the input
if i > 0:
feat = feat + feats_in[i]
feat = qkv(feat)
q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1) # B, C/h, H, W
q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1)
q = self.dws[i](q)
q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) # B, C/h, N
q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
attn = (
(q.transpose(-2, -1) @ k) * self.scale
+
(trainingab[i] if self.training else self.ab[i])
)
attn = attn.softmax(dim=-1) # BNN
feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W) # BCHW
attn = attn.softmax(dim=-1)
feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W)
feats_out.append(feat)
x = self.proj(torch.cat(feats_out, 1))
return x
class LocalWindowAttention(torch.nn.Module):
r""" Local Window Attention.
@ -211,19 +214,18 @@ class LocalWindowAttention(torch.nn.Module):
self.resolution = resolution
assert window_resolution > 0, 'window_size must be greater than 0'
self.window_resolution = window_resolution
window_resolution = min(window_resolution, resolution)
self.attn = CascadedGroupAttention(dim, key_dim, num_heads,
attn_ratio=attn_ratio,
resolution=window_resolution,
kernels=kernels,)
attn_ratio=attn_ratio,
resolution=window_resolution,
kernels=kernels,)
def forward(self, x):
H = W = self.resolution
B, C, H_, W_ = x.shape
# Only check this for classifcation models
assert H == H_ and W == W_, 'input feature has wrong size, expect {}, got {}'.format((H, W), (H_, W_))
if H <= self.window_resolution and W <= self.window_resolution:
x = self.attn(x)
else:
@ -246,14 +248,14 @@ class LocalWindowAttention(torch.nn.Module):
).permute(0, 3, 1, 2)
x = self.attn(x)
# window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC
x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution,
C).transpose(2, 3).reshape(B, pH, pW, C)
x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution, C).transpose(2, 3).reshape(B, pH, pW, C)
if padding:
x = x[:, :H, :W].contiguous()
x = x.permute(0, 3, 1, 2)
return x
class EfficientViTBlock(torch.nn.Module):
class EfficientViTBlock(torch.nn.Module):
""" A basic EfficientViT building block.
Args:
@ -265,26 +267,27 @@ class EfficientViTBlock(torch.nn.Module):
window_resolution (int): Local window resolution.
kernels (List[int]): The kernel size of the dw conv on query.
"""
def __init__(self,ed, kd, nh=8,
def __init__(self, ed, kd, nh=8,
ar=4,
resolution=14,
window_resolution=7,
kernels=[5, 5, 5, 5],):
super().__init__()
self.dw0 = ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0.))
self.ffn0 = ResidualDrop(FFN(ed, int(ed * 2)))
self.mixer = ResidualDrop(
LocalWindowAttention(ed, kd, nh, attn_ratio=ar, resolution=resolution,
window_resolution=window_resolution, kernels=kernels))
window_resolution=window_resolution, kernels=kernels))
self.dw1 = ResidualDrop(ConvBN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0.))
self.ffn1 = ResidualDrop(FFN(ed, int(ed * 2)))
def forward(self, x):
return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
class PatchEmbedding(torch.nn.Sequential):
def __init__(self, in_chans, dim):
super().__init__()
@ -297,6 +300,7 @@ class PatchEmbedding(torch.nn.Sequential):
self.add_module('conv4', ConvBN(dim // 2, dim, 3, 2, 1))
self.patch_size = 16
class EfficientViTMSRA(nn.Module):
def __init__(
self,
@ -323,31 +327,31 @@ class EfficientViTMSRA(nn.Module):
attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]
self.feature_info = []
stages = []
self.feature_info += [dict(num_chs=embed_dim[0], reduction=stride, module=f'stages.{0}')]
# Build EfficientViT blocks
# Build EfficientViT blocks
for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate(
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
blocks = []
if do[0] == 'subsample':
if do[0] == 'subsample' and i != 0:
# Build EfficientViT downsample block
resolution_ = (resolution - 1) // do[1] + 1
blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i])),
ResidualDrop(FFN(embed_dim[i], int(embed_dim[i] * 2))),))
blocks.append(PatchMerging(*embed_dim[i:i + 2], resolution))
blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(embed_dim[i - 1], embed_dim[i - 1], 3, 1, 1, groups=embed_dim[i - 1])),
ResidualDrop(FFN(embed_dim[i - 1], int(embed_dim[i - 1] * 2))),))
blocks.append(PatchMerging(*embed_dim[i - 1:i + 1]))
resolution = resolution_
blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1, groups=embed_dim[i + 1])),
ResidualDrop(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2))),))
blocks.append(torch.nn.Sequential(ResidualDrop(ConvBN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i])),
ResidualDrop(FFN(embed_dim[i], int(embed_dim[i] * 2))),))
stride *= 2
for d in range(dpth):
blocks.append(EfficientViTBlock(ed, kd, nh, ar, resolution, wd, kernels))
stages.append(nn.Sequential(*blocks))
self.feature_info += [dict(num_chs=embed_dim[i+1], reduction=stride, module=f'stages.{i+1}')]
self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
self.head = BNLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
self.out_dims = embed_dim[-1]
self.head = BNLinear(self.out_dims, num_classes) if num_classes > 0 else torch.nn.Identity()
@torch.jit.ignore
def group_matcher(self, coarse=False):
@ -369,7 +373,7 @@ class EfficientViTMSRA(nn.Module):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
self.head = BNLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
self.head = BNLinear(self.out_dims, num_classes) if num_classes > 0 else torch.nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
@ -389,12 +393,19 @@ class EfficientViTMSRA(nn.Module):
def checkpoint_filter_fn(state_dict, model):
target_keys = list(model.state_dict().keys())
if 'state_dict' in state_dict.keys():
state_dict = state_dict['state_dict']
if 'model' in state_dict.keys():
state_dict = state_dict['model']
out_dict = {}
for i, (k, v) in enumerate(state_dict.items()):
out_dict[target_keys[i]] = v
for k, v in state_dict.items():
if k.startswith('patch_embed'):
k = k.split('.')
k[1] = 'conv' + str(int(k[1]) // 2 + 1)
k = '.'.join(k)
elif k.startswith('blocks'):
k = k.split('.')
k[0] = 'stages.' + str(int(k[0][6:]) - 1)
k = '.'.join(k)
out_dict[k] = v
return out_dict
@ -415,16 +426,33 @@ default_cfgs = generate_default_cfgs(
'efficientvit_m0.r224_in1k': _cfg(
url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
),
'efficientvit_m1.r224_in1k': _cfg(
url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
),
'efficientvit_m2.r224_in1k': _cfg(
url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
),
'efficientvit_m3.r224_in1k': _cfg(
url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
),
'efficientvit_m4.r224_in1k': _cfg(
url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
),
'efficientvit_m5.r224_in1k': _cfg(
url='https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
),
}
)
def _create_efficientvit_msra(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', (0, 1, 2))
model = build_model_with_cfg(
EfficientViTMSRA,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs
)
return model
@ -439,3 +467,53 @@ def efficientvit_m0(pretrained=False, **kwargs):
window_size=[7, 7, 7],
kernels=[5, 5, 5, 5])
return _create_efficientvit_msra('efficientvit_m0', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def efficientvit_m1(pretrained=False, **kwargs):
model_args = dict(img_size=224,
embed_dim=[128, 144, 192],
depth=[1, 2, 3],
num_heads=[2, 3, 3],
window_size=[7, 7, 7],
kernels=[7, 5, 3, 3])
return _create_efficientvit_msra('efficientvit_m1', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def efficientvit_m2(pretrained=False, **kwargs):
model_args = dict(img_size=224,
embed_dim=[128, 192, 224],
depth=[1, 2, 3],
num_heads=[4, 3, 2],
window_size=[7, 7, 7],
kernels=[7, 5, 3, 3])
return _create_efficientvit_msra('efficientvit_m2', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def efficientvit_m3(pretrained=False, **kwargs):
model_args = dict(img_size=224,
embed_dim=[128, 240, 320],
depth=[1, 2, 3],
num_heads=[4, 3, 4],
window_size=[7, 7, 7],
kernels=[5, 5, 5, 5])
return _create_efficientvit_msra('efficientvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def efficientvit_m4(pretrained=False, **kwargs):
model_args = dict(img_size=224,
embed_dim=[128, 256, 384],
depth=[1, 2, 3],
num_heads=[4, 4, 4],
window_size=[7, 7, 7],
kernels=[7, 5, 3, 3])
return _create_efficientvit_msra('efficientvit_m4', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def efficientvit_m5(pretrained=False, **kwargs):
model_args = dict(img_size=224,
embed_dim=[192, 288, 384],
depth=[1, 3, 4],
num_heads=[3, 3, 4],
window_size=[7, 7, 7],
kernels=[7, 5, 3, 3])
return _create_efficientvit_msra('efficientvit_m5', pretrained=pretrained, **dict(model_args, **kwargs))