support multi-dataset training
parent 81d3cc3995
commit a96dd6b916
@@ -65,4 +65,5 @@ class ImageDataManager(object):
         print(" # train images : {}".format(len(self.train)))
         print(" # train cameras : {}".format(self.num_train_cams))
+        print(" test names : {}".format(self.test_names))
         print(" *****************************************")
         print("\n")
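The new "test names" line reflects that the data manager now tracks a list of evaluation targets rather than a single dataset. A minimal sketch of the idea, assuming each dataset object exposes a train list of (img_path, pid, camid) tuples and a num_train_pids count; the class and helper names are illustrative, not the repository's actual code:

    # Sketch: merge several source datasets into one training set with a
    # single, disjoint person-ID label space, and keep one test set per target.
    class MultiDatasetSketch(object):
        def __init__(self, source_names, target_names, init_dataset):
            self.train = []
            pid_offset = 0
            for name in source_names:
                dataset = init_dataset(name)  # assumed factory: name -> dataset object
                # Shift this source's person IDs past those already assigned,
                # so identities from different datasets never collide.
                for img_path, pid, camid in dataset.train:
                    self.train.append((img_path, pid + pid_offset, camid))
                pid_offset += dataset.num_train_pids
            self.num_train_pids = pid_offset

            # One evaluation dataset per target, looked up by name later.
            self.testdataset_dict = {name: init_dataset(name) for name in target_names}
            self.test_names = list(self.testdataset_dict.keys())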
@@ -12,11 +12,9 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
-from torch.utils.data import DataLoader
 from torch.optim import lr_scheduler

-from torchreid import data_manager
-from torchreid.dataset_loader import ImageDataset
+from torchreid.data_manager import ImageDataManager
 from torchreid.transforms import build_transforms
 from torchreid import models
 from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
@@ -34,8 +32,8 @@ parser = argparse.ArgumentParser(description='Train image model with cross entro
 # Datasets
 parser.add_argument('--root', type=str, default='data',
                     help="root path to data directory")
-parser.add_argument('-d', '--dataset', type=str, default='market1501',
-                    choices=data_manager.get_names())
+parser.add_argument('-s', '--source', type=str, required=True, nargs='+')
+parser.add_argument('-t', '--target', type=str, required=True, nargs='+')
 parser.add_argument('-j', '--workers', default=4, type=int,
                     help="number of data loading workers (default: 4)")
 parser.add_argument('--height', type=int, default=256,
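The new -s/--source and -t/--target flags each accept one or more dataset names, which is what enables multi-dataset training. As a usage sketch, a run that trains on two datasets and evaluates on both might look like the line below; the script filename and the second dataset key are assumptions, since only market1501 appears in this diff:

    python train_imgreid.py --root data -s market1501 dukemtmcreid -t market1501 dukemtmcreid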
@@ -110,11 +108,10 @@ parser.add_argument('--visualize-ranks', action='store_true',

 # global variables
 args = parser.parse_args()
-best_rank1 = -np.inf


 def main():
-    global args, best_rank1
+    global args

     torch.manual_seed(args.seed)
     if not args.use_avai_gpus: os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
@@ -134,41 +131,25 @@ def main():
     else:
         print("Currently using CPU (GPU is highly recommended)")

-    print("Initializing dataset {}".format(args.dataset))
-    dataset = data_manager.init_imgreid_dataset(
-        root=args.root, name=args.dataset, split_id=args.split_id,
-        cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split,
-    )
-
     transform_train = build_transforms(args.height, args.width, is_train=True)
     transform_test = build_transforms(args.height, args.width, is_train=False)

     pin_memory = True if use_gpu else False

-    trainloader = DataLoader(
-        ImageDataset(dataset.train, transform=transform_train),
-        sampler=RandomIdentitySampler(dataset.train, args.train_batch, args.num_instances),
-        batch_size=args.train_batch, num_workers=args.workers,
-        pin_memory=pin_memory, drop_last=True,
-    )
-
-    queryloader = DataLoader(
-        ImageDataset(dataset.query, transform=transform_test),
-        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
-        pin_memory=pin_memory, drop_last=False,
-    )
-
-    galleryloader = DataLoader(
-        ImageDataset(dataset.gallery, transform=transform_test),
-        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
-        pin_memory=pin_memory, drop_last=False,
+    dm = ImageDataManager(
+        args.source, args.target, args.root, args.split_id, transform_train, transform_test,
+        args.train_batch, args.test_batch, args.workers, pin_memory,
+        cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split
     )

+    trainloader = dm.trainloader
+    testloader_dict = dm.testloader_dict
+
     print("Initializing model: {}".format(args.arch))
-    model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent', 'htri'})
+    model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'})
     print("Model size: {:.3f} M".format(count_num_param(model)))

-    criterion = CrossEntropyLoss(num_classes=dataset.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
+    criterion = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
    criterion_htri = TripletLoss(margin=args.margin)

     optimizer = init_optim(args.optim, model.parameters(), args.lr, args.weight_decay)
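With the DataLoader construction moved inside the data manager, the loaders deleted above must now be built behind dm.trainloader and dm.testloader_dict. A minimal sketch of that construction, reusing the ImageDataset and RandomIdentitySampler pieces from the deleted lines; the helper function, the testdataset_dict attribute, and the samplers import path are assumptions rather than the repository's actual code:

    from torch.utils.data import DataLoader
    from torchreid.dataset_loader import ImageDataset
    from torchreid.samplers import RandomIdentitySampler  # assumed module path

    def build_loaders(dm, transform_train, transform_test,
                      train_batch, test_batch, num_instances,
                      workers, pin_memory):
        # Training loader over the merged multi-source train list.
        dm.trainloader = DataLoader(
            ImageDataset(dm.train, transform=transform_train),
            sampler=RandomIdentitySampler(dm.train, train_batch, num_instances),
            batch_size=train_batch, num_workers=workers,
            pin_memory=pin_memory, drop_last=True,
        )
        # One query/gallery loader pair per evaluation target, keyed by
        # dataset name: exactly the shape the testloader_dict[name]['query']
        # lookups in the hunks below expect.
        dm.testloader_dict = {}
        for name, dataset in dm.testdataset_dict.items():
            dm.testloader_dict[name] = {
                'query': DataLoader(
                    ImageDataset(dataset.query, transform=transform_test),
                    batch_size=test_batch, shuffle=False, num_workers=workers,
                    pin_memory=pin_memory, drop_last=False,
                ),
                'gallery': DataLoader(
                    ImageDataset(dataset.gallery, transform=transform_test),
                    batch_size=test_batch, shuffle=False, num_workers=workers,
                    pin_memory=pin_memory, drop_last=False,
                ),
            }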
@@ -188,27 +169,31 @@ def main():
         checkpoint = torch.load(args.resume)
         model.load_state_dict(checkpoint['state_dict'])
         args.start_epoch = checkpoint['epoch'] + 1
-        best_rank1 = checkpoint['rank1']
         print("Loaded checkpoint from '{}'".format(args.resume))
-        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, best_rank1))
+        print("- start_epoch: {}\n- rank1: {}".format(args.start_epoch, checkpoint['rank1']))

     if use_gpu:
         model = nn.DataParallel(model).cuda()

     if args.evaluate:
         print("Evaluate only")
-        distmat = test(model, queryloader, galleryloader, use_gpu, return_distmat=True)
-        if args.visualize_ranks:
-            visualize_ranked_results(
-                distmat, dataset,
-                save_dir=osp.join(args.save_dir, 'ranked_results'),
-                topk=20,
-            )
+        for name in args.target:
+            print("Evaluating {} ...".format(name))
+            queryloader = testloader_dict[name]['query']
+            galleryloader = testloader_dict[name]['gallery']
+            distmat = test(model, queryloader, galleryloader, use_gpu, return_distmat=True)
+
+            if args.visualize_ranks:
+                visualize_ranked_results(
+                    distmat, dataset,
+                    save_dir=osp.join(args.save_dir, 'ranked_results', name),
+                    topk=20
+                )
         return

     start_time = time.time()
     train_time = 0
     best_epoch = args.start_epoch
     print("==> Start training")

     for epoch in range(args.start_epoch, args.max_epoch):
@@ -220,12 +205,12 @@ def main():

         if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (epoch + 1) == args.max_epoch:
             print("==> Test")
-            rank1 = test(model, queryloader, galleryloader, use_gpu)
-            is_best = rank1 > best_rank1
-
-            if is_best:
-                best_rank1 = rank1
-                best_epoch = epoch + 1
+
+            for name in args.target:
+                print("Evaluating {} ...".format(name))
+                queryloader = testloader_dict[name]['query']
+                galleryloader = testloader_dict[name]['gallery']
+                rank1 = test(model, queryloader, galleryloader, use_gpu)

             if use_gpu:
                 state_dict = model.module.state_dict()
@@ -236,9 +221,7 @@ def main():
                 'state_dict': state_dict,
                 'rank1': rank1,
                 'epoch': epoch,
-            }, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))
-
-    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch))
+            }, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

     elapsed = round(time.time() - start_time)
     elapsed = str(datetime.timedelta(seconds=elapsed))