From 322b428dc000065d60cc15c81cfbee16f36fb86b Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 6 Feb 2022 10:15:05 +0900
Subject: [PATCH] refactor: init_weights

---
 cait_models.py         | 2 +-
 patchconvnet_models.py | 2 +-
 resmlp_models.py       | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/cait_models.py b/cait_models.py
index 7821a1d..7822b24 100644
--- a/cait_models.py
+++ b/cait_models.py
@@ -212,7 +212,7 @@ class cait_models(nn.Module):
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
+            if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
diff --git a/patchconvnet_models.py b/patchconvnet_models.py
index b4bd046..db3ea16 100644
--- a/patchconvnet_models.py
+++ b/patchconvnet_models.py
@@ -279,7 +279,7 @@ class PatchConvnet(nn.Module):
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=self.rescale)
-            if isinstance(m, nn.Linear) and m.bias is not None:
+            if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
diff --git a/resmlp_models.py b/resmlp_models.py
index 9882a29..ed3df21 100644
--- a/resmlp_models.py
+++ b/resmlp_models.py
@@ -76,7 +76,7 @@ class resmlp_models(nn.Module):
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=0.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
+            if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
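
For context, a minimal sketch of what the simplified initializer looks like once the patch is applied. The wrapper class, the `self.apply(self._init_weights)` call, the `timm` import, and the final LayerNorm weight line are assumptions based on the usual DeiT-style setup and are not part of the hunks above:

```python
import torch.nn as nn
from timm.models.layers import trunc_normal_  # assumed source of trunc_normal_, as in the DeiT repos


class TinyModel(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)
        self.norm = nn.LayerNorm(16)
        # recursively apply _init_weights to every submodule
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            # the outer isinstance check already guarantees m is nn.Linear,
            # so only the bias check remains after the refactor
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)  # typical completion, not shown in the hunk context
```

The change is behavior-preserving: inside the `isinstance(m, nn.Linear)` branch the repeated type check is always true, so dropping it leaves only the `m.bias is not None` guard.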