mirror of
https://github.com/KaiyangZhou/deep-person-reid.git
synced 2025-06-03 14:53:23 +08:00
update isinstance(outputs, (tuple, list))
This commit is contained in:
parent
602971e4d7
commit
478da99655
@ -287,7 +287,7 @@ def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=Fa
|
|||||||
imgs, pids = imgs.cuda(), pids.cuda()
|
imgs, pids = imgs.cuda(), pids.cuda()
|
||||||
|
|
||||||
outputs = model(imgs)
|
outputs = model(imgs)
|
||||||
if isinstance(outputs, tuple):
|
if isinstance(outputs, (tuple, list)):
|
||||||
loss = DeepSupervision(criterion, outputs, pids)
|
loss = DeepSupervision(criterion, outputs, pids)
|
||||||
else:
|
else:
|
||||||
loss = criterion(outputs, pids)
|
loss = criterion(outputs, pids)
|
||||||
|
@ -275,17 +275,17 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
|
|||||||
|
|
||||||
outputs, features = model(imgs)
|
outputs, features = model(imgs)
|
||||||
if args.htri_only:
|
if args.htri_only:
|
||||||
if isinstance(features, tuple):
|
if isinstance(features, (tuple, list)):
|
||||||
loss = DeepSupervision(criterion_htri, features, pids)
|
loss = DeepSupervision(criterion_htri, features, pids)
|
||||||
else:
|
else:
|
||||||
loss = criterion_htri(features, pids)
|
loss = criterion_htri(features, pids)
|
||||||
else:
|
else:
|
||||||
if isinstance(outputs, tuple):
|
if isinstance(outputs, (tuple, list)):
|
||||||
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
|
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
|
||||||
else:
|
else:
|
||||||
xent_loss = criterion_xent(outputs, pids)
|
xent_loss = criterion_xent(outputs, pids)
|
||||||
|
|
||||||
if isinstance(features, tuple):
|
if isinstance(features, (tuple, list)):
|
||||||
htri_loss = DeepSupervision(criterion_htri, features, pids)
|
htri_loss = DeepSupervision(criterion_htri, features, pids)
|
||||||
else:
|
else:
|
||||||
htri_loss = criterion_htri(features, pids)
|
htri_loss = criterion_htri(features, pids)
|
||||||
|
@ -284,6 +284,9 @@ def train(epoch, model, criterion, optimizer, trainloader, use_gpu, freeze_bn=Fa
|
|||||||
imgs, pids = imgs.cuda(), pids.cuda()
|
imgs, pids = imgs.cuda(), pids.cuda()
|
||||||
|
|
||||||
outputs = model(imgs)
|
outputs = model(imgs)
|
||||||
|
if isinstance(outputs, (tuple, list)):
|
||||||
|
loss = DeepSupervision(criterion, outputs, pids)
|
||||||
|
else:
|
||||||
loss = criterion(outputs, pids)
|
loss = criterion(outputs, pids)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
@ -272,17 +272,17 @@ def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader,
|
|||||||
|
|
||||||
outputs, features = model(imgs)
|
outputs, features = model(imgs)
|
||||||
if args.htri_only:
|
if args.htri_only:
|
||||||
if isinstance(features, tuple):
|
if isinstance(features, (tuple, list)):
|
||||||
loss = DeepSupervision(criterion_htri, features, pids)
|
loss = DeepSupervision(criterion_htri, features, pids)
|
||||||
else:
|
else:
|
||||||
loss = criterion_htri(features, pids)
|
loss = criterion_htri(features, pids)
|
||||||
else:
|
else:
|
||||||
if isinstance(outputs, tuple):
|
if isinstance(outputs, (tuple, list)):
|
||||||
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
|
xent_loss = DeepSupervision(criterion_xent, outputs, pids)
|
||||||
else:
|
else:
|
||||||
xent_loss = criterion_xent(outputs, pids)
|
xent_loss = criterion_xent(outputs, pids)
|
||||||
|
|
||||||
if isinstance(features, tuple):
|
if isinstance(features, (tuple, list)):
|
||||||
htri_loss = DeepSupervision(criterion_htri, features, pids)
|
htri_loss = DeepSupervision(criterion_htri, features, pids)
|
||||||
else:
|
else:
|
||||||
htri_loss = criterion_htri(features, pids)
|
htri_loss = criterion_htri(features, pids)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user