[Refactor]: Change the interface of SimMIM

Author: liuyuan1.vendor
Date: 2022-05-16 09:32:33 +00:00
Committed by: fangyixiao18
Commit: e687aff595 (parent e87be11a98)

2 changed files with 83 additions and 21 deletions

Changed file 1 of 2: the SimMIM algorithm module (modified).

@@ -1,8 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Optional
+from typing import Dict, List, Optional, Tuple, Union

 import torch

+from mmselfsup.core import SelfSupDataSample
 from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
 from .base import BaseModel
@@ -15,19 +16,21 @@ class SimMIM(BaseModel):
     <https://arxiv.org/abs/2111.09886>`_.

     Args:
-        backbone (dict): Config dict for encoder. Defaults to None.
-        neck (dict): Config dict for encoder. Defaults to None.
-        head (dict): Config dict for loss functions. Defaults to None.
-        init_cfg (dict, optional): Config dict for weight initialization.
-            Defaults to None.
+        backbone (Dict): Config dict for encoder. Defaults to None.
+        neck (Dict): Config dict for encoder. Defaults to None.
+        head (Dict): Config dict for loss functions. Defaults to None.
+        preprocess_cfg (Dict): Config to preprocess images.
+        init_cfg (Union[List[Dict], Dict], optional): Config dict for weight
+            initialization. Defaults to None.
     """

     def __init__(self,
-                 backbone: dict,
-                 neck: dict,
-                 head: dict,
-                 init_cfg: Optional[dict] = None) -> None:
-        super(SimMIM, self).__init__(init_cfg)
+                 backbone: Dict,
+                 neck: Dict,
+                 head: Dict,
+                 preprocess_cfg: Dict,
+                 init_cfg: Optional[Union[List[Dict], Dict]] = None) -> None:
+        super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)

         assert backbone is not None
         self.backbone = build_backbone(backbone)
         assert neck is not None
@@ -35,27 +38,39 @@ class SimMIM(BaseModel):
         assert head is not None
         self.head = build_head(head)

-    def extract_feat(self, img: torch.Tensor) -> tuple:
-        """Function to extract features from backbone.
+    def extract_feat(self, inputs: List[torch.Tensor],
+                     data_samples: List[SelfSupDataSample],
+                     **kwarg) -> Tuple[torch.Tensor]:
+        """The forward function to extract features.

         Args:
-            img (torch.Tensor): Input images of shape (N, C, H, W).
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[SelfSupDataSample]): All elements required
+                during the forward function.

         Returns:
-            tuple[Tensor]: Latent representations of images.
+            Tuple[torch.Tensor]: backbone outputs.
         """
-        return self.backbone(img)
+        mask = torch.stack(
+            [data_sample.mask.value for data_sample in data_samples])
+        return self.backbone(inputs[0], mask)

-    def forward_train(self, x: List[torch.Tensor], **kwargs) -> dict:
-        """Forward the masked image and get the reconstruction loss.
+    def forward_train(self, inputs: List[torch.Tensor],
+                      data_samples: List[SelfSupDataSample],
+                      **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.

         Args:
-            x (List[torch.Tensor, torch.Tensor]): Images and masks.
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[SelfSupDataSample]): All elements required
+                during the forward function.

         Returns:
-            dict: Reconstructed loss.
+            Dict[str, Tensor]: A dictionary of loss components.
         """
-        img, mask = x
+        mask = torch.stack(
+            [data_sample.mask.value for data_sample in data_samples])
+        img = inputs[0]
         img_latent = self.backbone(img, mask)
         img_rec = self.neck(img_latent[0])
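
For callers, the practical effect of this refactor is that images and masks are no longer packed together into one list argument; the mask now travels inside each SelfSupDataSample. A minimal usage sketch of the new interface (assuming a `model` instance built as in the unit test added below; the tensors are random placeholders):

    import torch
    from mmengine.data import BaseDataElement as PixelData
    from mmselfsup.core import SelfSupDataSample

    # Old interface: model.forward_train([img, mask])
    # New interface: the mask is carried by each data sample.
    data_sample = SelfSupDataSample()
    data_sample.mask = PixelData(value=torch.rand((48, 48)))  # 192 / patch_size 4 = 48

    inputs = [torch.randn((2, 3, 192, 192))]    # inputs[0] is the batched image tensor
    data_samples = [data_sample, data_sample]   # one data sample per image in the batch

    latent = model.extract_feat(inputs=inputs, data_samples=data_samples)
    losses = model.forward_train(inputs=inputs, data_samples=data_samples)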

Changed file 2 of 2: a new unit test for SimMIM (added).

@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+
+import pytest
+import torch
+from mmengine.data import BaseDataElement as PixelData
+
+from mmselfsup.core import SelfSupDataSample
+from mmselfsup.models.algorithms import SimMIM
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
+def test_simmim():
+    # model config
+    model_config = dict(
+        backbone=dict(
+            type='SimMIMSwinTransformer',
+            arch='B',
+            img_size=192,
+            stage_cfgs=dict(block_cfgs=dict(window_size=6))),
+        neck=dict(
+            type='SimMIMNeck', in_channels=128 * 2**3, encoder_stride=32),
+        head=dict(type='SimMIMHead', patch_size=4, encoder_in_channels=3),
+        preprocess_cfg={
+            'mean': [0.5, 0.5, 0.5],
+            'std': [0.5, 0.5, 0.5],
+            'to_rgb': True
+        })
+    model = SimMIM(**model_config)
+
+    # test forward_train
+    fake_data_sample = SelfSupDataSample()
+    fake_mask = PixelData(value=torch.rand((48, 48)))
+    fake_data_sample.mask = fake_mask
+    fake_data = [{
+        'inputs': [torch.randn((3, 192, 192))],
+        'data_sample': fake_data_sample
+    } for _ in range(2)]
+
+    outputs = model(fake_data, return_loss=True)
+    assert isinstance(outputs['loss'], torch.Tensor)
+
+    # test extract_feat
+    fake_inputs, fake_data_samples = model.preprocss_data(fake_data)
+    fake_feat = model.extract_feat(
+        inputs=fake_inputs, data_samples=fake_data_samples)
+    assert list(fake_feat[0].shape) == [2, 1024, 6, 6]
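
The expected shape in the final assertion follows from the config values above: Swin-B starts from an embedding dimension of 128 and doubles it across three stage transitions (128 * 2**3 = 1024, which is also the neck's in_channels), and the 192-pixel input is downsampled by the backbone's overall stride of 32 (matching encoder_stride), so the spatial size is 192 / 32 = 6. A quick arithmetic check:

    # Sanity check of the asserted feature shape [2, 1024, 6, 6].
    batch_size = 2               # two fake samples in fake_data
    out_channels = 128 * 2**3    # Swin-B: 128 channels doubled over three stage transitions -> 1024
    feat_size = 192 // 32        # img_size / overall backbone stride (== encoder_stride) -> 6
    assert [batch_size, out_channels, feat_size, feat_size] == [2, 1024, 6, 6]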