[Enhance]: make total loss at the end of all losses (#369)

This commit is contained in:
Mashiro 2022-10-09 15:37:59 +08:00 committed by GitHub
parent 8f91cf0be8
commit b5fd38ab8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -159,20 +159,23 @@ class BaseModel(BaseModule):
all losses, and the second is log_vars which will be sent to the
logger.
"""
log_vars = OrderedDict()
log_vars = []
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
log_vars.append([loss_name, loss_value.mean()])
elif is_list_of(loss_value, torch.Tensor):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
log_vars.append(
[loss_name,
sum(_loss.mean() for _loss in loss_value)])
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')
loss = sum(value for key, value in log_vars.items() if 'loss' in key)
log_vars['loss'] = loss
loss = sum(value for key, value in log_vars if 'loss' in key)
log_vars.insert(0, ['loss', loss])
log_vars = OrderedDict(log_vars) # type: ignore
return loss, log_vars
return loss, log_vars # type: ignore
def to(self,
device: Optional[Union[int, str, torch.device]] = None,