deep-person-reid/train_vidreid_xent.py

278 lines
10 KiB
Python
Raw Normal View History

2018-07-04 17:32:43 +08:00
from __future__ import print_function
from __future__ import division
2018-03-12 22:09:56 +08:00
import os
import sys
import time
import datetime
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
2018-11-08 01:09:23 +08:00
from args import argument_parser, video_dataset_kwargs, optimizer_kwargs
2018-11-07 23:36:49 +08:00
from torchreid.data_manager import VideoDataManager
2018-08-15 16:48:17 +08:00
from torchreid import models
from torchreid.losses import CrossEntropyLoss
2018-08-15 16:48:17 +08:00
from torchreid.utils.iotools import save_checkpoint, check_isfile
from torchreid.utils.avgmeter import AverageMeter
2018-11-09 05:41:32 +08:00
from torchreid.utils.loggers import Logger, RankLogger
2018-11-09 08:02:46 +08:00
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers
2018-08-15 16:48:17 +08:00
from torchreid.utils.reidtools import visualize_ranked_results
from torchreid.eval_metrics import evaluate
2018-11-08 01:09:23 +08:00
from torchreid.optimizers import init_optimizer
2018-03-12 22:09:56 +08:00
2018-07-02 20:57:11 +08:00
# Global variables: the command line is parsed once, when the module is
# imported, so that main(), train() and test() all read the same ``args``.
parser = argument_parser()
args = parser.parse_args(sys.argv[1:])
2018-07-02 20:57:11 +08:00
2018-03-12 22:09:56 +08:00
def _storage_to_cpu(storage, location):
    """``torch.load`` map_location callback that keeps every storage on the CPU."""
    return storage


def _load_pretrained_weights(model, weight_path):
    """Copy compatible tensors from the checkpoint at ``weight_path`` into ``model``.

    Only entries of ``checkpoint['state_dict']`` whose key exists in the model
    and whose size matches are copied. Everything else keeps the model's
    current values, so layers that do not match in size are ignored.
    """
    # Load onto CPU: the model has not been moved to GPU yet at this point, and
    # this also lets GPU-saved weights be used on CPU-only machines.
    checkpoint = torch.load(weight_path, map_location=_storage_to_cpu)
    pretrain_dict = checkpoint['state_dict']
    model_dict = model.state_dict()
    matched = {
        k: v for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(matched)
    model.load_state_dict(model_dict)


def main():
    """Train and/or evaluate a video re-ID model with a cross-entropy loss.

    All configuration comes from the module-level ``args`` namespace.

    With ``args.evaluate``, the model is evaluated on every dataset in
    ``args.target_names``. Ranked results are optionally visualized. The
    function then returns.

    Otherwise it runs two training phases:
      * an optional "fixbase" phase (only ``args.open_layers`` trainable);
      * the main training loop.

    During the main loop the model is evaluated on the eval schedule and a
    checkpoint is written to ``args.save_dir`` at each evaluation. A rank
    summary is printed at the end.
    """
    global args

    torch.manual_seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False

    # Everything printed from here on is also written to the log file.
    log_name = 'log_test.txt' if args.evaluate else 'log_train.txt'
    sys.stdout = Logger(osp.join(args.save_dir, log_name))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU, however, GPU is highly recommended")

    print("Initializing video data manager")
    dm = VideoDataManager(use_gpu, **video_dataset_kwargs(args))
    trainloader, testloader_dict = dm.return_dataloaders()

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'})
    print("Model size: {:.3f} M".format(count_num_param(model)))

    criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
    optimizer = init_optimizer(model.parameters(), **optimizer_kwargs(args))
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma)

    if args.load_weights and check_isfile(args.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        _load_pretrained_weights(model, args.load_weights)
        print("Loaded pretrained weights from '{}'".format(args.load_weights))

    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume, map_location=_storage_to_cpu)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch'] + 1
        print("Loaded checkpoint from '{}'".format(args.resume))
        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, checkpoint['rank1']))
        # The training loop calls scheduler.step() once per finished epoch.
        # Replay those steps so a resumed run continues with the learning rate
        # it would have had, instead of restarting the schedule from scratch.
        for _ in range(args.start_epoch):
            scheduler.step()

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        for name in args.target_names:
            print("Evaluating {} ...".format(name))
            queryloader = testloader_dict[name]['query']
            galleryloader = testloader_dict[name]['gallery']
            distmat = test(model, queryloader, galleryloader, args.pool_tracklet_features, use_gpu, return_distmat=True)
            if args.visualize_ranks:
                visualize_ranked_results(
                    distmat, dm.return_testdataset_by_name(name),
                    save_dir=osp.join(args.save_dir, 'ranked_results', name),
                    topk=20
                )
        return

    start_time = time.time()
    ranklogger = RankLogger(args.source_names, args.target_names)
    train_time = 0
    print("==> Start training")

    if args.fixbase_epoch > 0:
        print("Train {} for {} epochs while keeping other layers frozen".format(args.open_layers, args.fixbase_epoch))
        # Snapshot the optimizer so the main phase starts from a clean state.
        initial_optim_state = optimizer.state_dict()

        for epoch in range(args.fixbase_epoch):
            start_train_time = time.time()
            train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=True)
            train_time += round(time.time() - start_train_time)

        print("Done. All layers are open to train for {} epochs".format(args.max_epoch))
        optimizer.load_state_dict(initial_optim_state)

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        # Evaluate on the regular schedule after start_eval, and always after
        # the final epoch.
        on_schedule = (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0
        if on_schedule or (epoch + 1) == args.max_epoch:
            print("==> Test")

            for name in args.target_names:
                print("Evaluating {} ...".format(name))
                queryloader = testloader_dict[name]['query']
                galleryloader = testloader_dict[name]['gallery']
                rank1 = test(model, queryloader, galleryloader, args.pool_tracklet_features, use_gpu)
                ranklogger.write(name, epoch + 1, rank1)

            # Save the unwrapped module's weights so the checkpoint loads into
            # a plain (non-DataParallel) model.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'rank1': rank1,
                'epoch': epoch,
            }, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
    ranklogger.show_summary()
