deep-person-reid/train_imgreid_xent.py

258 lines
9.4 KiB
Python
Raw Normal View History

2018-07-04 17:32:43 +08:00
from __future__ import print_function
from __future__ import division
2018-03-12 05:17:48 +08:00
import os
import sys
import time
import datetime
import os.path as osp
import numpy as np
import torch
2018-03-12 05:27:48 +08:00
import torch.nn as nn
2018-03-12 05:17:48 +08:00
import torch.backends.cudnn as cudnn
2019-02-08 19:44:12 +08:00
from args import argument_parser, image_dataset_kwargs, optimizer_kwargs, lr_scheduler_kwargs
2018-11-06 05:18:56 +08:00
from torchreid.data_manager import ImageDataManager
2018-08-15 16:48:17 +08:00
from torchreid import models
from torchreid.losses import CrossEntropyLoss, DeepSupervision
2018-08-15 16:48:17 +08:00
from torchreid.utils.iotools import save_checkpoint, check_isfile
from torchreid.utils.avgmeter import AverageMeter
2018-11-09 05:41:32 +08:00
from torchreid.utils.loggers import Logger, RankLogger
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, load_pretrained_weights
2018-08-15 16:48:17 +08:00
from torchreid.utils.reidtools import visualize_ranked_results
2019-01-10 05:30:46 +08:00
from torchreid.utils.generaltools import set_random_seed
2018-08-15 16:48:17 +08:00
from torchreid.eval_metrics import evaluate
2018-11-08 01:09:23 +08:00
from torchreid.optimizers import init_optimizer
2019-02-08 19:44:12 +08:00
from torchreid.lr_schedulers import init_lr_scheduler
2018-03-12 05:17:48 +08:00
2018-07-02 20:57:11 +08:00
# Module-level configuration.
# The command line is parsed once, at import time, and the resulting namespace
# is read as a global by main(), train() and test() below.
parser = argument_parser()
args = parser.parse_args()
2018-07-02 20:57:11 +08:00
2018-03-12 05:17:48 +08:00
def main():
    """Entry point: build data, model and optimizer, then evaluate or train.

    All settings are read from the module-level ``args`` namespace. The only
    field written back is ``args.start_epoch``, which is overwritten when a
    checkpoint is resumed.

    Flow:
      1. Seed RNGs, pick the device and redirect ``sys.stdout`` to a log file
         under ``args.save_dir``.
      2. Build the data loaders, the model, the loss, the optimizer and the
         learning-rate scheduler.
      3. If ``args.evaluate`` is set, run ``test`` on every target dataset
         (optionally dumping ranked-result visualisations) and return.
      4. Otherwise optionally run a warm-up phase that trains only
         ``args.open_layers``, then train for the remaining epochs,
         evaluating and checkpointing on the configured schedule.
    """
    set_random_seed(args.seed)

    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    # --use-cpu overrides whatever CUDA reports.
    use_gpu = torch.cuda.is_available() and not args.use_cpu

    # From here on everything printed goes through the project Logger.
    if args.evaluate:
        log_name = 'log_test.txt'
    else:
        log_name = 'log_train.txt'
    sys.stdout = Logger(osp.join(args.save_dir, log_name))
    print('==========\nArgs:{}\n=========='.format(args))

    if not use_gpu:
        print('Currently using CPU, however, GPU is highly recommended')
    else:
        print('Currently using GPU {}'.format(args.gpu_devices))
        cudnn.benchmark = True

    print('Initializing image data manager')
    dm = ImageDataManager(use_gpu, **image_dataset_kwargs(args))
    trainloader, testloader_dict = dm.return_dataloaders()

    print('Initializing model: {}'.format(args.arch))
    model = models.init_model(
        name=args.arch,
        num_classes=dm.num_train_pids,
        loss={'xent'},
        use_gpu=use_gpu,
    )
    print('Model size: {:.3f} M'.format(count_num_param(model)))

    # Pretrained weights first, then (optionally) a full checkpoint on top.
    if args.load_weights and check_isfile(args.load_weights):
        load_pretrained_weights(model, args.load_weights)

    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        # Checkpoints store the 0-based index of the epoch just finished.
        args.start_epoch = checkpoint['epoch'] + 1
        print('Loaded checkpoint from "{}"'.format(args.resume))
        print('- start_epoch: {}\n- rank1: {}'.format(args.start_epoch, checkpoint['rank1']))

    # Weights are loaded into the bare model above; wrap it only afterwards.
    if use_gpu:
        model = nn.DataParallel(model).cuda()

    criterion = CrossEntropyLoss(
        num_classes=dm.num_train_pids,
        use_gpu=use_gpu,
        label_smooth=args.label_smooth,
    )
    optimizer = init_optimizer(model, **optimizer_kwargs(args))
    scheduler = init_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))

    # ------------------------------------------------------------------
    # Evaluation-only mode
    # ------------------------------------------------------------------
    if args.evaluate:
        print('Evaluate only')

        for name in args.target_names:
            print('Evaluating {} ...'.format(name))
            loaders = testloader_dict[name]
            queryloader = loaders['query']
            galleryloader = loaders['gallery']
            distmat = test(model, queryloader, galleryloader, use_gpu, return_distmat=True)

            if args.visualize_ranks:
                visualize_ranked_results(
                    distmat, dm.return_testdataset_by_name(name),
                    save_dir=osp.join(args.save_dir, 'ranked_results', name),
                    topk=20
                )

        return

    # ------------------------------------------------------------------
    # Training mode
    # ------------------------------------------------------------------
    start_time = time.time()
    ranklogger = RankLogger(args.source_names, args.target_names)
    train_time = 0  # accumulated whole seconds spent inside train()
    print('=> Start training')

    if args.fixbase_epoch > 0:
        print('Train {} for {} epochs while keeping other layers frozen'.format(args.open_layers, args.fixbase_epoch))
        # Snapshot the optimizer so the warm-up phase leaves no trace in it.
        initial_optim_state = optimizer.state_dict()

        for epoch in range(args.fixbase_epoch):
            tic = time.time()
            train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=True)
            train_time += round(time.time() - tic)

        print('Done. All layers are open to train for {} epochs'.format(args.max_epoch))
        optimizer.load_state_dict(initial_optim_state)

    for epoch in range(args.start_epoch, args.max_epoch):
        tic = time.time()
        train(epoch, model, criterion, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - tic)

        scheduler.step()

        # Evaluate on the periodic schedule once past start_eval, and always
        # after the final epoch.
        epoch_no = epoch + 1
        is_final_epoch = epoch_no == args.max_epoch
        is_scheduled = (
            epoch_no > args.start_eval
            and args.eval_freq > 0
            and epoch_no % args.eval_freq == 0
        )
        if not (is_scheduled or is_final_epoch):
            continue

        print('=> Test')

        for name in args.target_names:
            print('Evaluating {} ...'.format(name))
            loaders = testloader_dict[name]
            queryloader = loaders['query']
            galleryloader = loaders['gallery']
            rank1 = test(model, queryloader, galleryloader, use_gpu)
            ranklogger.write(name, epoch_no, rank1)

        # 'rank1' stored here is the value from the last target evaluated.
        checkpoint_path = osp.join(args.save_dir, 'checkpoint_ep' + str(epoch_no) + '.pth.tar')
        save_checkpoint({
            'state_dict': model.state_dict(),
            'rank1': rank1,
            'epoch': epoch,
        }, False, checkpoint_path)

    elapsed = str(datetime.timedelta(seconds=round(time.time() - start_time)))
    train_time = str(datetime.timedelta(seconds=train_time))
    print('Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.'.format(elapsed, train_time))
    ranklogger.show_summary()
