Should have included Conv2d layers in original weight init. Lets see what the impact is...

This commit is contained in:
Ross Wightman 2021-03-18 23:15:48 -07:00
parent 4de57ccf01
commit cbcb76d72c

View File

@ -476,7 +476,7 @@ class VisionTransformer(nn.Module):
def _init_weights_original(m: nn.Module, n: str = ''):
if isinstance(m, nn.Linear):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)