reinitialize mlp
parent
7dea7af347
commit
7fbc15248a
|
@ -39,13 +39,13 @@ class MoCo(nn.Module):
|
|||
self.base_encoder = base_encoder(num_classes=mlp_dim)
|
||||
self.momentum_encoder = base_encoder(num_classes=mlp_dim)
|
||||
|
||||
self.base_encoder.fc = nn.Sequential(self.base_encoder.fc,
|
||||
hidden_dim = self.base_encoder.fc.weight.shape[1]
|
||||
self.base_encoder.fc = nn.Sequential(nn.Linear(hidden_dim, mlp_dim, bias=False),
|
||||
nn.BatchNorm1d(mlp_dim),
|
||||
nn.ReLU(inplace=True), # first layer
|
||||
nn.Linear(mlp_dim, dim, bias=False),
|
||||
nn.BatchNorm1d(dim, affine=False)) # second layer
|
||||
self.base_encoder.fc[0].bias.requires_grad = False # hack: not use bias as it is followed by BN
|
||||
self.momentum_encoder.fc = nn.Sequential(self.momentum_encoder.fc,
|
||||
self.momentum_encoder.fc = nn.Sequential(nn.Linear(hidden_dim, mlp_dim, bias=False),
|
||||
nn.BatchNorm1d(mlp_dim),
|
||||
nn.ReLU(inplace=True), # first layer
|
||||
nn.Linear(mlp_dim, dim, bias=False),
|
||||
|
@ -64,7 +64,8 @@ class MoCo(nn.Module):
|
|||
self.base_encoder = base_encoder(num_classes=mlp_dim)
|
||||
self.momentum_encoder = base_encoder(num_classes=mlp_dim)
|
||||
|
||||
self.base_encoder.head = nn.Sequential(self.base_encoder.head,
|
||||
hidden_dim = self.base_encoder.head.weight.shape[1]
|
||||
self.base_encoder.head = nn.Sequential(nn.Linear(hidden_dim, mlp_dim, bias=False),
|
||||
nn.BatchNorm1d(mlp_dim),
|
||||
nn.GELU(), # first layer
|
||||
nn.Linear(mlp_dim, mlp_dim, bias=False),
|
||||
|
@ -73,8 +74,7 @@ class MoCo(nn.Module):
|
|||
nn.BatchNorm1d(mlp_dim),
|
||||
nn.Linear(mlp_dim, dim, bias=False),
|
||||
nn.BatchNorm1d(dim, affine=False)) # third layer
|
||||
self.base_encoder.head[0].bias.requires_grad = False # hack: not use bias as it is followed by BN
|
||||
self.momentum_encoder.head = nn.Sequential(self.momentum_encoder.head,
|
||||
self.momentum_encoder.head = nn.Sequential(nn.Linear(hidden_dim, mlp_dim, bias=False),
|
||||
nn.BatchNorm1d(mlp_dim),
|
||||
nn.GELU(), # first layer
|
||||
nn.Linear(mlp_dim, mlp_dim, bias=False),
|
||||
|
|
Loading…
Reference in New Issue