deep-person-reid/scripts/main.py

192 lines
5.7 KiB
Python
Raw Normal View History

2019-12-01 10:35:44 +08:00
import sys
2019-05-24 22:34:27 +08:00
import time
2019-12-01 10:35:44 +08:00
import os.path as osp
2019-08-26 17:34:31 +08:00
import argparse
2019-03-24 19:06:12 +08:00
import torch
import torch.nn as nn
2019-03-20 01:26:08 +08:00
2019-03-24 19:06:12 +08:00
import torchreid
from torchreid.utils import (
2019-12-01 10:35:44 +08:00
Logger, check_isfile, set_random_seed, collect_env_info,
resume_from_checkpoint, load_pretrained_weights, compute_model_complexity
)
from default_config import (
imagedata_kwargs, optimizer_kwargs, videodata_kwargs, engine_run_kwargs,
get_default_config, lr_scheduler_kwargs
2019-03-20 01:26:08 +08:00
)
2019-03-24 19:06:12 +08:00
2019-08-26 17:34:31 +08:00
def build_datamanager(cfg):
    """Return the torchreid data manager selected by ``cfg.data.type``.

    The value ``'image'`` produces an ``ImageDataManager``; every other
    value falls through to a ``VideoDataManager``. Constructor keyword
    arguments are derived from ``cfg`` by the matching ``*data_kwargs``
    helper imported from ``default_config``.
    """
    wants_images = cfg.data.type == 'image'
    if not wants_images:
        return torchreid.data.VideoDataManager(**videodata_kwargs(cfg))
    return torchreid.data.ImageDataManager(**imagedata_kwargs(cfg))

def build_engine(cfg, datamanager, model, optimizer, scheduler):
    """Instantiate the torchreid engine matching the data type and loss.

    ``cfg.data.type == 'image'`` selects an image engine; any other value
    selects a video engine (the same rule ``build_datamanager`` uses).
    ``cfg.loss.name`` then selects the softmax or the triplet variant.

    Args:
        cfg: merged configuration node.
        datamanager: data manager for the configured datasets.
        model: the re-id network (``main`` may have wrapped it in
            ``nn.DataParallel``).
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler attached to ``optimizer``.

    Returns:
        The constructed engine; the caller invokes ``engine.run(...)``.

    Raises:
        ValueError: if ``cfg.loss.name`` is neither ``'softmax'`` nor
            ``'triplet'``. Such names previously fell through silently to a
            triplet engine.
    """
    loss_name = cfg.loss.name
    if loss_name not in ('softmax', 'triplet'):
        raise ValueError(
            "Unsupported loss '{}'; expected 'softmax' or 'triplet'".format(
                loss_name
            )
        )

    # Keyword arguments shared by every engine. Triplet engines also take
    # label_smooth, which is read from the softmax sub-config.
    engine_kwargs = dict(
        optimizer=optimizer,
        scheduler=scheduler,
        use_gpu=cfg.use_gpu,
        label_smooth=cfg.loss.softmax.label_smooth,
    )
    if loss_name == 'triplet':
        engine_kwargs.update(
            margin=cfg.loss.triplet.margin,
            weight_t=cfg.loss.triplet.weight_t,
            weight_x=cfg.loss.triplet.weight_x,
        )

    if cfg.data.type == 'image':
        if loss_name == 'softmax':
            engine_cls = torchreid.engine.ImageSoftmaxEngine
        else:
            engine_cls = torchreid.engine.ImageTripletEngine
    else:
        # cfg.video.pooling_method used to reach only VideoSoftmaxEngine, so
        # triplet video training silently ignored the configured value.
        # Forward it to both video engines (assumes VideoTripletEngine
        # accepts a pooling_method kwarg, as torchreid's does — confirm
        # against the installed version).
        engine_kwargs['pooling_method'] = cfg.video.pooling_method
        if loss_name == 'softmax':
            engine_cls = torchreid.engine.VideoSoftmaxEngine
        else:
            engine_cls = torchreid.engine.VideoTripletEngine

    return engine_cls(datamanager, model, **engine_kwargs)

def reset_config(cfg, args):
    """Copy data-related command-line options onto ``cfg.data``.

    The fields ``root``, ``sources``, ``targets`` and ``transforms`` are
    applied in that order, and only when the parsed value is truthy. An
    omitted ``nargs='+'`` flag (``None``) or the empty-string ``--root``
    default therefore leaves the corresponding config entry unchanged.
    """
    for field in ('root', 'sources', 'targets', 'transforms'):
        override = getattr(args, field)
        if override:
            setattr(cfg.data, field, override)

