mirror of
https://github.com/KaiyangZhou/deep-person-reid.git
synced 2025-06-03 14:53:23 +08:00
update isinstance(outputs, (tuple, list))
This commit is contained in:
parent
602971e4d7
commit
478da99655
@ -287,7 +287,7 @@ def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=Fa
|
||||
imgs, pids = imgs.cuda(), pids.cuda()
|
||||
|
||||
outputs = model(imgs)
|
||||
if isinstance(outputs, tuple):
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
loss = DeepSupervision(criterion, outputs, pids)
|
||||
else:
|
||||
loss = criterion(outputs, pids)
|
||||
|
@ -275,17 +275,17 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
|
||||
|
||||
outputs, features = model(imgs)
|
||||
if args.htri_only:
|
||||
if isinstance(features, tuple):
|
||||
if isinstance(features, (tuple, list)):
|
||||
loss = DeepSupervision(criterion_htri, features, pids)
|
||||
else:
|
||||
loss = criterion_htri(features, pids)
|
||||
else:
|
||||
if isinstance(outputs, tuple):
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
|
||||
else:
|
||||
xent_loss = criterion_xent(outputs, pids)
|
||||
|
||||
if isinstance(features, tuple):
|
||||
if isinstance(features, (tuple, list)):
|
||||
htri_loss = DeepSupervision(criterion_htri, features, pids)
|
||||
else:
|
||||
htri_loss = criterion_htri(features, pids)
|
||||
|
@ -284,7 +284,10 @@ def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=Fa
|
||||
imgs, pids = imgs.cuda(), pids.cuda()
|
||||
|
||||
outputs = model(imgs)
|
||||
loss = criterion(outputs, pids)
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
loss = DeepSupervision(criterion, outputs, pids)
|
||||
else:
|
||||
loss = criterion(outputs, pids)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
@ -272,17 +272,17 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
|
||||
|
||||
outputs, features = model(imgs)
|
||||
if args.htri_only:
|
||||
if isinstance(features, tuple):
|
||||
if isinstance(features, (tuple, list)):
|
||||
loss = DeepSupervision(criterion_htri, features, pids)
|
||||
else:
|
||||
loss = criterion_htri(features, pids)
|
||||
else:
|
||||
if isinstance(outputs, tuple):
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
|
||||
else:
|
||||
xent_loss = criterion_xent(outputs, pids)
|
||||
|
||||
if isinstance(features, tuple):
|
||||
if isinstance(features, (tuple, list)):
|
||||
htri_loss = DeepSupervision(criterion_htri, features, pids)
|
||||
else:
|
||||
htri_loss = criterion_htri(features, pids)
|
||||
|
Loading…
x
Reference in New Issue
Block a user