reinitialize mlp

pull/3/head
Xinlei Chen 2021-07-01 12:54:43 -07:00
parent 7dea7af347
commit 7fbc15248a
1 changed file with 6 additions and 6 deletions

@@ -39,13 +39,13 @@ class MoCo(nn.Module):
         self.base_encoder = base_encoder(num_classes=mlp_dim)
         self.momentum_encoder = base_encoder(num_classes=mlp_dim)
 
-        self.base_encoder.fc = nn.Sequential(self.base_encoder.fc,
+        hidden_dim = self.base_encoder.fc.weight.shape[1]
+        self.base_encoder.fc = nn.Sequential(nn.Linear(hidden_dim, mlp_dim, bias=False),
                                              nn.BatchNorm1d(mlp_dim),
                                              nn.ReLU(inplace=True), # first layer
                                              nn.Linear(mlp_dim, dim, bias=False),
                                              nn.BatchNorm1d(dim, affine=False)) # second layer
-        self.base_encoder.fc[0].bias.requires_grad = False # hack: not use bias as it is followed by BN
-        self.momentum_encoder.fc = nn.Sequential(self.momentum_encoder.fc,
+        self.momentum_encoder.fc = nn.Sequential(nn.Linear(hidden_dim, mlp_dim, bias=False),
                                                  nn.BatchNorm1d(mlp_dim),
                                                  nn.ReLU(inplace=True), # first layer
                                                  nn.Linear(mlp_dim, dim, bias=False),
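
The hunk above stops wrapping the encoder's existing fc (and drops the bias-freezing hack) in favor of a freshly initialized two-layer projector. A minimal runnable sketch of what the new ResNet branch builds, assuming torchvision's resnet50 as the base encoder and illustrative values for dim and mlp_dim (neither is taken from this diff):

    import torch
    import torch.nn as nn
    import torchvision.models as models

    dim, mlp_dim = 256, 4096                       # illustrative values (assumed)
    base_encoder = models.resnet50(num_classes=mlp_dim)

    # read the backbone's feature width from the default fc, then discard that fc
    hidden_dim = base_encoder.fc.weight.shape[1]   # 2048 for resnet50

    # freshly initialized 2-layer projector, replacing the default fc outright
    base_encoder.fc = nn.Sequential(
        nn.Linear(hidden_dim, mlp_dim, bias=False),
        nn.BatchNorm1d(mlp_dim),
        nn.ReLU(inplace=True),                     # first layer
        nn.Linear(mlp_dim, dim, bias=False),
        nn.BatchNorm1d(dim, affine=False),         # second layer
    )

    x = torch.randn(4, 3, 224, 224)
    print(base_encoder(x).shape)                   # torch.Size([4, 256])

Reading hidden_dim off the default fc's weight keeps the projector's input width tied to the backbone, so the same construction works unchanged for any ResNet variant.
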
@@ -64,7 +64,8 @@ class MoCo(nn.Module):
         self.base_encoder = base_encoder(num_classes=mlp_dim)
         self.momentum_encoder = base_encoder(num_classes=mlp_dim)
 
-        self.base_encoder.head = nn.Sequential(self.base_encoder.head,
+        hidden_dim = self.base_encoder.head.weight.shape[1]
+        self.base_encoder.head = nn.Sequential(nn.Linear(hidden_dim, mlp_dim, bias=False),
                                                nn.BatchNorm1d(mlp_dim),
                                                nn.GELU(), # first layer
                                                nn.Linear(mlp_dim, mlp_dim, bias=False),
@@ -73,8 +74,7 @@ class MoCo(nn.Module):
                                                nn.BatchNorm1d(mlp_dim),
                                                nn.Linear(mlp_dim, dim, bias=False),
                                                nn.BatchNorm1d(dim, affine=False)) # third layer
-        self.base_encoder.head[0].bias.requires_grad = False # hack: not use bias as it is followed by BN
-        self.momentum_encoder.head = nn.Sequential(self.momentum_encoder.head,
+        self.momentum_encoder.head = nn.Sequential(nn.Linear(hidden_dim, mlp_dim, bias=False),
                                                    nn.BatchNorm1d(mlp_dim),
                                                    nn.GELU(), # first layer
                                                    nn.Linear(mlp_dim, mlp_dim, bias=False),
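
The ViT branch gets the same treatment with a three-layer head. A sketch under stated assumptions: hidden_dim is hard-coded to 768 (ViT-Base's embedding width) instead of being read from a real encoder, and the middle GELU follows the "second layer" pattern implied by the context lines omitted between the two hunks:

    import torch
    import torch.nn as nn

    dim, mlp_dim, hidden_dim = 256, 4096, 768     # 768 = ViT-B width (assumed)

    # freshly initialized 3-layer projector head, as built by the new code path
    head = nn.Sequential(
        nn.Linear(hidden_dim, mlp_dim, bias=False),
        nn.BatchNorm1d(mlp_dim),
        nn.GELU(),                                # first layer
        nn.Linear(mlp_dim, mlp_dim, bias=False),
        nn.BatchNorm1d(mlp_dim),
        nn.GELU(),                                # second layer (inferred from omitted context)
        nn.Linear(mlp_dim, dim, bias=False),
        nn.BatchNorm1d(dim, affine=False),        # third layer
    )

    features = torch.randn(4, hidden_dim)         # stand-in for the ViT's pooled features
    print(head(features).shape)                   # torch.Size([4, 256])
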