fast-reid/engine/trainer.py
2019-07-23 09:46:37 +08:00

127 lines
4.2 KiB
Python

# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
from data.datasets.eval_reid import evaluate
from fastai.vision import *
@dataclass
class TrackValue(Callback):
log_path: Path
total_iter: int
def on_batch_end(self, num_batch, last_loss, **kwargs:Any):
if (num_batch+1) % (self.total_iter//3) == 0:
with self.log_path.open('a') as f:
f.write('Iter [{}/{}], loss: {:.4f}\n'.format(num_batch, self.total_iter, last_loss.item()))
def on_epoch_end(self, epoch, smooth_loss, **kwargs:Any):
with self.log_path.open('a') as f:
f.write('Epoch {}, loss: {:.4f}\n\n'.format(epoch, smooth_loss.item()))
@dataclass
class LRScheduler(Callback):
learn: Learner
lr_sched: Scheduler
def on_train_begin(self, **kwargs:Any):
self.opt = self.learn.opt
def on_epoch_begin(self, **kwargs:Any):
self.opt.lr = self.lr_sched.step()
class TestModel(LearnerCallback):
def __init__(self, learn: Learner, test_labels: Iterator, eval_period: int, num_query: int,
output_dir: Path, log_path: Path, norm=True):
super().__init__(learn)
self._test_dl = learn.data.test_dl
self._eval_period = eval_period
self._norm = norm
self._output_dir = output_dir
self._log_path = log_path
self._num_query = num_query
pids = []
camids = []
for i in test_labels:
pids.append(i[0])
camids.append(i[1])
self.q_pids = np.asarray(pids[:num_query])
self.q_camids = np.asarray(camids[:num_query])
self.g_pids = np.asarray(pids[num_query:])
self.g_camids = np.asarray(camids[num_query:])
def on_epoch_end(self, epoch, **kwargs: Any) -> None:
# test model performance
if (epoch + 1) % self._eval_period == 0:
print('Testing ...')
feats, pids, camids = [], [], []
self.learn.model.eval()
with torch.no_grad():
for imgs, _ in self._test_dl:
feat = self.learn.model(imgs)
feats.append(feat)
if self._norm:
feats = torch.norm(feats, p=2, dim=1)
feats = torch.cat(feats, dim=0)
# query
qf = feats[:self._num_query]
# gallery
gf = feats[self._num_query:]
m, n = qf.shape[0], gf.shape[0]
distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
distmat.addmm_(1, -2, qf, gf.t())
distmat = to_np(distmat)
cmc, mAP = evaluate(distmat, self.q_pids, self.g_pids, self.q_camids, self.g_camids)
print("Test Results - Epoch: {}".format(epoch + 1))
print("mAP: {:.1%}".format(mAP))
for r in [1, 5, 10]:
print("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
self.learn.save(self._output_dir / 'reid_model_{}'.format(epoch))
with open(self.log_path, 'a') as f:
f.write("Test Results - Epoch: {}\n mAP: {:.1%}\n".format(epoch + 1, mAP))
for r in [1, 5, 10]:
f.write("CMC curve, Rank-{:<3}:{:.1%}\n".format(r, cmc[r - 1]))
f.write('\n')
def do_train(
cfg,
log_path,
model,
data_bunch,
test_labels,
opt_func,
lr_sched,
loss_func,
num_query
):
eval_period = cfg.SOLVER.EVAL_PERIOD
output_dir = Path(cfg.OUTPUT_DIR)
epochs = cfg.SOLVER.MAX_EPOCHS
total_iter = len(data_bunch.train_dl)
print("Start training")
learn = Learner(
data_bunch,
model,
opt_func=opt_func,
loss_func=loss_func,
true_wd=False,
callback_fns=[
partial(LRScheduler, lr_sched=lr_sched),
partial(TestModel,test_labels=test_labels, eval_period=eval_period,
num_query=num_query, output_dir=output_dir, log_path=log_path)],
callbacks=[TrackValue(log_path, total_iter)])
learn.fit(epochs, lr=cfg.SOLVER.BASE_LR, wd=cfg.SOLVER.WEIGHT_DECAY)