add scripts
parent
d0fe37c558
commit
a0e54ede32
|
@ -7,12 +7,9 @@ import argparse
|
||||||
def init_parser():
|
def init_parser():
|
||||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
|
|
||||||
# ************************************************************
|
parser.add_argument('--app', type=str, default='image', choices=['image', 'video'],
|
||||||
# Method
|
help='application')
|
||||||
# ************************************************************
|
parser.add_argument('--loss', type=str, default='softmax', choices=['softmax', 'triplet'],
|
||||||
parser.add_argument('--application', type=str, default='image', choices=['image', 'video'],
|
|
||||||
help='image-reid or video-reid')
|
|
||||||
parser.add_argument('--method', type=str, default='softmax',
|
|
||||||
help='methodology')
|
help='methodology')
|
||||||
|
|
||||||
# ************************************************************
|
# ************************************************************
|
||||||
|
@ -20,9 +17,9 @@ def init_parser():
|
||||||
# ************************************************************
|
# ************************************************************
|
||||||
parser.add_argument('--root', type=str, default='data',
|
parser.add_argument('--root', type=str, default='data',
|
||||||
help='root path to data directory')
|
help='root path to data directory')
|
||||||
parser.add_argument('-s', '--source-names', type=str, required=True, nargs='+',
|
parser.add_argument('-s', '--sources', type=str, required=True, nargs='+',
|
||||||
help='source datasets (delimited by space)')
|
help='source datasets (delimited by space)')
|
||||||
parser.add_argument('-t', '--target-names', type=str, required=True, nargs='+',
|
parser.add_argument('-t', '--targets', type=str, required=False, nargs='+',
|
||||||
help='target datasets (delimited by space)')
|
help='target datasets (delimited by space)')
|
||||||
parser.add_argument('-j', '--workers', type=int, default=4,
|
parser.add_argument('-j', '--workers', type=int, default=4,
|
||||||
help='number of data loading workers (tips: 4 or 8 times number of gpus)')
|
help='number of data loading workers (tips: 4 or 8 times number of gpus)')
|
||||||
|
@ -102,16 +99,11 @@ def init_parser():
|
||||||
help='maximum epochs to run')
|
help='maximum epochs to run')
|
||||||
parser.add_argument('--start-epoch', type=int, default=0,
|
parser.add_argument('--start-epoch', type=int, default=0,
|
||||||
help='manual epoch number (useful when restart)')
|
help='manual epoch number (useful when restart)')
|
||||||
|
parser.add_argument('--batch-size', type=int, default=32,
|
||||||
parser.add_argument('--train-batch-size', type=int, default=32,
|
help='batch size')
|
||||||
help='training batch size')
|
|
||||||
parser.add_argument('--test-batch-size', type=int, default=100,
|
|
||||||
help='test batch size')
|
|
||||||
|
|
||||||
parser.add_argument('--always-fixbase', action='store_true',
|
|
||||||
help='always fix base network and only train specified layers')
|
|
||||||
parser.add_argument('--fixbase-epoch', type=int, default=0,
|
parser.add_argument('--fixbase-epoch', type=int, default=0,
|
||||||
help='how many epochs to fix base network (only train randomly initialized classifier)')
|
help='number of epochs to fix base layers')
|
||||||
parser.add_argument('--open-layers', type=str, nargs='+', default=['classifier'],
|
parser.add_argument('--open-layers', type=str, nargs='+', default=['classifier'],
|
||||||
help='open specified layers for training while keeping others frozen')
|
help='open specified layers for training while keeping others frozen')
|
||||||
|
|
||||||
|
@ -153,7 +145,8 @@ def init_parser():
|
||||||
# ************************************************************
|
# ************************************************************
|
||||||
# Architecture
|
# Architecture
|
||||||
# ************************************************************
|
# ************************************************************
|
||||||
parser.add_argument('-a', '--arch', type=str, default='resnet50')
|
parser.add_argument('-a', '--arch', type=str, default='resnet50',
|
||||||
|
help='model architecture')
|
||||||
parser.add_argument('--no-pretrained', action='store_true',
|
parser.add_argument('--no-pretrained', action='store_true',
|
||||||
help='do not load pretrained weights')
|
help='do not load pretrained weights')
|
||||||
|
|
||||||
|
@ -170,6 +163,8 @@ def init_parser():
|
||||||
help='start to evaluate after a specific epoch')
|
help='start to evaluate after a specific epoch')
|
||||||
parser.add_argument('--dist-metric', type=str, default='euclidean',
|
parser.add_argument('--dist-metric', type=str, default='euclidean',
|
||||||
help='distance metric')
|
help='distance metric')
|
||||||
|
parser.add_argument('--ranks', type=str, default=[1, 5, 10, 20], nargs='+',
|
||||||
|
help='cmc ranks')
|
||||||
|
|
||||||
# ************************************************************
|
# ************************************************************
|
||||||
# Miscs
|
# Miscs
|
||||||
|
@ -193,4 +188,95 @@ def init_parser():
|
||||||
parser.add_argument('--visrank-topk', type=int, default=20,
|
parser.add_argument('--visrank-topk', type=int, default=20,
|
||||||
help='visualize topk ranks')
|
help='visualize topk ranks')
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def imagedata_kwargs(parsed_args):
|
||||||
|
return {
|
||||||
|
'root': parsed_args.root,
|
||||||
|
'sources': parsed_args.sources,
|
||||||
|
'targets': parsed_args.targets,
|
||||||
|
'height': parsed_args.height,
|
||||||
|
'width': parsed_args.width,
|
||||||
|
'random_erase': parsed_args.random_erase,
|
||||||
|
'color_jitter': parsed_args.color_jitter,
|
||||||
|
'color_aug': parsed_args.color_aug,
|
||||||
|
'use_cpu': parsed_args.use_cpu,
|
||||||
|
'split_id': parsed_args.split_id,
|
||||||
|
'combineall': parsed_args.combineall,
|
||||||
|
'batch_size': parsed_args.batch_size,
|
||||||
|
'workers': parsed_args.workers,
|
||||||
|
'num_instances': parsed_args.num_instances,
|
||||||
|
'train_sampler': parsed_args.train_sampler,
|
||||||
|
# image
|
||||||
|
'cuhk03_labeled': parsed_args.cuhk03_labeled,
|
||||||
|
'cuhk03_classic_split': parsed_args.cuhk03_classic_split,
|
||||||
|
'market1501_500k': parsed_args.market1501_500k,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def videodata_kwargs(parsed_args):
|
||||||
|
return {
|
||||||
|
'root': parsed_args.root,
|
||||||
|
'sources': parsed_args.sources,
|
||||||
|
'targets': parsed_args.targets,
|
||||||
|
'height': parsed_args.height,
|
||||||
|
'width': parsed_args.width,
|
||||||
|
'random_erase': parsed_args.random_erase,
|
||||||
|
'color_jitter': parsed_args.color_jitter,
|
||||||
|
'color_aug': parsed_args.color_aug,
|
||||||
|
'use_cpu': parsed_args.use_cpu,
|
||||||
|
'split_id': parsed_args.split_id,
|
||||||
|
'combineall': parsed_args.combineall,
|
||||||
|
'batch_size': parsed_args.batch_size,
|
||||||
|
'workers': parsed_args.workers,
|
||||||
|
'num_instances': parsed_args.num_instances,
|
||||||
|
'train_sampler': parsed_args.train_sampler,
|
||||||
|
# video
|
||||||
|
'seq_len': parsed_args.seq_len,
|
||||||
|
'sample_method': parsed_args.sample_method
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def optimizer_kwargs(parsed_args):
|
||||||
|
return {
|
||||||
|
'optim': parsed_args.optim,
|
||||||
|
'lr': parsed_args.lr,
|
||||||
|
'weight_decay': parsed_args.weight_decay,
|
||||||
|
'momentum': parsed_args.momentum,
|
||||||
|
'sgd_dampening': parsed_args.sgd_dampening,
|
||||||
|
'sgd_nesterov': parsed_args.sgd_nesterov,
|
||||||
|
'rmsprop_alpha': parsed_args.rmsprop_alpha,
|
||||||
|
'adam_beta1': parsed_args.adam_beta1,
|
||||||
|
'adam_beta2': parsed_args.adam_beta2,
|
||||||
|
'staged_lr': parsed_args.staged_lr,
|
||||||
|
'new_layers': parsed_args.new_layers,
|
||||||
|
'base_lr_mult': parsed_args.base_lr_mult
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def lr_scheduler_kwargs(parsed_args):
|
||||||
|
return {
|
||||||
|
'lr_scheduler': parsed_args.lr_scheduler,
|
||||||
|
'stepsize': parsed_args.stepsize,
|
||||||
|
'gamma': parsed_args.gamma
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def engine_run_kwargs(parsed_args):
|
||||||
|
return {
|
||||||
|
'save_dir': parsed_args.save_dir,
|
||||||
|
'max_epoch': parsed_args.max_epoch,
|
||||||
|
'start_epoch': parsed_args.start_epoch,
|
||||||
|
'fixbase_epoch': parsed_args.fixbase_epoch,
|
||||||
|
'open_layers': parsed_args.open_layers,
|
||||||
|
'start_eval': parsed_args.start_eval,
|
||||||
|
'eval_freq': parsed_args.eval_freq,
|
||||||
|
'test_only': parsed_args.evaluate,
|
||||||
|
'print_freq': parsed_args.print_freq,
|
||||||
|
'dist_metric': parsed_args.dist_metric,
|
||||||
|
'visrank': parsed_args.visrank,
|
||||||
|
'visrank_topk': parsed_args.visrank_topk,
|
||||||
|
'use_metric_cuhk03': parsed_args.use_metric_cuhk03,
|
||||||
|
'ranks': parsed_args.ranks
|
||||||
|
}
|
146
scripts/main.py
146
scripts/main.py
|
@ -1,19 +1,133 @@
|
||||||
import torchreid
|
import sys
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import warnings
|
||||||
|
|
||||||
datamanager = torchreid.data.ImageDataManager(
|
import torch
|
||||||
root='reid-data',
|
import torch.nn as nn
|
||||||
sources='market1501',
|
|
||||||
height=128,
|
from default_parser import (
|
||||||
width=64,
|
init_parser, imagedata_kwargs, videodata_kwargs,
|
||||||
combineall=False,
|
optimizer_kwargs, lr_scheduler_kwargs, engine_run_kwargs
|
||||||
batch_size=16
|
|
||||||
)
|
)
|
||||||
model = torchreid.models.build_model(
|
import torchreid
|
||||||
name='squeezenet1_0',
|
from torchreid.utils import (
|
||||||
num_classes=datamanager.num_train_pids,
|
Logger, set_random_seed, check_isfile, resume_from_checkpoint,
|
||||||
loss='softmax'
|
load_pretrained_weights
|
||||||
)
|
)
|
||||||
optimizer = torchreid.optim.build_optimizer(model)
|
|
||||||
scheduler = torchreid.optim.build_lr_scheduler(optimizer, lr_scheduler='single_step', stepsize=20)
|
|
||||||
engine = torchreid.engine.ImageSoftmaxEngine(datamanager, model, optimizer, scheduler=scheduler)
|
parser = init_parser()
|
||||||
#engine.run(max_epoch=1, print_freq=1, fixbase_epoch=0, open_layers='classifier', test_only=True)
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def build_datamanager(args):
|
||||||
|
if args.app == 'image':
|
||||||
|
return torchreid.data.ImageDataManager(**imagedata_kwargs(args))
|
||||||
|
else:
|
||||||
|
return torchreid.data.VideoDataManager(**videodata_kwargs(args))
|
||||||
|
|
||||||
|
|
||||||
|
def build_engine(args, datamanager, model, optimizer, scheduler):
|
||||||
|
if args.app == 'image':
|
||||||
|
if args.loss == 'softmax':
|
||||||
|
engine = torchreid.engine.ImageSoftmaxEngine(
|
||||||
|
datamanager,
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
use_cpu=args.use_cpu,
|
||||||
|
label_smooth=args.label_smooth
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
engine = torchreid.engine.ImageTripletEngine(
|
||||||
|
datamanager,
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
margin=args.margin,
|
||||||
|
weight_t=args.weight_t,
|
||||||
|
weight_x=args.weight_x,
|
||||||
|
scheduler=scheduler,
|
||||||
|
use_cpu=args.use_cpu,
|
||||||
|
label_smooth=args.label_smooth
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if args.loss == 'softmax':
|
||||||
|
engine = torchreid.engine.VideoSoftmaxEngine(
|
||||||
|
datamanager,
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
use_cpu=args.use_cpu,
|
||||||
|
label_smooth=args.label_smooth,
|
||||||
|
pooling_method=args.pooling_method
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
engine = torchreid.engine.ImageTripletEngine(
|
||||||
|
datamanager,
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
margin=args.margin,
|
||||||
|
weight_t=args.weight_t,
|
||||||
|
weight_x=args.weight_x,
|
||||||
|
scheduler=scheduler,
|
||||||
|
use_cpu=args.use_cpu,
|
||||||
|
label_smooth=args.label_smooth
|
||||||
|
)
|
||||||
|
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
global args
|
||||||
|
|
||||||
|
set_random_seed(args.seed)
|
||||||
|
if not args.use_avai_gpus:
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
|
||||||
|
use_gpu = (torch.cuda.is_available() and not args.use_cpu)
|
||||||
|
log_name = 'test.log' if args.evaluate else 'train.log'
|
||||||
|
sys.stdout = Logger(osp.join(args.save_dir, log_name))
|
||||||
|
print('==========\nArgs:{}\n=========='.format(args))
|
||||||
|
if use_gpu:
|
||||||
|
print('Currently using GPU {}'.format(args.gpu_devices))
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
else:
|
||||||
|
warnings.warn('Currently using CPU, however, GPU is highly recommended')
|
||||||
|
|
||||||
|
datamanager = build_datamanager(args)
|
||||||
|
model = torchreid.models.build_model(
|
||||||
|
name=args.arch,
|
||||||
|
num_classes=datamanager.num_train_pids,
|
||||||
|
loss=args.loss.lower(),
|
||||||
|
pretrained=(not args.no_pretrained),
|
||||||
|
use_gpu=use_gpu
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.load_weights and check_isfile(args.load_weights):
|
||||||
|
load_pretrained_weights(model, args.load_weights)
|
||||||
|
|
||||||
|
if use_gpu:
|
||||||
|
model = nn.DataParallel(model).cuda()
|
||||||
|
|
||||||
|
optimizer = torchreid.optim.build_optimizer(
|
||||||
|
model,
|
||||||
|
**optimizer_kwargs(args)
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler = torchreid.optim.build_lr_scheduler(
|
||||||
|
optimizer,
|
||||||
|
**lr_scheduler_kwargs(args)
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.resume and check_isfile(args.resume):
|
||||||
|
args.start_epoch = resume_from_checkpoint(args.resume, model, optimizer=optimizer)
|
||||||
|
|
||||||
|
print('Building {}-engine for {}-reid'.format(args.loss, args.app))
|
||||||
|
engine = build_engine(args, datamanager, model, optimizer, scheduler)
|
||||||
|
|
||||||
|
engine.run(**engine_run_kwargs(args))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
2
setup.py
2
setup.py
|
@ -5,7 +5,7 @@ import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def readme():
|
def readme():
|
||||||
with open('README.md') as f:
|
with open('README.rst') as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,8 @@ class Engine(object):
|
||||||
max_epoch (int): maximum epoch.
|
max_epoch (int): maximum epoch.
|
||||||
start_epoch (int, optional): starting epoch. Default is 0.
|
start_epoch (int, optional): starting epoch. Default is 0.
|
||||||
fixbase_epoch (int, optional): number of epochs to train ``open_layers`` (new layers)
|
fixbase_epoch (int, optional): number of epochs to train ``open_layers`` (new layers)
|
||||||
while keeping base layers frozen. Default is 0.
|
while keeping base layers frozen. Default is 0. ``fixbase_epoch`` is not counted
|
||||||
|
in ``max_epoch``.
|
||||||
open_layers (str or list, optional): layers (attribute names) open for training.
|
open_layers (str or list, optional): layers (attribute names) open for training.
|
||||||
start_eval (int, optional): from which epoch to start evaluation. Default is 0.
|
start_eval (int, optional): from which epoch to start evaluation. Default is 0.
|
||||||
eval_freq (int, optional): evaluation frequency. Default is -1 (meaning evaluation
|
eval_freq (int, optional): evaluation frequency. Default is -1 (meaning evaluation
|
||||||
|
@ -114,18 +115,19 @@ class Engine(object):
|
||||||
)
|
)
|
||||||
self._save_checkpoint(epoch, rank1, save_dir)
|
self._save_checkpoint(epoch, rank1, save_dir)
|
||||||
|
|
||||||
print('=> Final test')
|
if max_epoch > 0:
|
||||||
rank1 = self.test(
|
print('=> Final test')
|
||||||
epoch,
|
rank1 = self.test(
|
||||||
testloader,
|
epoch,
|
||||||
dist_metric=dist_metric,
|
testloader,
|
||||||
visrank=visrank,
|
dist_metric=dist_metric,
|
||||||
visrank_topk=visrank_topk,
|
visrank=visrank,
|
||||||
save_dir=save_dir,
|
visrank_topk=visrank_topk,
|
||||||
use_metric_cuhk03=use_metric_cuhk03,
|
save_dir=save_dir,
|
||||||
ranks=ranks
|
use_metric_cuhk03=use_metric_cuhk03,
|
||||||
)
|
ranks=ranks
|
||||||
self._save_checkpoint(epoch, rank1, save_dir)
|
)
|
||||||
|
self._save_checkpoint(epoch, rank1, save_dir)
|
||||||
|
|
||||||
elapsed = round(time.time() - time_start)
|
elapsed = round(time.time() - time_start)
|
||||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||||
|
|
Loading…
Reference in New Issue