[Refactor]: Change the interface of SimMIM

This commit is contained in:
liuyuan1.vendor 2022-05-16 09:32:33 +00:00 committed by fangyixiao18
parent e87be11a98
commit e687aff595
2 changed files with 83 additions and 21 deletions

View File

@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
from typing import Dict, List, Optional, Tuple, Union
import torch
from mmselfsup.core import SelfSupDataSample
from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
from .base import BaseModel
@ -15,19 +16,21 @@ class SimMIM(BaseModel):
<https://arxiv.org/abs/2111.09886>`_.
Args:
backbone (dict): Config dict for encoder. Defaults to None.
neck (dict): Config dict for encoder. Defaults to None.
head (dict): Config dict for loss functions. Defaults to None.
init_cfg (dict, optional): Config dict for weight initialization.
Defaults to None.
backbone (Dict): Config dict for encoder. Defaults to None.
neck (Dict): Config dict for encoder. Defaults to None.
head (Dict): Config dict for loss functions. Defaults to None.
preprocess_cfg (Dict): Config to preprocess images.
init_cfg (Union[List[Dict], Dict], optional): Config dict for weight
initialization. Defaults to None.
"""
def __init__(self,
backbone: dict,
neck: dict,
head: dict,
init_cfg: Optional[dict] = None) -> None:
super(SimMIM, self).__init__(init_cfg)
backbone: Dict,
neck: Dict,
head: Dict,
preprocess_cfg: Dict,
init_cfg: Optional[Union[List[Dict], Dict]] = None) -> None:
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
assert backbone is not None
self.backbone = build_backbone(backbone)
assert neck is not None
@ -35,27 +38,39 @@ class SimMIM(BaseModel):
assert head is not None
self.head = build_head(head)
def extract_feat(self, img: torch.Tensor) -> tuple:
"""Function to extract features from backbone.
def extract_feat(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwarg) -> Tuple[torch.Tensor]:
"""The forward function to extract features.
Args:
img (torch.Tensor): Input images of shape (N, C, H, W).
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
tuple[Tensor]: Latent representations of images.
Tuple[torch.Tensor]: backbone outputs.
"""
return self.backbone(img)
mask = torch.stack(
[data_sample.mask.value for data_sample in data_samples])
return self.backbone(inputs[0], mask)
def forward_train(self, x: List[torch.Tensor], **kwargs) -> dict:
"""Forward the masked image and get the reconstruction loss.
def forward_train(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
x (List[torch.Tensor, torch.Tensor]): Images and masks.
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
dict: Reconstructed loss.
Dict[str, Tensor]: A dictionary of loss components.
"""
img, mask = x
mask = torch.stack(
[data_sample.mask.value for data_sample in data_samples])
img = inputs[0]
img_latent = self.backbone(img, mask)
img_rec = self.neck(img_latent[0])

View File

@ -0,0 +1,47 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmengine.data import BaseDataElement as PixelData
from mmselfsup.core import SelfSupDataSample
from mmselfsup.models.algorithms import SimMIM
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_simmim():
# model config
model_config = dict(
backbone=dict(
type='SimMIMSwinTransformer',
arch='B',
img_size=192,
stage_cfgs=dict(block_cfgs=dict(window_size=6))),
neck=dict(
type='SimMIMNeck', in_channels=128 * 2**3, encoder_stride=32),
head=dict(type='SimMIMHead', patch_size=4, encoder_in_channels=3),
preprocess_cfg={
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'to_rgb': True
})
model = SimMIM(**model_config)
# test forward_train
fake_data_sample = SelfSupDataSample()
fake_mask = PixelData(value=torch.rand((48, 48)))
fake_data_sample.mask = fake_mask
fake_data = [{
'inputs': [torch.randn((3, 192, 192))],
'data_sample': fake_data_sample
} for _ in range(2)]
outputs = model(fake_data, return_loss=True)
assert isinstance(outputs['loss'], torch.Tensor)
# test extract_feat
fake_inputs, fake_data_samples = model.preprocss_data(fake_data)
fake_feat = model.extract_feat(
inputs=fake_inputs, data_samples=fake_data_samples)
assert list(fake_feat[0].shape) == [2, 1024, 6, 6]