add momentum schedule

pull/3/head
Xinlei Chen 2021-06-23 23:26:16 -07:00
parent a9b3a1e0a6
commit 0ec146f8eb
2 changed files with 21 additions and 9 deletions

View File

@ -106,7 +106,9 @@ parser.add_argument('--moco-dim', default=256, type=int,
parser.add_argument('--moco-mlp-dim', default=4096, type=int,
help='hidden dimension in MLPs (default: 4096)')
parser.add_argument('--moco-m', default=0.99, type=float,
help='moco momentum of updating momentum encoder (default: 0.99)')
help='moco (base) momentum of updating momentum encoder (default: 0.99)')
parser.add_argument('--moco-m-cos', action='store_true',
help='increase moco (base) momentum with a half-cycle cosine schedule')
parser.add_argument('--moco-t', default=1.0, type=float,
help='softmax temperature (default: 1.0)')
@ -183,7 +185,7 @@ def main_worker(gpu, ngpus_per_node, args):
print("=> creating model '{}'".format(args.arch))
model = moco.builder.MoCo(
torchvision_models.__dict__[args.arch],
args.moco_dim, args.moco_mlp_dim, args.moco_m, args.moco_t)
args.moco_dim, args.moco_mlp_dim, args.moco_t)
# infer learning rate before changing batch size
init_lr = args.lr * args.batch_size / 256
@ -330,6 +332,7 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
model.train()
end = time.time()
moco_m = adjust_moco_momentum(epoch, args)
for i, (images, _) in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)
@ -339,7 +342,7 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
images[1] = images[1].cuda(args.gpu, non_blocking=True)
# compute output
output1, output2, target = model(im1=images[0], im2=images[1])
output1, output2, target = model(im1=images[0], im2=images[1], m=moco_m)
loss = (criterion(output1, target) + criterion(output2, target)) * (args.moco_t * 2.)
# acc1/acc5 are N-way contrast classifier accuracy
@ -419,6 +422,15 @@ def adjust_learning_rate(optimizer, init_lr, epoch, args):
param_group['lr'] = lr
def adjust_moco_momentum(epoch, args):
"""Adjust moco momentum based on current epoch"""
if args.moco_m_cos:
m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. - args.moco_m)
else:
m = args.moco_m
return m
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():

View File

@ -13,7 +13,7 @@ class MoCo(nn.Module):
Build a MoCo model with: a base encoder, a momentum encoder
https://arxiv.org/abs/1911.05722
"""
def __init__(self, base_encoder, dim=256, mlp_dim=4096, m=0.99, T=1.0):
def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0):
"""
dim: feature dimension (default: 256)
mlp_dim: hidden dimension in MLPs (default: 4096)
@ -22,7 +22,6 @@ class MoCo(nn.Module):
"""
super(MoCo, self).__init__()
self.m = m
self.T = T
# create the encoders
@ -51,16 +50,17 @@ class MoCo(nn.Module):
nn.Linear(mlp_dim, dim)) # output layer
@torch.no_grad()
def _update_momentum_encoder(self):
def _update_momentum_encoder(self, m):
"""Momentum update of the momentum encoder"""
for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()):
param_m.data = param_m.data * self.m + param_b.data * (1. - self.m)
param_m.data = param_m.data * m + param_b.data * (1. - m)
def forward(self, im1, im2):
def forward(self, im1, im2, m):
"""
Input:
im1: first views of images
im2: second views of images
m: moco momentum
Output:
logits, targets
"""
@ -74,7 +74,7 @@ class MoCo(nn.Module):
# compute momentum features as targets
with torch.no_grad(): # no gradient
self._update_momentum_encoder() # update the momentum encoder
self._update_momentum_encoder(m) # update the momentum encoder
t1 = self.momentum_encoder(im1)
t2 = self.momentum_encoder(im2)
# normalize