2018-03-12 22:09:56 +08:00
2018-07-02 20:57:11 +08:00
2018-11-09 08:02:46 +08:00
def train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=False):
    """Run one training epoch over ``trainloader``.

    Args:
        epoch: zero-based epoch index. It is printed as ``epoch + 1``.
        model: network to optimize. It is switched to train mode here.
        criterion: loss callable, applied as ``criterion(outputs, pids)``.
        optimizer: optimizer that updates the model's parameters.
        trainloader: iterable yielding ``(imgs, pids, _)`` batches.
        use_gpu: when True, each batch is moved to CUDA before the forward pass.
        fixbase: when True, only ``args.open_layers`` are made trainable via
            ``open_specified_layers``. Otherwise all layers are opened.
    """
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()
    if fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, (imgs, pids, _) in enumerate(trainloader):
        data_time.update(time.time() - end)
        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()

        outputs = model(imgs)
        if isinstance(outputs, (tuple, list)):
            # Deep supervision: average the loss over every output head.
            # NOTE(review): mirrors torchreid's DeepSupervision helper (which
            # was referenced here without being imported); confirm the
            # averaging matches the version in use.
            loss = sum(criterion(out, pids) for out in outputs) / len(outputs)
        else:
            loss = criterion(outputs, pids)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        losses.update(loss.item(), pids.size(0))

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                   epoch + 1, batch_idx + 1, len(trainloader), batch_time=batch_time,
                   data_time=data_time, loss=losses))

        end = time.time()
2018-03-12 22:09:56 +08:00
2018-07-02 20:57:11 +08:00
2018-08-01 19:04:58 +08:00
def _extract_features(model, loader, pool, use_gpu, batch_time):
    """Encode every tracklet in ``loader`` into one pooled feature vector.

    Each batch is unpacked as ``(imgs, pids, camids)``. ``imgs`` must be 5-D
    ``(b, s, c, h, w)``. All ``b * s`` frames are encoded in one forward pass,
    then pooled over the ``s`` axis: mean when ``pool == 'avg'``, max
    otherwise. Forward-pass time is recorded in ``batch_time``.

    Returns:
        (features, pids, camids): a CPU tensor with one row per tracklet, plus
        two numpy arrays aligned with its rows.
    """
    feats, all_pids, all_camids = [], [], []
    for imgs, pids, camids in loader:
        if use_gpu:
            imgs = imgs.cuda()
        b, s, c, h, w = imgs.size()
        imgs = imgs.view(b * s, c, h, w)

        end = time.time()
        features = model(imgs)
        batch_time.update(time.time() - end)

        features = features.view(b, s, -1)
        if pool == 'avg':
            features = torch.mean(features, 1)
        else:
            features, _ = torch.max(features, 1)
        feats.append(features.data.cpu())
        all_pids.extend(pids)
        all_camids.extend(camids)
    return torch.cat(feats, 0), np.asarray(all_pids), np.asarray(all_camids)


def test(model, queryloader, galleryloader, pool, use_gpu, ranks=(1, 5, 10, 20), return_distmat=False):
    """Evaluate ``model`` on a query/gallery split and print mAP and CMC.

    Args:
        model: network mapping a batch of frames to feature vectors.
        queryloader, galleryloader: loaders yielding ``(imgs, pids, camids)``
            with 5-D ``imgs`` of shape ``(b, s, c, h, w)``.
        pool: ``'avg'`` averages frame features per tracklet. Any other value
            takes the element-wise max.
        use_gpu: move inputs to CUDA before the forward pass.
        ranks: CMC ranks to print (1-based).
        return_distmat: if True, return the query-by-gallery distance matrix
            instead of the rank-1 score.

    Returns:
        The numpy distance matrix if ``return_distmat``; otherwise ``cmc[0]``.
    """
    batch_time = AverageMeter()

    model.eval()
    with torch.no_grad():
        qf, q_pids, q_camids = _extract_features(model, queryloader, pool, use_gpu, batch_time)
        print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))
        gf, g_pids, g_camids = _extract_features(model, galleryloader, pool, use_gpu, batch_time)
        print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))

    print("==> BatchTime(s)/BatchSize(img): {:.3f}/{}".format(batch_time.avg, args.test_batch_size * args.seq_len))

    # Squared Euclidean distance: ||q||^2 + ||g||^2 - 2 * q.g
    m, n = qf.size(0), gf.size(0)
    distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
        torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # Keyword form: the positional (beta, alpha, mat1, mat2) overload is
    # deprecated and rejected by newer PyTorch releases.
    distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
    distmat = distmat.numpy()

    print("Computing CMC and mAP")
    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)

    print("Results ----------")
    print("mAP: {:.1%}".format(mAP))
    print("CMC curve")
    for r in ranks:
        print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1]))
    print("------------------")

    if return_distmat:
        return distmat
    return cmc[0]
2018-07-02 20:57:11 +08:00
2018-03-12 22:09:56 +08:00
if __name__ == "__main__":
    # Only start training/evaluation when executed as a script.
    main()