mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance]: make total loss at the end of all losses (#369)
This commit is contained in:
parent
8f91cf0be8
commit
b5fd38ab8c
@ -159,20 +159,23 @@ class BaseModel(BaseModule):
|
|||||||
all losses, and the second is log_vars which will be sent to the
|
all losses, and the second is log_vars which will be sent to the
|
||||||
logger.
|
logger.
|
||||||
"""
|
"""
|
||||||
log_vars = OrderedDict()
|
log_vars = []
|
||||||
for loss_name, loss_value in losses.items():
|
for loss_name, loss_value in losses.items():
|
||||||
if isinstance(loss_value, torch.Tensor):
|
if isinstance(loss_value, torch.Tensor):
|
||||||
log_vars[loss_name] = loss_value.mean()
|
log_vars.append([loss_name, loss_value.mean()])
|
||||||
elif is_list_of(loss_value, torch.Tensor):
|
elif is_list_of(loss_value, torch.Tensor):
|
||||||
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
|
log_vars.append(
|
||||||
|
[loss_name,
|
||||||
|
sum(_loss.mean() for _loss in loss_value)])
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f'{loss_name} is not a tensor or list of tensors')
|
f'{loss_name} is not a tensor or list of tensors')
|
||||||
|
|
||||||
loss = sum(value for key, value in log_vars.items() if 'loss' in key)
|
loss = sum(value for key, value in log_vars if 'loss' in key)
|
||||||
log_vars['loss'] = loss
|
log_vars.insert(0, ['loss', loss])
|
||||||
|
log_vars = OrderedDict(log_vars) # type: ignore
|
||||||
|
|
||||||
return loss, log_vars
|
return loss, log_vars # type: ignore
|
||||||
|
|
||||||
def to(self,
|
def to(self,
|
||||||
device: Optional[Union[int, str, torch.device]] = None,
|
device: Optional[Union[int, str, torch.device]] = None,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user