diff --git a/resmlp_models.py b/resmlp_models.py index 87e1c1a..2a7d62b 100644 --- a/resmlp_models.py +++ b/resmlp_models.py @@ -22,7 +22,7 @@ class Affine(nn.Module): def forward(self, x): return self.alpha * x + self.beta - class layers_scale_mlp_blocks(nn.Module): +class layers_scale_mlp_blocks(nn.Module): def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU,init_values=1e-4,num_patches = 196): super().__init__() @@ -38,17 +38,17 @@ class Affine(nn.Module): x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x).transpose(1,2)).transpose(1,2)) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x - - - class resmlp_models(nn.Module): - + + +class resmlp_models(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,drop_rate=0., Patch_layer=PatchEmbed,act_layer=nn.GELU, drop_path_rate=0.0,init_scale=1e-4): super().__init__() - - + + self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim @@ -56,23 +56,23 @@ class Affine(nn.Module): img_size=img_size, patch_size=patch_size, in_chans=int(in_chans), embed_dim=embed_dim) num_patches = self.patch_embed.num_patches dpr = [drop_path_rate for i in range(depth)] - + self.blocks = nn.ModuleList([ layers_scale_mlp_blocks( dim=embed_dim,drop=drop_rate,drop_path=dpr[i], act_layer=act_layer,init_values=init_scale, num_patches=num_patches) for i in range(depth)]) - + self.norm = Affine(embed_dim) - + self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')] self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) - + def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) @@ -81,7 +81,7 @@ class Affine(nn.Module): elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - + def get_classifier(self): @@ -93,12 +93,12 @@ class Affine(nn.Module): def forward_features(self, x): B = x.shape[0] - + x = self.patch_embed(x) - + for i , blk in enumerate(self.blocks): x = blk(x) - + x = self.norm(x) x = x.mean(dim=1).reshape(B,1,-1) @@ -108,7 +108,7 @@ class Affine(nn.Module): x = self.forward_features(x) x = self.head(x) return x - + @register_model def resmlp_12(pretrained=False,dist=False, **kwargs): model = resmlp_models( @@ -173,7 +173,7 @@ def resmlp_36(pretrained=False,dist=False, **kwargs): return model @register_model -def resmlpB_24(pretrained=False,dist=False, 22k = False, **kwargs): +def resmlpB_24(pretrained=False,dist=False, in_22k = False, **kwargs): model = resmlp_models( patch_size=8, embed_dim=768, depth=24, Patch_layer=PatchEmbed, @@ -182,7 +182,7 @@ def resmlpB_24(pretrained=False,dist=False, 22k = False, **kwargs): if pretrained: if dist: url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth" - elif 22k: + elif in_22k: url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth" else: url_path = "https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth"