Merge pull request #147 from kozistr/refactor/init-weights

Refactor init_weights()
Hugo Touvron 2022-02-06 13:53:54 +01:00 committed by GitHub
commit 53c0a07cae
3 changed files with 3 additions and 3 deletions


@@ -212,7 +212,7 @@ class cait_models(nn.Module):
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
+            if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
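
The refactor is identical in all three files: the inner isinstance(m, nn.Linear) test can only ever be true, since that line already sits inside the outer isinstance(m, nn.Linear) branch, so the condition reduces to the bias check alone. A minimal runnable sketch of the refactored initializer, assuming timm's trunc_normal_ (which these model files import):

import torch.nn as nn
from timm.models.layers import trunc_normal_

def _init_weights(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        # Already inside the nn.Linear branch, so re-checking
        # isinstance(m, nn.Linear) was redundant; only the bias
        # check is needed (bias is None when created with bias=False).
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)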


@@ -279,7 +279,7 @@ class PatchConvnet(nn.Module):
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=self.rescale)
-            if isinstance(m, nn.Linear) and m.bias is not None:
+            if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)


@@ -76,7 +76,7 @@ class resmlp_models(nn.Module):
     def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
-           if isinstance(m, nn.Linear) and m.bias is not None:
+           if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
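
For context, these _init_weights methods are hooked up through nn.Module.apply, the standard PyTorch pattern that calls the function on every submodule recursively. A short usage sketch under that assumption; the toy Sequential below is hypothetical, not from the repo:

import torch.nn as nn
from timm.models.layers import trunc_normal_

def _init_weights(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)

# Hypothetical stand-in for cait_models / PatchConvnet / resmlp_models.
# apply() recurses into every submodule, so the bias-free Linear below
# exercises the "m.bias is not None" guard instead of crashing.
model = nn.Sequential(
    nn.Linear(16, 32),
    nn.LayerNorm(32),
    nn.Linear(32, 8, bias=False),  # m.bias is None for this layer
)
model.apply(_init_weights)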