[Feature]: Add reconstruction loss (#365)
* [Feature]: Add pixel reconstruction loss * [Fix]: Fix lintpull/372/head
parent
353e3f7c58
commit
10d9539f67
|
@ -3,10 +3,12 @@ from .cae_loss import CAELoss
|
|||
from .cosine_similarity_loss import CosineSimilarityLoss
|
||||
from .cross_correlation_loss import CrossCorrelationLoss
|
||||
from .mae_loss import MAEReconstructionLoss
|
||||
from .reconstruction_loss import PixelReconstructionLoss
|
||||
from .simmim_loss import SimMIMReconstructionLoss
|
||||
from .swav_loss import SwAVLoss
|
||||
|
||||
__all__ = [
|
||||
'CAELoss', 'CrossCorrelationLoss', 'CosineSimilarityLoss',
|
||||
'MAEReconstructionLoss', 'SimMIMReconstructionLoss', 'SwAVLoss'
|
||||
'MAEReconstructionLoss', 'SimMIMReconstructionLoss', 'SwAVLoss',
|
||||
'PixelReconstructionLoss'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmselfsup.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PixelReconstructionLoss(BaseModule):
|
||||
"""Loss for the reconstruction of pixel in Masked Image Modeling.
|
||||
|
||||
This module measures the distance between the target image and the
|
||||
reconstructed image and compute the loss to optimize the model. Currently,
|
||||
This module only provides L1 and L2 loss to penalize the reconstructed
|
||||
error. In addition, a mask can be passed in the ``forward`` function to
|
||||
only apply loss on visible region, like that in MAE.
|
||||
|
||||
Args:
|
||||
criterion (str): The loss the penalize the reconstructed error.
|
||||
Currently, only supports L1 and L2 loss
|
||||
channel (int, optional): The number of channels to average the
|
||||
reconstruction loss. If not None, the reconstruction loss
|
||||
will be divided by the channel. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, criterion: str, channel: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
if criterion == 'L1':
|
||||
self.penalty = torch.nn.L1Loss(reduction='none')
|
||||
elif criterion == 'L2':
|
||||
self.penalty = torch.nn.MSELoss(reduction='none')
|
||||
else:
|
||||
raise NotImplementedError(f'Currently, PixelReconstructionLoss \
|
||||
only supports L1 and L2 loss, but get {criterion}')
|
||||
|
||||
self.channel = channel if channel is not None else 1
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor,
|
||||
mask: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function to compute the reconstrction loss.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The reconstructed image.
|
||||
target (torch.Tensor): The target image.
|
||||
mask (torch.Tensor): The mask of the target image.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The reconstruction loss.
|
||||
"""
|
||||
loss = self.penalty(pred, target)
|
||||
|
||||
# if the dim of the loss is 3, take the average of the loss
|
||||
# along the last dim
|
||||
if len(loss.shape) == 3:
|
||||
loss = loss.mean(dim=-1)
|
||||
|
||||
loss = (loss * mask).sum() / mask.sum() / self.channel
|
||||
|
||||
return loss
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models import PixelReconstructionLoss
|
||||
|
||||
|
||||
def test_reconstruction_loss():
|
||||
|
||||
# test L2 loss
|
||||
loss_config = dict(criterion='L2')
|
||||
|
||||
fake_pred = torch.rand((2, 196, 768))
|
||||
fake_target = torch.rand((2, 196, 768))
|
||||
fake_mask = torch.ones((2, 196))
|
||||
|
||||
loss = PixelReconstructionLoss(**loss_config)
|
||||
loss_value = loss(fake_pred, fake_target, fake_mask)
|
||||
|
||||
assert isinstance(loss_value.item(), float)
|
||||
|
||||
# test L1 loss
|
||||
loss_config = dict(criterion='L1', channel=3)
|
||||
|
||||
fake_pred = torch.rand((2, 3, 192, 192))
|
||||
fake_target = torch.rand((2, 3, 192, 192))
|
||||
fake_mask = torch.ones((2, 1, 192, 192))
|
||||
|
||||
loss = PixelReconstructionLoss(**loss_config)
|
||||
loss_value = loss(fake_pred, fake_target, fake_mask)
|
||||
|
||||
assert isinstance(loss_value.item(), float)
|
Loading…
Reference in New Issue