diff --git a/torchreid/data_manager.py b/torchreid/data_manager.py index b406748..ffad5f7 100644 --- a/torchreid/data_manager.py +++ b/torchreid/data_manager.py @@ -65,4 +65,5 @@ class ImageDataManager(object): print(" # train images : {}".format(len(self.train))) print(" # train cameras : {}".format(self.num_train_cams)) print(" test names : {}".format(self.test_names)) - print(" *****************************************") \ No newline at end of file + print(" *****************************************") + print("\n") \ No newline at end of file diff --git a/train_imgreid_xent_htri.py b/train_imgreid_xent_htri.py index c606b4c..74537b2 100755 --- a/train_imgreid_xent_htri.py +++ b/train_imgreid_xent_htri.py @@ -12,11 +12,9 @@ import numpy as np import torch import torch.nn as nn import torch.backends.cudnn as cudnn -from torch.utils.data import DataLoader from torch.optim import lr_scheduler -from torchreid import data_manager -from torchreid.dataset_loader import ImageDataset +from torchreid.data_manager import ImageDataManager from torchreid.transforms import build_transforms from torchreid import models from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision @@ -34,8 +32,8 @@ parser = argparse.ArgumentParser(description='Train image model with cross entro # Datasets parser.add_argument('--root', type=str, default='data', help="root path to data directory") -parser.add_argument('-d', '--dataset', type=str, default='market1501', - choices=data_manager.get_names()) +parser.add_argument('-s', '--source', type=str, required=True, nargs='+') +parser.add_argument('-t', '--target', type=str, required=True, nargs='+') parser.add_argument('-j', '--workers', default=4, type=int, help="number of data loading workers (default: 4)") parser.add_argument('--height', type=int, default=256, @@ -110,11 +108,10 @@ parser.add_argument('--visualize-ranks', action='store_true', # global variables args = parser.parse_args() -best_rank1 = -np.inf def main(): - global args, best_rank1 + global args torch.manual_seed(args.seed) if not args.use_avai_gpus: os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices @@ -134,41 +131,25 @@ def main(): else: print("Currently using CPU (GPU is highly recommended)") - print("Initializing dataset {}".format(args.dataset)) - dataset = data_manager.init_imgreid_dataset( - root=args.root, name=args.dataset, split_id=args.split_id, - cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split, - ) - transform_train = build_transforms(args.height, args.width, is_train=True) transform_test = build_transforms(args.height, args.width, is_train=False) pin_memory = True if use_gpu else False - trainloader = DataLoader( - ImageDataset(dataset.train, transform=transform_train), - sampler=RandomIdentitySampler(dataset.train, args.train_batch, args.num_instances), - batch_size=args.train_batch, num_workers=args.workers, - pin_memory=pin_memory, drop_last=True, - ) - - queryloader = DataLoader( - ImageDataset(dataset.query, transform=transform_test), - batch_size=args.test_batch, shuffle=False, num_workers=args.workers, - pin_memory=pin_memory, drop_last=False, - ) - - galleryloader = DataLoader( - ImageDataset(dataset.gallery, transform=transform_test), - batch_size=args.test_batch, shuffle=False, num_workers=args.workers, - pin_memory=pin_memory, drop_last=False, + dm = ImageDataManager( + args.source, args.target, args.root, args.split_id, transform_train, transform_test, + args.train_batch, args.test_batch, args.workers, pin_memory, + cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split ) + + trainloader = dm.trainloader + testloader_dict = dm.testloader_dict print("Initializing model: {}".format(args.arch)) - model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent', 'htri'}) + model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'}) print("Model size: {:.3f} M".format(count_num_param(model))) - criterion = CrossEntropyLoss(num_classes=dataset.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth) + criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth) criterion_htri = TripletLoss(margin=args.margin) optimizer = init_optim(args.optim, model.parameters(), args.lr, args.weight_decay) @@ -188,27 +169,31 @@ def main(): checkpoint = torch.load(args.resume) model.load_state_dict(checkpoint['state_dict']) args.start_epoch = checkpoint['epoch'] + 1 - best_rank1 = checkpoint['rank1'] print("Loaded checkpoint from '{}'".format(args.resume)) - print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, best_rank1)) + print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, checkpoint['rank1'])) if use_gpu: model = nn.DataParallel(model).cuda() if args.evaluate: print("Evaluate only") - distmat = test(model, queryloader, galleryloader, use_gpu, return_distmat=True) - if args.visualize_ranks: - visualize_ranked_results( - distmat, dataset, - save_dir=osp.join(args.save_dir, 'ranked_results'), - topk=20, - ) + + for name in args.target: + print("Evaluating {} ...".format(name)) + queryloader = testloader_dict[name]['query'] + galleryloader = testloader_dict[name]['gallery'] + distmat = test(model, queryloader, galleryloader, use_gpu, return_distmat=True) + + if args.visualize_ranks: + visualize_ranked_results( + distmat, dataset, + save_dir=osp.join(args.save_dir, 'ranked_results', name), + topk=20 + ) return start_time = time.time() train_time = 0 - best_epoch = args.start_epoch print("==> Start training") for epoch in range(args.start_epoch, args.max_epoch): @@ -220,12 +205,12 @@ def main(): if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch: print("==> Test") - rank1 = test(model, queryloader, galleryloader, use_gpu) - is_best = rank1 > best_rank1 - if is_best: - best_rank1 = rank1 - best_epoch = epoch + 1 + for name in args.target: + print("Evaluating {} ...".format(name)) + queryloader = testloader_dict[name]['query'] + galleryloader = testloader_dict[name]['gallery'] + rank1 = test(model, queryloader, galleryloader, use_gpu) if use_gpu: state_dict = model.module.state_dict() @@ -236,9 +221,7 @@ def main(): 'state_dict': state_dict, 'rank1': rank1, 'epoch': epoch, - }, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar')) - - print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch)) + }, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar')) elapsed = round(time.time() - start_time) elapsed = str(datetime.timedelta(seconds=elapsed))