EasyCV/easycv/models/loss/mse_loss.py

46 lines
1.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/open-mmlab/mmpose/blob/master/mmpose/models/losses/mse_loss.py
import torch.nn as nn
from ..registry import LOSSES
@LOSSES.register_module()
class JointsMSELoss(nn.Module):
    """MSE loss for heatmaps.

    The squared error between predicted and ground-truth heatmaps is averaged
    over every sample, joint and heatmap element, then scaled by
    ``loss_weight``.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss()
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight):
        """Forward function.

        Args:
            output (torch.Tensor): Predicted heatmaps. The first two dims are
                read as (batch_size, num_joints); all remaining dims are
                flattened into one.
            target (torch.Tensor): Ground-truth heatmaps with the same number
                of elements as ``output``.
            target_weight (torch.Tensor): Per-joint weights, only used when
                ``use_target_weight`` is True. Accepted shapes are
                (batch_size, num_joints), (batch_size, num_joints, 1), or any
                shape that reshapes to (batch_size, num_joints, -1) and
                broadcasts against the flattened heatmaps.

        Returns:
            torch.Tensor: Scalar loss value.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)
        heatmaps_pred = output.reshape((batch_size, num_joints, -1))
        heatmaps_gt = target.reshape((batch_size, num_joints, -1))

        if self.use_target_weight:
            # Bring the weights to (N, K, 1) (or (N, K, P) for per-element
            # weights) so they broadcast over the flattened heatmap dim.
            weight = target_weight.reshape((batch_size, num_joints, -1))
            heatmaps_pred = heatmaps_pred * weight
            heatmaps_gt = heatmaps_gt * weight

        # Each joint contributes exactly batch_size * P elements, so a single
        # mean-reduced MSE over the whole (N, K, P) tensor equals the average
        # of the K per-joint MSEs -- no Python loop over joints is needed.
        return self.criterion(heatmaps_pred, heatmaps_gt) * self.loss_weight