mirror of https://github.com/JDAI-CV/fast-reid.git
411 lines
14 KiB
Python
411 lines
14 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: xyliao1993@qq.com
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import os
|
|
import sys
|
|
from os import path as osp
|
|
from pprint import pprint
|
|
|
|
import numpy as np
|
|
import torch
|
|
from datasets.samplers import RandomIdentitySampler
|
|
from tensorboardX import SummaryWriter
|
|
from torch import nn
|
|
from torch.backends import cudnn
|
|
from torch.utils.data import DataLoader
|
|
|
|
from config import args
|
|
from datasets import data_manager
|
|
from datasets.data_loader import ImageData
|
|
from models import ResNetBuilder
|
|
from trainers import ResNetClsTrainer, ResNetTriTrainer, ResNetClsTriTrainer, ResNetEvaluator
|
|
from utils.loss import TripletLoss
|
|
from utils.serialization import Logger
|
|
from utils.serialization import save_checkpoint
|
|
from utils.transforms import TrainTransform, TestTransform
|
|
|
|
|
|
def train_classification(**kwargs):
    """Train a ResNet re-id model with an identity (cross-entropy) loss.

    Keyword arguments are handed to ``args._parse`` (presumably overriding the
    matching fields of the global config). Training uses SGD over
    ``model.optim_policy()`` with a step decay every ``args.step_size`` epochs
    (disabled when ``step_size <= 0``). The model is evaluated on the
    query/gallery split every ``args.eval_step`` epochs and always after the
    final epoch; a checkpoint is written after each evaluation and flagged as
    best whenever rank-1 improves.
    """
    args._parse(kwargs)

    # CUDA_VISIBLE_DEVICES is only honoured if it is set before the CUDA
    # runtime is initialised (torch.cuda.is_available() / manual_seed_all can
    # do that), so select the device first.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # set random seed and cudnn benchmark
    torch.manual_seed(args.seed)

    use_gpu = torch.cuda.is_available()
    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))

    print('=========user config==========')
    pprint(args._state_dict())
    print('============end===============')

    if use_gpu:
        print('currently using GPU {}'.format(args.gpu))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print('currently using cpu')

    print('initializing dataset {}'.format(args.dataset))
    dataset = data_manager.init_dataset(name=args.dataset)

    pin_memory = use_gpu

    tb_writer = SummaryWriter(osp.join(args.save_dir, 'tb_log'))

    trainloader = DataLoader(
        ImageData(dataset.train, TrainTransform(args.height, args.width)),
        batch_size=args.train_batch, shuffle=True, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=True
    )

    queryloader = DataLoader(
        ImageData(dataset.query, TestTransform(args.height, args.width)),
        batch_size=args.test_batch, num_workers=args.workers,
        pin_memory=pin_memory
    )

    galleryloader = DataLoader(
        ImageData(dataset.gallery, TestTransform(args.height, args.width)),
        batch_size=args.test_batch, num_workers=args.workers,
        pin_memory=pin_memory
    )

    print('initializing model ...')
    model = ResNetBuilder(num_classes=dataset.num_train_pids)
    print('model size: {:.5f}M'.format(sum(p.numel()
                                           for p in model.parameters()) / 1e6))

    cls_criterion = nn.CrossEntropyLoss()

    def xent_criterion(cls_scores, targets):
        """Cross-entropy between classifier scores and identity labels."""
        return cls_criterion(cls_scores, targets)

    # get optimizer
    optimizer = torch.optim.SGD(
        model.optim_policy(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum
    )

    def adjust_lr(optimizer, ep, decay_ep, gamma):
        """Step decay: lr = args.lr * gamma ** (ep // decay_ep), scaled per
        param group by its optional ``'lr_multi'`` entry (default 1)."""
        decay = gamma ** float(ep // decay_ep)
        for g in optimizer.param_groups:
            g['lr'] = args.lr * decay * g.get('lr_multi', 1)

    start_epoch = args.start_epoch
    if use_gpu:
        model = nn.DataParallel(model).cuda()

    # get trainer and evaluator
    reid_trainer = ResNetClsTrainer(model, xent_criterion, tb_writer)
    reid_evaluator = ResNetEvaluator(model)

    # start training
    best_rank1 = -np.inf
    best_epoch = 0
    try:
        for epoch in range(start_epoch, args.max_epoch):
            if args.step_size > 0:
                adjust_lr(optimizer, epoch + 1, args.step_size, args.gamma)
            reid_trainer.train(epoch, trainloader, optimizer, args.print_freq)

            # Evaluate (and checkpoint) every eval_step epochs, and always
            # after the final epoch.
            periodic_eval = args.eval_step > 0 and (epoch + 1) % args.eval_step == 0
            if periodic_eval or (epoch + 1) == args.max_epoch:
                rank1 = reid_evaluator.evaluate(queryloader, galleryloader)
                is_best = rank1 > best_rank1
                if is_best:
                    best_rank1 = rank1
                    best_epoch = epoch + 1

                # Store the underlying module's weights (without the
                # DataParallel wrapper) so the checkpoint is device-agnostic.
                if use_gpu:
                    state_dict = model.module.state_dict()
                else:
                    state_dict = model.state_dict()
                save_checkpoint({
                    'state_dict': state_dict,
                    'epoch': epoch + 1,
                }, is_best=is_best, save_dir=args.save_dir,
                    filename='checkpoint_ep' + str(epoch + 1) + '.pth.tar')
    finally:
        # Flush pending TensorBoard events and release the event file.
        tb_writer.close()

    print(
        'Best rank-1 {:.1%}, achieved at epoch {}'.format(best_rank1, best_epoch))
|
|
|
|
|
|
def train_triplet(**kwargs):
    """Train a ResNet re-id embedding with a batch-hard triplet loss only.

    Keyword arguments are handed to ``args._parse`` (presumably overriding the
    matching fields of the global config). Batches are drawn with
    ``RandomIdentitySampler`` using ``args.num_instances`` images per identity.
    Training uses Adam; once ``epoch + 1`` reaches ``args.step_size`` the
    learning rate decays exponentially towards ``args.lr * args.gamma`` at
    ``args.max_epoch``. The model is evaluated every ``args.eval_step`` epochs
    and always after the final epoch, writing a checkpoint each time.
    """
    args._parse(kwargs)

    # CUDA_VISIBLE_DEVICES is only honoured if it is set before the CUDA
    # runtime is initialised (torch.cuda.is_available() / manual_seed_all can
    # do that), so select the device first.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # set random seed and cudnn benchmark
    torch.manual_seed(args.seed)

    use_gpu = torch.cuda.is_available()
    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))

    print('=========user config==========')
    pprint(args._state_dict())
    print('============end===============')

    if use_gpu:
        print('currently using GPU {}'.format(args.gpu))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print('currently using cpu')

    print('initializing dataset {}'.format(args.dataset))
    dataset = data_manager.init_dataset(name=args.dataset)

    pin_memory = use_gpu

    tb_writer = SummaryWriter(osp.join(args.save_dir, 'tb_log'))

    trainloader = DataLoader(
        ImageData(dataset.train, TrainTransform(args.height, args.width)),
        sampler=RandomIdentitySampler(dataset.train, args.num_instances),
        batch_size=args.train_batch, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=True
    )

    queryloader = DataLoader(
        ImageData(dataset.query, TestTransform(args.height, args.width)),
        batch_size=args.test_batch, num_workers=args.workers,
        pin_memory=pin_memory
    )

    galleryloader = DataLoader(
        ImageData(dataset.gallery, TestTransform(args.height, args.width)),
        batch_size=args.test_batch, num_workers=args.workers,
        pin_memory=pin_memory
    )

    print('initializing model ...')
    model = ResNetBuilder()
    print('model size: {:.5f}M'.format(sum(p.numel()
                                           for p in model.parameters()) / 1e6))

    tri_criterion = TripletLoss(margin=args.margin)

    def tri_hard(feat, targets):
        """Triplet loss on the embeddings; the two extra outputs of
        ``tri_criterion`` are discarded."""
        tri_loss, _, _ = tri_criterion(feat, targets)
        return tri_loss

    # get optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    def adjust_lr_exp(optimizer, base_lr, ep, total_ep, start_decay_ep, gamma):
        """Set lr = base_lr * gamma ** ((ep - start_decay_ep) / (total_ep - start_decay_ep)).

        No-op before ``start_decay_ep``; reaches ``base_lr * gamma`` at
        ``total_ep``.
        """
        if ep < start_decay_ep:
            return
        # Clamp the span so start_decay_ep == total_ep (e.g. step_size ==
        # max_epoch) does not divide by zero on the final epoch.
        decay_span = max(total_ep - start_decay_ep, 1)
        lr_decay = gamma ** (float(ep - start_decay_ep) / decay_span)
        for g in optimizer.param_groups:
            g['lr'] = base_lr * lr_decay

    start_epoch = args.start_epoch
    if use_gpu:
        model = nn.DataParallel(model).cuda()

    # get trainer and evaluator
    reid_trainer = ResNetTriTrainer(model, tri_hard, tb_writer)
    reid_evaluator = ResNetEvaluator(model)

    # start training
    best_rank1 = -np.inf
    best_epoch = 0
    try:
        for epoch in range(start_epoch, args.max_epoch):
            if args.step_size > 0:
                adjust_lr_exp(optimizer, args.lr, epoch + 1, args.max_epoch, args.step_size, args.gamma)
            reid_trainer.train(epoch, trainloader, optimizer, args.print_freq)

            # Evaluate (and checkpoint) every eval_step epochs, and always
            # after the final epoch.
            periodic_eval = args.eval_step > 0 and (epoch + 1) % args.eval_step == 0
            if periodic_eval or (epoch + 1) == args.max_epoch:
                rank1 = reid_evaluator.evaluate(queryloader, galleryloader)
                is_best = rank1 > best_rank1
                if is_best:
                    best_rank1 = rank1
                    best_epoch = epoch + 1

                # Store the underlying module's weights (without the
                # DataParallel wrapper) so the checkpoint is device-agnostic.
                if use_gpu:
                    state_dict = model.module.state_dict()
                else:
                    state_dict = model.state_dict()
                save_checkpoint({
                    'state_dict': state_dict,
                    'epoch': epoch + 1,
                }, is_best=is_best, save_dir=args.save_dir,
                    filename='checkpoint_ep' + str(epoch + 1) + '.pth.tar')
    finally:
        # Flush pending TensorBoard events and release the event file.
        tb_writer.close()

    print(
        'Best rank-1 {:.1%}, achieved at epoch {}'.format(best_rank1, best_epoch))
