make engine more model-agnostic

This commit is contained in:
KaiyangZhou 2020-04-16 12:46:15 +01:00
parent 322ec2b2de
commit 36e22e8ce1
6 changed files with 296 additions and 333 deletions

View File

@ -27,8 +27,7 @@ class ImageSoftmaxNASEngine(Engine):
lmda_decay_rate=0.5, lmda_decay_rate=0.5,
fixed_lmda=False fixed_lmda=False
): ):
super(ImageSoftmaxNASEngine, self super(ImageSoftmaxNASEngine, self).__init__(datamanager, use_gpu)
).__init__(datamanager, model, optimizer, scheduler, use_gpu)
self.mc_iter = mc_iter self.mc_iter = mc_iter
self.init_lmda = init_lmda self.init_lmda = init_lmda
self.min_lmda = min_lmda self.min_lmda = min_lmda
@ -36,108 +35,44 @@ class ImageSoftmaxNASEngine(Engine):
self.lmda_decay_rate = lmda_decay_rate self.lmda_decay_rate = lmda_decay_rate
self.fixed_lmda = fixed_lmda self.fixed_lmda = fixed_lmda
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.register_model('model', model, optimizer, scheduler)
self.criterion = CrossEntropyLoss( self.criterion = CrossEntropyLoss(
num_classes=self.datamanager.num_train_pids, num_classes=self.datamanager.num_train_pids,
use_gpu=self.use_gpu, use_gpu=self.use_gpu,
label_smooth=label_smooth label_smooth=label_smooth
) )
def train( def forward_backward(self, data):
self, imgs, pids = self._parse_data_for_train(data)
epoch,
max_epoch,
writer,
fixbase_epoch=0,
open_layers=None,
print_freq=10
):
losses = AverageMeter()
accs = AverageMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
self.model.train() if self.use_gpu:
if (epoch + 1) <= fixbase_epoch and open_layers is not None: imgs = imgs.cuda()
print( pids = pids.cuda()
'* Only train {} (epoch: {}/{})'.format(
open_layers, epoch + 1, fixbase_epoch # softmax temperature
) if self.fixed_lmda or self.lmda_decay_step == -1:
) lmda = self.init_lmda
open_specified_layers(self.model, open_layers)
else: else:
open_all_layers(self.model) lmda = self.init_lmda * self.lmda_decay_rate**(
epoch // self.lmda_decay_step
)
if lmda < self.min_lmda:
lmda = self.min_lmda
num_batches = len(self.train_loader) for k in range(self.mc_iter):
end = time.time() outputs = self.model(imgs, lmda=lmda)
for batch_idx, data in enumerate(self.train_loader): loss = self._compute_loss(self.criterion, outputs, pids)
data_time.update(time.time() - end) self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
imgs, pids = self._parse_data_for_train(data) loss_dict = {
if self.use_gpu: 'loss': loss.item(),
imgs = imgs.cuda() 'acc': metrics.accuracy(outputs, pids)[0].item()
pids = pids.cuda() }
# softmax temporature return loss_dict
if self.fixed_lmda or self.lmda_decay_step == -1:
lmda = self.init_lmda
else:
lmda = self.init_lmda * self.lmda_decay_rate**(
epoch // self.lmda_decay_step
)
if lmda < self.min_lmda:
lmda = self.min_lmda
for k in range(self.mc_iter):
outputs = self.model(imgs, lmda=lmda)
loss = self._compute_loss(self.criterion, outputs, pids)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
batch_time.update(time.time() - end)
losses.update(loss.item(), pids.size(0))
accs.update(metrics.accuracy(outputs, pids)[0].item())
if (batch_idx+1) % print_freq == 0:
# estimate remaining time
eta_seconds = batch_time.avg * (
num_batches - (batch_idx+1) + (max_epoch -
(epoch+1)) * num_batches
)
eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
print(
'Epoch: [{0}/{1}][{2}/{3}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
'Lr {lr:.6f}\t'
'eta {eta}'.format(
epoch + 1,
max_epoch,
batch_idx + 1,
num_batches,
batch_time=batch_time,
data_time=data_time,
loss=losses,
acc=accs,
lr=self.optimizer.param_groups[0]['lr'],
eta=eta_str
)
)
if writer is not None:
n_iter = epoch*num_batches + batch_idx
writer.add_scalar('Train/Time', batch_time.avg, n_iter)
writer.add_scalar('Train/Data', data_time.avg, n_iter)
writer.add_scalar('Train/Loss', losses.avg, n_iter)
writer.add_scalar('Train/Acc', accs.avg, n_iter)
writer.add_scalar(
'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
)
end = time.time()
if self.scheduler is not None:
self.scheduler.step()

