mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #1586 from lorenzbaraldi/eval_loss
Put validation loss under amp_autocast
This commit is contained in:
commit
f266f841a0
16
train.py
16
train.py
@ -970,16 +970,16 @@ def validate(
|
|||||||
|
|
||||||
with amp_autocast():
|
with amp_autocast():
|
||||||
output = model(input)
|
output = model(input)
|
||||||
if isinstance(output, (tuple, list)):
|
if isinstance(output, (tuple, list)):
|
||||||
output = output[0]
|
output = output[0]
|
||||||
|
|
||||||
# augmentation reduction
|
# augmentation reduction
|
||||||
reduce_factor = args.tta
|
reduce_factor = args.tta
|
||||||
if reduce_factor > 1:
|
if reduce_factor > 1:
|
||||||
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
|
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
|
||||||
target = target[0:target.size(0):reduce_factor]
|
target = target[0:target.size(0):reduce_factor]
|
||||||
|
|
||||||
loss = loss_fn(output, target)
|
loss = loss_fn(output, target)
|
||||||
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
|
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
|
||||||
|
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
|
@ -294,9 +294,9 @@ def validate(args):
|
|||||||
with amp_autocast():
|
with amp_autocast():
|
||||||
output = model(input)
|
output = model(input)
|
||||||
|
|
||||||
if valid_labels is not None:
|
if valid_labels is not None:
|
||||||
output = output[:, valid_labels]
|
output = output[:, valid_labels]
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
|
|
||||||
if real_labels is not None:
|
if real_labels is not None:
|
||||||
real_labels.add_result(output)
|
real_labels.add_result(output)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user