mirror of https://github.com/facebookresearch/deit
Update resmlp_models.py
parent
6fa7ef60b4
commit
31b3d676b3
|
@ -22,7 +22,7 @@ class Affine(nn.Module):
|
|||
def forward(self, x):
|
||||
return self.alpha * x + self.beta
|
||||
|
||||
class layers_scale_mlp_blocks(nn.Module):
|
||||
class layers_scale_mlp_blocks(nn.Module):
|
||||
|
||||
def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU,init_values=1e-4,num_patches = 196):
|
||||
super().__init__()
|
||||
|
@ -40,7 +40,7 @@ class Affine(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class resmlp_models(nn.Module):
|
||||
class resmlp_models(nn.Module):
|
||||
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,drop_rate=0.,
|
||||
Patch_layer=PatchEmbed,act_layer=nn.GELU,
|
||||
|
@ -173,7 +173,7 @@ def resmlp_36(pretrained=False,dist=False, **kwargs):
|
|||
return model
|
||||
|
||||
@register_model
|
||||
def resmlpB_24(pretrained=False,dist=False, 22k = False, **kwargs):
|
||||
def resmlpB_24(pretrained=False,dist=False, in_22k = False, **kwargs):
|
||||
model = resmlp_models(
|
||||
patch_size=8, embed_dim=768, depth=24,
|
||||
Patch_layer=PatchEmbed,
|
||||
|
@ -182,7 +182,7 @@ def resmlpB_24(pretrained=False,dist=False, 22k = False, **kwargs):
|
|||
if pretrained:
|
||||
if dist:
|
||||
url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth"
|
||||
elif 22k:
|
||||
elif in_22k:
|
||||
url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth"
|
||||
else:
|
||||
url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth"
|
||||
|
|
Loading…
Reference in New Issue