mirror of
https://github.com/facebookresearch/moco-v3.git
synced 2025-06-03 14:59:22 +08:00
linear eval
This commit is contained in:
parent
89d29d8b64
commit
7c5f6867de
@ -1,7 +1,13 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import builtins
|
import builtins
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
@ -35,19 +41,17 @@ parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
|
|||||||
' (default: resnet50)')
|
' (default: resnet50)')
|
||||||
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
|
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
|
||||||
help='number of data loading workers (default: 32)')
|
help='number of data loading workers (default: 32)')
|
||||||
parser.add_argument('--epochs', default=100, type=int, metavar='N',
|
parser.add_argument('--epochs', default=90, type=int, metavar='N',
|
||||||
help='number of total epochs to run')
|
help='number of total epochs to run')
|
||||||
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
||||||
help='manual epoch number (useful on restarts)')
|
help='manual epoch number (useful on restarts)')
|
||||||
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
parser.add_argument('-b', '--batch-size', default=4096, type=int,
|
||||||
metavar='N',
|
metavar='N',
|
||||||
help='mini-batch size (default: 256), this is the total '
|
help='mini-batch size (default: 4096), this is the total '
|
||||||
'batch size of all GPUs on the current node when '
|
'batch size of all GPUs on the current node when '
|
||||||
'using Data Parallel or Distributed Data Parallel')
|
'using Data Parallel or Distributed Data Parallel')
|
||||||
parser.add_argument('--lr', '--learning-rate', default=30., type=float,
|
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
|
||||||
metavar='LR', help='initial learning rate', dest='lr')
|
metavar='LR', help='initial (base) learning rate', dest='lr')
|
||||||
parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int,
|
|
||||||
help='learning rate schedule (when to drop lr by a ratio)')
|
|
||||||
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
||||||
help='momentum')
|
help='momentum')
|
||||||
parser.add_argument('--wd', '--weight-decay', default=0., type=float,
|
parser.add_argument('--wd', '--weight-decay', default=0., type=float,
|
||||||
@ -77,6 +81,7 @@ parser.add_argument('--multiprocessing-distributed', action='store_true',
|
|||||||
'fastest way to use PyTorch for either single node or '
|
'fastest way to use PyTorch for either single node or '
|
||||||
'multi node data parallel training')
|
'multi node data parallel training')
|
||||||
|
|
||||||
|
# additional configs:
|
||||||
parser.add_argument('--pretrained', default='', type=str,
|
parser.add_argument('--pretrained', default='', type=str,
|
||||||
help='path to moco pretrained checkpoint')
|
help='path to moco pretrained checkpoint')
|
||||||
|
|
||||||
@ -161,10 +166,10 @@ def main_worker(gpu, ngpus_per_node, args):
|
|||||||
# rename moco pre-trained keys
|
# rename moco pre-trained keys
|
||||||
state_dict = checkpoint['state_dict']
|
state_dict = checkpoint['state_dict']
|
||||||
for k in list(state_dict.keys()):
|
for k in list(state_dict.keys()):
|
||||||
# retain only encoder_q up to before the embedding layer
|
# retain only base_encoder up to before the embedding layer
|
||||||
if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
|
if k.startswith('module.base_encoder') and not k.startswith('module.base_encoder.fc'):
|
||||||
# remove prefix
|
# remove prefix
|
||||||
state_dict[k[len("module.encoder_q."):]] = state_dict[k]
|
state_dict[k[len("module.base_encoder."):]] = state_dict[k]
|
||||||
# delete renamed or unused k
|
# delete renamed or unused k
|
||||||
del state_dict[k]
|
del state_dict[k]
|
||||||
|
|
||||||
@ -176,6 +181,9 @@ def main_worker(gpu, ngpus_per_node, args):
|
|||||||
else:
|
else:
|
||||||
print("=> no checkpoint found at '{}'".format(args.pretrained))
|
print("=> no checkpoint found at '{}'".format(args.pretrained))
|
||||||
|
|
||||||
|
# infer learning rate before changing batch size
|
||||||
|
init_lr = args.lr * args.batch_size / 256
|
||||||
|
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
# For multiprocessing distributed, DistributedDataParallel constructor
|
# For multiprocessing distributed, DistributedDataParallel constructor
|
||||||
# should always set the single device scope, otherwise,
|
# should always set the single device scope, otherwise,
|
||||||
@ -211,7 +219,8 @@ def main_worker(gpu, ngpus_per_node, args):
|
|||||||
# optimize only the linear classifier
|
# optimize only the linear classifier
|
||||||
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
|
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
|
||||||
assert len(parameters) == 2 # fc.weight, fc.bias
|
assert len(parameters) == 2 # fc.weight, fc.bias
|
||||||
optimizer = torch.optim.SGD(parameters, args.lr,
|
|
||||||
|
optimizer = torch.optim.SGD(parameters, init_lr,
|
||||||
momentum=args.momentum,
|
momentum=args.momentum,
|
||||||
weight_decay=args.weight_decay)
|
weight_decay=args.weight_decay)
|
||||||
|
|
||||||
@ -270,7 +279,7 @@ def main_worker(gpu, ngpus_per_node, args):
|
|||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
normalize,
|
normalize,
|
||||||
])),
|
])),
|
||||||
batch_size=args.batch_size, shuffle=False,
|
batch_size=256, shuffle=False,
|
||||||
num_workers=args.workers, pin_memory=True)
|
num_workers=args.workers, pin_memory=True)
|
||||||
|
|
||||||
if args.evaluate:
|
if args.evaluate:
|
||||||
@ -280,7 +289,7 @@ def main_worker(gpu, ngpus_per_node, args):
|
|||||||
for epoch in range(args.start_epoch, args.epochs):
|
for epoch in range(args.start_epoch, args.epochs):
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
train_sampler.set_epoch(epoch)
|
train_sampler.set_epoch(epoch)
|
||||||
adjust_learning_rate(optimizer, epoch, args)
|
adjust_learning_rate(optimizer, init_lr, epoch, args)
|
||||||
|
|
||||||
# train for one epoch
|
# train for one epoch
|
||||||
train(train_loader, model, criterion, optimizer, epoch, args)
|
train(train_loader, model, criterion, optimizer, epoch, args)
|
||||||
@ -422,8 +431,8 @@ def sanity_check(state_dict, pretrained_weights):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# name in pretrained model
|
# name in pretrained model
|
||||||
k_pre = 'module.encoder_q.' + k[len('module.'):] \
|
k_pre = 'module.base_encoder.' + k[len('module.'):] \
|
||||||
if k.startswith('module.') else 'module.encoder_q.' + k
|
if k.startswith('module.') else 'module.base_encoder.' + k
|
||||||
|
|
||||||
assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
|
assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
|
||||||
'{} is changed in linear classifier training.'.format(k)
|
'{} is changed in linear classifier training.'.format(k)
|
||||||
@ -472,13 +481,11 @@ class ProgressMeter(object):
|
|||||||
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
||||||
|
|
||||||
|
|
||||||
def adjust_learning_rate(optimizer, epoch, args):
|
def adjust_learning_rate(optimizer, init_lr, epoch, args):
|
||||||
"""Decay the learning rate based on schedule"""
|
"""Decay the learning rate based on schedule"""
|
||||||
lr = args.lr
|
cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
|
||||||
for milestone in args.schedule:
|
|
||||||
lr *= 0.1 if epoch >= milestone else 1.
|
|
||||||
for param_group in optimizer.param_groups:
|
for param_group in optimizer.param_groups:
|
||||||
param_group['lr'] = lr
|
param_group['lr'] = cur_lr
|
||||||
|
|
||||||
|
|
||||||
def accuracy(output, target, topk=(1,)):
|
def accuracy(output, target, topk=(1,)):
|
||||||
@ -493,7 +500,7 @@ def accuracy(output, target, topk=(1,)):
|
|||||||
|
|
||||||
res = []
|
res = []
|
||||||
for k in topk:
|
for k in topk:
|
||||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
||||||
res.append(correct_k.mul_(100.0 / batch_size))
|
res.append(correct_k.mul_(100.0 / batch_size))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user