|
|
|
|
|
|
def train_cls_triplet(**kwargs):
    """Train a ResNet re-id model with identity cross-entropy + triplet loss.

    Keyword arguments are handed to ``args._parse`` (presumably overriding the
    matching fields of the global config). Batches are drawn with
    ``RandomIdentitySampler`` using ``args.num_instances`` images per identity,
    and the total loss is the unweighted sum of the two terms. Training uses
    Adam with an exponential decay that starts once ``epoch + 1`` reaches
    ``args.step_size``. The model is evaluated every ``args.eval_step`` epochs
    and always after the final epoch, writing a checkpoint each time.
    """
    args._parse(kwargs)

    # CUDA_VISIBLE_DEVICES is only honoured if it is set before the CUDA
    # runtime is initialised (torch.cuda.is_available() / manual_seed_all can
    # do that), so select the device first.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # set random seed and cudnn benchmark
    torch.manual_seed(args.seed)

    use_gpu = torch.cuda.is_available()
    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))

    print('=========user config==========')
    pprint(args._state_dict())
    print('============end===============')

    if use_gpu:
        print('currently using GPU {}'.format(args.gpu))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print('currently using cpu')

    print('initializing dataset {}'.format(args.dataset))
    dataset = data_manager.init_dataset(name=args.dataset)

    pin_memory = use_gpu

    tb_writer = SummaryWriter(osp.join(args.save_dir, 'tb_log'))

    trainloader = DataLoader(
        ImageData(dataset.train, TrainTransform(args.height, args.width)),
        sampler=RandomIdentitySampler(dataset.train, args.num_instances),
        batch_size=args.train_batch, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=True
    )

    queryloader = DataLoader(
        ImageData(dataset.query, TestTransform(args.height, args.width)),
        batch_size=args.test_batch, num_workers=args.workers,
        pin_memory=pin_memory
    )

    galleryloader = DataLoader(
        ImageData(dataset.gallery, TestTransform(args.height, args.width)),
        batch_size=args.test_batch, num_workers=args.workers,
        pin_memory=pin_memory
    )

    print('initializing model ...')
    model = ResNetBuilder(num_classes=dataset.num_train_pids)
    print('model size: {:.5f}M'.format(sum(p.numel()
                                           for p in model.parameters()) / 1e6))

    cls_criterion = nn.CrossEntropyLoss()
    tri_criterion = TripletLoss(margin=args.margin)

    def xent_tri_criterion(cls_scores, global_feat, targets):
        """Sum of ID cross-entropy on ``cls_scores`` and triplet loss on
        ``global_feat``; the extra outputs of ``tri_criterion`` are unused."""
        cls_loss = cls_criterion(cls_scores, targets)
        tri_loss, _, _ = tri_criterion(global_feat, targets)
        return cls_loss + tri_loss

    # get optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    def adjust_lr_exp(optimizer, base_lr, ep, total_ep, start_decay_ep, gamma):
        """Set lr = base_lr * gamma ** ((ep - start_decay_ep) / (total_ep - start_decay_ep)).

        No-op before ``start_decay_ep``; reaches ``base_lr * gamma`` at
        ``total_ep``.
        """
        if ep < start_decay_ep:
            return
        # Clamp the span so start_decay_ep == total_ep (e.g. step_size ==
        # max_epoch) does not divide by zero on the final epoch.
        decay_span = max(total_ep - start_decay_ep, 1)
        lr_decay = gamma ** (float(ep - start_decay_ep) / decay_span)
        for g in optimizer.param_groups:
            g['lr'] = base_lr * lr_decay

    start_epoch = args.start_epoch
    if use_gpu:
        model = nn.DataParallel(model).cuda()

    # get trainer and evaluator
    reid_trainer = ResNetClsTriTrainer(model, xent_tri_criterion, tb_writer)
    reid_evaluator = ResNetEvaluator(model)

    # start training
    best_rank1 = -np.inf
    best_epoch = 0
    try:
        for epoch in range(start_epoch, args.max_epoch):
            if args.step_size > 0:
                adjust_lr_exp(optimizer, args.lr, epoch + 1, args.max_epoch, args.step_size, args.gamma)
            reid_trainer.train(epoch, trainloader, optimizer, args.print_freq)

            # Evaluate (and checkpoint) every eval_step epochs, and always
            # after the final epoch.
            periodic_eval = args.eval_step > 0 and (epoch + 1) % args.eval_step == 0
            if periodic_eval or (epoch + 1) == args.max_epoch:
                rank1 = reid_evaluator.evaluate(queryloader, galleryloader)
                is_best = rank1 > best_rank1
                if is_best:
                    best_rank1 = rank1
                    best_epoch = epoch + 1

                # Store the underlying module's weights (without the
                # DataParallel wrapper) so the checkpoint is device-agnostic.
                if use_gpu:
                    state_dict = model.module.state_dict()
                else:
                    state_dict = model.state_dict()
                save_checkpoint({
                    'state_dict': state_dict,
                    'epoch': epoch + 1,
                }, is_best=is_best, save_dir=args.save_dir,
                    filename='checkpoint_ep' + str(epoch + 1) + '.pth.tar')
    finally:
        # Flush pending TensorBoard events and release the event file.
        tb_writer.close()

    print(
        'Best rank-1 {:.1%}, achieved at epoch {}'.format(best_rank1, best_epoch))
