from __future__ import print_function
from __future__ import division

import os
import sys
import time
import datetime
import os.path as osp
import numpy as np

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from args import argument_parser, video_dataset_kwargs, optimizer_kwargs, lr_scheduler_kwargs
from torchreid.data_manager import VideoDataManager
from torchreid import models
from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
from torchreid.utils.iotools import save_checkpoint, check_isfile
from torchreid.utils.avgmeter import AverageMeter
from torchreid.utils.loggers import Logger, RankLogger
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers, accuracy, load_pretrained_weights
from torchreid.utils.reidtools import visualize_ranked_results
from torchreid.utils.generaltools import set_random_seed
from torchreid.eval_metrics import evaluate
from torchreid.samplers import RandomIdentitySampler
from torchreid.optimizers import init_optimizer
from torchreid.lr_schedulers import init_lr_scheduler

# global variables
parser = argument_parser()
args = parser.parse_args()


def main():
    global args

    set_random_seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False
    log_name = 'log_test.txt' if args.evaluate else 'log_train.txt'
    sys.stdout = Logger(osp.join(args.save_dir, log_name))
    print('==========\nArgs:{}\n=========='.format(args))

    if use_gpu:
        print('Currently using GPU {}'.format(args.gpu_devices))
        cudnn.benchmark = True
    else:
        print('Currently using CPU, however, GPU is highly recommended')

    print('Initializing video data manager')
    dm = VideoDataManager(use_gpu, **video_dataset_kwargs(args))
    trainloader, testloader_dict = dm.return_dataloaders()

    print('Initializing model: {}'.format(args.arch))
    model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'})
    print('Model size: {:.3f} M'.format(count_num_param(model)))

    if args.load_weights and check_isfile(args.load_weights):
        load_pretrained_weights(model, args.load_weights)

    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch'] + 1
        best_rank1 = checkpoint['rank1']
        print('Loaded checkpoint from "{}"'.format(args.resume))
        print('- start_epoch: {}\n- rank1: {}'.format(args.start_epoch, best_rank1))

    model = nn.DataParallel(model).cuda() if use_gpu else model

    # named criterion_xent (not criterion) so it matches the signature of train() below
    criterion_xent = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
    criterion_htri = TripletLoss(margin=args.margin)
    optimizer = init_optimizer(model, **optimizer_kwargs(args))
    scheduler = init_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))

    if args.evaluate:
        print('Evaluate only')

        for name in args.target_names:
            print('Evaluating {} ...'.format(name))
            queryloader = testloader_dict[name]['query']
            galleryloader = testloader_dict[name]['gallery']
            distmat = test(model, queryloader, galleryloader, args.pool_tracklet_features, use_gpu, return_distmat=True)

            if args.visualize_ranks:
                visualize_ranked_results(
                    distmat, dm.return_testdataset_by_name(name),
                    save_dir=osp.join(args.save_dir, 'ranked_results', name),
                    topk=20
                )
        return

    start_time = time.time()
    ranklogger = RankLogger(args.source_names, args.target_names)
    train_time = 0
    print('=> Start training')

    if args.fixbase_epoch > 0:
        print('Train {} for {} epochs while keeping other layers frozen'.format(args.open_layers, args.fixbase_epoch))
        initial_optim_state = optimizer.state_dict()

        for epoch in range(args.fixbase_epoch):
            start_train_time = time.time()
            train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=True)
            train_time += round(time.time() - start_train_time)

        print('Done. All layers are open to train for {} epochs'.format(args.max_epoch))
        optimizer.load_state_dict(initial_optim_state)

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        if (epoch + 1) > args.start_eval and args.eval_freq > 0 and (epoch + 1) % args.eval_freq == 0 or (epoch + 1) == args.max_epoch:
            print('=> Test')

            for name in args.target_names:
                print('Evaluating {} ...'.format(name))
                queryloader = testloader_dict[name]['query']
                galleryloader = testloader_dict[name]['gallery']
                rank1 = test(model, queryloader, galleryloader, args.pool_tracklet_features, use_gpu)
                ranklogger.write(name, epoch + 1, rank1)

            save_checkpoint({
                'state_dict': model.state_dict(),
                'rank1': rank1,
                'epoch': epoch,
            }, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print('Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.'.format(elapsed, train_time))
    ranklogger.show_summary()
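
# A minimal, uncalled sketch of the label-smoothed cross-entropy that
# CrossEntropyLoss(label_smooth=True) above stands for. The helper name,
# its eps default, and the reduction are illustrative assumptions, not the
# library implementation: each one-hot target is softened to
# (1 - eps) * one_hot + eps / num_classes before the usual log-softmax NLL,
# so the model is never pushed toward full confidence on one identity.
def _smoothed_xent_sketch(logits, pids, eps=0.1):
    num_classes = logits.size(1)
    log_probs = nn.functional.log_softmax(logits, dim=1)
    # build one-hot targets, then soften them with the smoothing mass eps
    targets = torch.zeros_like(log_probs).scatter_(1, pids.unsqueeze(1), 1)
    targets = (1 - eps) * targets + eps / num_classes
    # per-sample cross-entropy, averaged over the batch
    return (-targets * log_probs).mean(0).sum()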
def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu, fixbase=False):
    xent_losses = AverageMeter()
    htri_losses = AverageMeter()
    accs = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()

    if fixbase or args.always_fixbase:
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    end = time.time()
    for batch_idx, (imgs, pids, _, _) in enumerate(trainloader):
        data_time.update(time.time() - end)

        if use_gpu:
            imgs, pids = imgs.cuda(), pids.cuda()

        outputs, features = model(imgs)
        if isinstance(outputs, (tuple, list)):
            xent_loss = DeepSupervision(criterion_xent, outputs, pids)
        else:
            xent_loss = criterion_xent(outputs, pids)

        if isinstance(features, (tuple, list)):
            htri_loss = DeepSupervision(criterion_htri, features, pids)
        else:
            htri_loss = criterion_htri(features, pids)

        loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)

        xent_losses.update(xent_loss.item(), pids.size(0))
        htri_losses.update(htri_loss.item(), pids.size(0))
        accs.update(accuracy(outputs, pids)[0])

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Xent {xent.val:.4f} ({xent.avg:.4f})\t'
                  'Htri {htri.val:.4f} ({htri.avg:.4f})\t'
                  'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader),
                      batch_time=batch_time,
                      data_time=data_time,
                      xent=xent_losses,
                      htri=htri_losses,
                      acc=accs
                  ))

        end = time.time()
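
# A minimal, uncalled sketch of a batch-hard triplet loss in the spirit of
# the TripletLoss applied in train() above. The helper name and margin
# default are illustrative assumptions; the library version may mine
# triplets or reduce differently. For each anchor, the hardest positive is
# pulled in and the hardest negative is pushed beyond the margin.
def _triplet_loss_sketch(features, pids, margin=0.3):
    dist = torch.cdist(features, features)                  # pairwise L2 distances (PyTorch >= 1.1)
    same = pids.unsqueeze(0) == pids.unsqueeze(1)           # same-identity mask
    d_ap = (dist * same.float()).max(1)[0]                  # hardest positive per anchor
    d_an = dist.masked_fill(same, float('inf')).min(1)[0]   # hardest negative per anchor
    return torch.clamp(d_ap - d_an + margin, min=0).mean()  # hinge on the margin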
def test(model, queryloader, galleryloader, pool, use_gpu, ranks=[1, 5, 10, 20], return_distmat=False):
    batch_time = AverageMeter()

    model.eval()

    with torch.no_grad():
        qf, q_pids, q_camids = [], [], []
        for batch_idx, (imgs, pids, camids) in enumerate(queryloader):
            if use_gpu:
                imgs = imgs.cuda()

            # fold the tracklet dimension into the batch: (b, s, c, h, w) -> (b*s, c, h, w)
            b, s, c, h, w = imgs.size()
            imgs = imgs.view(b * s, c, h, w)

            end = time.time()
            features = model(imgs)
            batch_time.update(time.time() - end)

            # pool the s per-frame features of each tracklet into one vector
            features = features.view(b, s, -1)
            if pool == 'avg':
                features = torch.mean(features, 1)
            else:
                features, _ = torch.max(features, 1)
            features = features.data.cpu()
            qf.append(features)
            q_pids.extend(pids)
            q_camids.extend(camids)
        qf = torch.cat(qf, 0)
        q_pids = np.asarray(q_pids)
        q_camids = np.asarray(q_camids)

        print('Extracted features for query set, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))

        gf, g_pids, g_camids = [], [], []
        for batch_idx, (imgs, pids, camids) in enumerate(galleryloader):
            if use_gpu:
                imgs = imgs.cuda()

            b, s, c, h, w = imgs.size()
            imgs = imgs.view(b * s, c, h, w)

            end = time.time()
            features = model(imgs)
            batch_time.update(time.time() - end)

            features = features.view(b, s, -1)
            if pool == 'avg':
                features = torch.mean(features, 1)
            else:
                features, _ = torch.max(features, 1)
            features = features.data.cpu()
            gf.append(features)
            g_pids.extend(pids)
            g_camids.extend(camids)
        gf = torch.cat(gf, 0)
        g_pids = np.asarray(g_pids)
        g_camids = np.asarray(g_camids)

        print('Extracted features for gallery set, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

    print('=> BatchTime(s)/BatchSize(img): {:.3f}/{}'.format(batch_time.avg, args.test_batch_size * args.seq_len))

    # squared Euclidean distance, ||q - g||^2 = ||q||^2 + ||g||^2 - 2*q.g,
    # with the cross term added by addmm_ (alpha=-2); the keyword form
    # replaces the deprecated positional signature addmm_(1, -2, qf, gf.t())
    m, n = qf.size(0), gf.size(0)
    distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
              torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
    distmat = distmat.numpy()

    print('Computing CMC and mAP')
    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)

    print('Results ----------')
    print('mAP: {:.1%}'.format(mAP))
    print('CMC curve')
    for r in ranks:
        print('Rank-{:<3}: {:.1%}'.format(r, cmc[r - 1]))
    print('------------------')

    if return_distmat:
        return distmat
    return cmc[0]


if __name__ == '__main__':
    main()
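
# A hypothetical invocation for reference. The exact flag names live in
# args.py and are assumed here to mirror the args.* attributes used above
# (e.g. args.source_names -> --source-names); the dataset and architecture
# names are placeholders:
#
#   python <this_script>.py --root /path/to/data \
#       --source-names mars --target-names mars --arch resnet50 \
#       --optim adam --lr 0.0003 --max-epoch 60 \
#       --save-dir log/resnet50-mars --gpu-devices 0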