# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.

import torch
import torch.nn as nn
from functools import partial
from timm.models.vision_transformer import Mlp, PatchEmbed, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath


__all__ = [
    'resmlp_12', 'resmlp_24', 'resmlp_36', 'resmlpB_24'
]


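# ResMLP uses no LayerNorm; normalization is replaced by this learnable
# per-channel affine transform alpha * x + beta, initialized to the
# identity (alpha = 1, beta = 0).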
class Affine(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return self.alpha * x + self.beta


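# One ResMLP residual block: a cross-patch linear layer (kept under the name
# `attn` for symmetry with transformer blocks, though it contains no attention)
# followed by a per-patch MLP. Both branches are preceded by an Affine
# transform and scaled by learnable LayerScale parameters gamma_1 / gamma_2,
# initialized to the small constant init_values. The default num_patches=196
# matches a 224x224 input with 16x16 patches.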
class layers_scale_mlp_blocks(nn.Module):

    def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU, init_values=1e-4, num_patches=196):
        super().__init__()
        self.norm1 = Affine(dim)
        self.attn = nn.Linear(num_patches, num_patches)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = Affine(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(4.0 * dim), act_layer=act_layer, drop=drop)
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
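    # transpose(1, 2) turns (B, patches, channels) into (B, channels, patches)
    # so the `attn` Linear mixes across patches; the result is transposed back.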
    def forward(self, x):
        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x).transpose(1, 2)).transpose(1, 2))
        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


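# The full ResMLP model: patch embedding, `depth` residual MLP blocks, a final
# Affine transform, average pooling over patches, and a linear classifier head.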
class resmlp_models(nn.Module):

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, drop_rate=0.,
                 Patch_layer=PatchEmbed, act_layer=nn.GELU,
                 drop_path_rate=0.0, init_scale=1e-4):
        super().__init__()

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = Patch_layer(
            img_size=img_size, patch_size=patch_size, in_chans=int(in_chans), embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
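        # The same drop-path rate is applied to every block; there is no linear
        # scaling with depth, unlike some ViT implementations.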
        dpr = [drop_path_rate for i in range(depth)]

        self.blocks = nn.ModuleList([
            layers_scale_mlp_blocks(
                dim=embed_dim, drop=drop_rate, drop_path=dpr[i],
                act_layer=act_layer, init_values=init_scale,
                num_patches=num_patches)
            for i in range(depth)])

        self.norm = Affine(embed_dim)

        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)
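    # Linear layers get truncated-normal weights and zero bias. The LayerNorm
    # branch only fires if the supplied Patch_layer contains one; this model's
    # own normalization layers are Affine and keep their construction-time init.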
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]

        x = self.patch_embed(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x)

        x = self.norm(x)
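        # Average-pool over the patch dimension instead of using a class token.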
        x = x.mean(dim=1).reshape(B, 1, -1)

        return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


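# Registered model factories. `pretrained` downloads the official checkpoints;
# `dist` selects the distilled weights, and for resmlp_24, `dino` selects the
# self-supervised DINO weights.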
@register_model
def resmlp_12(pretrained=False, dist=False, **kwargs):
    model = resmlp_models(
        patch_size=16, embed_dim=384, depth=12,
        Patch_layer=PatchEmbed,
        init_scale=0.1, **kwargs)

    model.default_cfg = _cfg()
    if pretrained:
        if dist:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth"
        else:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth"
        checkpoint = torch.hub.load_state_dict_from_url(
            url=url_path,
            map_location="cpu", check_hash=True
        )

        model.load_state_dict(checkpoint)
    return model


@register_model
def resmlp_24(pretrained=False, dist=False, dino=False, **kwargs):
    model = resmlp_models(
        patch_size=16, embed_dim=384, depth=24,
        Patch_layer=PatchEmbed,
        init_scale=1e-5, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        if dist:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth"
        elif dino:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_24_dino.pth"
        else:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth"
        checkpoint = torch.hub.load_state_dict_from_url(
            url=url_path,
            map_location="cpu", check_hash=True
        )

        model.load_state_dict(checkpoint)
    return model


@register_model
def resmlp_36(pretrained=False, dist=False, **kwargs):
    model = resmlp_models(
        patch_size=16, embed_dim=384, depth=36,
        Patch_layer=PatchEmbed,
        init_scale=1e-6, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        if dist:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth"
        else:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth"
        checkpoint = torch.hub.load_state_dict_from_url(
            url=url_path,
            map_location="cpu", check_hash=True
        )

        model.load_state_dict(checkpoint)
    return model


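# ResMLP-B24: the larger variant (patch size 8, width 768); `in_22k` selects
# the ImageNet-22k pretrained checkpoint.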
@register_model
def resmlpB_24(pretrained=False, dist=False, in_22k=False, **kwargs):
    model = resmlp_models(
        patch_size=8, embed_dim=768, depth=24,
        Patch_layer=PatchEmbed,
        init_scale=1e-6, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        if dist:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth"
        elif in_22k:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth"
        else:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth"

        checkpoint = torch.hub.load_state_dict_from_url(
            url=url_path,
            map_location="cpu", check_hash=True
        )

        model.load_state_dict(checkpoint)

    return model
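

# A minimal smoke-test sketch, not part of the upstream file: build the
# smallest variant without pretrained weights and run a dummy batch to check
# the output shape. Illustrative only.
if __name__ == '__main__':
    model = resmlp_12(pretrained=False)
    out = model(torch.randn(2, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([2, 1000])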