Update resmlp_models.py

Hugo Touvron 2021-06-18 20:59:28 +02:00 committed by GitHub
parent 6fa7ef60b4
commit 31b3d676b3
1 changed file with 18 additions and 18 deletions


@@ -22,7 +22,7 @@ class Affine(nn.Module):
    def forward(self, x):
        return self.alpha * x + self.beta
class layers_scale_mlp_blocks(nn.Module):
    def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU,init_values=1e-4,num_patches = 196):
        super().__init__()
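For context, Affine is ResMLP's replacement for LayerNorm: a learnable per-channel scale and shift with no normalization statistics. A minimal standalone sketch of the module whose forward appears in the hunk above, assuming the usual ones/zeros initialization from the ResMLP paper (so the layer starts as an identity):

import torch
import torch.nn as nn

class Affine(nn.Module):
    # per-channel affine transform: y = alpha * x + beta (no mean/var statistics)
    def __init__(self, dim):
        super().__init__()
        # assumption: alpha starts at 1, beta at 0, so the layer is initially an identity
        self.alpha = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return self.alpha * x + self.beta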
@@ -38,17 +38,17 @@ class Affine(nn.Module):
        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x).transpose(1,2)).transpose(1,2))
        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
class resmlp_models(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, drop_rate=0.,
                 Patch_layer=PatchEmbed, act_layer=nn.GELU,
                 drop_path_rate=0.0, init_scale=1e-4):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
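Filling in the pieces of layers_scale_mlp_blocks that the hunks above elide: despite its name, attn is not attention but a linear layer over the patch axis (hence the two transposes in forward), and gamma_1/gamma_2 are LayerScale vectors initialized to init_values. A hedged sketch, reusing Affine from the sketch above; Mlp and DropPath from timm and the 4x MLP expansion are assumptions based on the standard ResMLP design:

import torch
import torch.nn as nn
from timm.models.layers import DropPath, Mlp

class layers_scale_mlp_blocks(nn.Module):
    def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU, init_values=1e-4, num_patches=196):
        super().__init__()
        self.norm1 = Affine(dim)
        # cross-patch token mixing: a linear map over the patch axis, not attention
        self.attn = nn.Linear(num_patches, num_patches)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = Affine(dim)
        # assumption: the standard 4x hidden expansion used by ResMLP
        self.mlp = Mlp(in_features=dim, hidden_features=int(4.0 * dim), act_layer=act_layer, drop=drop)
        # LayerScale: small learnable per-channel gains on each residual branch
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        # x: (B, num_patches, dim); transpose so the linear mixes patches, then transpose back
        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x).transpose(1, 2)).transpose(1, 2))
        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x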
@@ -56,23 +56,23 @@ class Affine(nn.Module):
            img_size=img_size, patch_size=patch_size, in_chans=int(in_chans), embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        dpr = [drop_path_rate for i in range(depth)]
        self.blocks = nn.ModuleList([
            layers_scale_mlp_blocks(
                dim=embed_dim, drop=drop_rate, drop_path=dpr[i],
                act_layer=act_layer, init_values=init_scale,
                num_patches=num_patches)
            for i in range(depth)])
        self.norm = Affine(embed_dim)
        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
@@ -81,7 +81,7 @@ class Affine(nn.Module):
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
@@ -93,12 +93,12 @@ class Affine(nn.Module):
    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
        x = self.norm(x)
        x = x.mean(dim=1).reshape(B, 1, -1)
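Shape check for the pooling step above, assuming the defaults (224x224 input, patch_size=16, embed_dim=768, so num_patches=196):

x = self.patch_embed(x)               # (B, 196, 768): one 768-d token per 16x16 patch
x = blk(x)                            # (B, 196, 768): every block preserves the shape
x = self.norm(x)                      # (B, 196, 768): final Affine, no normalization stats
x = x.mean(dim=1).reshape(B, 1, -1)   # (B, 1, 768): average pool over the 196 patch tokens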
@@ -108,7 +108,7 @@ class Affine(nn.Module):
        x = self.forward_features(x)
        x = self.head(x)
        return x
@register_model
def resmlp_12(pretrained=False,dist=False, **kwargs):
    model = resmlp_models(
@@ -173,7 +173,7 @@ def resmlp_36(pretrained=False,dist=False, **kwargs):
    return model

@register_model
-def resmlpB_24(pretrained=False,dist=False, 22k = False, **kwargs):
+def resmlpB_24(pretrained=False,dist=False, in_22k = False, **kwargs):
    model = resmlp_models(
        patch_size=8, embed_dim=768, depth=24,
        Patch_layer=PatchEmbed,
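The rename above is the substance of the commit: a Python identifier cannot begin with a digit, so the 22k = False parameter (and the elif 22k: below) was a hard SyntaxError that prevented the file from being imported at all. A minimal reproduction (hypothetical snippet, not from the repo; the exact error message varies by Python version):

>>> def f(22k=False): pass
  File "<stdin>", line 1
    def f(22k=False): pass
          ^
SyntaxError: invalid syntax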
@@ -182,7 +182,7 @@ def resmlpB_24(pretrained=False,dist=False, 22k = False, **kwargs):
    if pretrained:
        if dist:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth"
-        elif 22k:
+        elif in_22k:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth"
        else:
            url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth"