From 0ec146f8ebf27d9c98c3162fa58874c8498e1f74 Mon Sep 17 00:00:00 2001
From: Xinlei Chen
Date: Wed, 23 Jun 2021 23:26:16 -0700
Subject: [PATCH] add momentum schedule

---
 main_moco.py    | 18 +++++++++++++++---
 moco/builder.py | 12 ++++++------
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/main_moco.py b/main_moco.py
index c15306e..33ff7a4 100755
--- a/main_moco.py
+++ b/main_moco.py
@@ -106,7 +106,9 @@ parser.add_argument('--moco-dim', default=256, type=int,
 parser.add_argument('--moco-mlp-dim', default=4096, type=int,
                     help='hidden dimension in MLPs (default: 4096)')
 parser.add_argument('--moco-m', default=0.99, type=float,
-                    help='moco momentum of updating momentum encoder (default: 0.99)')
+                    help='moco (base) momentum of updating momentum encoder (default: 0.99)')
+parser.add_argument('--moco-m-cos', action='store_true',
+                    help='increase moco (base) momentum with a half-cycle cosine schedule')
 parser.add_argument('--moco-t', default=1.0, type=float,
                     help='softmax temperature (default: 1.0)')
 
@@ -183,7 +185,7 @@ def main_worker(gpu, ngpus_per_node, args):
     print("=> creating model '{}'".format(args.arch))
     model = moco.builder.MoCo(
         torchvision_models.__dict__[args.arch],
-        args.moco_dim, args.moco_mlp_dim, args.moco_m, args.moco_t)
+        args.moco_dim, args.moco_mlp_dim, args.moco_t)
 
     # infer learning rate before changing batch size
     init_lr = args.lr * args.batch_size / 256
@@ -330,6 +332,7 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
     model.train()
 
     end = time.time()
+    moco_m = adjust_moco_momentum(epoch, args)
     for i, (images, _) in enumerate(train_loader):
         # measure data loading time
         data_time.update(time.time() - end)
@@ -339,7 +342,7 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
             images[1] = images[1].cuda(args.gpu, non_blocking=True)
 
         # compute output
-        output1, output2, target = model(im1=images[0], im2=images[1])
+        output1, output2, target = model(im1=images[0], im2=images[1], m=moco_m)
         loss = (criterion(output1, target) + criterion(output2, target)) * (args.moco_t * 2.)
 
         # acc1/acc5 are N-way contrast classifier accuracy
@@ -419,6 +422,15 @@ def adjust_learning_rate(optimizer, init_lr, epoch, args):
         param_group['lr'] = lr
 
 
+def adjust_moco_momentum(epoch, args):
+    """Adjust moco momentum based on current epoch"""
+    if args.moco_m_cos:
+        m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. - args.moco_m)
+    else:
+        m = args.moco_m
+    return m
+
+
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():
diff --git a/moco/builder.py b/moco/builder.py
index 07c8fcd..152a74a 100644
--- a/moco/builder.py
+++ b/moco/builder.py
@@ -13,7 +13,7 @@ class MoCo(nn.Module):
     Build a MoCo model with: a base encoder, a momentum encoder
     https://arxiv.org/abs/1911.05722
     """
-    def __init__(self, base_encoder, dim=256, mlp_dim=4096, m=0.99, T=1.0):
+    def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0):
         """
         dim: feature dimension (default: 256)
         mlp_dim: hidden dimension in MLPs (default: 4096)
@@ -22,7 +22,6 @@ class MoCo(nn.Module):
         """
         super(MoCo, self).__init__()
 
-        self.m = m
         self.T = T
 
         # create the encoders
@@ -51,16 +50,17 @@ class MoCo(nn.Module):
                                         nn.Linear(mlp_dim, dim)) # output layer
 
     @torch.no_grad()
-    def _update_momentum_encoder(self):
+    def _update_momentum_encoder(self, m):
         """Momentum update of the momentum encoder"""
         for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()):
-            param_m.data = param_m.data * self.m + param_b.data * (1. - self.m)
+            param_m.data = param_m.data * m + param_b.data * (1. - m)
 
-    def forward(self, im1, im2):
+    def forward(self, im1, im2, m):
         """
         Input:
             im1: first views of images
             im2: second views of images
+            m: moco momentum
         Output:
             logits, targets
         """
@@ -74,7 +74,7 @@ class MoCo(nn.Module):
 
         # compute momentum features as targets
         with torch.no_grad():  # no gradient
-            self._update_momentum_encoder()  # update the momentum encoder
+            self._update_momentum_encoder(m)  # update the momentum encoder
             t1 = self.momentum_encoder(im1)
             t2 = self.momentum_encoder(im2)
             # normalize
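
For reference, below is a small standalone sketch of the half-cycle cosine schedule this patch adds in adjust_moco_momentum. It is an illustration only: the epoch budget (300) and base momentum (0.99) are assumed example values rather than settings from this patch, and the helper unpacks the argparse fields into plain parameters.

    import math

    def moco_momentum(epoch, total_epochs, base_m, use_cos=True):
        # Same half-cycle cosine rule as the patched adjust_moco_momentum,
        # written against plain parameters instead of an argparse namespace.
        if use_cos:
            return 1. - 0.5 * (1. + math.cos(math.pi * epoch / total_epochs)) * (1. - base_m)
        return base_m

    for epoch in (0, 75, 150, 225, 300):
        print(epoch, round(moco_momentum(epoch, total_epochs=300, base_m=0.99), 4))
    # 0   -> 0.99   (starts at the base momentum)
    # 150 -> 0.995  (halfway through training)
    # 300 -> 1.0    (momentum encoder effectively stops moving)

With --moco-m-cos, the momentum grows monotonically from args.moco_m toward 1.0, so the momentum encoder is updated more and more slowly as training progresses; without the flag, the constant momentum behavior from before this patch is kept.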