add xent+htri trainer

pull/17/head
KaiyangZhou 2018-03-12 13:53:08 +00:00
parent 3aa71a0bea
commit 2949021259
8 changed files with 305 additions and 29 deletions

View File

@ -1,4 +1,4 @@
from __future__ import absolute_import
from __future__ import print_function, absolute_import
import os
from PIL import Image

View File

@ -1,4 +1,4 @@
from __future__ import absolute_import
from __future__ import print_function, absolute_import
import numpy as np
import copy

View File

@ -9,8 +9,9 @@ import torchvision
__all__ = ['DenseNet121']
class DenseNet121(nn.Module):
def __init__(self, num_classes, **kwargs):
def __init__(self, num_classes, loss={'xent'}, **kwargs):
super(DenseNet121, self).__init__()
self.loss = loss
densenet121 = torchvision.models.densenet121(pretrained=True)
self.base = densenet121.features
self.classifier = nn.Linear(1024, num_classes)
@ -18,8 +19,14 @@ class DenseNet121(nn.Module):
def forward(self, x):
x = self.base(x)
x = F.avg_pool2d(x, x.size()[2:])
x = x.view(x.size(0), -1)
f = x.view(x.size(0), -1)
if not self.training:
return x
x = self.classifier(x)
return x
return f
y = self.classifier(f)
if self.loss == {'xent'}:
return y
elif self.loss == {'xent', 'htri'}:
return y, f
else:
raise KeyError("Unknown loss: {}".format(self.loss))

View File

@ -9,8 +9,9 @@ import torchvision
__all__ = ['ResNet50', 'ResNet50M']
class ResNet50(nn.Module):
def __init__(self, num_classes, **kwargs):
def __init__(self, num_classes, loss={'xent'}, **kwargs):
super(ResNet50, self).__init__()
self.loss = loss
resnet50 = torchvision.models.resnet50(pretrained=True)
self.base = nn.Sequential(*list(resnet50.children())[:-2])
self.classifier = nn.Linear(2048, num_classes)
@ -18,11 +19,17 @@ class ResNet50(nn.Module):
def forward(self, x):
x = self.base(x)
x = F.avg_pool2d(x, x.size()[2:])
x = x.view(x.size(0), -1)
f = x.view(x.size(0), -1)
if not self.training:
return x
x = self.classifier(x)
return x
return f
y = self.classifier(f)
if self.loss == {'xent'}:
return y
elif self.loss == {'xent', 'htri'}:
return y, f
else:
raise KeyError("Unknown loss: {}".format(self.loss))
class ResNet50M(nn.Module):
"""ResNet50 + mid-level features.
@ -31,8 +38,9 @@ class ResNet50M(nn.Module):
Qian et al. The Devil is in the Middle: Exploiting Mid-level Representations for
Cross-Domain Instance Matching. arXiv:1711.08106.
"""
def __init__(self, num_classes=0, **kwargs):
def __init__(self, num_classes=0, loss={'xent'}, **kwargs):
super(ResNet50M, self).__init__()
self.loss = loss
resnet50 = torchvision.models.resnet50(pretrained=True)
self.base = nn.Sequential(*list(resnet50.children())[:-2])
self.layers1 = nn.Sequential(self.base[0], self.base[1], self.base[2])
@ -65,4 +73,17 @@ class ResNet50M(nn.Module):
if not self.training:
return combofeat
prelogits = self.classifier(combofeat)
return prelogits
if self.loss == {'xent'}:
return prelogits
elif self.loss == {'xent', 'htri'}:
return prelogits, combofeat
else:
raise KeyError("Unknown loss: {}".format(self.loss))

View File

@ -1,4 +1,6 @@
from __future__ import absolute_import
from collections import defaultdict
import numpy as np
import torch

View File

