refactor: init_weights

pull/147/head
kozistr 2022-02-06 10:15:05 +09:00
parent e9a4a1a848
commit 322b428dc0
3 changed files with 3 additions and 3 deletions


@@ -212,7 +212,7 @@ class cait_models(nn.Module):
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
+            if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)


@@ -279,7 +279,7 @@ class PatchConvnet(nn.Module):
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=self.rescale)
-            if isinstance(m, nn.Linear) and m.bias is not None:
+            if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)


@@ -76,7 +76,7 @@ class resmlp_models(nn.Module):
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=0.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
+            if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
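
All three hunks make the same fix: inside the `if isinstance(m, nn.Linear):` branch, re-checking `isinstance(m, nn.Linear)` is redundant, so only the `m.bias is not None` guard remains (the bias can be absent when a layer is built with `nn.Linear(..., bias=False)`). Below is a minimal self-contained sketch of the refactored initializer; it substitutes PyTorch's built-in `torch.nn.init.trunc_normal_` for the timm `trunc_normal_` these models import, and the LayerNorm weight reset to 1.0 is an assumption based on the usual pattern, not shown in the diff context above:

```python
import torch.nn as nn


def _init_weights(m: nn.Module) -> None:
    # Mirrors the refactored _init_weights from the patched models.
    if isinstance(m, nn.Linear):
        # Truncated-normal init for Linear weights (timm's trunc_normal_
        # in the repo; torch.nn.init.trunc_normal_ is the stand-in here).
        nn.init.trunc_normal_(m.weight, std=0.02)
        # The enclosing branch already guarantees m is an nn.Linear, so the
        # old `isinstance(m, nn.Linear) and` prefix was dead weight; only
        # the bias guard is needed.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)  # assumed: standard LayerNorm reset


# Usage: nn.Module.apply walks every submodule recursively.
model = nn.Sequential(
    nn.Linear(16, 32),
    nn.LayerNorm(32),
    nn.Linear(32, 10, bias=False),  # bias=False exercises the guard
)
model.apply(_init_weights)
```

Note that `PatchConvnet` passes `std=self.rescale` rather than a fixed `0.02`, so its version stays a bound method; otherwise the behavioral change is identical across all three files.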