add scripts
parent
d0fe37c558
commit
a0e54ede32
|
@ -7,12 +7,9 @@ import argparse
|
|||
def init_parser():
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
||||
# ************************************************************
|
||||
# Method
|
||||
# ************************************************************
|
||||
parser.add_argument('--application', type=str, default='image', choices=['image', 'video'],
|
||||
help='image-reid or video-reid')
|
||||
parser.add_argument('--method', type=str, default='softmax',
|
||||
parser.add_argument('--app', type=str, default='image', choices=['image', 'video'],
|
||||
help='application')
|
||||
parser.add_argument('--loss', type=str, default='softmax', choices=['softmax', 'triplet'],
|
||||
help='methodology')
|
||||
|
||||
# ************************************************************
|
||||
|
@ -20,9 +17,9 @@ def init_parser():
|
|||
# ************************************************************
|
||||
parser.add_argument('--root', type=str, default='data',
|
||||
help='root path to data directory')
|
||||
parser.add_argument('-s', '--source-names', type=str, required=True, nargs='+',
|
||||
parser.add_argument('-s', '--sources', type=str, required=True, nargs='+',
|
||||
help='source datasets (delimited by space)')
|
||||
parser.add_argument('-t', '--target-names', type=str, required=True, nargs='+',
|
||||
parser.add_argument('-t', '--targets', type=str, required=False, nargs='+',
|
||||
help='target datasets (delimited by space)')
|
||||
parser.add_argument('-j', '--workers', type=int, default=4,
|
||||
help='number of data loading workers (tips: 4 or 8 times number of gpus)')
|
||||
|
@ -102,16 +99,11 @@ def init_parser():
|
|||
help='maximum epochs to run')
|
||||
parser.add_argument('--start-epoch', type=int, default=0,
|
||||
help='manual epoch number (useful when restart)')
|
||||
|
||||
parser.add_argument('--train-batch-size', type=int, default=32,
|
||||
help='training batch size')
|
||||
parser.add_argument('--test-batch-size', type=int, default=100,
|
||||
help='test batch size')
|
||||
parser.add_argument('--batch-size', type=int, default=32,
|
||||
help='batch size')
|
||||
|
||||
parser.add_argument('--always-fixbase', action='store_true',
|
||||
help='always fix base network and only train specified layers')
|
||||
parser.add_argument('--fixbase-epoch', type=int, default=0,
|
||||
help='how many epochs to fix base network (only train randomly initialized classifier)')
|
||||
help='number of epochs to fix base layers')
|
||||
parser.add_argument('--open-layers', type=str, nargs='+', default=['classifier'],
|
||||
help='open specified layers for training while keeping others frozen')
|
||||
|
||||
|
@ -153,7 +145,8 @@ def init_parser():
|
|||
# ************************************************************
|
||||
# Architecture
|
||||
# ************************************************************
|
||||
parser.add_argument('-a', '--arch', type=str, default='resnet50')
|
||||
parser.add_argument('-a', '--arch', type=str, default='resnet50',
|
||||
help='model architecture')
|
||||
parser.add_argument('--no-pretrained', action='store_true',
|
||||
help='do not load pretrained weights')
|
||||
|
||||
|
@ -170,6 +163,8 @@ def init_parser():
|
|||
help='start to evaluate after a specific epoch')
|
||||
parser.add_argument('--dist-metric', type=str, default='euclidean',
|
||||
help='distance metric')
|
||||
parser.add_argument('--ranks', type=str, default=[1, 5, 10, 20], nargs='+',
|
||||
help='cmc ranks')
|
||||
|
||||
# ************************************************************
|
||||
# Miscs
|
||||
|
@ -193,4 +188,95 @@ def init_parser():
|
|||
parser.add_argument('--visrank-topk', type=int, default=20,
|
||||
help='visualize topk ranks')
|
||||
|
||||
return parser
|
||||
return parser
|
||||
|
||||
|
||||
def imagedata_kwargs(parsed_args):
|
||||
return {
|
||||
'root': parsed_args.root,
|
||||
'sources': parsed_args.sources,
|
||||
'targets': parsed_args.targets,
|
||||
'height': parsed_args.height,
|
||||
'width': parsed_args.width,
|
||||
'random_erase': parsed_args.random_erase,
|
||||
'color_jitter': parsed_args.color_jitter,
|
||||
'color_aug': parsed_args.color_aug,
|
||||
'use_cpu': parsed_args.use_cpu,
|
||||
'split_id': parsed_args.split_id,
|
||||
'combineall': parsed_args.combineall,
|
||||
'batch_size': parsed_args.batch_size,
|
||||
'workers': parsed_args.workers,
|
||||
'num_instances': parsed_args.num_instances,
|
||||
'train_sampler': parsed_args.train_sampler,
|
||||
# image
|
||||
'cuhk03_labeled': parsed_args.cuhk03_labeled,
|
||||
'cuhk03_classic_split': parsed_args.cuhk03_classic_split,
|
||||
'market1501_500k': parsed_args.market1501_500k,
|
||||
}
|
||||
|
||||
|
||||
def videodata_kwargs(parsed_args):
|
||||
return {
|
||||
'root': parsed_args.root,
|
||||
'sources': parsed_args.sources,
|
||||
'targets': parsed_args.targets,
|
||||
'height': parsed_args.height,
|
||||
'width': parsed_args.width,
|
||||
'random_erase': parsed_args.random_erase,
|
||||
'color_jitter': parsed_args.color_jitter,
|
||||
'color_aug': parsed_args.color_aug,
|
||||
'use_cpu': parsed_args.use_cpu,
|
||||
'split_id': parsed_args.split_id,
|
||||
'combineall': parsed_args.combineall,
|
||||
'batch_size': parsed_args.batch_size,
|
||||
'workers': parsed_args.workers,
|
||||
'num_instances': parsed_args.num_instances,
|
||||
'train_sampler': parsed_args.train_sampler,
|
||||
# video
|
||||
'seq_len': parsed_args.seq_len,
|
||||
'sample_method': parsed_args.sample_method
|
||||
}
|
||||
|
||||
|
||||
def optimizer_kwargs(parsed_args):
|
||||
return {
|
||||
'optim': parsed_args.optim,
|
||||
'lr': parsed_args.lr,
|
||||
'weight_decay': parsed_args.weight_decay,
|
||||
'momentum': parsed_args.momentum,
|
||||
'sgd_dampening': parsed_args.sgd_dampening,
|
||||
'sgd_nesterov': parsed_args.sgd_nesterov,
|
||||
'rmsprop_alpha': parsed_args.rmsprop_alpha,
|
||||
'adam_beta1': parsed_args.adam_beta1,
|
||||
'adam_beta2': parsed_args.adam_beta2,
|
||||
'staged_lr': parsed_args.staged_lr,
|
||||
'new_layers': parsed_args.new_layers,
|
||||
'base_lr_mult': parsed_args.base_lr_mult
|
||||
}
|
||||
|
||||
|
||||
def lr_scheduler_kwargs(parsed_args):
|
||||
return {
|
||||
'lr_scheduler': parsed_args.lr_scheduler,
|
||||
'stepsize': parsed_args.stepsize,
|
||||
'gamma': parsed_args.gamma
|
||||
}
|
||||
|
||||
|
||||
def engine_run_kwargs(parsed_args):
|
||||
return {
|
||||
'save_dir': parsed_args.save_dir,
|
||||
'max_epoch': parsed_args.max_epoch,
|
||||
'start_epoch': parsed_args.start_epoch,
|
||||
'fixbase_epoch': parsed_args.fixbase_epoch,
|
||||
'open_layers': parsed_args.open_layers,
|
||||
'start_eval': parsed_args.start_eval,
|
||||
'eval_freq': parsed_args.eval_freq,
|
||||
'test_only': parsed_args.evaluate,
|
||||
'print_freq': parsed_args.print_freq,
|
||||
'dist_metric': parsed_args.dist_metric,
|
||||
'visrank': parsed_args.visrank,
|
||||
'visrank_topk': parsed_args.visrank_topk,
|
||||
'use_metric_cuhk03': parsed_args.use_metric_cuhk03,
|
||||
'ranks': parsed_args.ranks
|
||||
}
|
146
scripts/main.py
146
scripts/main.py
|
@ -1,19 +1,133 @@
|
|||
import torchreid
|
||||
import sys
|
||||
import os
|
||||
import os.path as osp
|
||||
import warnings
|
||||
|
||||
datamanager = torchreid.data.ImageDataManager(
|
||||
root='reid-data',
|
||||
sources='market1501',
|
||||
height=128,
|
||||
width=64,
|
||||
combineall=False,
|
||||
batch_size=16
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from default_parser import (
|
||||
init_parser, imagedata_kwargs, videodata_kwargs,
|
||||
optimizer_kwargs, lr_scheduler_kwargs, engine_run_kwargs
|
||||
)
|
||||
model = torchreid.models.build_model(
|
||||
name='squeezenet1_0',
|
||||
num_classes=datamanager.num_train_pids,
|
||||
loss='softmax'
|
||||
import torchreid
|
||||
from torchreid.utils import (
|
||||
Logger, set_random_seed, check_isfile, resume_from_checkpoint,
|
||||
load_pretrained_weights
|
||||
)
|
||||
optimizer = torchreid.optim.build_optimizer(model)
|
||||
scheduler = torchreid.optim.build_lr_scheduler(optimizer, lr_scheduler='single_step', stepsize=20)
|
||||
engine = torchreid.engine.ImageSoftmaxEngine(datamanager, model, optimizer, scheduler=scheduler)
|
||||
#engine.run(max_epoch=1, print_freq=1, fixbase_epoch=0, open_layers='classifier', test_only=True)
|
||||
|
||||
|
||||
parser = init_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def build_datamanager(args):
|
||||
if args.app == 'image':
|
||||
return torchreid.data.ImageDataManager(**imagedata_kwargs(args))
|
||||
else:
|
||||
return torchreid.data.VideoDataManager(**videodata_kwargs(args))
|
||||
|
||||
|
||||
def build_engine(args, datamanager, model, optimizer, scheduler):
|
||||
if args.app == 'image':
|
||||
if args.loss == 'softmax':
|
||||
engine = torchreid.engine.ImageSoftmaxEngine(
|
||||
datamanager,
|
||||
model,
|
||||
optimizer,
|
||||
scheduler=scheduler,
|
||||
use_cpu=args.use_cpu,
|
||||
label_smooth=args.label_smooth
|
||||
)
|
||||
else:
|
||||
engine = torchreid.engine.ImageTripletEngine(
|
||||
datamanager,
|
||||
model,
|
||||
optimizer,
|
||||
margin=args.margin,
|
||||
weight_t=args.weight_t,
|
||||
weight_x=args.weight_x,
|
||||
scheduler=scheduler,
|
||||
use_cpu=args.use_cpu,
|
||||
label_smooth=args.label_smooth
|
||||
)
|
||||
|
||||
else:
|
||||
if args.loss == 'softmax':
|
||||
engine = torchreid.engine.VideoSoftmaxEngine(
|
||||
datamanager,
|
||||
model,
|
||||
optimizer,
|
||||
scheduler=scheduler,
|
||||
use_cpu=args.use_cpu,
|
||||
label_smooth=args.label_smooth,
|
||||
pooling_method=args.pooling_method
|
||||
)
|
||||
else:
|
||||
engine = torchreid.engine.ImageTripletEngine(
|
||||
datamanager,
|
||||
model,
|
||||
optimizer,
|
||||
margin=args.margin,
|
||||
weight_t=args.weight_t,
|
||||
weight_x=args.weight_x,
|
||||
scheduler=scheduler,
|
||||
use_cpu=args.use_cpu,
|
||||
label_smooth=args.label_smooth
|
||||
)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def main():
|
||||
global args
|
||||
|
||||
set_random_seed(args.seed)
|
||||
if not args.use_avai_gpus:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
|
||||
use_gpu = (torch.cuda.is_available() and not args.use_cpu)
|
||||
log_name = 'test.log' if args.evaluate else 'train.log'
|
||||
sys.stdout = Logger(osp.join(args.save_dir, log_name))
|
||||
print('==========\nArgs:{}\n=========='.format(args))
|
||||
if use_gpu:
|
||||
print('Currently using GPU {}'.format(args.gpu_devices))
|
||||
torch.backends.cudnn.benchmark = True
|
||||
else:
|
||||
warnings.warn('Currently using CPU, however, GPU is highly recommended')
|
||||
|
||||
datamanager = build_datamanager(args)
|
||||
model = torchreid.models.build_model(
|
||||
name=args.arch,
|
||||
num_classes=datamanager.num_train_pids,
|
||||
loss=args.loss.lower(),
|
||||
pretrained=(not args.no_pretrained),
|
||||
use_gpu=use_gpu
|
||||
)
|
||||
|
||||
if args.load_weights and check_isfile(args.load_weights):
|
||||
load_pretrained_weights(model, args.load_weights)
|
||||
|
||||
if use_gpu:
|
||||
model = nn.DataParallel(model).cuda()
|
||||
|
||||
optimizer = torchreid.optim.build_optimizer(
|
||||
model,
|
||||
**optimizer_kwargs(args)
|
||||
)
|
||||
|
||||
scheduler = torchreid.optim.build_lr_scheduler(
|
||||
optimizer,
|
||||
**lr_scheduler_kwargs(args)
|
||||
)
|
||||
|
||||
if args.resume and check_isfile(args.resume):
|
||||
args.start_epoch = resume_from_checkpoint(args.resume, model, optimizer=optimizer)
|
||||
|
||||
print('Building {}-engine for {}-reid'.format(args.loss, args.app))
|
||||
engine = build_engine(args, datamanager, model, optimizer, scheduler)
|
||||
|
||||
engine.run(**engine_run_kwargs(args))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
2
setup.py
2
setup.py
|
@ -5,7 +5,7 @@ import numpy as np
|
|||
|
||||
|
||||
def readme():
|
||||
with open('README.md') as f:
|
||||
with open('README.rst') as f:
|
||||
content = f.read()
|
||||
return content
|
||||
|
||||
|
|
|
@ -51,7 +51,8 @@ class Engine(object):
|
|||
max_epoch (int): maximum epoch.
|
||||
start_epoch (int, optional): starting epoch. Default is 0.
|
||||
fixbase_epoch (int, optional): number of epochs to train ``open_layers`` (new layers)
|
||||
while keeping base layers frozen. Default is 0.
|
||||
while keeping base layers frozen. Default is 0. ``fixbase_epoch`` is not counted
|
||||
in ``max_epoch``.
|
||||
open_layers (str or list, optional): layers (attribute names) open for training.
|
||||
start_eval (int, optional): from which epoch to start evaluation. Default is 0.
|
||||
eval_freq (int, optional): evaluation frequency. Default is -1 (meaning evaluation
|
||||
|
@ -114,18 +115,19 @@ class Engine(object):
|
|||
)
|
||||
self._save_checkpoint(epoch, rank1, save_dir)
|
||||
|
||||
print('=> Final test')
|
||||
rank1 = self.test(
|
||||
epoch,
|
||||
testloader,
|
||||
dist_metric=dist_metric,
|
||||
visrank=visrank,
|
||||
visrank_topk=visrank_topk,
|
||||
save_dir=save_dir,
|
||||
use_metric_cuhk03=use_metric_cuhk03,
|
||||
ranks=ranks
|
||||
)
|
||||
self._save_checkpoint(epoch, rank1, save_dir)
|
||||
if max_epoch > 0:
|
||||
print('=> Final test')
|
||||
rank1 = self.test(
|
||||
epoch,
|
||||
testloader,
|
||||
dist_metric=dist_metric,
|
||||
visrank=visrank,
|
||||
visrank_topk=visrank_topk,
|
||||
save_dir=save_dir,
|
||||
use_metric_cuhk03=use_metric_cuhk03,
|
||||
ranks=ranks
|
||||
)
|
||||
self._save_checkpoint(epoch, rank1, save_dir)
|
||||
|
||||
elapsed = round(time.time() - time_start)
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
|
Loading…
Reference in New Issue