[Feature]: Add reconstruction loss (#365)

* [Feature]: Add pixel reconstruction loss

* [Fix]: Fix lint
pull/372/head
Yuan Liu 2022-07-22 17:28:50 +08:00 committed by GitHub
parent 353e3f7c58
commit 10d9539f67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 96 additions and 1 deletions

View File

@ -3,10 +3,12 @@ from .cae_loss import CAELoss
from .cosine_similarity_loss import CosineSimilarityLoss
from .cross_correlation_loss import CrossCorrelationLoss
from .mae_loss import MAEReconstructionLoss
from .reconstruction_loss import PixelReconstructionLoss
from .simmim_loss import SimMIMReconstructionLoss
from .swav_loss import SwAVLoss
__all__ = [
'CAELoss', 'CrossCorrelationLoss', 'CosineSimilarityLoss',
'MAEReconstructionLoss', 'SimMIMReconstructionLoss', 'SwAVLoss'
'MAEReconstructionLoss', 'SimMIMReconstructionLoss', 'SwAVLoss',
'PixelReconstructionLoss'
]

View File

@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
from mmengine.model import BaseModule
from mmselfsup.registry import MODELS
@MODELS.register_module()
class PixelReconstructionLoss(BaseModule):
"""Loss for the reconstruction of pixel in Masked Image Modeling.
This module measures the distance between the target image and the
reconstructed image and compute the loss to optimize the model. Currently,
This module only provides L1 and L2 loss to penalize the reconstructed
error. In addition, a mask can be passed in the ``forward`` function to
only apply loss on visible region, like that in MAE.
Args:
criterion (str): The loss the penalize the reconstructed error.
Currently, only supports L1 and L2 loss
channel (int, optional): The number of channels to average the
reconstruction loss. If not None, the reconstruction loss
will be divided by the channel. Defaults to None.
"""
def __init__(self, criterion: str, channel: Optional[int] = None) -> None:
super().__init__()
if criterion == 'L1':
self.penalty = torch.nn.L1Loss(reduction='none')
elif criterion == 'L2':
self.penalty = torch.nn.MSELoss(reduction='none')
else:
raise NotImplementedError(f'Currently, PixelReconstructionLoss \
only supports L1 and L2 loss, but get {criterion}')
self.channel = channel if channel is not None else 1
def forward(self, pred: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
"""Forward function to compute the reconstrction loss.
Args:
pred (torch.Tensor): The reconstructed image.
target (torch.Tensor): The target image.
mask (torch.Tensor): The mask of the target image.
Returns:
torch.Tensor: The reconstruction loss.
"""
loss = self.penalty(pred, target)
# if the dim of the loss is 3, take the average of the loss
# along the last dim
if len(loss.shape) == 3:
loss = loss.mean(dim=-1)
loss = (loss * mask).sum() / mask.sum() / self.channel
return loss

View File

@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmselfsup.models import PixelReconstructionLoss
def test_reconstruction_loss():
# test L2 loss
loss_config = dict(criterion='L2')
fake_pred = torch.rand((2, 196, 768))
fake_target = torch.rand((2, 196, 768))
fake_mask = torch.ones((2, 196))
loss = PixelReconstructionLoss(**loss_config)
loss_value = loss(fake_pred, fake_target, fake_mask)
assert isinstance(loss_value.item(), float)
# test L1 loss
loss_config = dict(criterion='L1', channel=3)
fake_pred = torch.rand((2, 3, 192, 192))
fake_target = torch.rand((2, 3, 192, 192))
fake_mask = torch.ones((2, 1, 192, 192))
loss = PixelReconstructionLoss(**loss_config)
loss_value = loss(fake_pred, fake_target, fake_mask)
assert isinstance(loss_value.item(), float)