400 lines
12 KiB
Python
400 lines
12 KiB
Python
from __future__ import division, print_function
|
|
import sys
|
|
import copy
|
|
import time
|
|
import numpy as np
|
|
import os.path as osp
|
|
import datetime
|
|
import warnings
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
import torchreid
|
|
from torchreid.utils import (
|
|
Logger, AverageMeter, check_isfile, open_all_layers, save_checkpoint,
|
|
set_random_seed, collect_env_info, open_specified_layers,
|
|
load_pretrained_weights, compute_model_complexity
|
|
)
|
|
from torchreid.data.transforms import (
|
|
Resize, Compose, ToTensor, Normalize, Random2DTranslation,
|
|
RandomHorizontalFlip
|
|
)
|
|
|
|
import models
|
|
import datasets
|
|
from default_parser import init_parser, optimizer_kwargs, lr_scheduler_kwargs
|
|
|
|
# Parse command-line options once at import time; every function below reads
# this module-level ``args`` namespace rather than taking config parameters.
parser = init_parser()
args = parser.parse_args()
|
|
|
|
|
|
def init_dataset(use_gpu):
    """Build the train/val/test dataloaders for the configured dataset.

    Reads dataset name, root, image size, batch size and worker count from
    the module-level ``args``.

    Returns:
        tuple: (trainloader, valloader, testloader, num_attrs, attr_dict)
    """
    normalize = Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    # Training uses random translation + horizontal flip augmentation;
    # evaluation uses a deterministic resize only.
    transform_tr = Compose([
        Random2DTranslation(args.height, args.width, p=0.5),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ])
    transform_te = Compose([
        Resize([args.height, args.width]),
        ToTensor(),
        normalize,
    ])

    def _build_split(mode, transform, verbose):
        # Thin wrapper over the project dataset factory.
        return datasets.init_dataset(
            args.dataset,
            root=args.root,
            transform=transform,
            mode=mode,
            verbose=verbose,
        )

    trainset = _build_split('train', transform_tr, True)
    valset = _build_split('val', transform_te, False)
    testset = _build_split('test', transform_te, False)

    def _build_loader(split, shuffle, drop_last):
        return torch.utils.data.DataLoader(
            split,
            batch_size=args.batch_size,
            shuffle=shuffle,
            num_workers=args.workers,
            pin_memory=use_gpu,  # faster host->device copies when on GPU
            drop_last=drop_last,
        )

    # Only training shuffles and drops the trailing partial batch.
    trainloader = _build_loader(trainset, True, True)
    valloader = _build_loader(valset, False, False)
    testloader = _build_loader(testset, False, False)

    return (
        trainloader,
        valloader,
        testloader,
        trainset.num_attrs,
        trainset.attr_dict,
    )
|
|
|
|
|
|
def main():
    """Entry point: set up logging/data/model/criterion, then train or evaluate.

    All configuration comes from the module-level ``args``. When
    ``args.evaluate`` is set, runs a single test pass and returns; otherwise
    trains for ``args.max_epoch`` epochs, testing and checkpointing after
    each epoch (best model selected by label-based mean accuracy).
    """
    global args

    set_random_seed(args.seed)
    use_gpu = torch.cuda.is_available() and not args.use_cpu
    log_name = 'test.log' if args.evaluate else 'train.log'
    # Redirect stdout so everything printed below also lands in the log file.
    sys.stdout = Logger(osp.join(args.save_dir, log_name))

    print('** Arguments **')
    arg_keys = list(args.__dict__.keys())
    arg_keys.sort()
    for key in arg_keys:
        print('{}: {}'.format(key, args.__dict__[key]))
    print('\n')
    print('Collecting env info ...')
    print('** System info **\n{}\n'.format(collect_env_info()))

    if use_gpu:
        torch.backends.cudnn.benchmark = True
    else:
        warnings.warn(
            'Currently using CPU, however, GPU is highly recommended'
        )

    dataset_vars = init_dataset(use_gpu)
    trainloader, valloader, testloader, num_attrs, attr_dict = dataset_vars

    if args.weighted_bce:
        print('Use weighted binary cross entropy')
        print('Computing the weights ...')
        bce_weights = torch.zeros(num_attrs, dtype=torch.float)
        for _, attrs, _ in trainloader:
            bce_weights += attrs.sum(0) # sum along the batch dim
        # Fraction of training samples in which each attribute is positive
        # (exact because the train loader uses drop_last=True).
        bce_weights /= len(trainloader) * args.batch_size
        print('Sample ratio for each attribute: {}'.format(bce_weights))
        # Rarer attributes receive exponentially larger loss weights.
        bce_weights = torch.exp(-1 * bce_weights)
        print('BCE weights: {}'.format(bce_weights))
        bce_weights = bce_weights.expand(args.batch_size, num_attrs)
        criterion = nn.BCEWithLogitsLoss(weight=bce_weights)

    else:
        print('Use plain binary cross entropy')
        criterion = nn.BCEWithLogitsLoss()

    print('Building model: {}'.format(args.arch))
    model = models.build_model(
        args.arch,
        num_attrs,
        pretrained=not args.no_pretrained,
        use_gpu=use_gpu
    )
    num_params, flops = compute_model_complexity(
        model, (1, 3, args.height, args.width)
    )
    print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))

    if args.load_weights and check_isfile(args.load_weights):
        load_pretrained_weights(model, args.load_weights)

    if use_gpu:
        model = nn.DataParallel(model).cuda()
        criterion = criterion.cuda()

    if args.evaluate:
        test(model, testloader, attr_dict, use_gpu)
        return

    optimizer = torchreid.optim.build_optimizer(
        model, **optimizer_kwargs(args)
    )
    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer, **lr_scheduler_kwargs(args)
    )

    start_epoch = args.start_epoch
    best_result = -np.inf
    # Optionally resume full training state: weights, optimizer, epoch, best mA.
    # NOTE: resume happens after the DataParallel wrap above, so checkpoints
    # saved by this script (from the wrapped model) load with matching keys.
    if args.resume and check_isfile(args.resume):
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        best_result = checkpoint['label_mA']
        print('Loaded checkpoint from "{}"'.format(args.resume))
        print('- start epoch: {}'.format(start_epoch))
        print('- label_mA: {}'.format(best_result))

    time_start = time.time()

    for epoch in range(start_epoch, args.max_epoch):
        train(
            epoch, model, criterion, optimizer, scheduler, trainloader, use_gpu
        )
        test_outputs = test(model, testloader, attr_dict, use_gpu)
        label_mA = test_outputs[0]
        # Model selection criterion: label-based mean accuracy.
        is_best = label_mA > best_result
        if is_best:
            best_result = label_mA

        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'label_mA': label_mA,
                'optimizer': optimizer.state_dict(),
            },
            args.save_dir,
            is_best=is_best
        )

    elapsed = round(time.time() - time_start)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print('Elapsed {}'.format(elapsed))
