add momentum schedule
parent a9b3a1e0a6
commit 0ec146f8eb

main_moco.py (18 changed lines)
@@ -106,7 +106,9 @@ parser.add_argument('--moco-dim', default=256, type=int,
 parser.add_argument('--moco-mlp-dim', default=4096, type=int,
                     help='hidden dimension in MLPs (default: 4096)')
 parser.add_argument('--moco-m', default=0.99, type=float,
-                    help='moco momentum of updating momentum encoder (default: 0.99)')
+                    help='moco (base) momentum of updating momentum encoder (default: 0.99)')
+parser.add_argument('--moco-m-cos', action='store_true',
+                    help='increase moco (base) momentum with a half-cycle cosine schedule')
 parser.add_argument('--moco-t', default=1.0, type=float,
                     help='softmax temperature (default: 1.0)')
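The new flag is additive: with --moco-m-cos, the value of --moco-m becomes the starting (base) momentum rather than a constant. A hypothetical invocation (all other required arguments elided):

    python main_moco.py --moco-m 0.99 --moco-m-cos [...]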
@@ -183,7 +185,7 @@ def main_worker(gpu, ngpus_per_node, args):
     print("=> creating model '{}'".format(args.arch))
     model = moco.builder.MoCo(
         torchvision_models.__dict__[args.arch],
-        args.moco_dim, args.moco_mlp_dim, args.moco_m, args.moco_t)
+        args.moco_dim, args.moco_mlp_dim, args.moco_t)

     # infer learning rate before changing batch size
     init_lr = args.lr * args.batch_size / 256
@@ -330,6 +332,7 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
     model.train()

     end = time.time()
+    moco_m = adjust_moco_momentum(epoch, args)
     for i, (images, _) in enumerate(train_loader):
         # measure data loading time
         data_time.update(time.time() - end)
@@ -339,7 +342,7 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
         images[1] = images[1].cuda(args.gpu, non_blocking=True)

         # compute output
-        output1, output2, target = model(im1=images[0], im2=images[1])
+        output1, output2, target = model(im1=images[0], im2=images[1], m=moco_m)
         loss = (criterion(output1, target) + criterion(output2, target)) * (args.moco_t * 2.)

         # acc1/acc5 are N-way contrast classifier accuracy
@@ -419,6 +422,15 @@ def adjust_learning_rate(optimizer, init_lr, epoch, args):
         param_group['lr'] = lr


+def adjust_moco_momentum(epoch, args):
+    """Adjust moco momentum based on current epoch"""
+    if args.moco_m_cos:
+        m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. - args.moco_m)
+    else:
+        m = args.moco_m
+    return m
+
+
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():
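For intuition, a minimal standalone sketch of the schedule added above (the SimpleNamespace stand-in for args is illustrative, not from the commit). With --moco-m-cos, m starts at the base value and rises to 1.0 by the final epoch, so the momentum encoder is updated ever more slowly as training converges:

    import math
    from types import SimpleNamespace

    def adjust_moco_momentum(epoch, args):
        """Half-cycle cosine ramp from args.moco_m (epoch 0) to 1.0 (args.epochs)."""
        if args.moco_m_cos:
            return 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. - args.moco_m)
        return args.moco_m

    args = SimpleNamespace(moco_m=0.99, moco_m_cos=True, epochs=100)
    for epoch in (0, 50, 100):
        print(epoch, adjust_moco_momentum(epoch, args))
    # 0   0.99   (base momentum)
    # 50  0.995  (halfway up the cosine)
    # 100 1.0    (momentum encoder effectively frozen)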
moco/builder.py
@@ -13,7 +13,7 @@ class MoCo(nn.Module):
     Build a MoCo model with: a base encoder, a momentum encoder
     https://arxiv.org/abs/1911.05722
     """
-    def __init__(self, base_encoder, dim=256, mlp_dim=4096, m=0.99, T=1.0):
+    def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0):
         """
         dim: feature dimension (default: 256)
         mlp_dim: hidden dimension in MLPs (default: 4096)
@@ -22,7 +22,6 @@ class MoCo(nn.Module):
         """
         super(MoCo, self).__init__()

-        self.m = m
         self.T = T

         # create the encoders
@@ -51,16 +50,17 @@ class MoCo(nn.Module):
             nn.Linear(mlp_dim, dim)) # output layer

     @torch.no_grad()
-    def _update_momentum_encoder(self):
+    def _update_momentum_encoder(self, m):
         """Momentum update of the momentum encoder"""
         for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()):
-            param_m.data = param_m.data * self.m + param_b.data * (1. - self.m)
+            param_m.data = param_m.data * m + param_b.data * (1. - m)

-    def forward(self, im1, im2):
+    def forward(self, im1, im2, m):
         """
         Input:
             im1: first views of images
             im2: second views of images
+            m: moco momentum
         Output:
             logits, targets
         """
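The update itself is a standard exponential moving average, now parameterized by the scheduled m instead of a fixed self.m. An equivalent standalone sketch using in-place ops (ema_update is an illustrative name, not part of the repo):

    import torch

    @torch.no_grad()
    def ema_update(base_params, momentum_params, m):
        # param_m <- m * param_m + (1 - m) * param_b
        for p_b, p_m in zip(base_params, momentum_params):
            p_m.data.mul_(m).add_(p_b.data, alpha=1. - m)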
@@ -74,7 +74,7 @@ class MoCo(nn.Module):

         # compute momentum features as targets
         with torch.no_grad():  # no gradient
-            self._update_momentum_encoder()  # update the momentum encoder
+            self._update_momentum_encoder(m)  # update the momentum encoder
             t1 = self.momentum_encoder(im1)
             t2 = self.momentum_encoder(im2)
             # normalize
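After this change the call pattern is: construct the model without a momentum argument and pass m on every forward call, as main_moco.py now does. A usage sketch under assumptions (resnet50 and the random inputs are for illustration only):

    import torch
    import torchvision.models as torchvision_models
    import moco.builder

    model = moco.builder.MoCo(torchvision_models.__dict__['resnet50'],
                              dim=256, mlp_dim=4096, T=1.0)  # no m at construction
    x1, x2 = torch.randn(2, 3, 224, 224), torch.randn(2, 3, 224, 224)
    output1, output2, target = model(im1=x1, im2=x2, m=0.99)  # m supplied per step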