deep-person-reid/torchreid/engine/image/triplet.py

114 lines
3.5 KiB
Python
Raw Normal View History

2019-12-01 10:35:44 +08:00
from __future__ import division, print_function, absolute_import
2019-03-20 01:26:08 +08:00
from torchreid import metrics
2019-12-01 10:35:44 +08:00
from torchreid.losses import TripletLoss, CrossEntropyLoss
2019-03-20 01:26:08 +08:00
2019-12-01 11:31:32 +08:00
from ..engine import Engine
2019-03-20 01:26:08 +08:00
class ImageTripletEngine(Engine):
    r"""Triplet-loss engine for image-reid.

    Each training step minimizes a weighted sum of a triplet loss on the
    embedding features and a cross-entropy (softmax) loss on the model's
    classification outputs::

        loss = weight_t * loss_t + weight_x * loss_x

    Args:
        datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager``
            or ``torchreid.data.VideoDataManager``.
        model (nn.Module): model instance. In training mode, ``model(imgs)``
            must return a pair ``(outputs, features)``.
        optimizer (Optimizer): an Optimizer.
        margin (float, optional): margin for triplet loss. Default is 0.3.
        weight_t (float, optional): weight for triplet loss. Must be >= 0.
            Default is 1.
        weight_x (float, optional): weight for softmax loss. Must be >= 0.
            Default is 1.
        scheduler (LRScheduler, optional): if None, no learning rate decay will be performed.
        use_gpu (bool, optional): use gpu. Default is True.
        label_smooth (bool, optional): use label smoothing regularizer. Default is True.

    Raises:
        ValueError: if ``weight_t`` or ``weight_x`` is negative, or if both
            are zero (the combined loss would then carry no training signal).

    Examples::

        import torchreid
        datamanager = torchreid.data.ImageDataManager(
            root='path/to/reid-data',
            sources='market1501',
            height=256,
            width=128,
            combineall=False,
            batch_size=32,
            num_instances=4,
            train_sampler='RandomIdentitySampler' # this is important
        )
        model = torchreid.models.build_model(
            name='resnet50',
            num_classes=datamanager.num_train_pids,
            loss='triplet'
        )
        model = model.cuda()
        optimizer = torchreid.optim.build_optimizer(
            model, optim='adam', lr=0.0003
        )
        scheduler = torchreid.optim.build_lr_scheduler(
            optimizer,
            lr_scheduler='single_step',
            stepsize=20
        )
        engine = torchreid.engine.ImageTripletEngine(
            datamanager, model, optimizer, margin=0.3,
            weight_t=0.7, weight_x=1, scheduler=scheduler
        )
        engine.run(
            max_epoch=60,
            save_dir='log/resnet50-triplet-market1501',
            print_freq=10
        )
    """

    def __init__(
        self,
        datamanager,
        model,
        optimizer,
        margin=0.3,
        weight_t=1,
        weight_x=1,
        scheduler=None,
        use_gpu=True,
        label_smooth=True
    ):
        # Fail fast on nonsensical loss weights, before anything is registered.
        # A negative weight would make minimizing the total loss *maximize*
        # that term; two zero weights give a loss with zero gradient.
        if weight_t < 0 or weight_x < 0:
            raise ValueError(
                'weight_t and weight_x must be non-negative, '
                'got weight_t={} and weight_x={}'.format(weight_t, weight_x)
            )
        if weight_t + weight_x <= 0:
            raise ValueError(
                'at least one of weight_t and weight_x must be positive'
            )

        super(ImageTripletEngine, self).__init__(datamanager, use_gpu)

        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.register_model('model', model, optimizer, scheduler)

        self.weight_t = weight_t
        self.weight_x = weight_x

        self.criterion_t = TripletLoss(margin=margin)
        # self.datamanager / self.use_gpu are read back from the base class
        # after super().__init__ rather than taken from the raw arguments.
        self.criterion_x = CrossEntropyLoss(
            num_classes=self.datamanager.num_train_pids,
            use_gpu=self.use_gpu,
            label_smooth=label_smooth
        )

    def forward_backward(self, data):
        """Run one optimization step on a single training batch.

        Args:
            data: a training batch; ``self.parse_data_for_train`` unpacks it
                into ``(imgs, pids)``.

        Returns:
            dict: Python floats keyed by ``'loss_t'`` (triplet loss),
            ``'loss_x'`` (cross-entropy loss) and ``'acc'`` (the first value
            returned by ``metrics.accuracy`` on the outputs — presumably
            top-1 accuracy).
        """
        imgs, pids = self.parse_data_for_train(data)

        if self.use_gpu:
            imgs = imgs.cuda()
            pids = pids.cuda()

        # The model is expected to return both classification outputs and
        # embedding features while training.
        outputs, features = self.model(imgs)

        # Triplet loss works on the embeddings; cross-entropy on the outputs.
        loss_t = self.compute_loss(self.criterion_t, features, pids)
        loss_x = self.compute_loss(self.criterion_x, outputs, pids)
        loss = self.weight_t * loss_t + self.weight_x * loss_x

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        loss_summary = {
            'loss_t': loss_t.item(),
            'loss_x': loss_x.item(),
            'acc': metrics.accuracy(outputs, pids)[0].item()
        }

        return loss_summary