make projects compatible with the latest torchreid

pull/345/head
KaiyangZhou 2020-05-05 13:41:10 +01:00
parent d60ca7f224
commit 2a9f44af9b
2 changed files with 9 additions and 9 deletions

View File

@ -56,19 +56,19 @@ class ImageDMLEngine(Engine):
) )
def forward_backward(self, data): def forward_backward(self, data):
imgs, pids = self._parse_data_for_train(data) imgs, pids = self.parse_data_for_train(data)
if self.use_gpu: if self.use_gpu:
imgs = imgs.cuda() imgs = imgs.cuda()
pids = pids.cuda() pids = pids.cuda()
outputs1, features1 = self.model1(imgs) outputs1, features1 = self.model1(imgs)
loss1_x = self._compute_loss(self.criterion_x, outputs1, pids) loss1_x = self.compute_loss(self.criterion_x, outputs1, pids)
loss1_t = self._compute_loss(self.criterion_t, features1, pids) loss1_t = self.compute_loss(self.criterion_t, features1, pids)
outputs2, features2 = self.model2(imgs) outputs2, features2 = self.model2(imgs)
loss2_x = self._compute_loss(self.criterion_x, outputs2, pids) loss2_x = self.compute_loss(self.criterion_x, outputs2, pids)
loss2_t = self._compute_loss(self.criterion_t, features2, pids) loss2_t = self.compute_loss(self.criterion_t, features2, pids)
loss1_ml = self.compute_kl_div( loss1_ml = self.compute_kl_div(
outputs2.detach(), outputs1, is_logit=True outputs2.detach(), outputs1, is_logit=True
@ -113,7 +113,7 @@ class ImageDMLEngine(Engine):
q = F.softmax(q, dim=1) q = F.softmax(q, dim=1)
return -(p * torch.log(q + 1e-8)).sum(1).mean() return -(p * torch.log(q + 1e-8)).sum(1).mean()
def _two_stepped_transfer_learning( def two_stepped_transfer_learning(
self, epoch, fixbase_epoch, open_layers, model=None self, epoch, fixbase_epoch, open_layers, model=None
): ):
"""Two stepped transfer learning. """Two stepped transfer learning.
@ -138,7 +138,7 @@ class ImageDMLEngine(Engine):
open_all_layers(model1) open_all_layers(model1)
open_all_layers(model2) open_all_layers(model2)
def _extract_features(self, input): def extract_features(self, input):
if self.deploy == 'model1': if self.deploy == 'model1':
return self.model1(input) return self.model1(input)

View File

@ -47,7 +47,7 @@ class ImageSoftmaxNASEngine(Engine):
) )
def forward_backward(self, data): def forward_backward(self, data):
imgs, pids = self._parse_data_for_train(data) imgs, pids = self.parse_data_for_train(data)
if self.use_gpu: if self.use_gpu:
imgs = imgs.cuda() imgs = imgs.cuda()
@ -65,7 +65,7 @@ class ImageSoftmaxNASEngine(Engine):
for k in range(self.mc_iter): for k in range(self.mc_iter):
outputs = self.model(imgs, lmda=lmda) outputs = self.model(imgs, lmda=lmda)
loss = self._compute_loss(self.criterion, outputs, pids) loss = self.compute_loss(self.criterion, outputs, pids)
self.optimizer.zero_grad() self.optimizer.zero_grad()
loss.backward() loss.backward()
self.optimizer.step() self.optimizer.step()