2019-12-01 10:35:44 +08:00
|
|
|
from __future__ import division, print_function, absolute_import
|
2019-11-08 21:00:39 +08:00
|
|
|
|
2019-12-01 10:35:44 +08:00
|
|
|
from torchreid import metrics
|
2019-11-08 21:00:39 +08:00
|
|
|
from torchreid.engine import Engine
|
|
|
|
from torchreid.losses import CrossEntropyLoss
|
|
|
|
|
|
|
|
|
|
|
|
class ImageSoftmaxNASEngine(Engine):
    """Softmax (cross-entropy) training engine with a temperature-scheduled model.

    On every training batch the model is called as ``model(imgs, lmda=lmda)``,
    where ``lmda`` is a softmax temperature that can be decayed step-wise by
    epoch. Each batch is used for ``mc_iter`` forward/backward/update passes.
    The class name suggests this targets a NAS setting; nothing here depends
    on that.

    Args:
        datamanager: data manager forwarded to :class:`Engine`; its
            ``num_train_pids`` sets the number of classes of the loss.
        model: network to train; must accept an ``lmda`` keyword argument.
        optimizer: optimizer updating ``model``.
        scheduler (optional): learning-rate scheduler. Default is None.
        use_gpu (bool, optional): move batches to CUDA. Default is False.
        label_smooth (bool, optional): enable label smoothing in the
            cross-entropy loss. Default is True.
        mc_iter (int, optional): optimization passes per batch; must be
            >= 1 for training. Default is 1.
        init_lmda (float, optional): initial temperature. Default is 1.
        min_lmda (float, optional): floor for the decayed temperature.
            Default is 1.
        lmda_decay_step (int, optional): decay the temperature every this
            many epochs; a non-positive value (e.g. -1) disables decay.
            Default is 20.
        lmda_decay_rate (float, optional): multiplicative decay factor.
            Default is 0.5.
        fixed_lmda (bool, optional): always use ``init_lmda``.
            Default is False.
    """

    def __init__(
        self,
        datamanager,
        model,
        optimizer,
        scheduler=None,
        use_gpu=False,
        label_smooth=True,
        mc_iter=1,
        init_lmda=1.,
        min_lmda=1.,
        lmda_decay_step=20,
        lmda_decay_rate=0.5,
        fixed_lmda=False
    ):
        super(ImageSoftmaxNASEngine, self).__init__(datamanager, use_gpu)
        self.mc_iter = mc_iter
        self.init_lmda = init_lmda
        self.min_lmda = min_lmda
        self.lmda_decay_step = lmda_decay_step
        self.lmda_decay_rate = lmda_decay_rate
        self.fixed_lmda = fixed_lmda

        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.register_model('model', model, optimizer, scheduler)

        self.criterion = CrossEntropyLoss(
            num_classes=self.datamanager.num_train_pids,
            use_gpu=self.use_gpu,
            label_smooth=label_smooth
        )

    def _compute_lmda(self):
        """Return the softmax temperature to use for the current epoch.

        Returns ``init_lmda`` when ``fixed_lmda`` is set or decay is disabled
        (``lmda_decay_step <= 0``). Otherwise returns
        ``init_lmda * lmda_decay_rate ** (epoch // lmda_decay_step)``,
        floored at ``min_lmda``.
        """
        # A non-positive step disables decay: 0 would divide by zero and any
        # negative step would produce a negative exponent (growing lmda).
        if self.fixed_lmda or self.lmda_decay_step <= 0:
            return self.init_lmda
        # NOTE(review): self.epoch is presumably maintained by the Engine
        # base class during training — confirm.
        num_decays = self.epoch // self.lmda_decay_step
        lmda = self.init_lmda * self.lmda_decay_rate**num_decays
        return max(lmda, self.min_lmda)

    def forward_backward(self, data):
        """Run ``mc_iter`` optimization passes on one training batch.

        Args:
            data: one batch from the train loader, split into images and
                person IDs by ``parse_data_for_train``.

        Returns:
            dict: ``'loss'`` and ``'acc'`` (first entry of
            ``metrics.accuracy``) of the last pass only.

        Raises:
            ValueError: if ``mc_iter`` is less than 1, because no pass
                would run and there would be no loss to report.
        """
        if self.mc_iter < 1:
            raise ValueError(
                'mc_iter must be >= 1, but got {}'.format(self.mc_iter)
            )

        imgs, pids = self.parse_data_for_train(data)

        if self.use_gpu:
            imgs = imgs.cuda()
            pids = pids.cuda()

        # softmax temperature for this epoch
        lmda = self._compute_lmda()

        # Repeated passes on the same batch; each pass updates the weights.
        for _ in range(self.mc_iter):
            outputs = self.model(imgs, lmda=lmda)
            loss = self.compute_loss(self.criterion, outputs, pids)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Only the final pass's statistics are reported.
        loss_dict = {
            'loss': loss.item(),
            'acc': metrics.accuracy(outputs, pids)[0].item()
        }

        return loss_dict
|