2019-03-20 01:26:08 +08:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import print_function
|
|
|
|
from __future__ import division
|
|
|
|
|
|
|
|
import time
|
|
|
|
import datetime
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
import torchreid
|
|
|
|
from torchreid.engine import engine
|
|
|
|
from torchreid.losses import CrossEntropyLoss, TripletLoss
|
|
|
|
from torchreid.utils import AverageMeter, open_specified_layers, open_all_layers
|
|
|
|
from torchreid import metrics
|
|
|
|
|
|
|
|
|
|
|
|
class ImageTripletEngine(engine.Engine):
    r"""Triplet-loss engine for image-reid.

    Each training step minimizes a weighted sum of a triplet loss (computed on
    the feature vectors returned by the model) and a softmax cross-entropy loss
    (computed on the classification outputs returned by the model).

    Args:
        datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager``
            or ``torchreid.data.VideoDataManager``.
        model (nn.Module): model instance. In training mode it must return an
            ``(outputs, features)`` pair: ``outputs`` feed the softmax loss and
            accuracy, ``features`` feed the triplet loss.
        optimizer (Optimizer): an Optimizer.
        margin (float, optional): margin for triplet loss. Default is 0.3.
        weight_t (float, optional): weight for triplet loss. Default is 1.
        weight_x (float, optional): weight for softmax loss. Default is 1.
        scheduler (LRScheduler, optional): if None, no learning rate decay will be performed.
        use_gpu (bool, optional): use gpu. Default is True.
        label_smooth (bool, optional): use label smoothing regularizer. Default is True.

    Examples::

        import torch
        import torchreid
        datamanager = torchreid.data.ImageDataManager(
            root='path/to/reid-data',
            sources='market1501',
            height=256,
            width=128,
            combineall=False,
            batch_size=32,
            num_instances=4,
            train_sampler='RandomIdentitySampler' # this is important
        )
        model = torchreid.models.build_model(
            name='resnet50',
            num_classes=datamanager.num_train_pids,
            loss='triplet'
        )
        model = model.cuda()
        optimizer = torchreid.optim.build_optimizer(
            model, optim='adam', lr=0.0003
        )
        scheduler = torchreid.optim.build_lr_scheduler(
            optimizer,
            lr_scheduler='single_step',
            stepsize=20
        )
        engine = torchreid.engine.ImageTripletEngine(
            datamanager, model, optimizer, margin=0.3,
            weight_t=0.7, weight_x=1, scheduler=scheduler
        )
        engine.run(
            max_epoch=60,
            save_dir='log/resnet50-triplet-market1501',
            print_freq=10
        )
    """

    def __init__(self, datamanager, model, optimizer, margin=0.3,
                 weight_t=1, weight_x=1, scheduler=None, use_gpu=True,
                 label_smooth=True):
        super(ImageTripletEngine, self).__init__(datamanager, model, optimizer, scheduler, use_gpu)

        # Relative weights of the two loss terms in the total training loss.
        self.weight_t = weight_t
        self.weight_x = weight_x

        self.criterion_t = TripletLoss(margin=margin)
        # NOTE(review): self.datamanager and self.use_gpu are presumably set by
        # the base Engine.__init__ called above — confirm in engine.Engine.
        self.criterion_x = CrossEntropyLoss(
            num_classes=self.datamanager.num_train_pids,
            use_gpu=self.use_gpu,
            label_smooth=label_smooth
        )

    def train(self, epoch, max_epoch, trainloader, fixbase_epoch=0, open_layers=None, print_freq=10):
        """Run one training epoch over ``trainloader``.

        Args:
            epoch (int): current epoch index, 0-based (printed as ``epoch+1``).
            max_epoch (int): total number of epochs; used only for the ETA estimate.
            trainloader (iterable): training batches; must support ``len()``.
            fixbase_epoch (int, optional): during the first ``fixbase_epoch``
                epochs only ``open_layers`` are trained. Default is 0.
            open_layers (optional): layers to train during the first
                ``fixbase_epoch`` epochs; passed to ``open_specified_layers``.
                Default is None (train all layers).
            print_freq (int, optional): print progress every ``print_freq``
                batches. Values <= 0 disable console printing (previously 0
                raised ZeroDivisionError). Default is 10.
        """
        losses_t = AverageMeter()
        losses_x = AverageMeter()
        accs = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(open_layers, epoch+1, fixbase_epoch))
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        num_batches = len(trainloader)
        end = time.time()
        for batch_idx, data in enumerate(trainloader):
            data_time.update(time.time() - end)

            # NOTE(review): _parse_data_for_train is provided by the base Engine;
            # presumably returns (image batch, identity labels) — confirm.
            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()

            self.optimizer.zero_grad()
            outputs, features = self.model(imgs)
            loss_t = self._compute_loss(self.criterion_t, features, pids)
            loss_x = self._compute_loss(self.criterion_x, outputs, pids)
            loss = self.weight_t * loss_t + self.weight_x * loss_x
            loss.backward()
            self.optimizer.step()

            batch_time.update(time.time() - end)

            # Weight every meter by batch size so the running averages are
            # per-sample, consistently across losses and accuracy.
            batch_size = pids.size(0)
            losses_t.update(loss_t.item(), batch_size)
            losses_x.update(loss_x.item(), batch_size)
            accs.update(metrics.accuracy(outputs, pids)[0].item(), batch_size)

            if print_freq > 0 and (batch_idx + 1) % print_freq == 0:
                # estimate remaining time: batches left in this epoch plus all
                # batches of the remaining epochs, at the average batch time
                remaining_batches = (num_batches - (batch_idx + 1)
                                     + (max_epoch - (epoch + 1)) * num_batches)
                eta_seconds = batch_time.avg * remaining_batches
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss_t {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
                      'Loss_x {loss_x.val:.4f} ({loss_x.avg:.4f})\t'
                      'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
                      'Lr {lr:.6f}\t'
                      'eta {eta}'.format(
                          epoch+1, max_epoch, batch_idx+1, num_batches,
                          batch_time=batch_time,
                          data_time=data_time,
                          loss_t=losses_t,
                          loss_x=losses_x,
                          acc=accs,
                          lr=self.optimizer.param_groups[0]['lr'],
                          eta=eta_str
                      ))

            # NOTE(review): self.writer is presumably a TensorBoard-style
            # SummaryWriter set up by the base Engine (None when disabled) — confirm.
            if self.writer is not None:
                # global iteration index across epochs, used as the x-axis
                n_iter = epoch * num_batches + batch_idx
                self.writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                self.writer.add_scalar('Train/Data', data_time.avg, n_iter)
                self.writer.add_scalar('Train/Loss_t', losses_t.avg, n_iter)
                self.writer.add_scalar('Train/Loss_x', losses_x.avg, n_iter)
                self.writer.add_scalar('Train/Acc', accs.avg, n_iter)
                self.writer.add_scalar('Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter)

            end = time.time()

        # Step the LR scheduler once per epoch, after all optimizer steps.
        if self.scheduler is not None:
            self.scheduler.step()