@ -39,7 +39,7 @@ parser.add_argument('--start-epoch', default=0, type=int,
help="manual epoch number (useful on restarts)")
parser.add_argument('--train-batch', default=32, type=int,
help="train batch size")
parser.add_argument('--test-batch', default=50, type=int, help="test batch size")
parser.add_argument('--test-batch', default=32, type=int, help="test batch size")
parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
help="initial learning rate")
parser.add_argument('--stepsize', default=20, type=int,
@ -119,7 +119,7 @@ def main():
)
print("Initializing model: {}".format(args.arch))
model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids)
model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent'})
print("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())/1000000.0))
criterion = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu)

View File

@ -0,0 +1,259 @@
from __future__ import absolute_import
import os
import sys
import time
import datetime
import argparse
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.optim import lr_scheduler
import data_manager
from dataset_loader import ImageDataset
import transforms as T
import models
from losses import CrossEntropyLabelSmooth, TripletLoss
from utils import AverageMeter, Logger, save_checkpoint
from eval_metrics import evaluate
from samplers import RandomIdentitySampler
parser = argparse.ArgumentParser(description='Train image model with cross entropy loss and hard triplet loss')
# Datasets
parser.add_argument('-d', '--dataset', type=str, default='market1501',
choices=data_manager.get_names())
parser.add_argument('-j', '--workers', default=4, type=int,
help="number of data loading workers (default: 4)")
parser.add_argument('--height', type=int, default=256,
help="height of an image (default: 256)")
parser.add_argument('--width', type=int, default=128,
help="width of an image (default: 128)")
# Optimization options
parser.add_argument('--max-epoch', default=60, type=int,
help="maximum epochs to run")
parser.add_argument('--start-epoch', default=0, type=int,
help="manual epoch number (useful on restarts)")
parser.add_argument('--train-batch', default=32, type=int,
help="train batch size")
parser.add_argument('--test-batch', default=32, type=int, help="test batch size")
parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
help="initial learning rate")
parser.add_argument('--stepsize', default=20, type=int,
help="stepsize to decay learning rate (>0 means this is enabled)")
parser.add_argument('--gamma', default=0.1, type=float,
help="learning rate decay")
parser.add_argument('--weight-decay', default=5e-04, type=float,
help="weight decay (default: 5e-04)")
parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss")
parser.add_argument('--num-instances', type=int, default=4,
help="number of instances per identity")
# Architecture
parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names())
# Miscs
parser.add_argument('--print-freq', type=int, default=10, help="print frequency")
parser.add_argument('--seed', type=int, default=1, help="manual seed")
parser.add_argument('--resume', type=str, default='', metavar='PATH')
parser.add_argument('--evaluate', action='store_true', help="evaluation only")
parser.add_argument('--eval-step', type=int, default=-1,
help="run evaluation for every N epochs (set to -1 to test after training)")
parser.add_argument('--save-dir', type=str, default='log')
parser.add_argument('--use-cpu', action='store_true', help="use cpu")
parser.add_argument('--gpu-devices', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES')
args = parser.parse_args()
def main():
torch.manual_seed(args.seed)
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
use_gpu = torch.cuda.is_available()
if args.use_cpu: use_gpu = False
if not args.evaluate:
sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
else:
sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
print("==========\nArgs:{}\n==========".format(args))
if use_gpu:
print("Currently using GPU")
cudnn.benchmark = True
torch.cuda.manual_seed_all(args.seed)
else:
print("Currently using CPU (GPU is highly recommended)")
print("Initializing dataset {}".format(args.dataset))
dataset = data_manager.init_dataset(name=args.dataset)
transform_train = T.Compose([
T.Random2DTranslation(args.height, args.width),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
transform_test = T.Compose([
T.Resize((args.height, args.width)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
pin_memory = True if use_gpu else False
trainloader = DataLoader(
ImageDataset(dataset.train, transform=transform_train),
sampler=RandomIdentitySampler(dataset.train, num_instances=args.num_instances),
batch_size=args.train_batch, num_workers=args.workers,
pin_memory=pin_memory, drop_last=True,
)
queryloader = DataLoader(
ImageDataset(dataset.query, transform=transform_test),
batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
pin_memory=pin_memory, drop_last=False,
)
galleryloader = DataLoader(
ImageDataset(dataset.gallery, transform=transform_test),
batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
pin_memory=pin_memory, drop_last=False,
)
print("Initializing model: {}".format(args.arch))
model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent', 'htri'})
print("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())/1000000.0))
criterion_xent = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu)
criterion_htri = TripletLoss(margin=args.margin)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
if args.stepsize > 0:
scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
start_epoch = args.start_epoch
if args.resume:
print("Loading checkpoint from '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
model.load_state_dict(checkpoint['state_dict'])
start_epoch = checkpoint['epoch']
if use_gpu:
model = nn.DataParallel(model).cuda()
if args.evaluate:
print("Evaluate only")
test(model, queryloader, galleryloader, use_gpu)
return
start_time = time.time()
best_rank1 = -np.inf
for epoch in range(start_epoch, args.max_epoch):
print("==> Epoch {}/{}".format(epoch+1, args.max_epoch))
train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu)
if args.stepsize > 0: scheduler.step()
if args.eval_step > 0 and (epoch+1) % args.eval_step == 0 or (epoch+1) == args.max_epoch:
print("==> Test")
rank1 = test(model, queryloader, galleryloader, use_gpu)
is_best = rank1 > best_rank1
if is_best: best_rank1 = rank1
save_checkpoint({
'state_dict': model.state_dict(),
'rank1': rank1,
'epoch': epoch,
}, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch+1) + '.pth.tar'))
elapsed = round(time.time() - start_time)
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
def train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
model.train()
losses = AverageMeter()
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
if use_gpu:
imgs, pids = imgs.cuda(), pids.cuda()
imgs, pids = Variable(imgs), Variable(pids)
outputs, features = model(imgs)
xent_loss = criterion_xent(outputs, pids)
htri_loss = criterion_htri(features, pids)
loss = xent_loss + htri_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.update(loss.data[0], pids.size(0))
if (batch_idx+1) % args.print_freq == 0:
print("Batch {}/{}\t Loss {:.6f} ({:.6f})".format(batch_idx+1, len(trainloader), losses.val, losses.avg))
def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20]):
model.eval()
qf, q_pids, q_camids = [], [], []
for batch_idx, (imgs, pids, camids) in enumerate(queryloader):
if use_gpu:
imgs = imgs.cuda()
imgs = Variable(imgs)
features = model(imgs)
features = features.data.cpu()
qf.append(features)
q_pids.extend(pids)
q_camids.extend(camids)
qf = torch.cat(qf, 0)
q_pids = np.asarray(q_pids)
q_camids = np.asarray(q_camids)
print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))
gf, g_pids, g_camids = [], [], []
for batch_idx, (imgs, pids, camids) in enumerate(galleryloader):
if use_gpu:
imgs = imgs.cuda()
imgs = Variable(imgs)
features = model(imgs)
features = features.data.cpu()
gf.append(features)
g_pids.extend(pids)
g_camids.extend(camids)
gf = torch.cat(gf, 0)
g_pids = np.asarray(g_pids)
g_camids = np.asarray(g_camids)
print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))
print("Computing distance matrix")
m, n = qf.size(0), gf.size(0)
distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
distmat.addmm_(1, -2, qf, gf.t())
distmat = distmat.numpy()
print("Computing CMC and mAP")
cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
print("Results ----------")
print("mAP: {:.1%}".format(mAP))
print("CMC curve")
for r in ranks:
print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1]))
print("------------------")
return cmc[0]
if __name__ == '__main__':
main()

View File

@ -40,17 +40,4 @@ class Random2DTranslation(object):
return croped_img
if __name__ == '__main__':
"""import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-impath', type=str)
parser.add_argument('-nlevel', type=float, default=0.1)
args = parser.parse_args()
RC = RandomOcclusion(nlevel=args.nlevel, p=1)
im = Image.open(args.impath)
transformed_im = RC(im)
basename = osp.basename(args.impath)
save_name = osp.splitext(basename)[0] + '_nlevel_' + str(args.nlevel) + osp.splitext(basename)[1]
transformed_im.save(save_name)"""
pass