2019-03-24 19:06:12 +08:00
|
|
|
import sys
|
|
|
|
import os
|
|
|
|
import os.path as osp
|
2019-05-24 22:34:27 +08:00
|
|
|
import time
|
2019-08-26 17:34:31 +08:00
|
|
|
import argparse
|
2019-03-24 19:06:12 +08:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
2019-03-20 01:26:08 +08:00
|
|
|
|
2019-08-26 17:34:31 +08:00
|
|
|
from default_config import (
|
|
|
|
get_default_config, imagedata_kwargs, videodata_kwargs,
|
2019-03-24 19:06:12 +08:00
|
|
|
optimizer_kwargs, lr_scheduler_kwargs, engine_run_kwargs
|
2019-03-20 01:26:08 +08:00
|
|
|
)
|
2019-03-24 19:06:12 +08:00
|
|
|
import torchreid
|
|
|
|
from torchreid.utils import (
|
|
|
|
Logger, set_random_seed, check_isfile, resume_from_checkpoint,
|
2019-05-24 22:34:27 +08:00
|
|
|
load_pretrained_weights, compute_model_complexity, collect_env_info
|
2019-03-20 01:26:08 +08:00
|
|
|
)
|
2019-03-24 19:06:12 +08:00
|
|
|
|
|
|
|
|
2019-08-26 17:34:31 +08:00
|
|
|
def build_datamanager(cfg):
    """Build the data manager selected by ``cfg.data.type``.

    Args:
        cfg: configuration node. ``cfg.data.type`` is read here; the whole
            node is handed to ``imagedata_kwargs`` / ``videodata_kwargs`` to
            produce the constructor keyword arguments.

    Returns:
        A ``torchreid.data.ImageDataManager`` when ``cfg.data.type`` equals
        ``'image'``, otherwise a ``torchreid.data.VideoDataManager``.
    """
    if cfg.data.type == 'image':
        return torchreid.data.ImageDataManager(**imagedata_kwargs(cfg))
    # Every value other than 'image' falls through to the video manager.
    return torchreid.data.VideoDataManager(**videodata_kwargs(cfg))
|
2019-03-24 19:06:12 +08:00
|
|
|
|
|
|
|
|
2019-08-26 17:34:31 +08:00
|
|
|
def build_engine(cfg, datamanager, model, optimizer, scheduler):
    """Instantiate the torchreid engine matching the data type and loss.

    The engine class is chosen from two config values:

    * ``cfg.data.type``: ``'image'`` selects the ``Image*`` engines, any
      other value selects the ``Video*`` engines.
    * ``cfg.loss.name``: ``'softmax'`` selects the ``*SoftmaxEngine``, any
      other value selects the ``*TripletEngine``.

    Args:
        cfg: configuration node (reads ``use_gpu``, ``data.type``,
            ``loss.name``, ``loss.softmax.label_smooth`` and, depending on
            the branch, ``loss.triplet.*`` or ``video.pooling_method``).
        datamanager: data manager passed positionally to the engine.
        model: model passed positionally to the engine.
        optimizer: optimizer forwarded as ``optimizer=``.
        scheduler: learning-rate scheduler forwarded as ``scheduler=``.

    Returns:
        The constructed engine instance.
    """
    is_image = cfg.data.type == 'image'
    is_softmax = cfg.loss.name == 'softmax'

    # Keyword arguments shared by all four engine variants.
    engine_kwargs = {
        'optimizer': optimizer,
        'scheduler': scheduler,
        'use_gpu': cfg.use_gpu,
        'label_smooth': cfg.loss.softmax.label_smooth,
    }

    if is_softmax:
        if not is_image:
            # Only the video softmax engine receives the pooling method.
            engine_kwargs['pooling_method'] = cfg.video.pooling_method
    else:
        # Triplet engines additionally take the margin and the loss weights.
        # NOTE(review): the video triplet branch does not forward
        # cfg.video.pooling_method, so the engine's own default is used.
        # Confirm whether that is intentional.
        engine_kwargs['margin'] = cfg.loss.triplet.margin
        engine_kwargs['weight_t'] = cfg.loss.triplet.weight_t
        engine_kwargs['weight_x'] = cfg.loss.triplet.weight_x

    # Resolve only the class that is actually needed.
    if is_image:
        if is_softmax:
            engine_cls = torchreid.engine.ImageSoftmaxEngine
        else:
            engine_cls = torchreid.engine.ImageTripletEngine
    else:
        if is_softmax:
            engine_cls = torchreid.engine.VideoSoftmaxEngine
        else:
            engine_cls = torchreid.engine.VideoTripletEngine

    return engine_cls(datamanager, model, **engine_kwargs)