View File

@ -3,13 +3,15 @@ import time
import numpy as np import numpy as np
import os.path as osp import os.path as osp
import datetime import datetime
from collections import OrderedDict
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from torchreid import metrics from torchreid import metrics
from torchreid.utils import ( from torchreid.utils import (
AverageMeter, re_ranking, save_checkpoint, visualize_ranked_results MetricMeter, AverageMeter, re_ranking, open_all_layers, save_checkpoint,
open_specified_layers, visualize_ranked_results
) )
from torchreid.losses import DeepSupervision from torchreid.losses import DeepSupervision
@ -26,22 +28,89 @@ class Engine(object):
use_gpu (bool, optional): use gpu. Default is True. use_gpu (bool, optional): use gpu. Default is True.
""" """
def __init__( def __init__(self, datamanager, use_gpu=True):
self,
datamanager,
model,
optimizer=None,
scheduler=None,
use_gpu=True
):
self.datamanager = datamanager self.datamanager = datamanager
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.use_gpu = (torch.cuda.is_available() and use_gpu)
self.writer = None
self.train_loader = self.datamanager.train_loader self.train_loader = self.datamanager.train_loader
self.test_loader = self.datamanager.test_loader self.test_loader = self.datamanager.test_loader
self.use_gpu = (torch.cuda.is_available() and use_gpu)
self.writer = None
self.model = None
self.optimizer = None
self.scheduler = None
self._models = OrderedDict()
self._optims = OrderedDict()
self._scheds = OrderedDict()
def register_model(self, name='model', model=None, optim=None, sched=None):
    """Register a model together with its optimizer and scheduler.

    Must be called after ``super().__init__()`` has created the internal
    registries; otherwise an AttributeError is raised.
    """
    for attr, what in (('_models', 'model'), ('_optims', 'optim'),
                       ('_scheds', 'sched')):
        # same check/message as before, applied registry by registry
        if self.__dict__.get(attr) is None:
            raise AttributeError(
                'Cannot assign {} before super().__init__() call'.format(what)
            )
    self._models[name] = model
    self._optims[name] = optim
    self._scheds[name] = sched
def get_model_names(self, names=None):
    """Return registered model names, validating any requested subset.

    Args:
        names (str | list, optional): a single name or a list of names to
            validate against the registry. If None, all registered names
            are returned.

    Returns:
        list: ``names`` normalized to a list, or all registered names.

    Raises:
        KeyError: if a requested name was never registered.
    """
    names_real = list(self._models.keys())
    if names is None:
        return names_real
    if not isinstance(names, list):
        names = [names]
    for name in names:
        # raise (not assert) so validation survives `python -O`
        if name not in names_real:
            raise KeyError(
                'Unknown model name "{}". Registered names are {}'.format(
                    name, names_real
                )
            )
    return names
def save_model(self, epoch, rank1, save_dir, is_best=False):
    """Save a checkpoint for every registered model.

    Each checkpoint bundles the model/optimizer/scheduler state dicts plus
    the (1-based) epoch and rank-1 accuracy, and is written under
    ``save_dir/<name>``.
    """
    for name in self.get_model_names():
        checkpoint = {
            'state_dict': self._models[name].state_dict(),
            'epoch': epoch + 1,
            'rank1': rank1,
            'optimizer': self._optims[name].state_dict(),
            'scheduler': self._scheds[name].state_dict()
        }
        save_checkpoint(
            checkpoint, osp.join(save_dir, name), is_best=is_best
        )
def set_model_mode(self, mode='train', names=None):
    """Set train/eval mode on the selected registered models.

    ``mode='train'`` calls ``.train()``; both ``'eval'`` and ``'test'``
    call ``.eval()``.
    """
    assert mode in ['train', 'eval', 'test']
    for name in self.get_model_names(names):
        model = self._models[name]
        if mode != 'train':
            model.eval()
        else:
            model.train()
def get_current_lr(self, names=None):
    """Return the learning rate of the first selected model's optimizer."""
    first = self.get_model_names(names)[0]
    optimizer = self._optims[first]
    return optimizer.param_groups[0]['lr']
