mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance]: make total loss at the end of all losses (#369)
This commit is contained in:
parent
8f91cf0be8
commit
b5fd38ab8c
@ -159,20 +159,23 @@ class BaseModel(BaseModule):
|
||||
all losses, and the second is log_vars which will be sent to the
|
||||
logger.
|
||||
"""
|
||||
log_vars = OrderedDict()
|
||||
log_vars = []
|
||||
for loss_name, loss_value in losses.items():
|
||||
if isinstance(loss_value, torch.Tensor):
|
||||
log_vars[loss_name] = loss_value.mean()
|
||||
log_vars.append([loss_name, loss_value.mean()])
|
||||
elif is_list_of(loss_value, torch.Tensor):
|
||||
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
|
||||
log_vars.append(
|
||||
[loss_name,
|
||||
sum(_loss.mean() for _loss in loss_value)])
|
||||
else:
|
||||
raise TypeError(
|
||||
f'{loss_name} is not a tensor or list of tensors')
|
||||
|
||||
loss = sum(value for key, value in log_vars.items() if 'loss' in key)
|
||||
log_vars['loss'] = loss
|
||||
loss = sum(value for key, value in log_vars if 'loss' in key)
|
||||
log_vars.insert(0, ['loss', loss])
|
||||
log_vars = OrderedDict(log_vars) # type: ignore
|
||||
|
||||
return loss, log_vars
|
||||
return loss, log_vars # type: ignore
|
||||
|
||||
def to(self,
|
||||
device: Optional[Union[int, str, torch.device]] = None,
|
||||
|
Loading…
x
Reference in New Issue
Block a user