#!/usr/bin/env python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import builtins
import math
import os
import random
import shutil
import time
import warnings

# ===== to delete =====
import signal
from tensorboardX import SummaryWriter
# ===== to delete =====

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as torchvision_models

from functools import partial
import vits

import moco.builder
import moco.loader
import moco.optimizer


# ===== to delete =====
def signalHandler(signum, frame):
    if signum == signal.SIGUSR1:
        print('Got SIGUSR1.')
    elif signum == signal.SIGTERM:
        print('Got SIGTERM.')

signal.signal(signal.SIGUSR1, signalHandler)
signal.signal(signal.SIGTERM, signalHandler)
# ===== to delete =====

torchvision_model_names = sorted(name for name in torchvision_models.__dict__
                                 if name.islower() and not name.startswith("__")
                                 and callable(torchvision_models.__dict__[name]))

model_names = ['vit_small', 'vit_base', 'vit_large', 'vit_huge'] + torchvision_model_names

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet50)')
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                    help='number of data loading workers (default: 32)')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=4096, type=int,
                    metavar='N',
                    help='mini-batch size (default: 4096), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.45, type=float,
                    metavar='LR', help='initial (base) learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-6, type=float,
                    metavar='W', help='weight decay (default: 1e-6)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training.')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')

# moco specific configs:
parser.add_argument('--moco-dim', default=256, type=int,
                    help='feature dimension (default: 256)')
parser.add_argument('--moco-mlp-dim', default=4096, type=int,
                    help='hidden dimension in MLPs (default: 4096)')
parser.add_argument('--moco-m', default=0.99, type=float,
                    help='moco (initial) momentum of updating the momentum encoder; '
                         'the value gradually increases to 1 with a '
                         'half-cycle cosine schedule (default: 0.99)')
parser.add_argument('--moco-t', default=1.0, type=float,
                    help='softmax temperature (default: 1.0)')

# vit specific configs:
parser.add_argument('--fix-init', action='store_true',
                    help='fix weight init for first conv, or patch embedding')
parser.add_argument('--stop-grad-conv1', action='store_true',
                    help='stop-grad after first conv, or patch embedding')

# other upgrades
parser.add_argument('--optimizer', default='lars', type=str,
                    choices=['lars', 'adamw'],
                    help='optimizer used (default: lars)')
parser.add_argument('--warmup-epochs', default=10, type=int, metavar='N',
                    help='number of warmup epochs')

# ===== to delete =====
parser.add_argument('--checkpoint-folder', default='.', type=str, metavar='PATH',
                    help='path to save the checkpoints (default: .)')
# ===== to delete =====


def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    # ===== to delete =====
    if not os.path.exists(args.checkpoint_folder):
        os.makedirs(args.checkpoint_folder)
    # ===== to delete =====

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


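# Note: mp.spawn passes the local process index (0 .. ngpus_per_node - 1) as the
# first argument, so each spawned worker receives its local GPU id in `gpu` below.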
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    # suppress printing if not the first GPU on the first node (rank 0)
    if args.multiprocessing_distributed and (args.gpu != 0 or args.rank != 0):
        def print_pass(*args):
            pass
        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
        torch.distributed.barrier()

    # create model
    print("=> creating model '{}'".format(args.arch))
    if args.arch.startswith('vit'):
        model = moco.builder.MoCo(
            partial(vits.__dict__[args.arch], fix_init=args.fix_init, stop_grad_conv1=args.stop_grad_conv1),
            True,  # with vit setup
            args.moco_dim, args.moco_mlp_dim, args.moco_t)
    else:
        model = moco.builder.MoCo(
            partial(torchvision_models.__dict__[args.arch], zero_init_residual=True),
            False,  # with resnet setup
            args.moco_dim, args.moco_mlp_dim, args.moco_t)

    # infer learning rate before changing batch size
    init_lr = args.lr * args.batch_size / 256

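    # Linear scaling rule: the effective rate scales with the total batch size
    # relative to a base of 256, e.g. the defaults --lr 0.45 and --batch-size 4096
    # give init_lr = 0.45 * 4096 / 256 = 7.2 before the per-GPU batch split below.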
    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # Apply SyncBN
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
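        # Example of the batch split below: --batch-size 4096 with a world size of
        # 16 processes (e.g. 2 nodes x 8 GPUs) gives a per-GPU batch of 256.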
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / args.world_size)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        # raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather/rank implementation in this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    print(model)  # print model after SyncBatchNorm

    if args.optimizer == 'lars':
        optimizer = moco.optimizer.LARS(model.parameters(), init_lr,
                                        weight_decay=args.weight_decay,
                                        momentum=args.momentum)
    elif args.optimizer == 'adamw':
        optimizer = moco.optimizer.AdamW(model.parameters(), init_lr,
                                         weight_decay=args.weight_decay)

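    # Note: the MoCo v3 recipe typically pairs LARS with the ResNet-50 backbone and
    # AdamW with ViT backbones; either can be selected here via --optimizer.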
    scaler = torch.cuda.amp.GradScaler()

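    # GradScaler performs dynamic loss scaling for mixed-precision training; its
    # state is saved under 'scaler' in each checkpoint and restored on resume below.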
    # ===== to delete =====
    if args.rank == 0:
        summary_writer = SummaryWriter(logdir=args.checkpoint_folder)
    else:
        summary_writer = None
    # ===== to delete =====

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scaler.load_state_dict(checkpoint['scaler'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # BYOL's augmentation recipe: https://arxiv.org/abs/2006.07733
    # except min-scale kept as 0.2
    augmentation1 = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=1.0),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]

    augmentation2 = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.1),
        transforms.RandomApply([moco.loader.Solarize()], p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]

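    # The two pipelines are intentionally asymmetric, as in BYOL: view 1 is always
    # blurred, while view 2 is blurred rarely (p=0.1) and may be solarized (p=0.2).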
    train_dataset = datasets.ImageFolder(
        traindir,
        moco.loader.TwoCropsTransform(transforms.Compose(augmentation1),
                                      transforms.Compose(augmentation2)))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

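    # Note: DistributedSampler.set_epoch() is called at the start of each epoch below
    # so the shuffling order differs across epochs on every rank.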
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, init_lr, epoch, args)

        # train for one epoch
        train(train_loader, model, optimizer, scaler, summary_writer, epoch, args)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank == 0):  # only the first GPU saves checkpoint
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scaler': scaler.state_dict(),
            }, is_best=False, filename='%s/checkpoint_%04d.pth.tar' % (args.checkpoint_folder, epoch))

    # ===== to delete =====
    if args.rank == 0:
        summary_writer.close()
    # ===== to delete =====


def train(train_loader, model, optimizer, scaler, summary_writer, epoch, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    moco_m = adjust_moco_momentum(epoch, args)
    for i, (images, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images[0] = images[0].cuda(args.gpu, non_blocking=True)
            images[1] = images[1].cuda(args.gpu, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast(True):
            loss = model(images[0], images[1], moco_m)

        losses.update(loss.item(), images[0].size(0))
        # ===== to delete =====
        if args.rank == 0:
            summary_writer.add_scalar("loss", loss.item(), epoch * len(train_loader) + i)
        # ===== to delete =====

        # compute gradient and do SGD step
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

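        # GradScaler sequence: scale() multiplies the loss before backward to avoid
        # fp16 underflow, step() unscales the gradients and skips the optimizer step
        # if they contain inf/NaN, and update() adapts the scale factor for the next step.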
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def adjust_learning_rate(optimizer, init_lr, epoch, args):
    """Decays the learning rate with half-cycle cosine after warmup"""
    if epoch < args.warmup_epochs:
        lr = init_lr / (args.warmup_epochs + 1) * (epoch + 1)
    else:
        lr = init_lr * 0.5 * (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

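# Example for adjust_learning_rate, assuming the defaults (--epochs 100, --warmup-epochs 10):
# warmup ramps the rate from init_lr/11 at epoch 0 up to 10*init_lr/11 at epoch 9, and from
# epoch 10 onward the rate follows a half-cycle cosine decay from init_lr toward 0.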
def adjust_moco_momentum(epoch, args):
    """Adjust moco momentum based on current epoch"""
    return 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. - args.moco_m)

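# Example: with the default --moco-m 0.99, the momentum is 0.99 at epoch 0,
# 0.995 halfway through training, and approaches 1.0 by the end of training.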
if __name__ == '__main__':
    main()