def update_lr(self, names=None):
    """Step the scheduler (when one exists) of each selected model."""
    for name in self.get_model_names(names):
        scheduler = self._scheds[name]
        if scheduler is not None:
            scheduler.step()
def run( def run(
self, self,
@ -142,7 +211,7 @@ class Engine(object):
use_metric_cuhk03=use_metric_cuhk03, use_metric_cuhk03=use_metric_cuhk03,
ranks=ranks ranks=ranks
) )
self._save_checkpoint(epoch, rank1, save_dir) self.save_model(epoch, rank1, save_dir)
if max_epoch > 0: if max_epoch > 0:
print('=> Final test') print('=> Final test')
@ -156,7 +225,7 @@ class Engine(object):
use_metric_cuhk03=use_metric_cuhk03, use_metric_cuhk03=use_metric_cuhk03,
ranks=ranks ranks=ranks
) )
self._save_checkpoint(epoch, rank1, save_dir) self.save_model(epoch, rank1, save_dir)
elapsed = round(time.time() - time_start) elapsed = round(time.time() - time_start)
elapsed = str(datetime.timedelta(seconds=elapsed)) elapsed = str(datetime.timedelta(seconds=elapsed))
@ -164,20 +233,68 @@ class Engine(object):
if self.writer is not None: if self.writer is not None:
self.writer.close() self.writer.close()
def train(self): def train(
r"""Performs training on source datasets for one epoch. self,
epoch,
max_epoch,
writer,
print_freq=10,
fixbase_epoch=0,
open_layers=None
):
losses = MetricMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
This will be called every epoch in ``run()``, e.g. self.set_model_mode('train')
.. code-block:: python self._two_stepped_transfer_learning(epoch, fixbase_epoch, open_layers)
for epoch in range(start_epoch, max_epoch):
self.train(some_arguments)
.. note:: num_batches = len(self.train_loader)
end = time.time()
This must be implemented in subclasses. for batch_idx, data in enumerate(self.train_loader):
""" data_time.update(time.time() - end)
loss_dict = self.forward_backward(data)
batch_time.update(time.time() - end)
losses.update(loss_dict)
if (batch_idx+1) % print_freq == 0:
nb_this_epoch = num_batches - (batch_idx+1)
nb_future_epochs = (max_epoch - (epoch+1)) * num_batches
eta_seconds = batch_time.avg * (nb_this_epoch+nb_future_epochs)
eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
print(
'epoch: [{0}/{1}][{2}/{3}]\t'
'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'eta {eta}\t'
'{losses}\t'
'lr {lr:.6f}'.format(
epoch + 1,
max_epoch,
batch_idx + 1,
num_batches,
batch_time=batch_time,
data_time=data_time,
eta=eta_str,
losses=losses,
lr=self.get_current_lr()
)
)
if writer is not None:
n_iter = epoch*num_batches + batch_idx
writer.add_scalar('Train/time', batch_time.avg, n_iter)
writer.add_scalar('Train/data', data_time.avg, n_iter)
for name, meter in losses.meters.items():
writer.add_scalar('Train/' + name, meter.avg, n_iter)
writer.add_scalar('Train/lr', self.get_current_lr(), n_iter)
end = time.time()
self.update_lr()
def forward_backward(self, data):
raise NotImplementedError raise NotImplementedError
def test( def test(
@ -205,6 +322,7 @@ class Engine(object):
``_extract_features()`` and ``_parse_data_for_eval()`` (most of the time), ``_extract_features()`` and ``_parse_data_for_eval()`` (most of the time),
but not a must. Please refer to the source code for more details. but not a must. Please refer to the source code for more details.
""" """
self.set_model_mode('eval')
targets = list(self.test_loader.keys()) targets = list(self.test_loader.keys())
for name in targets: for name in targets:
@ -330,7 +448,6 @@ class Engine(object):
return loss return loss
def _extract_features(self, input): def _extract_features(self, input):
self.model.eval()
return self.model(input) return self.model(input)
def _parse_data_for_train(self, data): def _parse_data_for_train(self, data):
@ -344,15 +461,26 @@ class Engine(object):
camids = data[2] camids = data[2]
return imgs, pids, camids return imgs, pids, camids
def _save_checkpoint(self, epoch, rank1, save_dir, is_best=False): def _two_stepped_transfer_learning(
save_checkpoint( self, epoch, fixbase_epoch, open_layers, model=None
{ ):
'state_dict': self.model.state_dict(), """Two stepped transfer learning.
'epoch': epoch + 1,
'rank1': rank1, The idea is to freeze base layers for a certain number of epochs
'optimizer': self.optimizer.state_dict(), and then open all layers for training.
'scheduler': self.scheduler.state_dict(),
}, Reference: https://arxiv.org/abs/1611.05244
save_dir, """
is_best=is_best model = self.model if model is None else model
) if model is None:
return
if (epoch + 1) <= fixbase_epoch and open_layers is not None:
print(
'* Only train {} (epoch: {}/{})'.format(
open_layers, epoch + 1, fixbase_epoch
)
)
open_specified_layers(model, open_layers)
else:
open_all_layers(model)

View File

@ -3,9 +3,6 @@ import time
import datetime import datetime
from torchreid import metrics from torchreid import metrics
from torchreid.utils import (
AverageMeter, open_all_layers, open_specified_layers
)
from torchreid.losses import CrossEntropyLoss from torchreid.losses import CrossEntropyLoss
from ..engine import Engine from ..engine import Engine
@ -67,8 +64,12 @@ class ImageSoftmaxEngine(Engine):
use_gpu=True, use_gpu=True,
label_smooth=True label_smooth=True
): ):
super(ImageSoftmaxEngine, self super(ImageSoftmaxEngine, self).__init__(datamanager, use_gpu)
).__init__(datamanager, model, optimizer, scheduler, use_gpu)
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.register_model('model', model, optimizer, scheduler)
self.criterion = CrossEntropyLoss( self.criterion = CrossEntropyLoss(
num_classes=self.datamanager.num_train_pids, num_classes=self.datamanager.num_train_pids,
@ -76,91 +77,22 @@ class ImageSoftmaxEngine(Engine):
label_smooth=label_smooth label_smooth=label_smooth
) )
def train( def forward_backward(self, data):
self, imgs, pids = self._parse_data_for_train(data)
epoch,
max_epoch,
writer,
print_freq=10,
fixbase_epoch=0,
open_layers=None
):
losses = AverageMeter()
accs = AverageMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
self.model.train() if self.use_gpu:
if (epoch + 1) <= fixbase_epoch and open_layers is not None: imgs = imgs.cuda()
print( pids = pids.cuda()
'* Only train {} (epoch: {}/{})'.format(
open_layers, epoch + 1, fixbase_epoch
)
)
open_specified_layers(self.model, open_layers)
else:
open_all_layers(self.model)
num_batches = len(self.train_loader) self.optimizer.zero_grad()
end = time.time() outputs = self.model(imgs)
for batch_idx, data in enumerate(self.train_loader): loss = self._compute_loss(self.criterion, outputs, pids)
data_time.update(time.time() - end) loss.backward()
self.optimizer.step()
imgs, pids = self._parse_data_for_train(data) loss_dict = {
if self.use_gpu: 'loss': loss.item(),
imgs = imgs.cuda() 'acc': metrics.accuracy(outputs, pids)[0].item()
pids = pids.cuda() }
self.optimizer.zero_grad() return loss_dict
outputs = self.model(imgs)
loss = self._compute_loss(self.criterion, outputs, pids)
loss.backward()
self.optimizer.step()
batch_time.update(time.time() - end)
losses.update(loss.item(), pids.size(0))
accs.update(metrics.accuracy(outputs, pids)[0].item())
if (batch_idx+1) % print_freq == 0:
# estimate remaining time
eta_seconds = batch_time.avg * (
num_batches - (batch_idx+1) + (max_epoch -
(epoch+1)) * num_batches
)
eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
print(
'Epoch: [{0}/{1}][{2}/{3}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
'Lr {lr:.6f}\t'
'eta {eta}'.format(
epoch + 1,
max_epoch,
batch_idx + 1,
num_batches,
batch_time=batch_time,
data_time=data_time,
loss=losses,
acc=accs,
lr=self.optimizer.param_groups[0]['lr'],
eta=eta_str
)
)
if writer is not None:
n_iter = epoch*num_batches + batch_idx
writer.add_scalar('Train/Time', batch_time.avg, n_iter)
writer.add_scalar('Train/Data', data_time.avg, n_iter)
writer.add_scalar('Train/Loss', losses.avg, n_iter)
writer.add_scalar('Train/Acc', accs.avg, n_iter)
writer.add_scalar(
'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
)
end = time.time()
if self.scheduler is not None:
self.scheduler.step()

View File

@ -3,9 +3,6 @@ import time
import datetime import datetime
from torchreid import metrics from torchreid import metrics
from torchreid.utils import (
AverageMeter, open_all_layers, open_specified_layers
)
from torchreid.losses import TripletLoss, CrossEntropyLoss from torchreid.losses import TripletLoss, CrossEntropyLoss
from ..engine import Engine from ..engine import Engine
@ -76,8 +73,12 @@ class ImageTripletEngine(Engine):
use_gpu=True, use_gpu=True,
label_smooth=True label_smooth=True
): ):
super(ImageTripletEngine, self super(ImageTripletEngine, self).__init__(datamanager, use_gpu)
).__init__(datamanager, model, optimizer, scheduler, use_gpu)
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.register_model('model', model, optimizer, scheduler)
self.weight_t = weight_t self.weight_t = weight_t
self.weight_x = weight_x self.weight_x = weight_x
@ -89,98 +90,25 @@ class ImageTripletEngine(Engine):
label_smooth=label_smooth label_smooth=label_smooth
) )
def train( def forward_backward(self, data):
self, imgs, pids = self._parse_data_for_train(data)
epoch,
max_epoch,
writer,
print_freq=10,
fixbase_epoch=0,
open_layers=None
):
losses_t = AverageMeter()
losses_x = AverageMeter()
accs = AverageMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
self.model.train() if self.use_gpu:
if (epoch + 1) <= fixbase_epoch and open_layers is not None: imgs = imgs.cuda()
print( pids = pids.cuda()
'* Only train {} (epoch: {}/{})'.format(
open_layers, epoch + 1, fixbase_epoch
)
)
open_specified_layers(self.model, open_layers)
else:
open_all_layers(self.model)
num_batches = len(self.train_loader) self.optimizer.zero_grad()
end = time.time() outputs, features = self.model(imgs)
for batch_idx, data in enumerate(self.train_loader): loss_t = self._compute_loss(self.criterion_t, features, pids)
data_time.update(time.time() - end) loss_x = self._compute_loss(self.criterion_x, outputs, pids)
loss = self.weight_t * loss_t + self.weight_x * loss_x
loss.backward()
self.optimizer.step()
imgs, pids = self._parse_data_for_train(data) loss_dict = {
if self.use_gpu: 'loss_t': loss_t.item(),
imgs = imgs.cuda() 'loss_x': loss_x.item(),
pids = pids.cuda() 'acc': metrics.accuracy(outputs, pids)[0].item()
}
self.optimizer.zero_grad() return loss_dict
outputs, features = self.model(imgs)
loss_t = self._compute_loss(self.criterion_t, features, pids)
loss_x = self._compute_loss(self.criterion_x, outputs, pids)
loss = self.weight_t * loss_t + self.weight_x * loss_x
loss.backward()
self.optimizer.step()
batch_time.update(time.time() - end)
losses_t.update(loss_t.item(), pids.size(0))
losses_x.update(loss_x.item(), pids.size(0))
accs.update(metrics.accuracy(outputs, pids)[0].item())
if (batch_idx+1) % print_freq == 0:
# estimate remaining time
eta_seconds = batch_time.avg * (
num_batches - (batch_idx+1) + (max_epoch -
(epoch+1)) * num_batches
)
eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
print(
'Epoch: [{0}/{1}][{2}/{3}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss_t {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
'Loss_x {loss_x.val:.4f} ({loss_x.avg:.4f})\t'
'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
'Lr {lr:.6f}\t'
'eta {eta}'.format(
epoch + 1,
max_epoch,
batch_idx + 1,
num_batches,
batch_time=batch_time,
data_time=data_time,
loss_t=losses_t,
loss_x=losses_x,
acc=accs,
lr=self.optimizer.param_groups[0]['lr'],
eta=eta_str
)
)
if writer is not None:
n_iter = epoch*num_batches + batch_idx
writer.add_scalar('Train/Time', batch_time.avg, n_iter)
writer.add_scalar('Train/Data', data_time.avg, n_iter)
writer.add_scalar('Train/Loss_t', losses_t.avg, n_iter)
writer.add_scalar('Train/Loss_x', losses_x.avg, n_iter)
writer.add_scalar('Train/Acc', accs.avg, n_iter)
writer.add_scalar(
'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
)
end = time.time()
if self.scheduler is not None:
self.scheduler.step()

