mirror of https://github.com/KaiyangZhou/deep-person-reid.git
synced 2025-06-03 14:53:23 +08:00

commit 36e22e8ce1 (parent 322ec2b2de)
make engine more model-agnostic
--- a/torchreid/engine/image/softmax_nas.py
+++ b/torchreid/engine/image/softmax_nas.py
@@ -27,8 +27,7 @@ class ImageSoftmaxNASEngine(Engine):
         lmda_decay_rate=0.5,
         fixed_lmda=False
     ):
-        super(ImageSoftmaxNASEngine, self
-              ).__init__(datamanager, model, optimizer, scheduler, use_gpu)
+        super(ImageSoftmaxNASEngine, self).__init__(datamanager, use_gpu)
         self.mc_iter = mc_iter
         self.init_lmda = init_lmda
         self.min_lmda = min_lmda
@@ -36,108 +35,44 @@ class ImageSoftmaxNASEngine(Engine):
         self.lmda_decay_rate = lmda_decay_rate
         self.fixed_lmda = fixed_lmda
 
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.register_model('model', model, optimizer, scheduler)
+
         self.criterion = CrossEntropyLoss(
             num_classes=self.datamanager.num_train_pids,
             use_gpu=self.use_gpu,
             label_smooth=label_smooth
         )
 
-    def train(
-        self,
-        epoch,
-        max_epoch,
-        writer,
-        fixbase_epoch=0,
-        open_layers=None,
-        print_freq=10
-    ):
-        losses = AverageMeter()
-        accs = AverageMeter()
-        batch_time = AverageMeter()
-        data_time = AverageMeter()
-
-        self.model.train()
-        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
-            print(
-                '* Only train {} (epoch: {}/{})'.format(
-                    open_layers, epoch + 1, fixbase_epoch
-                )
-            )
-            open_specified_layers(self.model, open_layers)
-        else:
-            open_all_layers(self.model)
-
-        num_batches = len(self.train_loader)
-        end = time.time()
-        for batch_idx, data in enumerate(self.train_loader):
-            data_time.update(time.time() - end)
-
-            imgs, pids = self._parse_data_for_train(data)
-            if self.use_gpu:
-                imgs = imgs.cuda()
-                pids = pids.cuda()
-
-            # softmax temperature
-            if self.fixed_lmda or self.lmda_decay_step == -1:
-                lmda = self.init_lmda
-            else:
-                lmda = self.init_lmda * self.lmda_decay_rate**(
-                    epoch // self.lmda_decay_step
-                )
-                if lmda < self.min_lmda:
-                    lmda = self.min_lmda
-
-            for k in range(self.mc_iter):
-                outputs = self.model(imgs, lmda=lmda)
-                loss = self._compute_loss(self.criterion, outputs, pids)
-                self.optimizer.zero_grad()
-                loss.backward()
-                self.optimizer.step()
-
-            batch_time.update(time.time() - end)
-
-            losses.update(loss.item(), pids.size(0))
-            accs.update(metrics.accuracy(outputs, pids)[0].item())
-
-            if (batch_idx+1) % print_freq == 0:
-                # estimate remaining time
-                eta_seconds = batch_time.avg * (
-                    num_batches - (batch_idx+1) + (max_epoch -
-                                                   (epoch+1)) * num_batches
-                )
-                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
-                print(
-                    'Epoch: [{0}/{1}][{2}/{3}]\t'
-                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
-                    'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
-                    'Lr {lr:.6f}\t'
-                    'eta {eta}'.format(
-                        epoch + 1,
-                        max_epoch,
-                        batch_idx + 1,
-                        num_batches,
-                        batch_time=batch_time,
-                        data_time=data_time,
-                        loss=losses,
-                        acc=accs,
-                        lr=self.optimizer.param_groups[0]['lr'],
-                        eta=eta_str
-                    )
-                )
-
-            if writer is not None:
-                n_iter = epoch*num_batches + batch_idx
-                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
-                writer.add_scalar('Train/Data', data_time.avg, n_iter)
-                writer.add_scalar('Train/Loss', losses.avg, n_iter)
-                writer.add_scalar('Train/Acc', accs.avg, n_iter)
-                writer.add_scalar(
-                    'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
-                )
-
-            end = time.time()
-
-        if self.scheduler is not None:
-            self.scheduler.step()
+    def forward_backward(self, data):
+        imgs, pids = self._parse_data_for_train(data)
+
+        if self.use_gpu:
+            imgs = imgs.cuda()
+            pids = pids.cuda()
+
+        # softmax temperature
+        if self.fixed_lmda or self.lmda_decay_step == -1:
+            lmda = self.init_lmda
+        else:
+            lmda = self.init_lmda * self.lmda_decay_rate**(
+                epoch // self.lmda_decay_step
+            )
+            if lmda < self.min_lmda:
+                lmda = self.min_lmda
+
+        for k in range(self.mc_iter):
+            outputs = self.model(imgs, lmda=lmda)
+            loss = self._compute_loss(self.criterion, outputs, pids)
+            self.optimizer.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+        loss_dict = {
+            'loss': loss.item(),
+            'acc': metrics.accuracy(outputs, pids)[0].item()
+        }
+
+        return loss_dict
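The schedule above anneals the softmax temperature `lmda` by `lmda_decay_rate` every `lmda_decay_step` epochs and clamps it at `min_lmda`. Note that the decay still reads `epoch`, which `forward_backward(self, data)` no longer receives as an argument, so the sketch below treats the epoch as an explicit loop variable; the numeric values are illustrative, the real ones come from the engine's constructor:

    init_lmda, min_lmda = 10.0, 0.01
    lmda_decay_rate, lmda_decay_step = 0.5, 20   # illustrative values

    for epoch in (0, 19, 20, 40, 200):
        lmda = init_lmda * lmda_decay_rate**(epoch // lmda_decay_step)
        lmda = max(lmda, min_lmda)   # same clamp as the engine's if-check
        print(epoch, lmda)           # 10.0 for epochs 0-19, 5.0 at 20, 2.5 at 40, 0.01 at 200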
--- a/torchreid/engine/engine.py
+++ b/torchreid/engine/engine.py
@@ -3,13 +3,15 @@ import time
 import numpy as np
 import os.path as osp
 import datetime
+from collections import OrderedDict
 import torch
 from torch.nn import functional as F
 from torch.utils.tensorboard import SummaryWriter
 
 from torchreid import metrics
 from torchreid.utils import (
-    AverageMeter, re_ranking, save_checkpoint, visualize_ranked_results
+    MetricMeter, AverageMeter, re_ranking, open_all_layers, save_checkpoint,
+    open_specified_layers, visualize_ranked_results
 )
 from torchreid.losses import DeepSupervision
@@ -26,22 +28,89 @@ class Engine(object):
         use_gpu (bool, optional): use gpu. Default is True.
     """
 
-    def __init__(
-        self,
-        datamanager,
-        model,
-        optimizer=None,
-        scheduler=None,
-        use_gpu=True
-    ):
+    def __init__(self, datamanager, use_gpu=True):
         self.datamanager = datamanager
-        self.model = model
-        self.optimizer = optimizer
-        self.scheduler = scheduler
-        self.use_gpu = (torch.cuda.is_available() and use_gpu)
-        self.writer = None
         self.train_loader = self.datamanager.train_loader
         self.test_loader = self.datamanager.test_loader
+        self.use_gpu = (torch.cuda.is_available() and use_gpu)
+        self.writer = None
+
+        self.model = None
+        self.optimizer = None
+        self.scheduler = None
+
+        self._models = OrderedDict()
+        self._optims = OrderedDict()
+        self._scheds = OrderedDict()
+
+    def register_model(self, name='model', model=None, optim=None, sched=None):
+        if self.__dict__.get('_models') is None:
+            raise AttributeError(
+                'Cannot assign model before super().__init__() call'
+            )
+
+        if self.__dict__.get('_optims') is None:
+            raise AttributeError(
+                'Cannot assign optim before super().__init__() call'
+            )
+
+        if self.__dict__.get('_scheds') is None:
+            raise AttributeError(
+                'Cannot assign sched before super().__init__() call'
+            )
+
+        self._models[name] = model
+        self._optims[name] = optim
+        self._scheds[name] = sched
+
+    def get_model_names(self, names=None):
+        names_real = list(self._models.keys())
+        if names is not None:
+            if not isinstance(names, list):
+                names = [names]
+            for name in names:
+                assert name in names_real
+            return names
+        else:
+            return names_real
+
+    def save_model(self, epoch, rank1, save_dir, is_best=False):
+        names = self.get_model_names()
+
+        for name in names:
+            save_checkpoint(
+                {
+                    'state_dict': self._models[name].state_dict(),
+                    'epoch': epoch + 1,
+                    'rank1': rank1,
+                    'optimizer': self._optims[name].state_dict(),
+                    'scheduler': self._scheds[name].state_dict()
+                },
+                osp.join(save_dir, name),
+                is_best=is_best
+            )
+
+    def set_model_mode(self, mode='train', names=None):
+        assert mode in ['train', 'eval', 'test']
+        names = self.get_model_names(names)
+
+        for name in names:
+            if mode == 'train':
+                self._models[name].train()
+            else:
+                self._models[name].eval()
+
+    def get_current_lr(self, names=None):
+        names = self.get_model_names(names)
+        name = names[0]
+        return self._optims[name].param_groups[0]['lr']
+
+    def update_lr(self, names=None):
+        names = self.get_model_names(names)
+
+        for name in names:
+            if self._scheds[name] is not None:
+                self._scheds[name].step()
 
     def run(
         self,
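The `(model, optimizer, scheduler)` registry is what makes the base class model-agnostic: `save_model()`, `set_model_mode()`, `get_current_lr()` and `update_lr()` all iterate over whatever has been registered, so an engine that trains several networks just registers each triple. A minimal sketch; the two-stream setup and names are hypothetical, not part of this commit:

    class TwoStreamEngine(Engine):
        """Hypothetical engine that trains two networks side by side."""

        def __init__(self, datamanager, model_a, optim_a, model_b, optim_b,
                     use_gpu=True):
            super(TwoStreamEngine, self).__init__(datamanager, use_gpu)
            # Each registered entry gets its own checkpoint file (save_model()
            # writes to osp.join(save_dir, name)) and its own scheduler step
            # in update_lr().
            self.register_model('stream_a', model_a, optim_a)
            self.register_model('stream_b', model_b, optim_b)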
@@ -142,7 +211,7 @@ class Engine(object):
                     use_metric_cuhk03=use_metric_cuhk03,
                     ranks=ranks
                 )
-                self._save_checkpoint(epoch, rank1, save_dir)
+                self.save_model(epoch, rank1, save_dir)
 
         if max_epoch > 0:
             print('=> Final test')
@@ -156,7 +225,7 @@ class Engine(object):
                 use_metric_cuhk03=use_metric_cuhk03,
                 ranks=ranks
             )
-            self._save_checkpoint(epoch, rank1, save_dir)
+            self.save_model(epoch, rank1, save_dir)
 
         elapsed = round(time.time() - time_start)
         elapsed = str(datetime.timedelta(seconds=elapsed))
@@ -164,20 +233,68 @@ class Engine(object):
         if self.writer is not None:
             self.writer.close()
 
-    def train(self):
-        r"""Performs training on source datasets for one epoch.
-
-        This will be called every epoch in ``run()``, e.g.
-
-        .. code-block:: python
-
-            for epoch in range(start_epoch, max_epoch):
-                self.train(some_arguments)
-
-        .. note::
-
-            This must be implemented in subclasses.
-        """
+    def train(
+        self,
+        epoch,
+        max_epoch,
+        writer,
+        print_freq=10,
+        fixbase_epoch=0,
+        open_layers=None
+    ):
+        losses = MetricMeter()
+        batch_time = AverageMeter()
+        data_time = AverageMeter()
+
+        self.set_model_mode('train')
+
+        self._two_stepped_transfer_learning(epoch, fixbase_epoch, open_layers)
+
+        num_batches = len(self.train_loader)
+        end = time.time()
+        for batch_idx, data in enumerate(self.train_loader):
+            data_time.update(time.time() - end)
+            loss_dict = self.forward_backward(data)
+            batch_time.update(time.time() - end)
+            losses.update(loss_dict)
+
+            if (batch_idx+1) % print_freq == 0:
+                nb_this_epoch = num_batches - (batch_idx+1)
+                nb_future_epochs = (max_epoch - (epoch+1)) * num_batches
+                eta_seconds = batch_time.avg * (nb_this_epoch+nb_future_epochs)
+                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
+                print(
+                    'epoch: [{0}/{1}][{2}/{3}]\t'
+                    'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                    'data {data_time.val:.3f} ({data_time.avg:.3f})\t'
+                    'eta {eta}\t'
+                    '{losses}\t'
+                    'lr {lr:.6f}'.format(
+                        epoch + 1,
+                        max_epoch,
+                        batch_idx + 1,
+                        num_batches,
+                        batch_time=batch_time,
+                        data_time=data_time,
+                        eta=eta_str,
+                        losses=losses,
+                        lr=self.get_current_lr()
+                    )
+                )
+
+            if writer is not None:
+                n_iter = epoch*num_batches + batch_idx
+                writer.add_scalar('Train/time', batch_time.avg, n_iter)
+                writer.add_scalar('Train/data', data_time.avg, n_iter)
+                for name, meter in losses.meters.items():
+                    writer.add_scalar('Train/' + name, meter.avg, n_iter)
+                writer.add_scalar('Train/lr', self.get_current_lr(), n_iter)
+
+            end = time.time()
+
+        self.update_lr()
+
+    def forward_backward(self, data):
         raise NotImplementedError
 
     def test(
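With this change `train()` is fully generic: it owns timing, progress printing, TensorBoard logging and the warm-up logic, and delegates the actual optimization step to `forward_backward(data)`, which must return a dict of scalars for `MetricMeter`. A minimal sketch of the subclass contract, assuming a `self.criterion` set up in the subclass constructor:

    class MyEngine(Engine):

        def forward_backward(self, data):
            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs, pids = imgs.cuda(), pids.cuda()

            outputs = self.model(imgs)
            loss = self._compute_loss(self.criterion, outputs, pids)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Dict keys become meter names and TensorBoard tags ('Train/<key>').
            return {'loss': loss.item()}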
@@ -205,6 +322,7 @@ class Engine(object):
         ``_extract_features()`` and ``_parse_data_for_eval()`` (most of the time),
         but not a must. Please refer to the source code for more details.
         """
+        self.set_model_mode('eval')
         targets = list(self.test_loader.keys())
 
         for name in targets:
@@ -330,7 +448,6 @@ class Engine(object):
         return loss
 
     def _extract_features(self, input):
-        self.model.eval()
         return self.model(input)
 
     def _parse_data_for_train(self, data):
@@ -344,15 +461,26 @@ class Engine(object):
         camids = data[2]
         return imgs, pids, camids
 
-    def _save_checkpoint(self, epoch, rank1, save_dir, is_best=False):
-        save_checkpoint(
-            {
-                'state_dict': self.model.state_dict(),
-                'epoch': epoch + 1,
-                'rank1': rank1,
-                'optimizer': self.optimizer.state_dict(),
-                'scheduler': self.scheduler.state_dict(),
-            },
-            save_dir,
-            is_best=is_best
-        )
+    def _two_stepped_transfer_learning(
+        self, epoch, fixbase_epoch, open_layers, model=None
+    ):
+        """Two stepped transfer learning.
+
+        The idea is to freeze base layers for a certain number of epochs
+        and then open all layers for training.
+
+        Reference: https://arxiv.org/abs/1611.05244
+        """
+        model = self.model if model is None else model
+        if model is None:
+            return
+
+        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
+            print(
+                '* Only train {} (epoch: {}/{})'.format(
+                    open_layers, epoch + 1, fixbase_epoch
+                )
+            )
+            open_specified_layers(model, open_layers)
+        else:
+            open_all_layers(model)
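`_two_stepped_transfer_learning()` centralizes the warm-up that each engine's `train()` used to repeat inline: for the first `fixbase_epoch` epochs only the modules named in `open_layers` receive gradients, then everything is unfrozen. Assuming `run()` forwards these arguments to `train()` as elsewhere in this repo (its full signature is not shown in this diff), usage would look like this, with illustrative values:

    engine.run(
        save_dir='log/resnet50',    # illustrative path
        max_epoch=60,
        fixbase_epoch=5,            # freeze base layers for the first 5 epochs
        open_layers='classifier'    # warm up only this submodule
    )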
--- a/torchreid/engine/image/softmax.py
+++ b/torchreid/engine/image/softmax.py
@@ -3,9 +3,6 @@ import time
 import datetime
 
 from torchreid import metrics
-from torchreid.utils import (
-    AverageMeter, open_all_layers, open_specified_layers
-)
 from torchreid.losses import CrossEntropyLoss
 
 from ..engine import Engine
@@ -67,8 +64,12 @@ class ImageSoftmaxEngine(Engine):
         use_gpu=True,
         label_smooth=True
     ):
-        super(ImageSoftmaxEngine, self
-              ).__init__(datamanager, model, optimizer, scheduler, use_gpu)
+        super(ImageSoftmaxEngine, self).__init__(datamanager, use_gpu)
+
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.register_model('model', model, optimizer, scheduler)
 
         self.criterion = CrossEntropyLoss(
             num_classes=self.datamanager.num_train_pids,
@@ -76,91 +77,22 @@ class ImageSoftmaxEngine(Engine):
             label_smooth=label_smooth
         )
 
-    def train(
-        self,
-        epoch,
-        max_epoch,
-        writer,
-        print_freq=10,
-        fixbase_epoch=0,
-        open_layers=None
-    ):
-        losses = AverageMeter()
-        accs = AverageMeter()
-        batch_time = AverageMeter()
-        data_time = AverageMeter()
-
-        self.model.train()
-        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
-            print(
-                '* Only train {} (epoch: {}/{})'.format(
-                    open_layers, epoch + 1, fixbase_epoch
-                )
-            )
-            open_specified_layers(self.model, open_layers)
-        else:
-            open_all_layers(self.model)
-
-        num_batches = len(self.train_loader)
-        end = time.time()
-        for batch_idx, data in enumerate(self.train_loader):
-            data_time.update(time.time() - end)
-
-            imgs, pids = self._parse_data_for_train(data)
-            if self.use_gpu:
-                imgs = imgs.cuda()
-                pids = pids.cuda()
-
-            self.optimizer.zero_grad()
-            outputs = self.model(imgs)
-            loss = self._compute_loss(self.criterion, outputs, pids)
-            loss.backward()
-            self.optimizer.step()
-
-            batch_time.update(time.time() - end)
-
-            losses.update(loss.item(), pids.size(0))
-            accs.update(metrics.accuracy(outputs, pids)[0].item())
-
-            if (batch_idx+1) % print_freq == 0:
-                # estimate remaining time
-                eta_seconds = batch_time.avg * (
-                    num_batches - (batch_idx+1) + (max_epoch -
-                                                   (epoch+1)) * num_batches
-                )
-                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
-                print(
-                    'Epoch: [{0}/{1}][{2}/{3}]\t'
-                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
-                    'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
-                    'Lr {lr:.6f}\t'
-                    'eta {eta}'.format(
-                        epoch + 1,
-                        max_epoch,
-                        batch_idx + 1,
-                        num_batches,
-                        batch_time=batch_time,
-                        data_time=data_time,
-                        loss=losses,
-                        acc=accs,
-                        lr=self.optimizer.param_groups[0]['lr'],
-                        eta=eta_str
-                    )
-                )
-
-            if writer is not None:
-                n_iter = epoch*num_batches + batch_idx
-                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
-                writer.add_scalar('Train/Data', data_time.avg, n_iter)
-                writer.add_scalar('Train/Loss', losses.avg, n_iter)
-                writer.add_scalar('Train/Acc', accs.avg, n_iter)
-                writer.add_scalar(
-                    'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
-                )
-
-            end = time.time()
-
-        if self.scheduler is not None:
-            self.scheduler.step()
+    def forward_backward(self, data):
+        imgs, pids = self._parse_data_for_train(data)
+
+        if self.use_gpu:
+            imgs = imgs.cuda()
+            pids = pids.cuda()
+
+        self.optimizer.zero_grad()
+        outputs = self.model(imgs)
+        loss = self._compute_loss(self.criterion, outputs, pids)
+        loss.backward()
+        self.optimizer.step()
+
+        loss_dict = {
+            'loss': loss.item(),
+            'acc': metrics.accuracy(outputs, pids)[0].item()
+        }
+
+        return loss_dict
--- a/torchreid/engine/image/triplet.py
+++ b/torchreid/engine/image/triplet.py
@@ -3,9 +3,6 @@ import time
 import datetime
 
 from torchreid import metrics
-from torchreid.utils import (
-    AverageMeter, open_all_layers, open_specified_layers
-)
 from torchreid.losses import TripletLoss, CrossEntropyLoss
 
 from ..engine import Engine
@@ -76,8 +73,12 @@ class ImageTripletEngine(Engine):
         use_gpu=True,
         label_smooth=True
     ):
-        super(ImageTripletEngine, self
-              ).__init__(datamanager, model, optimizer, scheduler, use_gpu)
+        super(ImageTripletEngine, self).__init__(datamanager, use_gpu)
+
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.register_model('model', model, optimizer, scheduler)
 
         self.weight_t = weight_t
         self.weight_x = weight_x
@@ -89,98 +90,25 @@ class ImageTripletEngine(Engine):
             label_smooth=label_smooth
         )
 
-    def train(
-        self,
-        epoch,
-        max_epoch,
-        writer,
-        print_freq=10,
-        fixbase_epoch=0,
-        open_layers=None
-    ):
-        losses_t = AverageMeter()
-        losses_x = AverageMeter()
-        accs = AverageMeter()
-        batch_time = AverageMeter()
-        data_time = AverageMeter()
-
-        self.model.train()
-        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
-            print(
-                '* Only train {} (epoch: {}/{})'.format(
-                    open_layers, epoch + 1, fixbase_epoch
-                )
-            )
-            open_specified_layers(self.model, open_layers)
-        else:
-            open_all_layers(self.model)
-
-        num_batches = len(self.train_loader)
-        end = time.time()
-        for batch_idx, data in enumerate(self.train_loader):
-            data_time.update(time.time() - end)
-
-            imgs, pids = self._parse_data_for_train(data)
-            if self.use_gpu:
-                imgs = imgs.cuda()
-                pids = pids.cuda()
-
-            self.optimizer.zero_grad()
-            outputs, features = self.model(imgs)
-            loss_t = self._compute_loss(self.criterion_t, features, pids)
-            loss_x = self._compute_loss(self.criterion_x, outputs, pids)
-            loss = self.weight_t * loss_t + self.weight_x * loss_x
-            loss.backward()
-            self.optimizer.step()
-
-            batch_time.update(time.time() - end)
-
-            losses_t.update(loss_t.item(), pids.size(0))
-            losses_x.update(loss_x.item(), pids.size(0))
-            accs.update(metrics.accuracy(outputs, pids)[0].item())
-
-            if (batch_idx+1) % print_freq == 0:
-                # estimate remaining time
-                eta_seconds = batch_time.avg * (
-                    num_batches - (batch_idx+1) + (max_epoch -
-                                                   (epoch+1)) * num_batches
-                )
-                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
-                print(
-                    'Epoch: [{0}/{1}][{2}/{3}]\t'
-                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                    'Loss_t {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
-                    'Loss_x {loss_x.val:.4f} ({loss_x.avg:.4f})\t'
-                    'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
-                    'Lr {lr:.6f}\t'
-                    'eta {eta}'.format(
-                        epoch + 1,
-                        max_epoch,
-                        batch_idx + 1,
-                        num_batches,
-                        batch_time=batch_time,
-                        data_time=data_time,
-                        loss_t=losses_t,
-                        loss_x=losses_x,
-                        acc=accs,
-                        lr=self.optimizer.param_groups[0]['lr'],
-                        eta=eta_str
-                    )
-                )
-
-            if writer is not None:
-                n_iter = epoch*num_batches + batch_idx
-                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
-                writer.add_scalar('Train/Data', data_time.avg, n_iter)
-                writer.add_scalar('Train/Loss_t', losses_t.avg, n_iter)
-                writer.add_scalar('Train/Loss_x', losses_x.avg, n_iter)
-                writer.add_scalar('Train/Acc', accs.avg, n_iter)
-                writer.add_scalar(
-                    'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
-                )
-
-            end = time.time()
-
-        if self.scheduler is not None:
-            self.scheduler.step()
+    def forward_backward(self, data):
+        imgs, pids = self._parse_data_for_train(data)
+
+        if self.use_gpu:
+            imgs = imgs.cuda()
+            pids = pids.cuda()
+
+        self.optimizer.zero_grad()
+        outputs, features = self.model(imgs)
+        loss_t = self._compute_loss(self.criterion_t, features, pids)
+        loss_x = self._compute_loss(self.criterion_x, outputs, pids)
+        loss = self.weight_t * loss_t + self.weight_x * loss_x
+        loss.backward()
+        self.optimizer.step()
+
+        loss_dict = {
+            'loss_t': loss_t.item(),
+            'loss_x': loss_x.item(),
+            'acc': metrics.accuracy(outputs, pids)[0].item()
+        }
+
+        return loss_dict
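`ImageTripletEngine.forward_backward()` combines the two objectives as `loss = weight_t * loss_t + weight_x * loss_x` and reports each term separately, so both show up as meters and TensorBoard curves. A hedged construction sketch following the project's builder API as documented in its README (dataset root and hyperparameters are illustrative):

    import torchreid

    datamanager = torchreid.data.ImageDataManager(
        root='reid-data', sources='market1501'   # illustrative dataset setup
    )
    model = torchreid.models.build_model(
        name='resnet50',
        num_classes=datamanager.num_train_pids,
        loss='triplet'
    )
    optimizer = torchreid.optim.build_optimizer(model, optim='adam')
    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer, lr_scheduler='single_step', stepsize=20
    )

    engine = torchreid.engine.ImageTripletEngine(
        datamanager, model, optimizer, scheduler=scheduler,
        margin=0.3, weight_t=1, weight_x=1
    )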
--- a/torchreid/losses/cross_entropy_loss.py
+++ b/torchreid/losses/cross_entropy_loss.py
@@ -13,25 +13,23 @@ class CrossEntropyLoss(nn.Module):
 
     .. math::
         \begin{equation}
-        (1 - \epsilon) \times y + \frac{\epsilon}{K},
+        (1 - \eps) \times y + \frac{\eps}{K},
         \end{equation}
 
-    where :math:`K` denotes the number of classes and :math:`\epsilon` is a weight. When
-    :math:`\epsilon = 0`, the loss function reduces to the normal cross entropy.
+    where :math:`K` denotes the number of classes and :math:`\eps` is a weight. When
+    :math:`\eps = 0`, the loss function reduces to the normal cross entropy.
 
     Args:
         num_classes (int): number of classes.
-        epsilon (float, optional): weight. Default is 0.1.
+        eps (float, optional): weight. Default is 0.1.
         use_gpu (bool, optional): whether to use gpu devices. Default is True.
         label_smooth (bool, optional): whether to apply label smoothing. Default is True.
     """
 
-    def __init__(
-        self, num_classes, epsilon=0.1, use_gpu=True, label_smooth=True
-    ):
+    def __init__(self, num_classes, eps=0.1, use_gpu=True, label_smooth=True):
         super(CrossEntropyLoss, self).__init__()
         self.num_classes = num_classes
-        self.epsilon = epsilon if label_smooth else 0
+        self.eps = eps if label_smooth else 0
         self.use_gpu = use_gpu
         self.logsoftmax = nn.LogSoftmax(dim=1)
@@ -48,7 +46,5 @@ class CrossEntropyLoss(nn.Module):
         targets = zeros.scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
         if self.use_gpu:
             targets = targets.cuda()
-        targets = (
-            1 - self.epsilon
-        ) * targets + self.epsilon / self.num_classes
+        targets = (1 - self.eps) * targets + self.eps / self.num_classes
         return (-targets * log_probs).mean(0).sum()
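The smoothing itself is the one-liner restored above: a one-hot target `y` becomes `(1 - eps) * y + eps / K`. A quick worked check with `eps=0.1` and `K=4`:

    import torch

    eps, K = 0.1, 4
    one_hot = torch.tensor([0., 0., 1., 0.])
    smoothed = (1 - eps) * one_hot + eps / K
    print(smoothed)   # tensor([0.0250, 0.0250, 0.9250, 0.0250])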
--- a/torchreid/utils/avgmeter.py
+++ b/torchreid/utils/avgmeter.py
@@ -1,6 +1,8 @@
 from __future__ import division, absolute_import
+from collections import defaultdict
+import torch
 
-__all__ = ['AverageMeter']
+__all__ = ['AverageMeter', 'MetricMeter']
 
 
 class AverageMeter(object):
@@ -27,3 +29,45 @@ class AverageMeter(object):
         self.sum += val * n
         self.count += n
         self.avg = self.sum / self.count
+
+
+class MetricMeter(object):
+    """A collection of metrics.
+
+    Source: https://github.com/KaiyangZhou/Dassl.pytorch
+
+    Examples::
+        >>> # 1. Create an instance of MetricMeter
+        >>> metric = MetricMeter()
+        >>> # 2. Update using a dictionary as input
+        >>> input_dict = {'loss_1': value_1, 'loss_2': value_2}
+        >>> metric.update(input_dict)
+        >>> # 3. Convert to string and print
+        >>> print(str(metric))
+    """
+
+    def __init__(self, delimiter='\t'):
+        self.meters = defaultdict(AverageMeter)
+        self.delimiter = delimiter
+
+    def update(self, input_dict):
+        if input_dict is None:
+            return
+
+        if not isinstance(input_dict, dict):
+            raise TypeError(
+                'Input to MetricMeter.update() must be a dictionary'
+            )
+
+        for k, v in input_dict.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            self.meters[k].update(v)
+
+    def __str__(self):
+        output_str = []
+        for name, meter in self.meters.items():
+            output_str.append(
+                '{} {:.4f} ({:.4f})'.format(name, meter.val, meter.avg)
+            )
+        return self.delimiter.join(output_str)
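A minimal end-to-end use of the new meter, mirroring what `Engine.train()` does with the dicts returned by `forward_backward()` (numbers illustrative):

    meter = MetricMeter()
    meter.update({'loss_t': 0.8, 'loss_x': 2.0, 'acc': 55.0})
    meter.update({'loss_t': 0.6, 'loss_x': 1.5, 'acc': 60.0})

    # __str__ shows 'name val (avg)' per metric, joined by the delimiter:
    # loss_t 0.6000 (0.7000)   loss_x 1.5000 (1.7500)   acc 60.0000 (57.5000)
    print(meter)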