def check_cfg(cfg):
    """Reject configuration combinations that cannot train correctly.

    One rule is enforced: with the triplet loss and
    ``cfg.loss.triplet.weight_x == 0``, the classifier output does not enter
    the loss, so ``cfg.train.fixbase_epoch`` must be 0.

    Raises:
        ValueError: if the rule above is violated. An explicit ``raise`` is
            used instead of ``assert`` so the check still runs when Python
            is started with ``-O`` (which strips assert statements).
    """
    classifier_unused = (
        cfg.loss.name == 'triplet' and cfg.loss.triplet.weight_x == 0
    )
    if classifier_unused and cfg.train.fixbase_epoch != 0:
        # NOTE(review): presumably the fixbase epochs update only newly added
        # layers such as the classifier, which would get no gradient here.
        raise ValueError(
            'The output of classifier is not included in the computational graph'
        )

def _build_argument_parser():
    """Create the command-line parser used by :func:`main`."""
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    cli.add_argument(
        '--config-file', type=str, default='', help='path to config file'
    )
    cli.add_argument(
        '-s', '--sources', type=str, nargs='+',
        help='source datasets (delimited by space)'
    )
    cli.add_argument(
        '-t', '--targets', type=str, nargs='+',
        help='target datasets (delimited by space)'
    )
    cli.add_argument(
        '--transforms', type=str, nargs='+', help='data augmentation'
    )
    cli.add_argument('--root', type=str, default='', help='path to data root')
    # Everything after the named options is collected verbatim and later
    # handed to cfg.merge_from_list as KEY VALUE overrides.
    cli.add_argument(
        'opts',
        default=None,
        nargs=argparse.REMAINDER,
        help='Modify config options using the command-line'
    )
    return cli


def main():
    """Parse CLI arguments, assemble the config, then build and run an engine.

    Config precedence, lowest to highest:
    1. ``get_default_config()``;
    2. the optional ``--config-file``;
    3. the ``--root/--sources/--targets/--transforms`` flags;
    4. trailing ``opts`` overrides.

    ``cfg.use_gpu`` is seeded from ``torch.cuda.is_available()`` before any
    of these merges.
    """
    args = _build_argument_parser().parse_args()

    cfg = get_default_config()
    cfg.use_gpu = torch.cuda.is_available()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    reset_config(cfg, args)
    cfg.merge_from_list(args.opts)
    set_random_seed(cfg.train.seed)
    check_cfg(cfg)

    # Replace stdout with torchreid's Logger pointed at a timestamped file
    # under save_dir (presumably it echoes to the console as well).
    phase = 'test' if cfg.test.evaluate else 'train'
    log_name = phase + '.log' + time.strftime('-%Y-%m-%d-%H-%M-%S')
    sys.stdout = Logger(osp.join(cfg.data.save_dir, log_name))

    print('Show configuration\n{}\n'.format(cfg))
    print('Collecting env info ...')
    print('** System info **\n{}\n'.format(collect_env_info()))

    if cfg.use_gpu:
        torch.backends.cudnn.benchmark = True

    datamanager = build_datamanager(cfg)

    print('Building model: {}'.format(cfg.model.name))
    model = torchreid.models.build_model(
        name=cfg.model.name,
        num_classes=datamanager.num_train_pids,
        loss=cfg.loss.name,
        pretrained=cfg.model.pretrained,
        use_gpu=cfg.use_gpu
    )
    # Complexity is measured with a single dummy input of shape (1, 3, H, W).
    probe_shape = (1, 3, cfg.data.height, cfg.data.width)
    num_params, flops = compute_model_complexity(model, probe_shape)
    print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))

    weights_path = cfg.model.load_weights
    if weights_path and check_isfile(weights_path):
        load_pretrained_weights(model, weights_path)

    if cfg.use_gpu:
        model = nn.DataParallel(model).cuda()

    optimizer = torchreid.optim.build_optimizer(model, **optimizer_kwargs(cfg))
    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer, **lr_scheduler_kwargs(cfg)
    )

    resume_path = cfg.model.resume
    if resume_path and check_isfile(resume_path):
        # Restores model/optimizer/scheduler state; the returned epoch
        # becomes the starting epoch for training.
        cfg.train.start_epoch = resume_from_checkpoint(
            resume_path, model, optimizer=optimizer, scheduler=scheduler
        )

    print(
        'Building {}-engine for {}-reid'.format(cfg.loss.name, cfg.data.type)
    )
    engine = build_engine(cfg, datamanager, model, optimizer, scheduler)
    engine.run(**engine_run_kwargs(cfg))

if __name__ == '__main__':
    # Run the training/evaluation pipeline only when this file is executed
    # as a script, not when it is imported.
    main()