update isinstance(outputs, (tuple, list))

This commit is contained in:
KaiyangZhou 2018-10-25 17:45:00 +01:00
parent 602971e4d7
commit 478da99655
4 changed files with 11 additions and 8 deletions

View File

@ -287,7 +287,7 @@ def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=Fa
imgs, pids = imgs.cuda(), pids.cuda() imgs, pids = imgs.cuda(), pids.cuda()
outputs = model(imgs) outputs = model(imgs)
if isinstance(outputs, tuple): if isinstance(outputs, (tuple, list)):
loss = DeepSupervision(criterion, outputs, pids) loss = DeepSupervision(criterion, outputs, pids)
else: else:
loss = criterion(outputs, pids) loss = criterion(outputs, pids)

View File

@ -275,17 +275,17 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
outputs, features = model(imgs) outputs, features = model(imgs)
if args.htri_only: if args.htri_only:
if isinstance(features, tuple): if isinstance(features, (tuple, list)):
loss = DeepSupervision(criterion_htri, features, pids) loss = DeepSupervision(criterion_htri, features, pids)
else: else:
loss = criterion_htri(features, pids) loss = criterion_htri(features, pids)
else: else:
if isinstance(outputs, tuple): if isinstance(outputs, (tuple, list)):
xent_loss = DeepSupervision(criterion_xent, outputs, pids) xent_loss = DeepSupervision(criterion_xent, outputs, pids)
else: else:
xent_loss = criterion_xent(outputs, pids) xent_loss = criterion_xent(outputs, pids)
if isinstance(features, tuple): if isinstance(features, (tuple, list)):
htri_loss = DeepSupervision(criterion_htri, features, pids) htri_loss = DeepSupervision(criterion_htri, features, pids)
else: else:
htri_loss = criterion_htri(features, pids) htri_loss = criterion_htri(features, pids)

View File

@ -284,6 +284,9 @@ def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=Fa
imgs, pids = imgs.cuda(), pids.cuda() imgs, pids = imgs.cuda(), pids.cuda()
outputs = model(imgs) outputs = model(imgs)
if isinstance(outputs, (tuple, list)):
loss = DeepSupervision(criterion, outputs, pids)
else:
loss = criterion(outputs, pids) loss = criterion(outputs, pids)
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()

View File

@ -272,17 +272,17 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
outputs, features = model(imgs) outputs, features = model(imgs)
if args.htri_only: if args.htri_only:
if isinstance(features, tuple): if isinstance(features, (tuple, list)):
loss = DeepSupervision(criterion_htri, features, pids) loss = DeepSupervision(criterion_htri, features, pids)
else: else:
loss = criterion_htri(features, pids) loss = criterion_htri(features, pids)
else: else:
if isinstance(outputs, tuple): if isinstance(outputs, (tuple, list)):
xent_loss = DeepSupervision(criterion_xent, outputs, pids) xent_loss = DeepSupervision(criterion_xent, outputs, pids)
else: else:
xent_loss = criterion_xent(outputs, pids) xent_loss = criterion_xent(outputs, pids)
if isinstance(features, tuple): if isinstance(features, (tuple, list)):
htri_loss = DeepSupervision(criterion_htri, features, pids) htri_loss = DeepSupervision(criterion_htri, features, pids)
else: else:
htri_loss = criterion_htri(features, pids) htri_loss = criterion_htri(features, pids)