51 lines
1.5 KiB
Python
51 lines
1.5 KiB
Python
|
import torch
import torch.nn.functional as F
|
||
|
|
||
|
|
||
|
def reduce_loss(loss, reduction):
    """Collapse an element-wise loss tensor according to ``reduction``.

    The reduction string is converted to an integer code by PyTorch's
    ``F._Reduction.get_enum`` (0 = "none", 1 = "mean", 2 = "sum"), and the
    matching reduction is applied to ``loss``.

    Args:
        loss (Tensor): Element-wise loss tensor.
        reduction (str): One of "none", "mean" or "sum".

    Returns:
        Tensor: ``loss`` itself for "none", otherwise its mean or its sum.
    """
    code = F._Reduction.get_enum(reduction)
    reducers = {
        0: lambda t: t,         # "none": keep the element-wise values
        1: lambda t: t.mean(),  # "mean" (legacy "elementwise_mean" also maps to 1)
        2: lambda t: t.sum(),   # "sum"
    }
    reducer = reducers.get(code)
    # Any other code falls through to None, exactly like the original
    # if/elif chain; get_enum is expected to raise for unknown strings first.
    return None if reducer is None else reducer(loss)
|
||
|
|
||
|
|
||
|
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor, optional): Element-wise weights, multiplied into
            ``loss`` before reduction. Defaults to None (no weighting).
        reduction (str): Same as built-in losses of PyTorch ("none", "mean"
            or "sum"). Defaults to "mean".
        avg_factor (float, optional): Average factor used as the denominator
            when computing the mean of losses. Defaults to None, in which
            case ``reduce_loss`` is used directly.

    Returns:
        Tensor: Processed loss values.

    Raises:
        ValueError: If ``avg_factor`` is given together with a reduction
            other than "mean" or "none".
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        return reduce_loss(loss, reduction)

    if reduction == 'mean':
        # A small epsilon keeps the result finite when avg_factor is 0
        # (e.g. no valid samples); without it 0 / 0 would produce NaN.
        eps = torch.finfo(torch.float32).eps
        return loss.sum() / (avg_factor + eps)
    if reduction != 'none':
        # Report the reduction actually passed instead of always claiming "sum".
        raise ValueError(
            f'avg_factor can not be used with reduction="{reduction}"')
    # reduction == 'none': return the (weighted) element-wise loss unchanged
    return loss
|