View File

@ -13,25 +13,23 @@ class CrossEntropyLoss(nn.Module):
.. math:: .. math::
\begin{equation} \begin{equation}
(1 - \epsilon) \times y + \frac{\epsilon}{K}, (1 - \eps) \times y + \frac{\eps}{K},
\end{equation} \end{equation}
where :math:`K` denotes the number of classes and :math:`\epsilon` is a weight. When where :math:`K` denotes the number of classes and :math:`\eps` is a weight. When
:math:`\epsilon = 0`, the loss function reduces to the normal cross entropy. :math:`\eps = 0`, the loss function reduces to the normal cross entropy.
Args: Args:
num_classes (int): number of classes. num_classes (int): number of classes.
epsilon (float, optional): weight. Default is 0.1. eps (float, optional): weight. Default is 0.1.
use_gpu (bool, optional): whether to use gpu devices. Default is True. use_gpu (bool, optional): whether to use gpu devices. Default is True.
label_smooth (bool, optional): whether to apply label smoothing. Default is True. label_smooth (bool, optional): whether to apply label smoothing. Default is True.
""" """
def __init__( def __init__(self, num_classes, eps=0.1, use_gpu=True, label_smooth=True):
self, num_classes, epsilon=0.1, use_gpu=True, label_smooth=True
):
super(CrossEntropyLoss, self).__init__() super(CrossEntropyLoss, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.epsilon = epsilon if label_smooth else 0 self.eps = eps if label_smooth else 0
self.use_gpu = use_gpu self.use_gpu = use_gpu
self.logsoftmax = nn.LogSoftmax(dim=1) self.logsoftmax = nn.LogSoftmax(dim=1)
@ -48,7 +46,5 @@ class CrossEntropyLoss(nn.Module):
targets = zeros.scatter_(1, targets.unsqueeze(1).data.cpu(), 1) targets = zeros.scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
if self.use_gpu: if self.use_gpu:
targets = targets.cuda() targets = targets.cuda()
targets = ( targets = (1 - self.eps) * targets + self.eps / self.num_classes
1 - self.epsilon
) * targets + self.epsilon / self.num_classes
return (-targets * log_probs).mean(0).sum() return (-targets * log_probs).mean(0).sum()

View File

@ -1,6 +1,8 @@
from __future__ import division, absolute_import from __future__ import division, absolute_import
from collections import defaultdict
import torch
__all__ = ['AverageMeter'] __all__ = ['AverageMeter', 'MetricMeter']
class AverageMeter(object): class AverageMeter(object):
@ -27,3 +29,45 @@ class AverageMeter(object):
self.sum += val * n self.sum += val * n
self.count += n self.count += n
self.avg = self.sum / self.count self.avg = self.sum / self.count
class MetricMeter(object):
    """A collection of named metrics, each tracked by an AverageMeter.

    Source: https://github.com/KaiyangZhou/Dassl.pytorch

    Examples::
        >>> # 1. Create an instance of MetricMeter
        >>> metric = MetricMeter()
        >>> # 2. Update using a dictionary as input
        >>> input_dict = {'loss_1': value_1, 'loss_2': value_2}
        >>> metric.update(input_dict)
        >>> # 3. Convert to string and print
        >>> print(str(metric))
    """

    def __init__(self, delimiter='\t'):
        # each metric name lazily gets its own AverageMeter on first update
        self.meters = defaultdict(AverageMeter)
        self.delimiter = delimiter

    def update(self, input_dict):
        if input_dict is None:
            return
        if not isinstance(input_dict, dict):
            raise TypeError(
                'Input to MetricMeter.update() must be a dictionary'
            )
        for name, value in input_dict.items():
            # unwrap scalar tensors into plain Python numbers
            if isinstance(value, torch.Tensor):
                value = value.item()
            self.meters[name].update(value)

    def __str__(self):
        parts = [
            '{} {:.4f} ({:.4f})'.format(name, meter.val, meter.avg)
            for name, meter in self.meters.items()
        ]
        return self.delimiter.join(parts)