mirror of
https://github.com/facebookresearch/moco-v3.git
synced 2025-06-03 14:59:22 +08:00
linear eval
This commit is contained in:
parent
89d29d8b64
commit
7c5f6867de
@ -1,7 +1,13 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import builtins
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
@ -35,19 +41,17 @@ parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
|
||||
' (default: resnet50)')
|
||||
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
|
||||
help='number of data loading workers (default: 32)')
|
||||
parser.add_argument('--epochs', default=100, type=int, metavar='N',
|
||||
parser.add_argument('--epochs', default=90, type=int, metavar='N',
|
||||
help='number of total epochs to run')
|
||||
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
||||
help='manual epoch number (useful on restarts)')
|
||||
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
||||
parser.add_argument('-b', '--batch-size', default=4096, type=int,
|
||||
metavar='N',
|
||||
help='mini-batch size (default: 256), this is the total '
|
||||
help='mini-batch size (default: 4096), this is the total '
|
||||
'batch size of all GPUs on the current node when '
|
||||
'using Data Parallel or Distributed Data Parallel')
|
||||
parser.add_argument('--lr', '--learning-rate', default=30., type=float,
|
||||
metavar='LR', help='initial learning rate', dest='lr')
|
||||
parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int,
|
||||
help='learning rate schedule (when to drop lr by a ratio)')
|
||||
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
|
||||
metavar='LR', help='initial (base) learning rate', dest='lr')
|
||||
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
||||
help='momentum')
|
||||
parser.add_argument('--wd', '--weight-decay', default=0., type=float,
|
||||
@ -77,6 +81,7 @@ parser.add_argument('--multiprocessing-distributed', action='store_true',
|
||||
'fastest way to use PyTorch for either single node or '
|
||||
'multi node data parallel training')
|
||||
|
||||
# additional configs:
|
||||
parser.add_argument('--pretrained', default='', type=str,
|
||||
help='path to moco pretrained checkpoint')
|
||||
|
||||
@ -161,10 +166,10 @@ def main_worker(gpu, ngpus_per_node, args):
|
||||
# rename moco pre-trained keys
|
||||
state_dict = checkpoint['state_dict']
|
||||
for k in list(state_dict.keys()):
|
||||
# retain only encoder_q up to before the embedding layer
|
||||
if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
|
||||
# retain only base_encoder up to before the embedding layer
|
||||
if k.startswith('module.base_encoder') and not k.startswith('module.base_encoder.fc'):
|
||||
# remove prefix
|
||||
state_dict[k[len("module.encoder_q."):]] = state_dict[k]
|
||||
state_dict[k[len("module.base_encoder."):]] = state_dict[k]
|
||||
# delete renamed or unused k
|
||||
del state_dict[k]
|
||||
|
||||
@ -176,6 +181,9 @@ def main_worker(gpu, ngpus_per_node, args):
|
||||
else:
|
||||
print("=> no checkpoint found at '{}'".format(args.pretrained))
|
||||
|
||||
# infer learning rate before changing batch size
|
||||
init_lr = args.lr * args.batch_size / 256
|
||||
|
||||
if args.distributed:
|
||||
# For multiprocessing distributed, DistributedDataParallel constructor
|
||||
# should always set the single device scope, otherwise,
|
||||
@ -211,7 +219,8 @@ def main_worker(gpu, ngpus_per_node, args):
|
||||
# optimize only the linear classifier
|
||||
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
|
||||
assert len(parameters) == 2 # fc.weight, fc.bias
|
||||
optimizer = torch.optim.SGD(parameters, args.lr,
|
||||
|
||||
optimizer = torch.optim.SGD(parameters, init_lr,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
@ -270,7 +279,7 @@ def main_worker(gpu, ngpus_per_node, args):
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
])),
|
||||
batch_size=args.batch_size, shuffle=False,
|
||||
batch_size=256, shuffle=False,
|
||||
num_workers=args.workers, pin_memory=True)
|
||||
|
||||
if args.evaluate:
|
||||
@ -280,7 +289,7 @@ def main_worker(gpu, ngpus_per_node, args):
|
||||
for epoch in range(args.start_epoch, args.epochs):
|
||||
if args.distributed:
|
||||
train_sampler.set_epoch(epoch)
|
||||
adjust_learning_rate(optimizer, epoch, args)
|
||||
adjust_learning_rate(optimizer, init_lr, epoch, args)
|
||||
|
||||
# train for one epoch
|
||||
train(train_loader, model, criterion, optimizer, epoch, args)
|
||||
@ -422,8 +431,8 @@ def sanity_check(state_dict, pretrained_weights):
|
||||
continue
|
||||
|
||||
# name in pretrained model
|
||||
k_pre = 'module.encoder_q.' + k[len('module.'):] \
|
||||
if k.startswith('module.') else 'module.encoder_q.' + k
|
||||
k_pre = 'module.base_encoder.' + k[len('module.'):] \
|
||||
if k.startswith('module.') else 'module.base_encoder.' + k
|
||||
|
||||
assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
|
||||
'{} is changed in linear classifier training.'.format(k)
|
||||
@ -472,13 +481,11 @@ class ProgressMeter(object):
|
||||
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
||||
|
||||
|
||||
def adjust_learning_rate(optimizer, epoch, args):
|
||||
def adjust_learning_rate(optimizer, init_lr, epoch, args):
|
||||
"""Decay the learning rate based on schedule"""
|
||||
lr = args.lr
|
||||
for milestone in args.schedule:
|
||||
lr *= 0.1 if epoch >= milestone else 1.
|
||||
cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
param_group['lr'] = cur_lr
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
@ -493,7 +500,7 @@ def accuracy(output, target, topk=(1,)):
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user