add scripts

pull/133/head
KaiyangZhou 2019-03-24 11:06:12 +00:00
parent d0fe37c558
commit a0e54ede32
4 changed files with 250 additions and 48 deletions

View File

@ -7,12 +7,9 @@ import argparse
def init_parser():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# ************************************************************
# Method
# ************************************************************
parser.add_argument('--application', type=str, default='image', choices=['image', 'video'],
help='image-reid or video-reid')
parser.add_argument('--method', type=str, default='softmax',
parser.add_argument('--app', type=str, default='image', choices=['image', 'video'],
help='application')
parser.add_argument('--loss', type=str, default='softmax', choices=['softmax', 'triplet'],
help='methodology')
# ************************************************************
@ -20,9 +17,9 @@ def init_parser():
# ************************************************************
parser.add_argument('--root', type=str, default='data',
help='root path to data directory')
parser.add_argument('-s', '--source-names', type=str, required=True, nargs='+',
parser.add_argument('-s', '--sources', type=str, required=True, nargs='+',
help='source datasets (delimited by space)')
parser.add_argument('-t', '--target-names', type=str, required=True, nargs='+',
parser.add_argument('-t', '--targets', type=str, required=False, nargs='+',
help='target datasets (delimited by space)')
parser.add_argument('-j', '--workers', type=int, default=4,
help='number of data loading workers (tips: 4 or 8 times number of gpus)')
@ -102,16 +99,11 @@ def init_parser():
help='maximum epochs to run')
parser.add_argument('--start-epoch', type=int, default=0,
help='manual epoch number (useful when restart)')
parser.add_argument('--train-batch-size', type=int, default=32,
help='training batch size')
parser.add_argument('--test-batch-size', type=int, default=100,
help='test batch size')
parser.add_argument('--batch-size', type=int, default=32,
help='batch size')
parser.add_argument('--always-fixbase', action='store_true',
help='always fix base network and only train specified layers')
parser.add_argument('--fixbase-epoch', type=int, default=0,
help='how many epochs to fix base network (only train randomly initialized classifier)')
help='number of epochs to fix base layers')
parser.add_argument('--open-layers', type=str, nargs='+', default=['classifier'],
help='open specified layers for training while keeping others frozen')
@ -153,7 +145,8 @@ def init_parser():
# ************************************************************
# Architecture
# ************************************************************
parser.add_argument('-a', '--arch', type=str, default='resnet50')
parser.add_argument('-a', '--arch', type=str, default='resnet50',
help='model architecture')
parser.add_argument('--no-pretrained', action='store_true',
help='do not load pretrained weights')
@ -170,6 +163,8 @@ def init_parser():
help='start to evaluate after a specific epoch')
parser.add_argument('--dist-metric', type=str, default='euclidean',
help='distance metric')
parser.add_argument('--ranks', type=str, default=[1, 5, 10, 20], nargs='+',
help='cmc ranks')
# ************************************************************
# Miscs
@ -193,4 +188,95 @@ def init_parser():
parser.add_argument('--visrank-topk', type=int, default=20,
help='visualize topk ranks')
return parser
return parser
def imagedata_kwargs(parsed_args):
    """Map parsed CLI arguments onto the keyword arguments expected by
    ``torchreid.data.ImageDataManager``.

    Args:
        parsed_args: namespace produced by the parser from ``init_parser()``.

    Returns:
        dict: keyword arguments ready to be splatted into ``ImageDataManager``.
    """
    return dict(
        # options shared with the video data manager
        root=parsed_args.root,
        sources=parsed_args.sources,
        targets=parsed_args.targets,
        height=parsed_args.height,
        width=parsed_args.width,
        random_erase=parsed_args.random_erase,
        color_jitter=parsed_args.color_jitter,
        color_aug=parsed_args.color_aug,
        use_cpu=parsed_args.use_cpu,
        split_id=parsed_args.split_id,
        combineall=parsed_args.combineall,
        batch_size=parsed_args.batch_size,
        workers=parsed_args.workers,
        num_instances=parsed_args.num_instances,
        train_sampler=parsed_args.train_sampler,
        # image-dataset-specific switches
        cuhk03_labeled=parsed_args.cuhk03_labeled,
        cuhk03_classic_split=parsed_args.cuhk03_classic_split,
        market1501_500k=parsed_args.market1501_500k,
    )
