# deep-person-reid/train_imgreid_xent.py
# Train an image-based person re-identification model with cross-entropy loss.
from __future__ import print_function
from __future__ import division

import os
import sys
import time
import datetime
import argparse
import os.path as osp

import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler

from torchreid import data_manager
from torchreid.dataset_loader import ImageDataset
from torchreid import transforms as T
from torchreid import models
from torchreid.losses import CrossEntropyLabelSmooth, DeepSupervision
from torchreid.utils.iotools import save_checkpoint, check_isfile
from torchreid.utils.avgmeter import AverageMeter
from torchreid.utils.logger import Logger
from torchreid.utils.torchtools import set_bn_to_eval, count_num_param
from torchreid.utils.reidtools import visualize_ranked_results
from torchreid.eval_metrics import evaluate
from torchreid.optimizers import init_optim
parser = argparse.ArgumentParser(description='Train image model with cross entropy loss')
# Datasets
parser.add_argument('--root', type=str, default='data',
                    help="root path to data directory")
parser.add_argument('-d', '--dataset', type=str, default='market1501',
                    choices=data_manager.get_names())
parser.add_argument('-j', '--workers', default=4, type=int,
                    help="number of data loading workers (default: 4)")
parser.add_argument('--height', type=int, default=256,
                    help="height of an image (default: 256)")
parser.add_argument('--width', type=int, default=128,
                    help="width of an image (default: 128)")
parser.add_argument('--split-id', type=int, default=0,
                    help="split index (0-based)")
# CUHK03-specific setting
parser.add_argument('--cuhk03-labeled', action='store_true',
                    help="use labeled images, if false, detected images are used (default: False)")
parser.add_argument('--cuhk03-classic-split', action='store_true',
                    help="use classic split by Li et al. CVPR'14 (default: False)")
parser.add_argument('--use-metric-cuhk03', action='store_true',
                    help="use cuhk03-metric (default: False)")
# Optimization options
parser.add_argument('--optim', type=str, default='adam',
                    help="optimization algorithm (see optimizers.py)")
parser.add_argument('--max-epoch', default=60, type=int,
                    help="maximum epochs to run")
parser.add_argument('--start-epoch', default=0, type=int,
                    help="manual epoch number (useful on restarts)")
parser.add_argument('--train-batch', default=32, type=int,
                    help="train batch size")
parser.add_argument('--test-batch', default=100, type=int,
                    help="test batch size")
parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
                    help="initial learning rate")
parser.add_argument('--stepsize', default=[20, 40], nargs='+', type=int,
                    help="stepsize to decay learning rate")
parser.add_argument('--gamma', default=0.1, type=float,
                    help="learning rate decay")
parser.add_argument('--weight-decay', default=5e-04, type=float,
                    help="weight decay (default: 5e-04)")
parser.add_argument('--fixbase-epoch', default=0, type=int,
                    help="epochs to fix base network (only train classifier, default: 0)")
parser.add_argument('--fixbase-lr', default=0.0003, type=float,
                    help="learning rate (when base network is frozen)")
parser.add_argument('--freeze-bn', action='store_true',
                    help="freeze running statistics in BatchNorm layers during training (default: False)")
parser.add_argument('--label-smooth', action='store_true',
                    help="use label smoothing regularizer in cross entropy loss")
# Architecture
parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names())
# Miscs
parser.add_argument('--print-freq', type=int, default=10,
                    help="print frequency")
parser.add_argument('--seed', type=int, default=1,
                    help="manual seed")
parser.add_argument('--resume', type=str, default='', metavar='PATH')
parser.add_argument('--load-weights', type=str, default='',
                    help="load pretrained weights but ignores layers that don't match in size")
parser.add_argument('--evaluate', action='store_true',
                    help="evaluation only")
parser.add_argument('--eval-step', type=int, default=-1,
                    help="run evaluation for every N epochs (set to -1 to test after training)")
parser.add_argument('--start-eval', type=int, default=0,
                    help="start to evaluate after specific epoch")
parser.add_argument('--save-dir', type=str, default='log')
parser.add_argument('--use-cpu', action='store_true',
                    help="use cpu")
parser.add_argument('--gpu-devices', default='0', type=str,
                    help='gpu device ids for CUDA_VISIBLE_DEVICES')
parser.add_argument('--use-avai-gpus', action='store_true',
                    help="use available gpus instead of specified devices (this is useful when using managed clusters)")
parser.add_argument('--visualize-ranks', action='store_true',
                    help="visualize ranked results, only available in evaluation mode (default: False)")

