fix bug for dimensions

pull/3/head
Xinlei Chen 2021-08-05 03:06:48 -07:00
parent 1e7ac7a290
commit 6cf5a82e24
1 changed files with 1 additions and 2 deletions

View File

@ -17,7 +17,6 @@ class MoCo(nn.Module):
"""
dim: feature dimension (default: 256)
mlp_dim: hidden dimension in MLPs (default: 4096)
m: moco momentum of updating momentum encoder (default: 0.99)
T: softmax temperature (default: 1.0)
"""
super(MoCo, self).__init__()
@ -111,7 +110,7 @@ class MoCo_ResNet(MoCo):
class MoCo_ViT(MoCo):
def _build_projector_and_predictor_mlps(self, base_encoder, dim=256, mlp_dim=4096):
def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
hidden_dim = self.base_encoder.head.weight.shape[1]
del self.base_encoder.head, self.momentum_encoder.head # remove original fc layer