linear eval

This commit is contained in:
Xinlei Chen 2021-06-18 20:33:39 -07:00
parent 89d29d8b64
commit 7c5f6867de

View File

@ -1,7 +1,13 @@
#!/usr/bin/env python #!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse import argparse
import builtins import builtins
import math
import os import os
import random import random
import shutil import shutil
@ -35,19 +41,17 @@ parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
' (default: resnet50)') ' (default: resnet50)')
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
help='number of data loading workers (default: 32)') help='number of data loading workers (default: 32)')
parser.add_argument('--epochs', default=100, type=int, metavar='N', parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run') help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)') help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int, parser.add_argument('-b', '--batch-size', default=4096, type=int,
metavar='N', metavar='N',
help='mini-batch size (default: 256), this is the total ' help='mini-batch size (default: 4096), this is the total '
'batch size of all GPUs on the current node when ' 'batch size of all GPUs on the current node when '
'using Data Parallel or Distributed Data Parallel') 'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=30., type=float, parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='initial learning rate', dest='lr') metavar='LR', help='initial (base) learning rate', dest='lr')
parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int,
help='learning rate schedule (when to drop lr by a ratio)')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum') help='momentum')
parser.add_argument('--wd', '--weight-decay', default=0., type=float, parser.add_argument('--wd', '--weight-decay', default=0., type=float,
@ -77,6 +81,7 @@ parser.add_argument('--multiprocessing-distributed', action='store_true',
'fastest way to use PyTorch for either single node or ' 'fastest way to use PyTorch for either single node or '
'multi node data parallel training') 'multi node data parallel training')
# additional configs:
parser.add_argument('--pretrained', default='', type=str, parser.add_argument('--pretrained', default='', type=str,
help='path to moco pretrained checkpoint') help='path to moco pretrained checkpoint')
@ -161,10 +166,10 @@ def main_worker(gpu, ngpus_per_node, args):
# rename moco pre-trained keys # rename moco pre-trained keys
state_dict = checkpoint['state_dict'] state_dict = checkpoint['state_dict']
for k in list(state_dict.keys()): for k in list(state_dict.keys()):
# retain only encoder_q up to before the embedding layer # retain only base_encoder up to before the embedding layer
if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): if k.startswith('module.base_encoder') and not k.startswith('module.base_encoder.fc'):
# remove prefix # remove prefix
state_dict[k[len("module.encoder_q."):]] = state_dict[k] state_dict[k[len("module.base_encoder."):]] = state_dict[k]
# delete renamed or unused k # delete renamed or unused k
del state_dict[k] del state_dict[k]
@ -176,6 +181,9 @@ def main_worker(gpu, ngpus_per_node, args):
else: else:
print("=> no checkpoint found at '{}'".format(args.pretrained)) print("=> no checkpoint found at '{}'".format(args.pretrained))
# infer learning rate before changing batch size
init_lr = args.lr * args.batch_size / 256
if args.distributed: if args.distributed:
# For multiprocessing distributed, DistributedDataParallel constructor # For multiprocessing distributed, DistributedDataParallel constructor
# should always set the single device scope, otherwise, # should always set the single device scope, otherwise,
@ -211,7 +219,8 @@ def main_worker(gpu, ngpus_per_node, args):
# optimize only the linear classifier # optimize only the linear classifier
parameters = list(filter(lambda p: p.requires_grad, model.parameters())) parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2 # fc.weight, fc.bias assert len(parameters) == 2 # fc.weight, fc.bias
optimizer = torch.optim.SGD(parameters, args.lr,
optimizer = torch.optim.SGD(parameters, init_lr,
momentum=args.momentum, momentum=args.momentum,
weight_decay=args.weight_decay) weight_decay=args.weight_decay)
@ -270,7 +279,7 @@ def main_worker(gpu, ngpus_per_node, args):
transforms.ToTensor(), transforms.ToTensor(),
normalize, normalize,
])), ])),
batch_size=args.batch_size, shuffle=False, batch_size=256, shuffle=False,
num_workers=args.workers, pin_memory=True) num_workers=args.workers, pin_memory=True)
if args.evaluate: if args.evaluate:
@ -280,7 +289,7 @@ def main_worker(gpu, ngpus_per_node, args):
for epoch in range(args.start_epoch, args.epochs): for epoch in range(args.start_epoch, args.epochs):
if args.distributed: if args.distributed:
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
adjust_learning_rate(optimizer, epoch, args) adjust_learning_rate(optimizer, init_lr, epoch, args)
# train for one epoch # train for one epoch
train(train_loader, model, criterion, optimizer, epoch, args) train(train_loader, model, criterion, optimizer, epoch, args)
@ -422,8 +431,8 @@ def sanity_check(state_dict, pretrained_weights):
continue continue
# name in pretrained model # name in pretrained model
k_pre = 'module.encoder_q.' + k[len('module.'):] \ k_pre = 'module.base_encoder.' + k[len('module.'):] \
if k.startswith('module.') else 'module.encoder_q.' + k if k.startswith('module.') else 'module.base_encoder.' + k
assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \ assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
'{} is changed in linear classifier training.'.format(k) '{} is changed in linear classifier training.'.format(k)
@ -472,13 +481,11 @@ class ProgressMeter(object):
return '[' + fmt + '/' + fmt.format(num_batches) + ']' return '[' + fmt + '/' + fmt.format(num_batches) + ']'
def adjust_learning_rate(optimizer, epoch, args): def adjust_learning_rate(optimizer, init_lr, epoch, args):
"""Decay the learning rate based on schedule""" """Decay the learning rate based on schedule"""
lr = args.lr cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
for milestone in args.schedule:
lr *= 0.1 if epoch >= milestone else 1.
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
param_group['lr'] = lr param_group['lr'] = cur_lr
def accuracy(output, target, topk=(1,)): def accuracy(output, target, topk=(1,)):
@ -493,7 +500,7 @@ def accuracy(output, target, topk=(1,)):
res = [] res = []
for k in topk: for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size)) res.append(correct_k.mul_(100.0 / batch_size))
return res return res