|
|
|
|
|
|
def test(**kwargs):
    """Evaluate a ResNet re-id model on the query/gallery split of ``args.dataset``.

    Keyword arguments are handed to ``args._parse`` (presumably overriding the
    matching fields of the global config). When ``args.load_model`` is set,
    weights are restored from that checkpoint; otherwise the freshly built
    model is evaluated as-is.
    """
    args._parse(kwargs)

    # CUDA_VISIBLE_DEVICES is only honoured if it is set before the CUDA
    # runtime is initialised, so select the device first. str() guards
    # against a non-string config value (os.environ only accepts str).
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # set random seed and cudnn benchmark
    torch.manual_seed(args.seed)

    use_gpu = torch.cuda.is_available()
    # Log to a separate file so evaluation output is not written into the
    # training log. NOTE(review): Logger may truncate the file it opens —
    # confirm against utils.serialization.
    sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))

    if use_gpu:
        print('currently using GPU {}'.format(args.gpu))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print('currently using cpu')

    print('initializing dataset {}'.format(args.dataset))
    dataset = data_manager.init_dataset(name=args.dataset)

    pin_memory = use_gpu

    queryloader = DataLoader(
        ImageData(dataset.query, TestTransform(args.height, args.width)),
        batch_size=args.test_batch, num_workers=args.workers,
        pin_memory=pin_memory
    )

    galleryloader = DataLoader(
        ImageData(dataset.gallery, TestTransform(args.height, args.width)),
        batch_size=args.test_batch, num_workers=args.workers,
        pin_memory=pin_memory
    )

    print('loading model ...')
    model = ResNetBuilder(num_classes=dataset.num_train_pids)
    load_path = getattr(args, 'load_model', None)
    if load_path:
        # The train_* functions in this file save the unwrapped (non-
        # DataParallel) state_dict under 'state_dict', so load it before
        # wrapping. The map_location lambda keeps GPU-saved tensors on the
        # CPU, which also works on machines without CUDA.
        # NOTE(review): train_triplet builds ResNetBuilder() without
        # num_classes, so its checkpoints may lack classifier weights and
        # fail strict loading here — confirm.
        ckpt = torch.load(load_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(ckpt['state_dict'])
    else:
        print('no checkpoint given (load_model is empty); evaluating the untrained model')
    print('model size: {:.5f}M'.format(sum(p.numel()
                                           for p in model.parameters()) / 1e6))

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    reid_evaluator = ResNetEvaluator(model)
    reid_evaluator.evaluate(queryloader, galleryloader)
|
|
|
|
|
|
if __name__ == '__main__':
    # python-fire is imported only when running as a script, so importing
    # this module does not require it.
    import fire

    # With no explicit component, fire builds the command line from the
    # calling module's namespace, so each top-level function here
    # (train_classification, train_triplet, train_cls_triplet, test) is
    # callable as a subcommand.
    fire.Fire(component=None)
|