from __future__ import print_function from __future__ import division import os import sys import time import datetime import argparse import os.path as osp import numpy as np import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import lr_scheduler from torchreid.data_manager import ImageDataManager from torchreid.transforms import build_transforms from torchreid import models from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision from torchreid.utils.iotools import save_checkpoint, check_isfile from torchreid.utils.avgmeter import AverageMeter from torchreid.utils.logger import Logger from torchreid.utils.torchtools import count_num_param from torchreid.utils.reidtools import visualize_ranked_results from torchreid.eval_metrics import evaluate from torchreid.samplers import RandomIdentitySampler from torchreid.optimizers import init_optim parser = argparse.ArgumentParser(description='Train image model with cross entropy loss and hard triplet loss') # Datasets parser.add_argument('--root', type=str, default='data', help="root path to data directory") parser.add_argument('-s', '--source', type=str, required=True, nargs='+') parser.add_argument('-t', '--target', type=str, required=True, nargs='+') parser.add_argument('-j', '--workers', default=4, type=int, help="number of data loading workers (default: 4)") parser.add_argument('--height', type=int, default=256, help="height of an image (default: 256)") parser.add_argument('--width', type=int, default=128, help="width of an image (default: 128)") parser.add_argument('--split-id', type=int, default=0, help="split index (0-based)") # CUHK03-specific setting parser.add_argument('--cuhk03-labeled', action='store_true', help="use labeled images, if false, detected images are used (default: False)") parser.add_argument('--cuhk03-classic-split', action='store_true', help="use classic split by Li et al. CVPR'14 (default: False)") parser.add_argument('--use-metric-cuhk03', action='store_true', help="use cuhk03-metric (default: False)") # Optimization options parser.add_argument('--optim', type=str, default='adam', help="optimization algorithm (see optimizers.py)") parser.add_argument('--max-epoch', default=60, type=int, help="maximum epochs to run") parser.add_argument('--start-epoch', default=0, type=int, help="manual epoch number (useful on restarts)") parser.add_argument('--train-batch', default=32, type=int, help="train batch size") parser.add_argument('--test-batch', default=100, type=int, help="test batch size") parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float, help="initial learning rate") parser.add_argument('--stepsize', default=[20, 40], nargs='+', type=int, help="stepsize to decay learning rate") parser.add_argument('--gamma', default=0.1, type=float, help="learning rate decay") parser.add_argument('--weight-decay', default=5e-04, type=float, help="weight decay (default: 5e-04)") parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss") parser.add_argument('--num-instances', type=int, default=4, help="number of instances per identity") parser.add_argument('--htri-only', action='store_true', help="only use hard triplet loss (default: Fasle)") parser.add_argument('--lambda-xent', type=float, default=1, help="weight to balance cross entropy loss") parser.add_argument('--lambda-htri', type=float, default=1, help="weight to balance hard triplet loss") parser.add_argument('--label-smooth', action='store_true', help="use label smoothing regularizer in cross entropy loss") # Architecture parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names()) # Miscs parser.add_argument('--print-freq', type=int, default=10, help="print frequency") parser.add_argument('--seed', type=int, default=1, help="manual seed") parser.add_argument('--resume', type=str, default='', metavar='PATH') parser.add_argument('--load-weights', type=str, default='', help="load pretrained weights but ignores layers that don't match in size") parser.add_argument('--evaluate', action='store_true', help="evaluation only") parser.add_argument('--eval-step', type=int, default=-1, help="run evaluation for every N epochs (set to -1 to test after training)") parser.add_argument('--start-eval', type=int, default=0, help="start to evaluate after specific epoch") parser.add_argument('--save-dir', type=str, default='log') parser.add_argument('--use-cpu', action='store_true', help="use cpu") parser.add_argument('--gpu-devices', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES') parser.add_argument('--use-avai-gpus', action='store_true', help="use available gpus instead of specified devices (this is useful when using managed clusters)") parser.add_argument('--visualize-ranks', action='store_true', help="visualize ranked results, only available in evaluation mode (default: False)") # global variables args = parser.parse_args() def main(): global args torch.manual_seed(args.seed) if not args.use_avai_gpus: os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices use_gpu = torch.cuda.is_available() if args.use_cpu: use_gpu = False if not args.evaluate: sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt')) else: sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt')) print("==========\nArgs:{}\n==========".format(args)) if use_gpu: print("Currently using GPU {}".format(args.gpu_devices)) cudnn.benchmark = True torch.cuda.manual_seed_all(args.seed) else: print("Currently using CPU (GPU is highly recommended)") transform_train = build_transforms(args.height, args.width, is_train=True) transform_test = build_transforms(args.height, args.width, is_train=False) pin_memory = True if use_gpu else False dm = ImageDataManager( args.source, args.target, args.root, args.split_id, transform_train, transform_test, args.train_batch, args.test_batch, args.workers, pin_memory, cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split ) trainloader = dm.trainloader testloader_dict = dm.testloader_dict print("Initializing model: {}".format(args.arch)) model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'}) print("Model size: {:.3f} M".format(count_num_param(model))) criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth) criterion_htri = TripletLoss(margin=args.margin) optimizer = init_optim(args.optim, model.parameters(), args.lr, args.weight_decay) scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma) if args.load_weights and check_isfile(args.load_weights): # load pretrained weights but ignore layers that don't match in size checkpoint = torch.load(args.load_weights) pretrain_dict = checkpoint['state_dict'] model_dict = model.state_dict() pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} model_dict.update(pretrain_dict) model.load_state_dict(model_dict) print("Loaded pretrained weights from '{}'".format(args.load_weights)) if args.resume and check_isfile(args.resume): checkpoint = torch.load(args.resume) model.load_state_dict(checkpoint['state_dict']) args.start_epoch = checkpoint['epoch'] + 1 print("Loaded checkpoint from '{}'".format(args.resume)) print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, checkpoint['rank1'])) if use_gpu: model = nn.DataParallel(model).cuda() if args.evaluate: print("Evaluate only") for name in args.target: print("Evaluating {} ...".format(name)) queryloader = testloader_dict[name]['query'] galleryloader = testloader_dict[name]['gallery'] distmat = test(model, queryloader, galleryloader, use_gpu, return_distmat=True) if args.visualize_ranks: visualize_ranked_results( distmat, dataset, save_dir=osp.join(args.save_dir, 'ranked_results', name), topk=20 ) return start_time = time.time() train_time = 0 print("==> Start training") for epoch in range(args.start_epoch, args.max_epoch): start_train_time = time.time() train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu) train_time += round(time.time() - start_train_time) scheduler.step() if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch: print("==> Test") for name in args.target: print("Evaluating {} ...".format(name)) queryloader = testloader_dict[name]['query'] galleryloader = testloader_dict[name]['gallery'] rank1 = test(model, queryloader, galleryloader, use_gpu) if use_gpu: state_dict = model.module.state_dict() else: state_dict = model.state_dict() save_checkpoint({ 'state_dict': state_dict, 'rank1': rank1, 'epoch': epoch, }, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar')) elapsed = round(time.time() - start_time) elapsed = str(datetime.timedelta(seconds=elapsed)) train_time = str(datetime.timedelta(seconds=train_time)) print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time)) def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu): losses = AverageMeter() batch_time = AverageMeter() data_time = AverageMeter() model.train() end = time.time() for batch_idx, (imgs, pids, _) in enumerate(trainloader): data_time.update(time.time() - end) if use_gpu: imgs, pids = imgs.cuda(), pids.cuda() outputs, features = model(imgs) if args.htri_only: if isinstance(features, (tuple, list)): loss = DeepSupervision(criterion_htri, features, pids) else: loss = criterion_htri(features, pids) else: if isinstance(outputs, (tuple, list)): xent_loss = DeepSupervision(criterion_xent, outputs, pids) else: xent_loss = criterion_xent(outputs, pids) if isinstance(features, (tuple, list)): htri_loss = DeepSupervision(criterion_htri, features, pids) else: htri_loss = criterion_htri(features, pids) loss = args.lambda_xent * xent_loss + args.lambda_htri * htri_loss optimizer.zero_grad() loss.backward() optimizer.step() batch_time.update(time.time() - end) losses.update(loss.item(), pids.size(0)) if (batch_idx + 1) % args.print_freq == 0: print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.4f} ({data_time.avg:.4f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( epoch + 1, batch_idx + 1, len(trainloader), batch_time=batch_time, data_time=data_time, loss=losses)) end = time.time() def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20], return_distmat=False): batch_time = AverageMeter() model.eval() with torch.no_grad(): qf, q_pids, q_camids = [], [], [] for batch_idx, (imgs, pids, camids) in enumerate(queryloader): if use_gpu: imgs = imgs.cuda() end = time.time() features = model(imgs) batch_time.update(time.time() - end) features = features.data.cpu() qf.append(features) q_pids.extend(pids) q_camids.extend(camids) qf = torch.cat(qf, 0) q_pids = np.asarray(q_pids) q_camids = np.asarray(q_camids) print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1))) gf, g_pids, g_camids = [], [], [] for batch_idx, (imgs, pids, camids) in enumerate(galleryloader): if use_gpu: imgs = imgs.cuda() end = time.time() features = model(imgs) batch_time.update(time.time() - end) features = features.data.cpu() gf.append(features) g_pids.extend(pids) g_camids.extend(camids) gf = torch.cat(gf, 0) g_pids = np.asarray(g_pids) g_camids = np.asarray(g_camids) print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1))) print("==> BatchTime(s)/BatchSize(img): {:.3f}/{}".format(batch_time.avg, args.test_batch)) m, n = qf.size(0), gf.size(0) distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() distmat.addmm_(1, -2, qf, gf.t()) distmat = distmat.numpy() print("Computing CMC and mAP") cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, use_metric_cuhk03=args.use_metric_cuhk03) print("Results ----------") print("mAP: {:.1%}".format(mAP)) print("CMC curve") for r in ranks: print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1])) print("------------------") if return_distmat: return distmat return cmc[0] if __name__ == '__main__': main()