deep-person-reid/projects/OSNet_AIN/softmax_nas.py

74 lines
2.1 KiB
Python
Raw Normal View History

2019-12-01 10:35:44 +08:00
from __future__ import division, print_function, absolute_import
2019-11-08 21:00:39 +08:00
2019-12-01 10:35:44 +08:00
from torchreid import metrics
2019-11-08 21:00:39 +08:00
from torchreid.engine import Engine
from torchreid.losses import CrossEntropyLoss
class ImageSoftmaxNASEngine(Engine):
    """Softmax (cross-entropy) training engine used for the OSNet-AIN search.

    Each call to ``forward_backward`` runs ``mc_iter`` forward / backward /
    optimizer-step passes on the same batch. A softmax temperature ``lmda``
    is passed to the model via its ``lmda`` keyword argument.

    Args:
        datamanager: data manager forwarded to ``Engine``; its
            ``num_train_pids`` sets the number of classes of the loss.
        model: network, called as ``model(imgs, lmda=lmda)``.
        optimizer: optimizer stepped after every backward pass.
        scheduler: optional LR scheduler, registered together with the model.
        use_gpu (bool): move each batch to CUDA before the forward pass.
        label_smooth (bool): enable label smoothing in ``CrossEntropyLoss``.
        mc_iter (int): number of optimization passes per batch; must be >= 1.
            (Presumably Monte-Carlo samples of a stochastic architecture --
            confirm against the model.)
        init_lmda (float): starting temperature.
        min_lmda (float): lower bound applied to the *decayed* temperature.
        lmda_decay_step (int): epochs between decays; ``-1`` disables decay.
        lmda_decay_rate (float): multiplicative factor applied at each decay.
        fixed_lmda (bool): always use ``init_lmda`` (no decay, no clamping).

    Raises:
        ValueError: if ``mc_iter`` is less than 1 (the training loop would
            otherwise never run and ``forward_backward`` would fail with an
            unbound ``loss``).
    """

    def __init__(
        self,
        datamanager,
        model,
        optimizer,
        scheduler=None,
        use_gpu=False,
        label_smooth=True,
        mc_iter=1,
        init_lmda=1.,
        min_lmda=1.,
        lmda_decay_step=20,
        lmda_decay_rate=0.5,
        fixed_lmda=False
    ):
        if mc_iter < 1:
            raise ValueError(
                'mc_iter must be >= 1, got {}'.format(mc_iter)
            )
        super(ImageSoftmaxNASEngine, self).__init__(datamanager, use_gpu)
        self.mc_iter = mc_iter
        self.init_lmda = init_lmda
        self.min_lmda = min_lmda
        self.lmda_decay_step = lmda_decay_step
        self.lmda_decay_rate = lmda_decay_rate
        self.fixed_lmda = fixed_lmda

        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.register_model('model', model, optimizer, scheduler)

        # self.datamanager / self.use_gpu are assumed to be set by
        # Engine.__init__ above -- they are read here, not assigned.
        self.criterion = CrossEntropyLoss(
            num_classes=self.datamanager.num_train_pids,
            use_gpu=self.use_gpu,
            label_smooth=label_smooth
        )

    def _compute_lmda(self):
        """Return the softmax temperature to use for the current epoch."""
        if self.fixed_lmda or self.lmda_decay_step == -1:
            return self.init_lmda
        # Step decay: multiply by lmda_decay_rate once every lmda_decay_step
        # epochs, never going below min_lmda. self.epoch is assumed to be
        # maintained by Engine's training loop -- confirm.
        num_decays = self.epoch // self.lmda_decay_step
        decayed = self.init_lmda * self.lmda_decay_rate**num_decays
        return max(decayed, self.min_lmda)

    def forward_backward(self, data):
        """Run ``mc_iter`` optimization passes on one training batch.

        Args:
            data: a training batch, unpacked by ``parse_data_for_train``.

        Returns:
            dict: ``'loss'`` (scalar loss of the last pass) and ``'acc'``
            (first entry of ``metrics.accuracy`` for the last pass's outputs,
            presumably top-1 accuracy).
        """
        imgs, pids = self.parse_data_for_train(data)

        if self.use_gpu:
            imgs = imgs.cuda()
            pids = pids.cuda()

        # softmax temperature
        lmda = self._compute_lmda()

        for _ in range(self.mc_iter):
            outputs = self.model(imgs, lmda=lmda)
            loss = self.compute_loss(self.criterion, outputs, pids)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Reported statistics come from the final pass only.
        return {
            'loss': loss.item(),
            'acc': metrics.accuracy(outputs, pids)[0].item()
        }