update isinstance(outputs, (tuple, list))

This commit is contained in:
KaiyangZhou 2018-10-25 17:45:00 +01:00
parent 602971e4d7
commit 478da99655
4 changed files with 11 additions and 8 deletions

View File

@ -287,7 +287,7 @@ def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=Fa
imgs, pids = imgs.cuda(), pids.cuda()
outputs = model(imgs)
if isinstance(outputs, tuple):
if isinstance(outputs, (tuple, list)):
loss = DeepSupervision(criterion, outputs, pids)
else:
loss = criterion(outputs, pids)

View File

@ -275,17 +275,17 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
outputs, features = model(imgs)
if args.htri_only:
if isinstance(features, tuple):
if isinstance(features, (tuple, list)):
loss = DeepSupervision(criterion_htri, features, pids)
else:
loss = criterion_htri(features, pids)
else:
if isinstance(outputs, tuple):
if isinstance(outputs, (tuple, list)):
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
else:
xent_loss = criterion_xent(outputs, pids)
if isinstance(features, tuple):
if isinstance(features, (tuple, list)):
htri_loss = DeepSupervision(criterion_htri, features, pids)
else:
htri_loss = criterion_htri(features, pids)

View File

@ -284,7 +284,10 @@ def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=Fa
imgs, pids = imgs.cuda(), pids.cuda()
outputs = model(imgs)
loss = criterion(outputs, pids)
if isinstance(outputs, (tuple, list)):
loss = DeepSupervision(criterion, outputs, pids)
else:
loss = criterion(outputs, pids)
optimizer.zero_grad()
loss.backward()
optimizer.step()

View File

@ -272,17 +272,17 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
outputs, features = model(imgs)
if args.htri_only:
if isinstance(features, tuple):
if isinstance(features, (tuple, list)):
loss = DeepSupervision(criterion_htri, features, pids)
else:
loss = criterion_htri(features, pids)
else:
if isinstance(outputs, tuple):
if isinstance(outputs, (tuple, list)):
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
else:
xent_loss = criterion_xent(outputs, pids)
if isinstance(features, tuple):
if isinstance(features, (tuple, list)):
htri_loss = DeepSupervision(criterion_htri, features, pids)
else:
htri_loss = criterion_htri(features, pids)