|
|
|
|
|
|
def train(epoch, model, criterion, optimizer, scheduler, trainloader, use_gpu):
    """Run one training epoch over ``trainloader`` and step the scheduler."""
    loss_meter = AverageMeter()
    iter_timer = AverageMeter()   # full iteration time (incl. data loading)
    load_timer = AverageMeter()   # data loading time only
    model.train()

    # During the warm-up phase only the layers named in args.open_layers
    # are trainable; afterwards the whole network is unfrozen.
    warmup = (epoch + 1) <= args.fixbase_epoch and args.open_layers is not None
    if warmup:
        print(
            '* Only train {} (epoch: {}/{})'.format(
                args.open_layers, epoch + 1, args.fixbase_epoch
            )
        )
        open_specified_layers(model, args.open_layers)
    else:
        open_all_layers(model)

    tic = time.time()
    for batch_idx, data in enumerate(trainloader):
        load_timer.update(time.time() - tic)

        imgs, attrs = data[0], data[1]
        if use_gpu:
            imgs = imgs.cuda()
            attrs = attrs.cuda()

        # Standard optimization step: forward, loss, backward, update.
        optimizer.zero_grad()
        loss = criterion(model(imgs), attrs)
        loss.backward()
        optimizer.step()

        iter_timer.update(time.time() - tic)
        loss_meter.update(loss.item(), imgs.size(0))

        if (batch_idx+1) % args.print_freq == 0:
            # Estimate remaining time from the average iteration time:
            # batches left in this epoch plus all batches of future epochs.
            num_batches = len(trainloader)
            remaining = (
                num_batches - (batch_idx+1)
                + (args.max_epoch - (epoch+1)) * num_batches
            )
            eta_seconds = iter_timer.avg * remaining
            eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
            print(
                'Epoch: [{0}/{1}][{2}/{3}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Lr {lr:.6f}\t'
                'Eta {eta}'.format(
                    epoch + 1,
                    args.max_epoch,
                    batch_idx + 1,
                    num_batches,
                    batch_time=iter_timer,
                    data_time=load_timer,
                    loss=loss_meter,
                    lr=optimizer.param_groups[0]['lr'],
                    eta=eta_str
                )
            )

        tic = time.time()

    scheduler.step()
