# deep-person-reid/train_vidreid_xent.py

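"""Train a video-based person re-identification model with cross-entropy (xent) loss."""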

from __future__ import print_function
from __future__ import division
import os
import sys
import time
import datetime
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from args import argument_parser, video_dataset_kwargs, optimizer_kwargs
from torchreid.data_manager import VideoDataManager
from torchreid import models
from torchreid.losses import CrossEntropyLoss, DeepSupervision
from torchreid.utils.iotools import save_checkpoint, check_isfile
from torchreid.utils.avgmeter import AverageMeter
from torchreid.utils.loggers import Logger, RankLogger
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy
from torchreid.utils.reidtools import visualize_ranked_results
from torchreid.utils.generaltools import set_random_seed
from torchreid.eval_metrics import evaluate
from torchreid.optimizers import init_optimizer
# global variables
parser = argument_parser()
args = parser.parse_args()


def main():
    global args

    set_random_seed(args.seed)
    if not args.use_avai_gpus: os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False
    log_name = 'log_test.txt' if args.evaluate else 'log_train.txt'
    sys.stdout = Logger(osp.join(args.save_dir, log_name))
    print('==========\nArgs:{}\n=========='.format(args))

    if use_gpu:
        print('Currently using GPU {}'.format(args.gpu_devices))
        cudnn.benchmark = True
    else:
        print('Currently using CPU, however, GPU is highly recommended')

    print('Initializing video data manager')
    dm = VideoDataManager(use_gpu, **video_dataset_kwargs(args))
    trainloader, testloader_dict = dm.return_dataloaders()

    print('Initializing model: {}'.format(args.arch))
    model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'})
    print('Model size: {:.3f} M'.format(count_num_param(model)))

    criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
    optimizer = init_optimizer(model.parameters(), **optimizer_kwargs(args))
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma)
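
    # Two ways to start from an existing model: args.load_weights copies only
    # the layers whose names and shapes match (useful for fine-tuning), while
    # args.resume restores the full checkpoint and continues from the saved epoch.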
    if args.load_weights and check_isfile(args.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        checkpoint = torch.load(args.load_weights)
        pretrain_dict = checkpoint['state_dict']
        model_dict = model.state_dict()
        pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print('Loaded pretrained weights from "{}"'.format(args.load_weights))

    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch'] + 1
        print('Loaded checkpoint from "{}"'.format(args.resume))
        print('- start_epoch: {}\n- rank1: {}'.format(args.start_epoch, checkpoint['rank1']))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print('Evaluate only')

        for name in args.target_names:
            print('Evaluating {} ...'.format(name))
            queryloader = testloader_dict[name]['query']
            galleryloader = testloader_dict[name]['gallery']
            distmat = test(model, queryloader, galleryloader, args.pool_tracklet_features, use_gpu, return_distmat=True)

            if args.visualize_ranks:
                visualize_ranked_results(
                    distmat, dm.return_testdataset_by_name(name),
                    save_dir=osp.join(args.save_dir, 'ranked_results', name),
                    topk=20
                )
        return

    start_time = time.time()
    ranklogger = RankLogger(args.source_names, args.target_names)
    train_time = 0
    print('=> Start training')
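
    # Optional warm-up: train only the layers named in args.open_layers for
    # args.fixbase_epoch epochs while the rest of the network stays frozen,
    # then restore the optimizer state and open all layers for full training.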
    if args.fixbase_epoch > 0:
        print('Train {} for {} epochs while keeping other layers frozen'.format(args.open_layers, args.fixbase_epoch))
        initial_optim_state = optimizer.state_dict()

        for epoch in range(args.fixbase_epoch):
            start_train_time = time.time()
            train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=True)
            train_time += round(time.time() - start_train_time)

        print('Done. All layers are open to train for {} epochs'.format(args.max_epoch))
        optimizer.load_state_dict(initial_optim_state)

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()
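
        # Evaluate on every target dataset and save a checkpoint whenever the
        # epoch passes args.start_eval and is a multiple of args.eval_freq;
        # the final epoch is always evaluated.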
        if ((epoch + 1) > args.start_eval and args.eval_freq > 0 and (epoch + 1) % args.eval_freq == 0) or (epoch + 1) == args.max_epoch:
            print('=> Test')

            for name in args.target_names:
                print('Evaluating {} ...'.format(name))
                queryloader = testloader_dict[name]['query']
                galleryloader = testloader_dict[name]['gallery']
                rank1 = test(model, queryloader, galleryloader, args.pool_tracklet_features, use_gpu)
                ranklogger.write(name, epoch + 1, rank1)

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'rank1': rank1,
                'epoch': epoch,
            }, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print('Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.'.format(elapsed, train_time))
    ranklogger.show_summary()


def train(epoch, model, criterion, optimizer, trainloader, use_gpu, fixbase=False):
    losses = AverageMeter()
    accs = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()

    if fixbase or args.always_fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):
        data_time.update(time.time() - end)

        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()

        outputs = model(imgs)
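
        # Some models return a tuple/list of logits (one per classifier head);
        # DeepSupervision then applies the criterion to every output. Otherwise
        # a single cross-entropy loss is computed on the logits.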
        if isinstance(outputs, (tuple, list)):
            loss = DeepSupervision(criterion, outputs, pids)
        else:
            loss = criterion(outputs, pids)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)

        losses.update(loss.item(), pids.size(0))
        accs.update(accuracy(outputs, pids)[0])

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      acc=accs
                  ))

        end = time.time()


def test(model, queryloader, galleryloader, pool, use_gpu, ranks=[1, 5, 10, 20], return_distmat=False):
    batch_time = AverageMeter()

    model.eval()

    with torch.no_grad():
        qf, q_pids, q_camids = [], [], []
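        # Each test sample is a tracklet of shape (batch, seq_len, c, h, w):
        # frames are flattened into the batch dimension for the forward pass,
        # then the per-frame features are pooled (avg or max) into one feature
        # vector per tracklet.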
        for batch_idx, (imgs, pids, camids) in enumerate(queryloader):
            if use_gpu: imgs = imgs.cuda()

            b, s, c, h, w = imgs.size()
            imgs = imgs.view(b * s, c, h, w)

            end = time.time()
            features = model(imgs)
            batch_time.update(time.time() - end)

            features = features.view(b, s, -1)
            if pool == 'avg':
                features = torch.mean(features, 1)
            else:
                features, _ = torch.max(features, 1)
            features = features.data.cpu()

            qf.append(features)
            q_pids.extend(pids)
            q_camids.extend(camids)

        qf = torch.cat(qf, 0)
        q_pids = np.asarray(q_pids)
        q_camids = np.asarray(q_camids)

        print('Extracted features for query set, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))

        gf, g_pids, g_camids = [], [], []
        for batch_idx, (imgs, pids, camids) in enumerate(galleryloader):
            if use_gpu: imgs = imgs.cuda()

            b, s, c, h, w = imgs.size()
            imgs = imgs.view(b * s, c, h, w)

            end = time.time()
            features = model(imgs)
            batch_time.update(time.time() - end)

            features = features.view(b, s, -1)
            if pool == 'avg':
                features = torch.mean(features, 1)
            else:
                features, _ = torch.max(features, 1)
            features = features.data.cpu()

            gf.append(features)
            g_pids.extend(pids)
            g_camids.extend(camids)

        gf = torch.cat(gf, 0)
        g_pids = np.asarray(g_pids)
        g_camids = np.asarray(g_camids)

        print('Extracted features for gallery set, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

    print('=> BatchTime(s)/BatchSize(img): {:.3f}/{}'.format(batch_time.avg, args.test_batch_size * args.seq_len))

    # pairwise squared Euclidean distances: ||q||^2 + ||g||^2 - 2 * q . g
    m, n = qf.size(0), gf.size(0)
    distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
              torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)  # keyword form; the positional (beta, alpha, mat1, mat2) signature is deprecated
    distmat = distmat.numpy()
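
    # evaluate() computes the CMC curve and mAP under the cross-camera re-ID
    # protocol (gallery entries sharing both the person ID and the camera ID
    # with the query are typically excluded before ranking).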
    print('Computing CMC and mAP')
    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)

    print('Results ----------')
    print('mAP: {:.1%}'.format(mAP))
    print('CMC curve')
    for r in ranks:
        print('Rank-{:<3}: {:.1%}'.format(r, cmc[r - 1]))
    print('------------------')

    if return_distmat:
        return distmat
    return cmc[0]


if __name__ == '__main__':
    main()