74 lines
2.5 KiB
Python
74 lines
2.5 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import print_function
|
|
from __future__ import division
|
|
|
|
import time
|
|
import datetime
|
|
|
|
import torch
|
|
|
|
import torchreid
|
|
from torchreid.engine import engine
|
|
from torchreid.losses import CrossEntropyLoss
|
|
from torchreid.utils import AverageMeter, open_specified_layers, open_all_layers
|
|
from torchreid import metrics
|
|
|
|
|
|
class ImageSoftmaxEngine(engine.Engine):
|
|
|
|
def __init__(self, dataset, model, optimizer, scheduler=None, use_cpu=False,
|
|
label_smooth=True):
|
|
super(ImageSoftmaxEngine, self).__init__(dataset, model, optimizer, scheduler, use_cpu)
|
|
|
|
self.criterion = CrossEntropyLoss(
|
|
num_classes=self.dataset.num_train_pids,
|
|
use_gpu=self.use_gpu,
|
|
label_smooth=label_smooth
|
|
)
|
|
|
|
def train(self, epoch, trainloader, fixbase=False, open_layers=None, print_freq=10):
|
|
losses = AverageMeter()
|
|
accs = AverageMeter()
|
|
batch_time = AverageMeter()
|
|
data_time = AverageMeter()
|
|
|
|
self.model.train()
|
|
|
|
if fixbase and (open_layers is not None):
|
|
open_specified_layers(self.model, open_layers)
|
|
else:
|
|
open_all_layers(self.model)
|
|
|
|
end = time.time()
|
|
for batch_idx, data in enumerate(trainloader):
|
|
data_time.update(time.time() - end)
|
|
|
|
imgs, pids = self._parse_data_for_train(data)
|
|
if self.use_gpu:
|
|
imgs = imgs.cuda()
|
|
pids = pids.cuda()
|
|
|
|
self.optimizer.zero_grad()
|
|
outputs = self.model(imgs)
|
|
loss = self._compute_loss(self.criterion, outputs, pids)
|
|
loss.backward()
|
|
self.optimizer.step()
|
|
|
|
batch_time.update(time.time() - end)
|
|
|
|
losses.update(loss.item(), pids.size(0))
|
|
accs.update(metrics.accuracy(outputs, pids)[0].item())
|
|
|
|
if (batch_idx+1) % print_freq==0:
|
|
print('Epoch: [{0}][{1}/{2}]\t'
|
|
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
|
|
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
|
|
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
|
|
'Acc {acc.val:.2f} ({acc.avg:.2f})\t'.format(
|
|
epoch + 1, batch_idx + 1, len(trainloader),
|
|
batch_time=batch_time,
|
|
data_time=data_time,
|
|
loss=losses,
|
|
acc=accs))
|
|
|
|
end = time.time() |