add xent+htri trainer
parent 3aa71a0bea
commit 2949021259
@@ -1,4 +1,4 @@
-from __future__ import absolute_import
+from __future__ import print_function, absolute_import
 import os
 from PIL import Image
@@ -1,4 +1,4 @@
-from __future__ import absolute_import
+from __future__ import print_function, absolute_import
 import numpy as np
 import copy
@@ -9,8 +9,9 @@ import torchvision
 __all__ = ['DenseNet121']

 class DenseNet121(nn.Module):
-    def __init__(self, num_classes, **kwargs):
+    def __init__(self, num_classes, loss={'xent'}, **kwargs):
         super(DenseNet121, self).__init__()
+        self.loss = loss
         densenet121 = torchvision.models.densenet121(pretrained=True)
         self.base = densenet121.features
         self.classifier = nn.Linear(1024, num_classes)
@@ -18,8 +19,14 @@ class DenseNet121(nn.Module):
     def forward(self, x):
         x = self.base(x)
         x = F.avg_pool2d(x, x.size()[2:])
-        x = x.view(x.size(0), -1)
+        f = x.view(x.size(0), -1)
         if not self.training:
-            return x
-        x = self.classifier(x)
-        return x
+            return f
+        y = self.classifier(f)
+
+        if self.loss == {'xent'}:
+            return y
+        elif self.loss == {'xent', 'htri'}:
+            return y, f
+        else:
+            raise KeyError("Unknown loss: {}".format(self.loss))
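Note: this forward() contract (and the identical one added to ResNet50/ResNet50M below) is what the new xent+htri trainer relies on: the logits y feed the cross-entropy loss, the pooled features f feed the triplet loss, and eval mode returns only f for retrieval. A minimal sketch of how a caller consumes it, not part of the diff; the import path, the 751-identity count, and the batch shape are illustrative assumptions:

```python
import torch
from torch.autograd import Variable

# Hedged sketch: assumes the DenseNet121 class defined above is importable and that
# Market-1501-style inputs (3 x 256 x 128) with 751 training identities are used.
model = DenseNet121(num_classes=751, loss={'xent', 'htri'})
imgs = Variable(torch.randn(8, 3, 256, 128))   # dummy batch

model.train()
y, f = model(imgs)    # y: (8, 751) logits for xent, f: (8, 1024) features for htri

model.eval()
f = model(imgs)       # eval mode: only the pooled 1024-d feature, used for ranking
```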
@@ -9,8 +9,9 @@ import torchvision
 __all__ = ['ResNet50', 'ResNet50M']

 class ResNet50(nn.Module):
-    def __init__(self, num_classes, **kwargs):
+    def __init__(self, num_classes, loss={'xent'}, **kwargs):
         super(ResNet50, self).__init__()
+        self.loss = loss
         resnet50 = torchvision.models.resnet50(pretrained=True)
         self.base = nn.Sequential(*list(resnet50.children())[:-2])
         self.classifier = nn.Linear(2048, num_classes)
@@ -18,11 +19,17 @@ class ResNet50(nn.Module):
     def forward(self, x):
         x = self.base(x)
         x = F.avg_pool2d(x, x.size()[2:])
-        x = x.view(x.size(0), -1)
+        f = x.view(x.size(0), -1)
         if not self.training:
-            return x
-        x = self.classifier(x)
-        return x
+            return f
+        y = self.classifier(f)
+
+        if self.loss == {'xent'}:
+            return y
+        elif self.loss == {'xent', 'htri'}:
+            return y, f
+        else:
+            raise KeyError("Unknown loss: {}".format(self.loss))

 class ResNet50M(nn.Module):
     """ResNet50 + mid-level features.
@@ -31,8 +38,9 @@ class ResNet50M(nn.Module):
     Qian et al. The Devil is in the Middle: Exploiting Mid-level Representations for
     Cross-Domain Instance Matching. arXiv:1711.08106.
     """
-    def __init__(self, num_classes=0, **kwargs):
+    def __init__(self, num_classes=0, loss={'xent'}, **kwargs):
         super(ResNet50M, self).__init__()
+        self.loss = loss
         resnet50 = torchvision.models.resnet50(pretrained=True)
         self.base = nn.Sequential(*list(resnet50.children())[:-2])
         self.layers1 = nn.Sequential(self.base[0], self.base[1], self.base[2])
@@ -65,4 +73,17 @@ class ResNet50M(nn.Module):
         if not self.training:
             return combofeat
         prelogits = self.classifier(combofeat)
-        return prelogits
+
+        if self.loss == {'xent'}:
+            return prelogits
+        elif self.loss == {'xent', 'htri'}:
+            return prelogits, combofeat
+        else:
+            raise KeyError("Unknown loss: {}".format(self.loss))
@@ -1,4 +1,6 @@
 from __future__ import absolute_import
 from collections import defaultdict
 import numpy as np
+
+import torch
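The import torch added here is used by RandomIdentitySampler, which the new xent+htri trainer below passes to its training DataLoader. The sampler implementation is not shown in this diff; a minimal sketch of an identity-balanced sampler with that interface (class layout and helper names are assumptions) looks roughly like:

```python
import random
from collections import defaultdict

import torch
from torch.utils.data.sampler import Sampler

class RandomIdentitySampler(Sampler):
    """Sketch only: yields indices so that every sampled identity contributes
    num_instances images, which is what the triplet loss needs to find positives
    inside a batch. Assumes data_source is a list of (img_path, pid, camid)
    tuples, as used by ImageDataset.
    """
    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source
        self.num_instances = num_instances
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())

    def __iter__(self):
        ret = []
        # visit identities in random order, drawing num_instances images per identity
        for i in torch.randperm(len(self.pids)).tolist():
            t = self.index_dic[self.pids[i]]
            if len(t) >= self.num_instances:
                t = random.sample(t, self.num_instances)
            else:
                # sample with replacement when an identity has too few images
                t = [random.choice(t) for _ in range(self.num_instances)]
            ret.extend(t)
        return iter(ret)

    def __len__(self):
        return len(self.pids) * self.num_instances
```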
@@ -39,7 +39,7 @@ parser.add_argument('--start-epoch', default=0, type=int,
                     help="manual epoch number (useful on restarts)")
 parser.add_argument('--train-batch', default=32, type=int,
                     help="train batch size")
-parser.add_argument('--test-batch', default=50, type=int, help="test batch size")
+parser.add_argument('--test-batch', default=32, type=int, help="test batch size")
 parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
                     help="initial learning rate")
 parser.add_argument('--stepsize', default=20, type=int,
@@ -119,7 +119,7 @@ def main():
     )

     print("Initializing model: {}".format(args.arch))
-    model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids)
+    model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent'})
     print("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())/1000000.0))

     criterion = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu)
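Passing loss={'xent'} here only works because the model factory forwards extra keyword arguments to the constructors updated above. The factory itself is not part of this diff; a sketch of the assumed dispatch (registry keys other than 'resnet50' are guesses) is:

```python
# Hypothetical models/__init__.py-style factory, not taken from the diff.
__model_factory = {
    'resnet50': ResNet50,        # default value of --arch in both trainers
    'resnet50m': ResNet50M,      # assumed key
    'densenet121': DenseNet121,  # assumed key
}

def get_names():
    return list(__model_factory.keys())

def init_model(name, *args, **kwargs):
    # kwargs such as num_classes and loss are forwarded untouched,
    # which is why the model constructors accept **kwargs.
    if name not in __model_factory:
        raise KeyError("Unknown model: {}".format(name))
    return __model_factory[name](*args, **kwargs)
```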
@@ -0,0 +1,259 @@
+from __future__ import absolute_import
+import os
+import sys
+import time
+import datetime
+import argparse
+import os.path as osp
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+from torch.autograd import Variable
+from torch.optim import lr_scheduler
+
+import data_manager
+from dataset_loader import ImageDataset
+import transforms as T
+import models
+from losses import CrossEntropyLabelSmooth, TripletLoss
+from utils import AverageMeter, Logger, save_checkpoint
+from eval_metrics import evaluate
+from samplers import RandomIdentitySampler
+
+parser = argparse.ArgumentParser(description='Train image model with cross entropy loss and hard triplet loss')
+# Datasets
+parser.add_argument('-d', '--dataset', type=str, default='market1501',
+                    choices=data_manager.get_names())
+parser.add_argument('-j', '--workers', default=4, type=int,
+                    help="number of data loading workers (default: 4)")
+parser.add_argument('--height', type=int, default=256,
+                    help="height of an image (default: 256)")
+parser.add_argument('--width', type=int, default=128,
+                    help="width of an image (default: 128)")
+# Optimization options
+parser.add_argument('--max-epoch', default=60, type=int,
+                    help="maximum epochs to run")
+parser.add_argument('--start-epoch', default=0, type=int,
+                    help="manual epoch number (useful on restarts)")
+parser.add_argument('--train-batch', default=32, type=int,
+                    help="train batch size")
+parser.add_argument('--test-batch', default=32, type=int, help="test batch size")
+parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
+                    help="initial learning rate")
+parser.add_argument('--stepsize', default=20, type=int,
+                    help="stepsize to decay learning rate (>0 means this is enabled)")
+parser.add_argument('--gamma', default=0.1, type=float,
+                    help="learning rate decay")
+parser.add_argument('--weight-decay', default=5e-04, type=float,
+                    help="weight decay (default: 5e-04)")
+parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss")
+parser.add_argument('--num-instances', type=int, default=4,
+                    help="number of instances per identity")
+# Architecture
+parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=models.get_names())
+# Miscs
+parser.add_argument('--print-freq', type=int, default=10, help="print frequency")
+parser.add_argument('--seed', type=int, default=1, help="manual seed")
+parser.add_argument('--resume', type=str, default='', metavar='PATH')
+parser.add_argument('--evaluate', action='store_true', help="evaluation only")
+parser.add_argument('--eval-step', type=int, default=-1,
+                    help="run evaluation for every N epochs (set to -1 to test after training)")
+parser.add_argument('--save-dir', type=str, default='log')
+parser.add_argument('--use-cpu', action='store_true', help="use cpu")
+parser.add_argument('--gpu-devices', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES')
+
+args = parser.parse_args()
+
+def main():
+    torch.manual_seed(args.seed)
+    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
+    use_gpu = torch.cuda.is_available()
+    if args.use_cpu: use_gpu = False
+
+    if not args.evaluate:
+        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
+    else:
+        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
+    print("==========\nArgs:{}\n==========".format(args))
+
+    if use_gpu:
+        print("Currently using GPU")
+        cudnn.benchmark = True
+        torch.cuda.manual_seed_all(args.seed)
+    else:
+        print("Currently using CPU (GPU is highly recommended)")
+
+    print("Initializing dataset {}".format(args.dataset))
+    dataset = data_manager.init_dataset(name=args.dataset)
+
+    transform_train = T.Compose([
+        T.Random2DTranslation(args.height, args.width),
+        T.RandomHorizontalFlip(),
+        T.ToTensor(),
+        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    ])
+
+    transform_test = T.Compose([
+        T.Resize((args.height, args.width)),
+        T.ToTensor(),
+        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    ])
+
+    pin_memory = True if use_gpu else False
+
+    trainloader = DataLoader(
+        ImageDataset(dataset.train, transform=transform_train),
+        sampler=RandomIdentitySampler(dataset.train, num_instances=args.num_instances),
+        batch_size=args.train_batch, num_workers=args.workers,
+        pin_memory=pin_memory, drop_last=True,
+    )
+
+    queryloader = DataLoader(
+        ImageDataset(dataset.query, transform=transform_test),
+        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
+        pin_memory=pin_memory, drop_last=False,
+    )
+
+    galleryloader = DataLoader(
+        ImageDataset(dataset.gallery, transform=transform_test),
+        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
+        pin_memory=pin_memory, drop_last=False,
+    )
+
+    print("Initializing model: {}".format(args.arch))
+    model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent', 'htri'})
+    print("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())/1000000.0))
+
+    criterion_xent = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu)
+    criterion_htri = TripletLoss(margin=args.margin)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+    if args.stepsize > 0:
+        scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
+    start_epoch = args.start_epoch
+
+    if args.resume:
+        print("Loading checkpoint from '{}'".format(args.resume))
+        checkpoint = torch.load(args.resume)
+        model.load_state_dict(checkpoint['state_dict'])
+        start_epoch = checkpoint['epoch']
+
+    if use_gpu:
+        model = nn.DataParallel(model).cuda()
+
+    if args.evaluate:
+        print("Evaluate only")
+        test(model, queryloader, galleryloader, use_gpu)
+        return
+
+    start_time = time.time()
+    best_rank1 = -np.inf
+
+    for epoch in range(start_epoch, args.max_epoch):
+        print("==> Epoch {}/{}".format(epoch+1, args.max_epoch))
+
+        train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu)
+
+        if args.stepsize > 0: scheduler.step()
+
+        if args.eval_step > 0 and (epoch+1) % args.eval_step == 0 or (epoch+1) == args.max_epoch:
+            print("==> Test")
+            rank1 = test(model, queryloader, galleryloader, use_gpu)
+            is_best = rank1 > best_rank1
+            if is_best: best_rank1 = rank1
+
+            save_checkpoint({
+                'state_dict': model.state_dict(),
+                'rank1': rank1,
+                'epoch': epoch,
+            }, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch+1) + '.pth.tar'))
+
+    elapsed = round(time.time() - start_time)
+    elapsed = str(datetime.timedelta(seconds=elapsed))
+    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
+
+def train(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
+    model.train()
+    losses = AverageMeter()
+
+    for batch_idx, (imgs, pids, _) in enumerate(trainloader):
+        if use_gpu:
+            imgs, pids = imgs.cuda(), pids.cuda()
+        imgs, pids = Variable(imgs), Variable(pids)
+        outputs, features = model(imgs)
+        xent_loss = criterion_xent(outputs, pids)
+        htri_loss = criterion_htri(features, pids)
+        loss = xent_loss + htri_loss
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        losses.update(loss.data[0], pids.size(0))
+
+        if (batch_idx+1) % args.print_freq == 0:
+            print("Batch {}/{}\t Loss {:.6f} ({:.6f})".format(batch_idx+1, len(trainloader), losses.val, losses.avg))
+
+def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20]):
+    model.eval()
+
+    qf, q_pids, q_camids = [], [], []
+    for batch_idx, (imgs, pids, camids) in enumerate(queryloader):
+        if use_gpu:
+            imgs = imgs.cuda()
+        imgs = Variable(imgs)
+        features = model(imgs)
+        features = features.data.cpu()
+        qf.append(features)
+        q_pids.extend(pids)
+        q_camids.extend(camids)
+    qf = torch.cat(qf, 0)
+    q_pids = np.asarray(q_pids)
+    q_camids = np.asarray(q_camids)
+
+    print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))
+
+    gf, g_pids, g_camids = [], [], []
+    for batch_idx, (imgs, pids, camids) in enumerate(galleryloader):
+        if use_gpu:
+            imgs = imgs.cuda()
+        imgs = Variable(imgs)
+        features = model(imgs)
+        features = features.data.cpu()
+        gf.append(features)
+        g_pids.extend(pids)
+        g_camids.extend(camids)
+    gf = torch.cat(gf, 0)
+    g_pids = np.asarray(g_pids)
+    g_camids = np.asarray(g_camids)
+
+    print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))
+    print("Computing distance matrix")
+
+    m, n = qf.size(0), gf.size(0)
+    distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
+              torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
+    distmat.addmm_(1, -2, qf, gf.t())
+    distmat = distmat.numpy()
+
+    print("Computing CMC and mAP")
+    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
+
+    print("Results ----------")
+    print("mAP: {:.1%}".format(mAP))
+    print("CMC curve")
+    for r in ranks:
+        print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1]))
+    print("------------------")
+
+    return cmc[0]
+
+if __name__ == '__main__':
+    main()
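The new trainer combines CrossEntropyLabelSmooth on the logits with TripletLoss(margin=args.margin) on the features returned by the models. losses.py is not part of this diff, so the following is only a sketch of a batch-hard triplet loss that matches the criterion_htri(features, pids) call above, written in the same Variable-era PyTorch style as the trainer:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

class TripletLoss(nn.Module):
    """Sketch of a triplet loss with batch-hard mining (not the repository's exact code).

    inputs: (batch_size, feat_dim) feature matrix; targets: (batch_size,) identity labels.
    For each anchor, the hardest positive and hardest negative in the batch are
    selected and fed to a margin ranking loss.
    """
    def __init__(self, margin=0.3):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        n = inputs.size(0)
        # pairwise Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b,
        # the same expansion the trainer's test() uses for distmat
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(1, -2, inputs, inputs.t())
        dist = dist.clamp(min=1e-12).sqrt()
        # mask[i][j] is 1 iff samples i and j share the same identity
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max())        # hardest positive
            dist_an.append(dist[i][mask[i] == 0].min())   # hardest negative
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)
        # require dist_an to exceed dist_ap by at least the margin
        y = Variable(dist_an.data.new(dist_an.size()).fill_(1))
        return self.ranking_loss(dist_an, dist_ap, y)
```

This is why the trainer pairs the loss with RandomIdentitySampler: batch-hard mining only works if every identity in the batch has at least one positive besides the anchor.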
@@ -40,17 +40,4 @@ class Random2DTranslation(object):
         return croped_img
-
-if __name__ == '__main__':
-    """import argparse
-    parser = argparse.ArgumentParser()
-    parser.add_argument('-impath', type=str)
-    parser.add_argument('-nlevel', type=float, default=0.1)
-    args = parser.parse_args()
-
-    RC = RandomOcclusion(nlevel=args.nlevel, p=1)
-    im = Image.open(args.impath)
-    transformed_im = RC(im)
-
-    basename = osp.basename(args.impath)
-    save_name = osp.splitext(basename)[0] + '_nlevel_' + str(args.nlevel) + osp.splitext(basename)[1]
-    transformed_im.save(save_name)"""
-    pass
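For context, Random2DTranslation (whose tail appears above and which the new trainer uses as its first training-time augmentation) is defined earlier in transforms.py and is not shown in this diff. A hedged sketch of such a transform, with the probability handling and the 1.125 enlargement factor assumed rather than taken from the source, is:

```python
import random
from PIL import Image

class Random2DTranslation(object):
    """Sketch: with probability p keep a plain resize to (width, height); otherwise
    enlarge the image (factor assumed to be 1.125) and take a random crop of the
    target size, which jitters the person's position inside the frame.
    """
    def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.p = p
        self.interpolation = interpolation

    def __call__(self, img):
        if random.random() < self.p:
            return img.resize((self.width, self.height), self.interpolation)
        new_width = int(round(self.width * 1.125))
        new_height = int(round(self.height * 1.125))
        resized_img = img.resize((new_width, new_height), self.interpolation)
        x1 = random.randint(0, new_width - self.width)
        y1 = random.randint(0, new_height - self.height)
        croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height))
        return croped_img
```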