mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
61 lines
1.9 KiB
Python
61 lines
1.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import Dict, List
|
|
|
|
import torch
|
|
|
|
from mmselfsup.registry import MODELS
|
|
from mmselfsup.structures import SelfSupDataSample
|
|
from .base import BaseModel
|
|
|
|
|
|
@MODELS.register_module()
class SimMIM(BaseModel):
    """SimMIM.

    Implementation of `SimMIM: A Simple Framework for Masked Image Modeling
    <https://arxiv.org/abs/2111.09886>`_.

    Both public entry points run the same pipeline: the per-sample masks are
    stacked into one tensor, the backbone receives the image together with
    that mask, and the neck consumes the first element of the backbone
    output. ``loss`` additionally passes the neck output, the input image and
    the mask to the head.
    """

    @staticmethod
    def _stack_masks(data_samples: List[SelfSupDataSample]) -> torch.Tensor:
        """Stack the mask of every data sample into a single tensor.

        Args:
            data_samples (List[SelfSupDataSample]): Samples whose
                ``mask.value`` attribute holds the mask of one image.

        Returns:
            torch.Tensor: The masks stacked along a new leading dimension,
            in the same order as ``data_samples``.
        """
        return torch.stack([sample.mask.value for sample in data_samples])

    def extract_feat(self, batch_inputs: List[torch.Tensor],
                     data_samples: List[SelfSupDataSample],
                     **kwargs) -> torch.Tensor:
        """The forward function to extract features.

        Args:
            batch_inputs (List[torch.Tensor]): The input images. Only the
                first element of the list is used.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function.
            **kwargs: Extra keyword arguments; accepted for interface
                compatibility and not used here.

        Returns:
            torch.Tensor: The reconstructed image.
        """
        mask = self._stack_masks(data_samples)
        # The backbone output is indexable; only its first element is
        # forwarded to the neck.
        img_latent = self.backbone(batch_inputs[0], mask)
        return self.neck(img_latent[0])

    def loss(self, batch_inputs: List[torch.Tensor],
             data_samples: List[SelfSupDataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            batch_inputs (List[torch.Tensor]): The input images. Only the
                first element of the list is used.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function.
            **kwargs: Extra keyword arguments; accepted for interface
                compatibility and not used here.

        Returns:
            Dict[str, Tensor]: A dictionary of loss components, holding the
            head output under the key ``loss``.
        """
        mask = self._stack_masks(data_samples)
        img = batch_inputs[0]

        img_latent = self.backbone(img, mask)
        img_rec = self.neck(img_latent[0])
        # The head compares the reconstruction with the original image and
        # also receives the mask.
        loss = self.head(img_rec, img, mask)

        return dict(loss=loss)
|