from __future__ import print_function from __future__ import division import os import sys import time import datetime import os.path as osp import numpy as np import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.utils.data import DataLoader from torch.optim import lr_scheduler from args import argument_parser from torchreid.data_manager import VideoDataManager from torchreid.dataset_loader import ImageDataset, VideoDataset from torchreid.transforms import build_transforms from torchreid import models from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision from torchreid.utils.iotools import save_checkpoint, check_isfile from torchreid.utils.avgmeter import AverageMeter from torchreid.utils.logger import Logger from torchreid.utils.torchtools import count_num_param from torchreid.utils.reidtools import visualize_ranked_results from torchreid.eval_metrics import evaluate from torchreid.samplers import RandomIdentitySampler from torchreid.optimizers import init_optim # global variables parser = argument_parser() args = parser.parse_args() def main(): global args torch.manual_seed(args.seed) if not args.use_avai_gpus: os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices use_gpu = torch.cuda.is_available() if args.use_cpu: use_gpu = False if not args.evaluate: sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt')) else: sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt')) print("==========\nArgs:{}\n==========".format(args)) if use_gpu: print("Currently using GPU {}".format(args.gpu_devices)) cudnn.benchmark = True torch.cuda.manual_seed_all(args.seed) else: print("Currently using CPU (GPU is highly recommended)") transform_train = build_transforms(args.height, args.width, is_train=True) transform_test = build_transforms(args.height, args.width, is_train=False) pin_memory = True if use_gpu else False dm = VideoDataManager( args.source, args.target, args.root, args.split_id, transform_train, transform_test, args.train_batch, args.test_batch, args.workers, pin_memory, args.seq_len, args.sample ) trainloader = dm.trainloader testloader_dict = dm.testloader_dict print("Initializing model: {}".format(args.arch)) model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'}) print("Model size: {:.3f} M".format(count_num_param(model))) criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth) criterion_htri = TripletLoss(margin=args.margin) optimizer = init_optim(args.optim, model.parameters(), args.lr, args.weight_decay) scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma) if args.load_weights and check_isfile(args.load_weights): # load pretrained weights but ignore layers that don't match in size checkpoint = torch.load(args.load_weights) pretrain_dict = checkpoint['state_dict'] model_dict = model.state_dict() pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} model_dict.update(pretrain_dict) model.load_state_dict(model_dict) print("Loaded pretrained weights from '{}'".format(args.load_weights)) if args.resume and check_isfile(args.resume): checkpoint = torch.load(args.resume) model.load_state_dict(checkpoint['state_dict']) args.start_epoch = checkpoint['epoch'] + 1 best_rank1 = checkpoint['rank1'] print("Loaded checkpoint from '{}'".format(args.resume)) print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, best_rank1)) if use_gpu: model = nn.DataParallel(model).cuda() if args.evaluate: print("Evaluate only") for name in args.target: print("Evaluating {} ...".format(name)) queryloader = testloader_dict[name]['query'] galleryloader = testloader_dict[name]['gallery'] distmat = test(model, queryloader, galleryloader, args.pool, use_gpu, return_distmat=True) if args.visualize_ranks: visualize_ranked_results( distmat, dataset, save_dir=osp.join(args.save_dir, 'ranked_results', name), topk=20 ) return start_time = time.time() train_time = 0 print("==> Start training") for epoch in range(args.start_epoch, args.max_epoch): start_train_time = time.time() train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu) train_time += round(time.time() - start_train_time) scheduler.step() if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch: print("==> Test") for name in args.target: print("Evaluating {} ...".format(name)) queryloader = testloader_dict[name]['query'] galleryloader = testloader_dict[name]['gallery'] rank1 = test(model, queryloader, galleryloader, args.pool, use_gpu) if use_gpu: state_dict = model.module.state_dict() else: state_dict = model.state_dict() save_checkpoint({ 'state_dict': state_dict, 'rank1': rank1, 'epoch': epoch, }, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar')) elapsed = round(time.time() - start_time) elapsed = str(datetime.timedelta(seconds=elapsed)) train_time = str(datetime.timedelta(seconds=train_time)) print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time)) def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu): losses = AverageMeter() batch_time = AverageMeter() data_time = AverageMeter() model.train() end = time.time() for batch_idx, (imgs, pids, _) in enumerate(trainloader): data_time.update(time.time() - end) if use_gpu: imgs, pids = imgs.cuda(), pids.cuda() outputs, features = model(imgs) if args.htri_only: if isinstance(features, (tuple, list)): loss = DeepSupervision(criterion_htri, features, pids) else: loss = criterion_htri(features, pids) else: if isinstance(outputs, (tuple, list)): xent_loss = DeepSupervision(criterion_xent, outputs, pids) else: xent_loss = criterion_xent(outputs, pids) if isinstance(features, (tuple, list)): htri_loss = DeepSupervision(criterion_htri, features, pids) else: htri_loss = criterion_htri(features, pids) loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss optimizer.zero_grad() loss.backward() optimizer.step() batch_time.update(time.time() - end) losses.update(loss.item(), pids.size(0)) if (batch_idx + 1) % args.print_freq == 0: print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.4f} ({data_time.avg:.4f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( epoch + 1, batch_idx + 1, len(trainloader), batch_time=batch_time, data_time=data_time, loss=losses)) end = time.time() def test(model, queryloader, galleryloader, pool, use_gpu, ranks=[1, 5, 10, 20], return_distmat=False): batch_time = AverageMeter() model.eval() with torch.no_grad(): qf, q_pids, q_camids = [], [], [] for batch_idx, (imgs, pids, camids) in enumerate(queryloader): if use_gpu: imgs = imgs.cuda() b, s, c, h, w = imgs.size() imgs = imgs.view(b*s, c, h, w) end = time.time() features = model(imgs) batch_time.update(time.time() - end) features = features.view(b, s, -1) if pool == 'avg': features = torch.mean(features, 1) else: features, _ = torch.max(features, 1) features = features.data.cpu() qf.append(features) q_pids.extend(pids) q_camids.extend(camids) qf = torch.cat(qf, 0) q_pids = np.asarray(q_pids) q_camids = np.asarray(q_camids) print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1))) gf, g_pids, g_camids = [], [], [] for batch_idx, (imgs, pids, camids) in enumerate(galleryloader): if use_gpu: imgs = imgs.cuda() b, s, c, h, w = imgs.size() imgs = imgs.view(b*s, c, h, w) end = time.time() features = model(imgs) batch_time.update(time.time() - end) features = features.view(b, s, -1) if pool == 'avg': features = torch.mean(features, 1) else: features, _ = torch.max(features, 1) features = features.data.cpu() gf.append(features) g_pids.extend(pids) g_camids.extend(camids) gf = torch.cat(gf, 0) g_pids = np.asarray(g_pids) g_camids = np.asarray(g_camids) print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1))) print("==> BatchTime(s)/BatchSize(img): {:.3f}/{}".format(batch_time.avg, args.test_batch*args.seq_len)) m, n = qf.size(0), gf.size(0) distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() distmat.addmm_(1, -2, qf, gf.t()) distmat = distmat.numpy() print("Computing CMC and mAP") cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids) print("Results ----------") print("mAP: {:.1%}".format(mAP)) print("CMC curve") for r in ranks: print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1])) print("------------------") if return_distmat: return distmat return cmc[0] if __name__ == '__main__': main()