def videodata_kwargs(parsed_args):
    """Map parsed CLI arguments onto the keyword arguments expected by
    ``torchreid.data.VideoDataManager``.

    Args:
        parsed_args: namespace produced by the parser from ``init_parser()``.

    Returns:
        dict: keyword arguments ready to be splatted into ``VideoDataManager``.
    """
    return dict(
        # options shared with the image data manager
        root=parsed_args.root,
        sources=parsed_args.sources,
        targets=parsed_args.targets,
        height=parsed_args.height,
        width=parsed_args.width,
        random_erase=parsed_args.random_erase,
        color_jitter=parsed_args.color_jitter,
        color_aug=parsed_args.color_aug,
        use_cpu=parsed_args.use_cpu,
        split_id=parsed_args.split_id,
        combineall=parsed_args.combineall,
        batch_size=parsed_args.batch_size,
        workers=parsed_args.workers,
        num_instances=parsed_args.num_instances,
        train_sampler=parsed_args.train_sampler,
        # video-specific sampling options
        seq_len=parsed_args.seq_len,
        sample_method=parsed_args.sample_method,
    )
def optimizer_kwargs(parsed_args):
    """Map parsed CLI arguments onto the keyword arguments expected by
    ``torchreid.optim.build_optimizer``.

    Args:
        parsed_args: namespace produced by the parser from ``init_parser()``.

    Returns:
        dict: optimizer construction options (algorithm choice, learning
        rate, per-algorithm hyperparameters, and staged-lr settings).
    """
    return dict(
        optim=parsed_args.optim,
        lr=parsed_args.lr,
        weight_decay=parsed_args.weight_decay,
        momentum=parsed_args.momentum,
        # SGD-only knobs
        sgd_dampening=parsed_args.sgd_dampening,
        sgd_nesterov=parsed_args.sgd_nesterov,
        # RMSprop-only knob
        rmsprop_alpha=parsed_args.rmsprop_alpha,
        # Adam-only knobs
        adam_beta1=parsed_args.adam_beta1,
        adam_beta2=parsed_args.adam_beta2,
        # staged learning-rate options
        staged_lr=parsed_args.staged_lr,
        new_layers=parsed_args.new_layers,
        base_lr_mult=parsed_args.base_lr_mult,
    )
def lr_scheduler_kwargs(parsed_args):
    """Map parsed CLI arguments onto the keyword arguments expected by
    ``torchreid.optim.build_lr_scheduler``.

    Args:
        parsed_args: namespace produced by the parser from ``init_parser()``.

    Returns:
        dict: scheduler type, step size(s), and decay factor.
    """
    return dict(
        lr_scheduler=parsed_args.lr_scheduler,
        stepsize=parsed_args.stepsize,
        gamma=parsed_args.gamma,
    )
def engine_run_kwargs(parsed_args):
    """Map parsed CLI arguments onto the keyword arguments expected by
    ``Engine.run``.

    Note the one renamed option: the CLI flag ``--evaluate`` becomes the
    engine keyword ``test_only``.

    Args:
        parsed_args: namespace produced by the parser from ``init_parser()``.

    Returns:
        dict: run-time options for training/evaluation.
    """
    return dict(
        save_dir=parsed_args.save_dir,
        max_epoch=parsed_args.max_epoch,
        start_epoch=parsed_args.start_epoch,
        fixbase_epoch=parsed_args.fixbase_epoch,
        open_layers=parsed_args.open_layers,
        start_eval=parsed_args.start_eval,
        eval_freq=parsed_args.eval_freq,
        test_only=parsed_args.evaluate,  # CLI name differs from engine name
        print_freq=parsed_args.print_freq,
        dist_metric=parsed_args.dist_metric,
        visrank=parsed_args.visrank,
        visrank_topk=parsed_args.visrank_topk,
        use_metric_cuhk03=parsed_args.use_metric_cuhk03,
        ranks=parsed_args.ranks,
    )

