linear eval

Xinlei Chen 2021-06-18 20:33:39 -07:00
parent 89d29d8b64
commit 7c5f6867de


@@ -1,7 +1,13 @@
 #!/usr/bin/env python
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
 import argparse
 import builtins
+import math
 import os
 import random
 import shutil
@@ -35,19 +41,17 @@ parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                         ' (default: resnet50)')
 parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                     help='number of data loading workers (default: 32)')
-parser.add_argument('--epochs', default=100, type=int, metavar='N',
+parser.add_argument('--epochs', default=90, type=int, metavar='N',
                     help='number of total epochs to run')
 parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
-parser.add_argument('-b', '--batch-size', default=256, type=int,
+parser.add_argument('-b', '--batch-size', default=4096, type=int,
                     metavar='N',
-                    help='mini-batch size (default: 256), this is the total '
+                    help='mini-batch size (default: 4096), this is the total '
                          'batch size of all GPUs on the current node when '
                          'using Data Parallel or Distributed Data Parallel')
-parser.add_argument('--lr', '--learning-rate', default=30., type=float,
-                    metavar='LR', help='initial learning rate', dest='lr')
-parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int,
-                    help='learning rate schedule (when to drop lr by a ratio)')
+parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                    metavar='LR', help='initial (base) learning rate', dest='lr')
 parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                     help='momentum')
 parser.add_argument('--wd', '--weight-decay', default=0., type=float,
@@ -77,6 +81,7 @@ parser.add_argument('--multiprocessing-distributed', action='store_true',
                          'fastest way to use PyTorch for either single node or '
                          'multi node data parallel training')
 
 # additional configs:
 parser.add_argument('--pretrained', default='', type=str,
                     help='path to moco pretrained checkpoint')
@@ -161,10 +166,10 @@ def main_worker(gpu, ngpus_per_node, args):
             # rename moco pre-trained keys
             state_dict = checkpoint['state_dict']
             for k in list(state_dict.keys()):
-                # retain only encoder_q up to before the embedding layer
-                if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
+                # retain only base_encoder up to before the embedding layer
+                if k.startswith('module.base_encoder') and not k.startswith('module.base_encoder.fc'):
                     # remove prefix
-                    state_dict[k[len("module.encoder_q."):]] = state_dict[k]
+                    state_dict[k[len("module.base_encoder."):]] = state_dict[k]
                 # delete renamed or unused k
                 del state_dict[k]
 
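Note: this hunk renames the pretrained prefix from MoCo v1/v2's encoder_q to MoCo v3's base_encoder. A minimal standalone sketch of how the converted checkpoint gets consumed, assuming a torchvision ResNet-50 and a hypothetical checkpoint path:

import torch
import torchvision.models as models

model = models.resnet50()
checkpoint = torch.load('moco_v3_pretrain.pth.tar', map_location='cpu')  # hypothetical path
state_dict = checkpoint['state_dict']
for k in list(state_dict.keys()):
    # keep only base_encoder weights up to (but not including) the fc head
    if k.startswith('module.base_encoder') and not k.startswith('module.base_encoder.fc'):
        state_dict[k[len('module.base_encoder.'):]] = state_dict[k]
    # drop the renamed original plus any momentum-encoder / projection keys
    del state_dict[k]

msg = model.load_state_dict(state_dict, strict=False)
# only the freshly initialized linear classifier should be missing
assert set(msg.missing_keys) == {'fc.weight', 'fc.bias'}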
@@ -176,6 +181,9 @@ def main_worker(gpu, ngpus_per_node, args):
         else:
             print("=> no checkpoint found at '{}'".format(args.pretrained))
 
+    # infer learning rate before changing batch size
+    init_lr = args.lr * args.batch_size / 256
+
     if args.distributed:
         # For multiprocessing distributed, DistributedDataParallel constructor
         # should always set the single device scope, otherwise,
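Note: init_lr applies the linear LR scaling rule, so --lr now specifies a base rate per 256 samples rather than the actual rate. A worked example with the new defaults:

# linear scaling rule: effective LR grows with the total batch size
base_lr = 0.1          # --lr default (defined for a 256-sample batch)
batch_size = 4096      # --batch-size default
init_lr = base_lr * batch_size / 256
print(init_lr)         # 1.6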
@@ -211,7 +219,8 @@ def main_worker(gpu, ngpus_per_node, args):
 
     # optimize only the linear classifier
     parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
     assert len(parameters) == 2  # fc.weight, fc.bias
-    optimizer = torch.optim.SGD(parameters, args.lr,
+    optimizer = torch.optim.SGD(parameters, init_lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
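Note: the two surviving parameters come from freezing everything except the classifier head earlier in main_worker. A sketch of that setup, assuming a plain torchvision ResNet-50 and the defaults above (init_lr = 1.6):

import torch
import torchvision.models as models

model = models.resnet50()
# freeze the backbone: only the linear head stays trainable
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False
# re-initialize the classifier head
model.fc.weight.data.normal_(mean=0.0, std=0.01)
model.fc.bias.data.zero_()

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias
optimizer = torch.optim.SGD(parameters, 1.6,  # init_lr for batch size 4096
                            momentum=0.9, weight_decay=0.)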
@@ -270,7 +279,7 @@ def main_worker(gpu, ngpus_per_node, args):
             transforms.ToTensor(),
             normalize,
         ])),
-        batch_size=args.batch_size, shuffle=False,
+        batch_size=256, shuffle=False,
         num_workers=args.workers, pin_memory=True)
 
     if args.evaluate:
@@ -280,7 +289,7 @@ def main_worker(gpu, ngpus_per_node, args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        adjust_learning_rate(optimizer, epoch, args)
+        adjust_learning_rate(optimizer, init_lr, epoch, args)
 
         # train for one epoch
         train(train_loader, model, criterion, optimizer, epoch, args)
@@ -422,8 +431,8 @@ def sanity_check(state_dict, pretrained_weights):
             continue
 
         # name in pretrained model
-        k_pre = 'module.encoder_q.' + k[len('module.'):] \
-            if k.startswith('module.') else 'module.encoder_q.' + k
+        k_pre = 'module.base_encoder.' + k[len('module.'):] \
+            if k.startswith('module.') else 'module.base_encoder.' + k
 
         assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
             '{} is changed in linear classifier training.'.format(k)
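Note: the ternary maps a key of the linear-eval model back to its name in the pretrained checkpoint. A quick illustration with a hypothetical backbone key:

# hypothetical key from the (DDP-wrapped) linear-eval model
k = 'module.layer4.1.conv2.weight'
k_pre = 'module.base_encoder.' + k[len('module.'):] \
    if k.startswith('module.') else 'module.base_encoder.' + k
print(k_pre)  # module.base_encoder.layer4.1.conv2.weight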
@@ -472,13 +481,11 @@ class ProgressMeter(object):
         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
 
 
-def adjust_learning_rate(optimizer, epoch, args):
+def adjust_learning_rate(optimizer, init_lr, epoch, args):
     """Decay the learning rate based on schedule"""
-    lr = args.lr
-    for milestone in args.schedule:
-        lr *= 0.1 if epoch >= milestone else 1.
+    cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
     for param_group in optimizer.param_groups:
-        param_group['lr'] = lr
+        param_group['lr'] = cur_lr
 
 
 def accuracy(output, target, topk=(1,)):
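Note: the stepwise schedule (drop by 10x at each --schedule milestone) is replaced by half-cosine decay to zero over args.epochs. A small sketch of the values it produces under the defaults above (init_lr = 1.6, 90 epochs):

import math

def cosine_lr(init_lr, epoch, total_epochs):
    # half-period cosine: init_lr at epoch 0, approaching 0 at the final epoch
    return init_lr * 0.5 * (1. + math.cos(math.pi * epoch / total_epochs))

for epoch in (0, 45, 89):
    print(epoch, round(cosine_lr(1.6, epoch, 90), 4))
# 0 1.6
# 45 0.8
# 89 0.0005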
@@ -493,7 +500,7 @@ def accuracy(output, target, topk=(1,)):
 
         res = []
         for k in topk:
-            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
             res.append(correct_k.mul_(100.0 / batch_size))
         return res
 
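Note: view(-1) requires a contiguous tensor, and correct is non-contiguous here because it derives from the transposed predictions (pred.t()); reshape falls back to a copy when a free view is impossible. A minimal repro of the distinction:

import torch

t = torch.arange(6).reshape(2, 3).t()  # transpose -> non-contiguous
print(t.is_contiguous())               # False
# t.view(-1) raises: RuntimeError: view size is not compatible with
# input tensor's size and stride
flat = t.reshape(-1)                   # reshape copies when a view is impossible
print(flat.tolist())                   # [0, 3, 1, 4, 2, 5]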