|
|
|
|
|
|
@torch.no_grad()
def test(model, testloader, attr_dict, use_gpu):
    """Evaluate attribute predictions over ``testloader``.

    Computes instance-based accuracy/precision/recall/F1 and the label-based
    mean accuracy (mA). When ``args.save_prediction`` is set, appends
    per-image correct/incorrect predictions to ``prediction.txt`` in
    ``args.save_dir``.

    Args:
        model: network mapping an image batch to per-attribute scores.
        testloader: yields (imgs, attrs, img_paths) batches.
        attr_dict: maps attribute index -> attribute name.
        use_gpu (bool): move images to GPU before the forward pass.

    Returns:
        tuple: (label_mA, ins_acc, ins_prec, ins_rec, ins_f1)
    """
    batch_time = AverageMeter()
    model.eval()

    num_persons = 0
    # Scores >= prob_thre are predicted positive.
    # NOTE(review): this assumes the model emits probabilities at eval time
    # (sigmoid applied inside the model) -- confirm, since training uses
    # BCEWithLogitsLoss on raw logits.
    prob_thre = 0.5
    ins_acc = 0
    ins_prec = 0
    ins_rec = 0
    # Per-attribute counters for label-based mean accuracy.
    mA_history = {
        'correct_pos': 0,
        'real_pos': 0,
        'correct_neg': 0,
        'real_neg': 0
    }

    print('Testing ...')

    for batch_idx, data in enumerate(testloader):
        imgs, attrs, img_paths = data
        if use_gpu:
            imgs = imgs.cuda()

        end = time.time()
        orig_outputs = model(imgs)
        batch_time.update(time.time() - end)

        orig_outputs = orig_outputs.data.cpu().numpy()
        attrs = attrs.data.numpy()

        # Binarize the raw scores (0/1 in the scores' own dtype); the raw
        # values in orig_outputs are kept for the prediction dump below.
        outputs = (orig_outputs >= prob_thre).astype(orig_outputs.dtype)

        # Label-based metric: accumulate per-attribute positive/negative hits.
        overlaps = outputs * attrs
        mA_history['correct_pos'] += overlaps.sum(0)
        mA_history['real_pos'] += attrs.sum(0)
        inv_overlaps = (1-outputs) * (1-attrs)
        mA_history['correct_neg'] += inv_overlaps.sum(0)
        mA_history['real_neg'] += (1 - attrs).sum(0)

        outputs = outputs.astype(bool)
        attrs = attrs.astype(bool)

        # Instance-based metrics (Jaccard-style accuracy per image).
        # NOTE(review): the divisions can yield nan/inf when an image has no
        # positive prediction or no positive label -- behavior kept as-is.
        intersect = (outputs & attrs).astype(float)
        union = (outputs | attrs).astype(float)
        ins_acc += (intersect.sum(1) / union.sum(1)).sum()
        ins_prec += (intersect.sum(1) / outputs.astype(float).sum(1)).sum()
        ins_rec += (intersect.sum(1) / attrs.astype(float).sum(1)).sum()

        num_persons += imgs.size(0)

        if (batch_idx+1) % args.print_freq == 0:
            print(
                'Processed batch {}/{}'.format(batch_idx + 1, len(testloader))
            )

        if args.save_prediction:
            # Append mode so results from all batches accumulate in one file;
            # the context manager guarantees the handle is closed per batch.
            with open(osp.join(args.save_dir, 'prediction.txt'), 'a') as txtfile:
                for idx in range(imgs.size(0)):
                    img_path = img_paths[idx]
                    probs = orig_outputs[idx, :]
                    labels = attrs[idx, :]
                    txtfile.write('{}\n'.format(img_path))
                    txtfile.write('*** Correct prediction ***\n')
                    for attr_idx, (label, prob) in enumerate(zip(labels, probs)):
                        if label:
                            attr_name = attr_dict[attr_idx]
                            info = '{}: {:.1%} '.format(attr_name, prob)
                            txtfile.write(info)
                    txtfile.write('\n*** Incorrect prediction ***\n')
                    for attr_idx, (label, prob) in enumerate(zip(labels, probs)):
                        if not label and prob > 0.5:
                            attr_name = attr_dict[attr_idx]
                            info = '{}: {:.1%} '.format(attr_name, prob)
                            txtfile.write(info)
                    txtfile.write('\n\n')

    print(
        '=> BatchTime(s)/BatchSize(img): {:.4f}/{}'.format(
            batch_time.avg, args.batch_size
        )
    )

    # Average the instance metrics over all test images.
    ins_acc /= num_persons
    ins_prec /= num_persons
    ins_rec /= num_persons
    ins_f1 = (2*ins_prec*ins_rec) / (ins_prec+ins_rec)

    # mA: mean of positive recall and negative recall, per attribute.
    term1 = mA_history['correct_pos'] / mA_history['real_pos']
    term2 = mA_history['correct_neg'] / mA_history['real_neg']
    label_mA_verbose = (term1+term2) * 0.5
    label_mA = label_mA_verbose.mean()

    print('* Results *')
    print(' # test persons: {}'.format(num_persons))
    print(' (instance-based) accuracy: {:.1%}'.format(ins_acc))
    print(' (instance-based) precision: {:.1%}'.format(ins_prec))
    print(' (instance-based) recall: {:.1%}'.format(ins_rec))
    print(' (instance-based) f1-score: {:.1%}'.format(ins_f1))
    print(' (label-based) mean accuracy: {:.1%}'.format(label_mA))
    print(' mA for each attribute: {}'.format(label_mA_verbose))

    return label_mA, ins_acc, ins_prec, ins_rec, ins_f1
|
|
|
|
|
|
# Script entry point: only run training/evaluation when executed directly.
if __name__ == '__main__':
    main()
|