View File

@ -1,19 +1,133 @@
import torchreid
import sys
import os
import os.path as osp
import warnings
datamanager = torchreid.data.ImageDataManager(
root='reid-data',
sources='market1501',
height=128,
width=64,
combineall=False,
batch_size=16
import torch
import torch.nn as nn
from default_parser import (
init_parser, imagedata_kwargs, videodata_kwargs,
optimizer_kwargs, lr_scheduler_kwargs, engine_run_kwargs
)
model = torchreid.models.build_model(
name='squeezenet1_0',
num_classes=datamanager.num_train_pids,
loss='softmax'
import torchreid
from torchreid.utils import (
Logger, set_random_seed, check_isfile, resume_from_checkpoint,
load_pretrained_weights
)
optimizer = torchreid.optim.build_optimizer(model)
scheduler = torchreid.optim.build_lr_scheduler(optimizer, lr_scheduler='single_step', stepsize=20)
engine = torchreid.engine.ImageSoftmaxEngine(datamanager, model, optimizer, scheduler=scheduler)
#engine.run(max_epoch=1, print_freq=1, fixbase_epoch=0, open_layers='classifier', test_only=True)
parser = init_parser()
args = parser.parse_args()
def build_datamanager(args):
    """Construct the data manager matching the selected application.

    Args:
        args: parsed CLI namespace; ``args.app`` is either ``'image'`` or
            ``'video'`` (constrained by the parser's ``choices``).

    Returns:
        A ``torchreid.data.ImageDataManager`` or ``VideoDataManager``.
    """
    if args.app != 'image':
        # anything other than 'image' is 'video' (parser restricts choices)
        return torchreid.data.VideoDataManager(**videodata_kwargs(args))
    return torchreid.data.ImageDataManager(**imagedata_kwargs(args))
def build_engine(args, datamanager, model, optimizer, scheduler):
    """Construct the training/evaluation engine matching ``args.app``
    (image or video) and ``args.loss`` (softmax or triplet).

    Args:
        args: parsed CLI namespace.
        datamanager: data manager built by ``build_datamanager``.
        model: network returned by ``torchreid.models.build_model``.
        optimizer: optimizer returned by ``torchreid.optim.build_optimizer``.
        scheduler: LR scheduler returned by ``build_lr_scheduler``.

    Returns:
        A ``torchreid.engine`` instance ready for ``engine.run``.
    """
    if args.app == 'image':
        if args.loss == 'softmax':
            engine = torchreid.engine.ImageSoftmaxEngine(
                datamanager,
                model,
                optimizer,
                scheduler=scheduler,
                use_cpu=args.use_cpu,
                label_smooth=args.label_smooth
            )
        else:
            engine = torchreid.engine.ImageTripletEngine(
                datamanager,
                model,
                optimizer,
                margin=args.margin,
                weight_t=args.weight_t,
                weight_x=args.weight_x,
                scheduler=scheduler,
                use_cpu=args.use_cpu,
                label_smooth=args.label_smooth
            )
    else:
        if args.loss == 'softmax':
            engine = torchreid.engine.VideoSoftmaxEngine(
                datamanager,
                model,
                optimizer,
                scheduler=scheduler,
                use_cpu=args.use_cpu,
                label_smooth=args.label_smooth,
                pooling_method=args.pooling_method
            )
        else:
            # BUG FIX: this branch previously instantiated ImageTripletEngine,
            # so video-reid with triplet loss silently ran the image engine and
            # ignored --pooling-method. Use the video engine and forward
            # pooling_method, mirroring the softmax branch above.
            engine = torchreid.engine.VideoTripletEngine(
                datamanager,
                model,
                optimizer,
                margin=args.margin,
                weight_t=args.weight_t,
                weight_x=args.weight_x,
                scheduler=scheduler,
                use_cpu=args.use_cpu,
                label_smooth=args.label_smooth,
                pooling_method=args.pooling_method
            )
    return engine
def main():
    """Script entry point: set up the environment, build the data manager,
    model, optimizer and scheduler from CLI arguments, then run the engine.
    """
    global args

    set_random_seed(args.seed)

    # Restrict visible GPUs unless the user asked to use whatever is available.
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    cuda_enabled = torch.cuda.is_available() and not args.use_cpu

    # Tee stdout to a log file named after the run mode.
    log_name = 'test.log' if args.evaluate else 'train.log'
    sys.stdout = Logger(osp.join(args.save_dir, log_name))

    print('==========\nArgs:{}\n=========='.format(args))

    if cuda_enabled:
        print('Currently using GPU {}'.format(args.gpu_devices))
        torch.backends.cudnn.benchmark = True
    else:
        warnings.warn('Currently using CPU, however, GPU is highly recommended')

    datamanager = build_datamanager(args)

    model = torchreid.models.build_model(
        name=args.arch,
        num_classes=datamanager.num_train_pids,
        loss=args.loss.lower(),
        pretrained=(not args.no_pretrained),
        use_gpu=cuda_enabled
    )

    # Optionally warm-start from externally trained weights.
    if args.load_weights and check_isfile(args.load_weights):
        load_pretrained_weights(model, args.load_weights)

    if cuda_enabled:
        model = nn.DataParallel(model).cuda()

    optimizer = torchreid.optim.build_optimizer(model, **optimizer_kwargs(args))
    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer, **lr_scheduler_kwargs(args)
    )

    # Resuming restores model/optimizer state and returns the epoch to
    # continue from.
    if args.resume and check_isfile(args.resume):
        args.start_epoch = resume_from_checkpoint(
            args.resume, model, optimizer=optimizer
        )

    print('Building {}-engine for {}-reid'.format(args.loss, args.app))
    engine = build_engine(args, datamanager, model, optimizer, scheduler)
    engine.run(**engine_run_kwargs(args))


if __name__ == '__main__':
    main()

View File

@ -5,7 +5,7 @@ import numpy as np
def readme():
    """Read the project README (reStructuredText) and return its full text.

    Returns:
        str: contents of ``README.rst`` in the current working directory.
    """
    # NOTE(review): the diff shows the file was renamed from README.md to
    # README.rst; this is the post-change state.
    with open('README.rst') as f:
        content = f.read()
    return content

View File

@ -51,7 +51,8 @@ class Engine(object):
max_epoch (int): maximum epoch.
start_epoch (int, optional): starting epoch. Default is 0.
fixbase_epoch (int, optional): number of epochs to train ``open_layers`` (new layers)
while keeping base layers frozen. Default is 0.
while keeping base layers frozen. Default is 0. ``fixbase_epoch`` is not counted
in ``max_epoch``.
open_layers (str or list, optional): layers (attribute names) open for training.
start_eval (int, optional): from which epoch to start evaluation. Default is 0.
eval_freq (int, optional): evaluation frequency. Default is -1 (meaning evaluation
@ -114,18 +115,19 @@ class Engine(object):
)
self._save_checkpoint(epoch, rank1, save_dir)
print('=> Final test')
rank1 = self.test(
epoch,
testloader,
dist_metric=dist_metric,
visrank=visrank,
visrank_topk=visrank_topk,
save_dir=save_dir,
use_metric_cuhk03=use_metric_cuhk03,
ranks=ranks
)
self._save_checkpoint(epoch, rank1, save_dir)
if max_epoch > 0:
print('=> Final test')
rank1 = self.test(
epoch,
testloader,
dist_metric=dist_metric,
visrank=visrank,
visrank_topk=visrank_topk,
save_dir=save_dir,
use_metric_cuhk03=use_metric_cuhk03,
ranks=ranks
)
self._save_checkpoint(epoch, rank1, save_dir)
elapsed = round(time.time() - time_start)
elapsed = str(datetime.timedelta(seconds=elapsed))