56 lines
1.6 KiB
Python
56 lines
1.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from mmengine.model import BaseModule
|
|
from torch import nn
|
|
|
|
from mmpretrain.registry import MODELS
|
|
|
|
|
|
@MODELS.register_module()
class CosineSimilarityLoss(BaseModule):
    """Cosine similarity loss function.

    Compute the cosine similarity between two features along their last
    dimension and optimize it as a loss::

        loss = shift_factor - scale_factor * cos(pred, target)

    Args:
        shift_factor (float): The shift factor of cosine similarity.
            Default: 0.0.
        scale_factor (float): The scale factor of cosine similarity.
            Default: 1.0.
    """

    def __init__(self,
                 shift_factor: float = 0.0,
                 scale_factor: float = 1.0) -> None:
        super().__init__()
        self.shift_factor = shift_factor
        self.scale_factor = scale_factor

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function of cosine similarity loss.

        Args:
            pred (torch.Tensor): The predicted features. Cosine similarity
                is taken over the last dimension.
            target (torch.Tensor): The target features, with the same
                feature (last) dimension as ``pred``.
            mask (torch.Tensor, optional): Weights applied to the
                per-position loss before averaging, e.g. a binary mask
                selecting which positions contribute. It should broadcast
                against the per-position loss, whose shape is that of
                ``pred`` (broadcast with ``target``) without the last
                dimension. If None, the loss is averaged over all
                positions. Defaults to None.

        Returns:
            torch.Tensor: The scalar cosine similarity loss. When ``mask``
            is given and sums to zero, the loss is 0 rather than NaN.
        """
        pred_norm = nn.functional.normalize(pred, dim=-1)
        target_norm = nn.functional.normalize(target, dim=-1)
        # Per-position loss: the feature (last) dimension is reduced away.
        loss = self.shift_factor - self.scale_factor * (
            pred_norm * target_norm).sum(dim=-1)

        if mask is None:
            return loss.mean()

        # Cast so boolean / integer masks are weighted and summed in the
        # loss dtype (normalize() guarantees a floating-point loss).
        mask = mask.to(loss.dtype)
        # Clamp the normalizer: an all-zero mask would otherwise give
        # 0 / 0 = NaN and poison the gradients; here it yields 0 instead.
        normalizer = mask.sum().clamp(min=torch.finfo(loss.dtype).eps)
        return (loss * mask).sum() / normalizer
|