Update resmlp_models.py

pull/118/head
Hugo Touvron 2021-06-18 20:59:28 +02:00 committed by GitHub
parent 6fa7ef60b4
commit 31b3d676b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 18 additions and 18 deletions

View File

@ -22,7 +22,7 @@ class Affine(nn.Module):
def forward(self, x):
return self.alpha * x + self.beta
class layers_scale_mlp_blocks(nn.Module):
class layers_scale_mlp_blocks(nn.Module):
def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU,init_values=1e-4,num_patches = 196):
super().__init__()
@ -40,7 +40,7 @@ class Affine(nn.Module):
return x
class resmlp_models(nn.Module):
class resmlp_models(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,drop_rate=0.,
Patch_layer=PatchEmbed,act_layer=nn.GELU,
@ -173,7 +173,7 @@ def resmlp_36(pretrained=False,dist=False, **kwargs):
return model
@register_model
def resmlpB_24(pretrained=False,dist=False, 22k = False, **kwargs):
def resmlpB_24(pretrained=False,dist=False, in_22k = False, **kwargs):
model = resmlp_models(
patch_size=8, embed_dim=768, depth=24,
Patch_layer=PatchEmbed,
@ -182,7 +182,7 @@ def resmlpB_24(pretrained=False,dist=False, 22k = False, **kwargs):
if pretrained:
if dist:
url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth"
elif 22k:
elif in_22k:
url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth"
else:
url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth"