# deep-person-reid/projects/attribute_recognition/main.py

from __future__ import division, print_function
import sys
import copy
import time
import numpy as np
import os.path as osp
import datetime
import warnings
import torch
import torch.nn as nn
import torchreid
from torchreid.utils import (
Logger, AverageMeter, check_isfile, open_all_layers, save_checkpoint,
set_random_seed, collect_env_info, open_specified_layers,
load_pretrained_weights, compute_model_complexity
)
from torchreid.data.transforms import (
Resize, Compose, ToTensor, Normalize, Random2DTranslation,
RandomHorizontalFlip
)
import models
import datasets
from default_parser import init_parser, optimizer_kwargs, lr_scheduler_kwargs

parser = init_parser()
args = parser.parse_args()


def init_dataset(use_gpu):
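    """Build transforms, datasets and data loaders for the train/val/test splits.

    Returns (trainloader, valloader, testloader, num_attrs, attr_dict).
    """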
normalize = Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
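    # training-time augmentation: random translation (enlarge-and-crop) plus horizontal flip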
transform_tr = Compose(
[
Random2DTranslation(args.height, args.width, p=0.5),
RandomHorizontalFlip(),
ToTensor(), normalize
]
)
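    # evaluation transform: deterministic resize only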
transform_te = Compose(
[Resize([args.height, args.width]),
ToTensor(), normalize]
)
trainset = datasets.init_dataset(
args.dataset,
root=args.root,
transform=transform_tr,
mode='train',
verbose=True
)
valset = datasets.init_dataset(
args.dataset,
root=args.root,
transform=transform_te,
mode='val',
verbose=False
)
testset = datasets.init_dataset(
args.dataset,
root=args.root,
transform=transform_te,
mode='test',
verbose=False
)
num_attrs = trainset.num_attrs
attr_dict = trainset.attr_dict
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
pin_memory=use_gpu,
drop_last=True
)
valloader = torch.utils.data.DataLoader(
valset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
pin_memory=use_gpu,
drop_last=False
)
testloader = torch.utils.data.DataLoader(
testset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
pin_memory=use_gpu,
drop_last=False
)
return trainloader, valloader, testloader, num_attrs, attr_dict


def main():
global args
set_random_seed(args.seed)
use_gpu = torch.cuda.is_available() and not args.use_cpu
log_name = 'test.log' if args.evaluate else 'train.log'
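    # mirror stdout to a log file under the save directory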
sys.stdout = Logger(osp.join(args.save_dir, log_name))
print('** Arguments **')
arg_keys = list(args.__dict__.keys())
arg_keys.sort()
for key in arg_keys:
print('{}: {}'.format(key, args.__dict__[key]))
print('\n')
print('Collecting env info ...')
print('** System info **\n{}\n'.format(collect_env_info()))
if use_gpu:
torch.backends.cudnn.benchmark = True
else:
warnings.warn(
            'Currently using CPU; however, GPU is highly recommended'
)
dataset_vars = init_dataset(use_gpu)
trainloader, valloader, testloader, num_attrs, attr_dict = dataset_vars
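    # Optionally weight the BCE loss per attribute: the positive-sample ratio is
    # estimated from the training loader and mapped through exp(-ratio), so rare
    # attributes receive larger weights than frequent ones.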
if args.weighted_bce:
print('Use weighted binary cross entropy')
print('Computing the weights ...')
bce_weights = torch.zeros(num_attrs, dtype=torch.float)
for _, attrs, _ in trainloader:
bce_weights += attrs.sum(0) # sum along the batch dim
bce_weights /= len(trainloader) * args.batch_size
print('Sample ratio for each attribute: {}'.format(bce_weights))
bce_weights = torch.exp(-1 * bce_weights)
print('BCE weights: {}'.format(bce_weights))
bce_weights = bce_weights.expand(args.batch_size, num_attrs)
criterion = nn.BCEWithLogitsLoss(weight=bce_weights)
else:
print('Use plain binary cross entropy')
criterion = nn.BCEWithLogitsLoss()
print('Building model: {}'.format(args.arch))
model = models.build_model(
args.arch,
num_attrs,
pretrained=not args.no_pretrained,
use_gpu=use_gpu
)
num_params, flops = compute_model_complexity(
model, (1, 3, args.height, args.width)
)
print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))
if args.load_weights and check_isfile(args.load_weights):
load_pretrained_weights(model, args.load_weights)
if use_gpu:
model = nn.DataParallel(model).cuda()
criterion = criterion.cuda()
if args.evaluate:
test(model, testloader, attr_dict, use_gpu)
return
optimizer = torchreid.optim.build_optimizer(
model, **optimizer_kwargs(args)
)
scheduler = torchreid.optim.build_lr_scheduler(
optimizer, **lr_scheduler_kwargs(args)
)
start_epoch = args.start_epoch
best_result = -np.inf
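    # optionally resume model/optimizer state, the starting epoch and the best
    # label-based mean accuracy (mA) obtained so far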
if args.resume and check_isfile(args.resume):
checkpoint = torch.load(args.resume)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
best_result = checkpoint['label_mA']
print('Loaded checkpoint from "{}"'.format(args.resume))
print('- start epoch: {}'.format(start_epoch))
print('- label_mA: {}'.format(best_result))
time_start = time.time()
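    # train for max_epoch epochs, evaluate on the test split after every epoch,
    # and save a checkpoint each time (flagging the one with the best label mA)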
for epoch in range(start_epoch, args.max_epoch):
train(
epoch, model, criterion, optimizer, scheduler, trainloader, use_gpu
)
test_outputs = test(model, testloader, attr_dict, use_gpu)
label_mA = test_outputs[0]
is_best = label_mA > best_result
if is_best:
best_result = label_mA
save_checkpoint(
{
'state_dict': model.state_dict(),
'epoch': epoch + 1,
'label_mA': label_mA,
'optimizer': optimizer.state_dict(),
},
args.save_dir,
is_best=is_best
)
elapsed = round(time.time() - time_start)
elapsed = str(datetime.timedelta(seconds=elapsed))
print('Elapsed {}'.format(elapsed))


def train(epoch, model, criterion, optimizer, scheduler, trainloader, use_gpu):
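    """Train the model for one epoch and step the learning-rate scheduler."""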
losses = AverageMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
model.train()
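    # during the first `fixbase_epoch` epochs only the layers listed in
    # `open_layers` are trained; all other layers stay frozen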
if (epoch + 1) <= args.fixbase_epoch and args.open_layers is not None:
print(
'* Only train {} (epoch: {}/{})'.format(
args.open_layers, epoch + 1, args.fixbase_epoch
)
)
open_specified_layers(model, args.open_layers)
else:
open_all_layers(model)
end = time.time()
for batch_idx, data in enumerate(trainloader):
data_time.update(time.time() - end)
imgs, attrs = data[0], data[1]
if use_gpu:
imgs = imgs.cuda()
attrs = attrs.cuda()
optimizer.zero_grad()
outputs = model(imgs)
loss = criterion(outputs, attrs)
loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
losses.update(loss.item(), imgs.size(0))
if (batch_idx+1) % args.print_freq == 0:
# estimate remaining time
num_batches = len(trainloader)
eta_seconds = batch_time.avg * (
num_batches - (batch_idx+1) + (args.max_epoch -
(epoch+1)) * num_batches
)
eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
print(
'Epoch: [{0}/{1}][{2}/{3}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Lr {lr:.6f}\t'
'Eta {eta}'.format(
epoch + 1,
args.max_epoch,
batch_idx + 1,
len(trainloader),
batch_time=batch_time,
data_time=data_time,
loss=losses,
lr=optimizer.param_groups[0]['lr'],
eta=eta_str
)
)
end = time.time()
scheduler.step()


@torch.no_grad()
def test(model, testloader, attr_dict, use_gpu):
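    """Evaluate the model on the test set.

    Computes instance-based accuracy/precision/recall/F1 and the label-based
    mean accuracy (mA), prints them and returns them.
    """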
batch_time = AverageMeter()
model.eval()
num_persons = 0
prob_thre = 0.5
ins_acc = 0
ins_prec = 0
ins_rec = 0
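    # per-attribute counters used for the label-based mean accuracy (mA)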
mA_history = {
'correct_pos': 0,
'real_pos': 0,
'correct_neg': 0,
'real_neg': 0
}
print('Testing ...')
for batch_idx, data in enumerate(testloader):
imgs, attrs, img_paths = data
if use_gpu:
imgs = imgs.cuda()
end = time.time()
orig_outputs = model(imgs)
batch_time.update(time.time() - end)
orig_outputs = orig_outputs.data.cpu().numpy()
attrs = attrs.data.numpy()
# transform raw outputs to attributes (binary codes)
outputs = copy.deepcopy(orig_outputs)
outputs[outputs < prob_thre] = 0
outputs[outputs >= prob_thre] = 1
# compute label-based metric
overlaps = outputs * attrs
mA_history['correct_pos'] += overlaps.sum(0)
mA_history['real_pos'] += attrs.sum(0)
inv_overlaps = (1-outputs) * (1-attrs)
mA_history['correct_neg'] += inv_overlaps.sum(0)
mA_history['real_neg'] += (1 - attrs).sum(0)
outputs = outputs.astype(bool)
attrs = attrs.astype(bool)
        # compute instance-based accuracy, precision and recall
intersect = (outputs & attrs).astype(float)
union = (outputs | attrs).astype(float)
ins_acc += (intersect.sum(1) / union.sum(1)).sum()
ins_prec += (intersect.sum(1) / outputs.astype(float).sum(1)).sum()
ins_rec += (intersect.sum(1) / attrs.astype(float).sum(1)).sum()
num_persons += imgs.size(0)
if (batch_idx+1) % args.print_freq == 0:
print(
'Processed batch {}/{}'.format(batch_idx + 1, len(testloader))
)
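        # optionally append, for every image in the batch, its ground-truth
        # positive attributes and any falsely predicted attributes to prediction.txt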
if args.save_prediction:
txtfile = open(osp.join(args.save_dir, 'prediction.txt'), 'a')
for idx in range(imgs.size(0)):
img_path = img_paths[idx]
probs = orig_outputs[idx, :]
labels = attrs[idx, :]
txtfile.write('{}\n'.format(img_path))
txtfile.write('*** Correct prediction ***\n')
for attr_idx, (label, prob) in enumerate(zip(labels, probs)):
if label:
attr_name = attr_dict[attr_idx]
info = '{}: {:.1%} '.format(attr_name, prob)
txtfile.write(info)
txtfile.write('\n*** Incorrect prediction ***\n')
for attr_idx, (label, prob) in enumerate(zip(labels, probs)):
if not label and prob > 0.5:
attr_name = attr_dict[attr_idx]
info = '{}: {:.1%} '.format(attr_name, prob)
txtfile.write(info)
txtfile.write('\n\n')
txtfile.close()
print(
'=> BatchTime(s)/BatchSize(img): {:.4f}/{}'.format(
batch_time.avg, args.batch_size
)
)
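    # average instance-based metrics over the number of test images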
ins_acc /= num_persons
ins_prec /= num_persons
ins_rec /= num_persons
ins_f1 = (2*ins_prec*ins_rec) / (ins_prec+ins_rec)
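    # label-based mA: per attribute, average the true-positive rate and the
    # true-negative rate, then take the mean over all attributes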
term1 = mA_history['correct_pos'] / mA_history['real_pos']
term2 = mA_history['correct_neg'] / mA_history['real_neg']
label_mA_verbose = (term1+term2) * 0.5
label_mA = label_mA_verbose.mean()
print('* Results *')
print(' # test persons: {}'.format(num_persons))
print(' (instance-based) accuracy: {:.1%}'.format(ins_acc))
    print(' (instance-based) precision: {:.1%}'.format(ins_prec))
print(' (instance-based) recall: {:.1%}'.format(ins_rec))
print(' (instance-based) f1-score: {:.1%}'.format(ins_f1))
print(' (label-based) mean accuracy: {:.1%}'.format(label_mA))
print(' mA for each attribute: {}'.format(label_mA_verbose))
return label_mA, ins_acc, ins_prec, ins_rec, ins_f1


if __name__ == '__main__':
main()