|
|
|
|
|
|
|
|
|
2019-08-26 17:34:31 +08:00
|
|
|
def reset_config(cfg, args):
    """Override data settings in ``cfg`` with command-line values, in place.

    For each of ``root``, ``sources``, ``targets`` and ``transforms``, the
    value on ``args`` is copied to the attribute of the same name on
    ``cfg.data`` when it is truthy. Falsy values (``None``, ``''``, empty
    list) leave the existing config value untouched.

    Args:
        cfg: configuration object exposing a mutable ``data`` attribute.
        args: parsed command-line namespace providing the four attributes.

    Returns:
        None. ``cfg`` is modified in place.
    """
    for field in ('root', 'sources', 'targets', 'transforms'):
        override = getattr(args, field)
        if override:
            setattr(cfg.data, field, override)
|
2019-03-24 19:06:12 +08:00
|
|
|
|
2019-08-26 17:34:31 +08:00
|
|
|
|
|
|
|
def main():
    """Command-line entry point: configure, build all components and run.

    Steps, in order:

    1. Parse command-line arguments.
    2. Export ``CUDA_VISIBLE_DEVICES`` if ``--gpu-devices`` was given.
    3. Build the config: defaults, then the optional config file, then the
       dedicated command-line overrides (``reset_config``), then the free-form
       ``opts`` list.
    4. Seed the random number generators with ``cfg.train.seed``.
    5. Replace ``sys.stdout`` with a ``Logger`` created on a timestamped log
       path under ``cfg.data.save_dir`` and print config and environment info.
    6. Build the data manager, model, optimizer and LR scheduler; optionally
       load pretrained weights and/or resume from a checkpoint.
    7. Build the engine and call ``engine.run``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--config-file', type=str, default='', help='path to config file')
    parser.add_argument('-s', '--sources', type=str, nargs='+', help='source datasets (delimited by space)')
    parser.add_argument('-t', '--targets', type=str, nargs='+', help='target datasets (delimited by space)')
    parser.add_argument('--transforms', type=str, nargs='+', help='data augmentation')
    parser.add_argument('--root', type=str, default='', help='path to data root')
    parser.add_argument(
        '--gpu-devices', type=str, default='',
        help='value for CUDA_VISIBLE_DEVICES; all available gpus are used if not specified'
    )
    parser.add_argument(
        'opts', default=None, nargs=argparse.REMAINDER,
        help='Modify config options using the command-line'
    )
    args = parser.parse_args()

    if args.gpu_devices:
        # CUDA_VISIBLE_DEVICES has to be exported before the first CUDA query
        # (torch.cuda.is_available() below); setting it afterwards may be
        # ignored because the device list has already been enumerated.
        # If gpu_devices is not specified, all available gpus will be used.
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices

    cfg = get_default_config()
    cfg.use_gpu = torch.cuda.is_available()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    reset_config(cfg, args)
    # Free-form "KEY VALUE" pairs are merged last.
    cfg.merge_from_list(args.opts)
    set_random_seed(cfg.train.seed)

    # From here on everything printed goes through the Logger object.
    log_name = 'test.log' if cfg.test.evaluate else 'train.log'
    log_name += time.strftime('-%Y-%m-%d-%H-%M-%S')
    sys.stdout = Logger(osp.join(cfg.data.save_dir, log_name))

    print('Show configuration\n{}\n'.format(cfg))
    print('Collecting env info ...')
    print('** System info **\n{}\n'.format(collect_env_info()))

    if cfg.use_gpu:
        torch.backends.cudnn.benchmark = True

    datamanager = build_datamanager(cfg)

    print('Building model: {}'.format(cfg.model.name))
    model = torchreid.models.build_model(
        name=cfg.model.name,
        num_classes=datamanager.num_train_pids,
        loss=cfg.loss.name,
        pretrained=cfg.model.pretrained,
        use_gpu=cfg.use_gpu
    )
    # Complexity is measured on a single 3-channel input of the configured size.
    num_params, flops = compute_model_complexity(model, (1, 3, cfg.data.height, cfg.data.width))
    print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))

    if cfg.model.load_weights and check_isfile(cfg.model.load_weights):
        load_pretrained_weights(model, cfg.model.load_weights)

    if cfg.use_gpu:
        model = nn.DataParallel(model).cuda()

    optimizer = torchreid.optim.build_optimizer(model, **optimizer_kwargs(cfg))
    scheduler = torchreid.optim.build_lr_scheduler(optimizer, **lr_scheduler_kwargs(cfg))

    if cfg.model.resume and check_isfile(cfg.model.resume):
        # The value returned by resume_from_checkpoint becomes the start epoch.
        cfg.train.start_epoch = resume_from_checkpoint(cfg.model.resume, model, optimizer=optimizer)

    print('Building {}-engine for {}-reid'.format(cfg.loss.name, cfg.data.type))
    engine = build_engine(cfg, datamanager, model, optimizer, scheduler)
    engine.run(**engine_run_kwargs(cfg))
|
2019-03-24 19:06:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
# Run the pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
|