2018-03-12 05:17:48 +08:00
2018-07-02 20:57:11 +08:00
2018-11-09 07:04:50 +08:00
def train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=False):
    """Run one training epoch and periodically print running statistics.

    Args:
        epoch: 0-based epoch index; only used (as ``epoch + 1``) in log lines.
        model: network to optimise; switched to training mode here.
        criterion: loss callable applied to ``(outputs, pids)``.
        optimizer: optimizer stepped once per batch.
        trainloader: iterable yielding 4-tuples whose first two items are the
            image batch and the identity labels (the last two are ignored).
        use_gpu: when true, images and labels are moved to CUDA.
        fixbase: when true (or when ``args.always_fixbase`` is set), only the
            layers named in ``args.open_layers`` are opened for training;
            otherwise all layers are opened.

    Reads ``args.always_fixbase``, ``args.open_layers`` and
    ``args.print_freq`` from the module-level ``args``.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    batch_timer = AverageMeter()
    data_timer = AverageMeter()

    model.train()

    train_subset_only = fixbase or args.always_fixbase
    if not train_subset_only:
        open_all_layers(model)
    else:
        open_specified_layers(model, args.open_layers)

    tic = time.time()
    for step, (imgs, pids, _, _) in enumerate(trainloader, 1):
        # Time spent waiting on the loader since the previous step finished.
        data_timer.update(time.time() - tic)

        if use_gpu:
            imgs = imgs.cuda()
            pids = pids.cuda()

        outputs = model(imgs)
        # Models returning several outputs go through the DeepSupervision
        # wrapper; a single output goes straight to the criterion.
        multi_output = isinstance(outputs, (tuple, list))
        loss = DeepSupervision(criterion, outputs, pids) if multi_output else criterion(outputs, pids)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Whole-step wall time: data loading plus forward/backward/update.
        batch_timer.update(time.time() - tic)

        loss_meter.update(loss.item(), pids.size(0))
        acc_meter.update(accuracy(outputs, pids)[0])

        if step % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
                   epoch + 1, step, len(trainloader),
                   batch_time=batch_timer,
                   data_time=data_timer,
                   loss=loss_meter,
                   acc=acc_meter
                  ))

        tic = time.time()
2018-03-12 05:17:48 +08:00
2018-07-02 20:57:11 +08:00
2018-08-01 17:52:06 +08:00
def _extract_features(model, loader, use_gpu, batch_time):
    """Run ``model`` over ``loader`` and collect features, pids and camids.

    Only the forward pass is timed; each measurement is pushed into
    ``batch_time``. Returns ``(features, pids, camids)`` where ``features`` is
    the per-batch outputs concatenated along dim 0 on the CPU, and the other
    two are NumPy arrays built from the loader's labels.
    """
    feats, pid_list, camid_list = [], [], []
    for imgs, pids, camids, _ in loader:
        if use_gpu:
            imgs = imgs.cuda()

        start = time.time()
        features = model(imgs)
        batch_time.update(time.time() - start)

        feats.append(features.data.cpu())
        pid_list.extend(pids)
        camid_list.extend(camids)

    return torch.cat(feats, 0), np.asarray(pid_list), np.asarray(camid_list)


def test(model, queryloader, galleryloader, use_gpu, ranks=(1, 5, 10, 20), return_distmat=False):
    """Evaluate ``model`` on a query/gallery pair and print CMC and mAP.

    Args:
        model: network used as a feature extractor; switched to eval mode.
        queryloader: iterable yielding ``(imgs, pids, camids, _)`` batches
            for the query set.
        galleryloader: same, for the gallery set.
        use_gpu: when true, image batches are moved to CUDA before the
            forward pass.
        ranks: 1-based ranks whose CMC value is printed. Any iterable of
            ints works; the default is an immutable tuple so it cannot be
            modified between calls.
        return_distmat: when true, return the query-by-gallery distance
            matrix instead of the rank-1 score.

    Returns:
        ``cmc[0]`` (the rank-1 value) by default, or the distance matrix as
        a NumPy array when ``return_distmat`` is true.

    Reads ``args.test_batch_size`` and ``args.use_metric_cuhk03`` from the
    module-level ``args``.
    """
    batch_time = AverageMeter()

    model.eval()

    with torch.no_grad():
        qf, q_pids, q_camids = _extract_features(model, queryloader, use_gpu, batch_time)
        print('Extracted features for query set, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))

        gf, g_pids, g_camids = _extract_features(model, galleryloader, use_gpu, batch_time)
        print('Extracted features for gallery set, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

    print('=> BatchTime(s)/BatchSize(img): {:.3f}/{}'.format(batch_time.avg, args.test_batch_size))

    # Squared Euclidean distance via ||q||^2 + ||g||^2 - 2 * q.g^T
    m, n = qf.size(0), gf.size(0)
    distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
              torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # Keyword beta/alpha: the positional (beta, alpha, mat1, mat2) overload
    # is deprecated in PyTorch.
    distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
    distmat = distmat.numpy()

    print('Computing CMC and mAP')
    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, use_metric_cuhk03=args.use_metric_cuhk03)

    print('Results ----------')
    print('mAP: {:.1%}'.format(mAP))
    print('CMC curve')
    for r in ranks:
        print('Rank-{:<3}: {:.1%}'.format(r, cmc[r-1]))
    print('------------------')

    if return_distmat:
        return distmat
    return cmc[0]
2018-07-02 20:57:11 +08:00
2018-03-12 05:17:48 +08:00
# Run the pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()