# global variables
args = parser.parse_args()
best_rank1 = -np.inf  # best Rank-1 accuracy observed so far; updated in main()
def main():
    """Entry point: set up data, model, optimizer, then train and/or evaluate.

    Reads all configuration from the module-level ``args`` (argparse namespace)
    and tracks the best Rank-1 accuracy in the module-level ``best_rank1``.
    """
    global args, best_rank1

    torch.manual_seed(args.seed)
    if not args.use_avai_gpus: os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    # Mirror stdout to a log file (train vs. test log depending on mode).
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_imgreid_dataset(
        root=args.root, name=args.dataset, split_id=args.split_id,
        cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split,
    )

    # Training uses random crop + flip augmentation; testing uses a fixed resize.
    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = True if use_gpu else False

    trainloader = DataLoader(
        ImageDataset(dataset.train, transform=transform_train),
        batch_size=args.train_batch, shuffle=True, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=True,
    )

    queryloader = DataLoader(
        ImageDataset(dataset.query, transform=transform_test),
        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=False,
    )

    galleryloader = DataLoader(
        ImageDataset(dataset.gallery, transform=transform_test),
        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent'}, use_gpu=use_gpu)
    print("Model size: {:.3f} M".format(count_num_param(model)))

    if args.label_smooth:
        criterion = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu)
    else:
        criterion = nn.CrossEntropyLoss()
    optimizer = init_optim(args.optim, model.parameters(), args.lr, args.weight_decay)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma)

    # Optional warm-up phase: train only the classifier while the base is frozen.
    if args.fixbase_epoch > 0:
        if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
            optimizer_tmp = init_optim(args.optim, model.classifier.parameters(), args.fixbase_lr, args.weight_decay)
        else:
            print("Warn: model has no attribute 'classifier' and fixbase_epoch is reset to 0")
            args.fixbase_epoch = 0

    if args.load_weights and check_isfile(args.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        checkpoint = torch.load(args.load_weights)
        pretrain_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(args.load_weights))

    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch'] + 1
        best_rank1 = checkpoint['rank1']
        print("Loaded checkpoint from '{}'".format(args.resume))
        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, best_rank1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        distmat = test(model, queryloader, galleryloader, use_gpu, return_distmat=True)
        if args.visualize_ranks:
            visualize_ranked_results(
                distmat, dataset,
                save_dir=osp.join(args.save_dir, 'ranked_results'),
                topk=20,
            )
        return

    start_time = time.time()
    train_time = 0
    best_epoch = args.start_epoch
    print("==> Start training")

    if args.fixbase_epoch > 0:
        print("Train classifier for {} epochs while keeping base network frozen".format(args.fixbase_epoch))

        for epoch in range(args.fixbase_epoch):
            start_train_time = time.time()
            train(epoch, model, criterion, optimizer_tmp, trainloader, use_gpu, freeze_bn=True)
            train_time += round(time.time() - start_train_time)

        del optimizer_tmp
        print("Now open all layers for training")

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        # Evaluate periodically (every eval_step epochs after start_eval) and
        # always at the final epoch.
        if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, use_gpu)
            is_best = rank1 > best_rank1

            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            # Unwrap DataParallel before saving so the checkpoint keys are portable.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'rank1': rank1,
                'epoch': epoch,
            }, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=False):
    """Run one training epoch over ``trainloader``.

    Args:
        epoch: zero-based epoch index (only used for log messages).
        model: network to train; set to train mode here.
        criterion: classification loss taking (logits, pids).
        optimizer: optimizer stepped once per batch.
        trainloader: yields (imgs, pids, _) batches.
        use_gpu: move batches to CUDA when True.
        freeze_bn: keep BatchNorm running stats fixed for this epoch
            (also forced when args.freeze_bn is set).
    """
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()

    if freeze_bn or args.freeze_bn:
        model.apply(set_bn_to_eval)

    end = time.time()
    for batch_idx, (imgs, pids, _) in enumerate(trainloader):
        data_time.update(time.time() - end)

        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()

        outputs = model(imgs)
        # Models with auxiliary heads return a tuple of logits; apply the loss
        # to each head via DeepSupervision.
        if isinstance(outputs, tuple):
            loss = DeepSupervision(criterion, outputs, pids)
        else:
            loss = criterion(outputs, pids)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)

        losses.update(loss.item(), pids.size(0))

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                   epoch + 1, batch_idx + 1, len(trainloader), batch_time=batch_time,
                   data_time=data_time, loss=losses))

        end = time.time()
def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20], return_distmat=False):
    """Evaluate the model: extract features, compute CMC and mAP.

    Args:
        model: feature extractor; set to eval mode here.
        queryloader / galleryloader: yield (imgs, pids, camids) batches.
        use_gpu: move batches to CUDA when True.
        ranks: CMC ranks to report.
        return_distmat: when True, return the query-gallery distance matrix
            (numpy array) instead of the Rank-1 score.

    Returns:
        Rank-1 accuracy (float), or the distance matrix if return_distmat.
    """
    batch_time = AverageMeter()

    model.eval()

    with torch.no_grad():
        qf, q_pids, q_camids = [], [], []
        for batch_idx, (imgs, pids, camids) in enumerate(queryloader):
            if use_gpu: imgs = imgs.cuda()

            end = time.time()
            features = model(imgs)
            batch_time.update(time.time() - end)

            features = features.data.cpu()
            qf.append(features)
            q_pids.extend(pids)
            q_camids.extend(camids)
        qf = torch.cat(qf, 0)
        q_pids = np.asarray(q_pids)
        q_camids = np.asarray(q_camids)

        print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))

        gf, g_pids, g_camids = [], [], []
        end = time.time()
        for batch_idx, (imgs, pids, camids) in enumerate(galleryloader):
            if use_gpu: imgs = imgs.cuda()

            end = time.time()
            features = model(imgs)
            batch_time.update(time.time() - end)

            features = features.data.cpu()
            gf.append(features)
            g_pids.extend(pids)
            g_camids.extend(camids)
        gf = torch.cat(gf, 0)
        g_pids = np.asarray(g_pids)
        g_camids = np.asarray(g_camids)

        print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))

    print("==> BatchTime(s)/BatchSize(img): {:.3f}/{}".format(batch_time.avg, args.test_batch))

    # Squared Euclidean distances via ||q||^2 + ||g||^2 - 2 q.g^T.
    m, n = qf.size(0), gf.size(0)
    distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
              torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    distmat.addmm_(1, -2, qf, gf.t())
    distmat = distmat.numpy()

    print("Computing CMC and mAP")
    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, use_metric_cuhk03=args.use_metric_cuhk03)

    print("Results ----------")
    print("mAP: {:.1%}".format(mAP))
    print("CMC curve")
    for r in ranks:
        print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1]))
    print("------------------")

    if return_distmat:
        return distmat
    return cmc[0]
if __name__ == '__main__':
    main()