From e687aff59531c9a01523a5cf4060adb27fea27d0 Mon Sep 17 00:00:00 2001 From: "liuyuan1.vendor" Date: Mon, 16 May 2022 09:32:33 +0000 Subject: [PATCH] [Refactor]: Change the interface of SimMIM --- mmselfsup/models/algorithms/simmim.py | 57 ++++++++++++------- .../test_algorithms/test_simmim.py | 47 +++++++++++++++ 2 files changed, 83 insertions(+), 21 deletions(-) create mode 100644 tests/test_models/test_algorithms/test_simmim.py diff --git a/mmselfsup/models/algorithms/simmim.py b/mmselfsup/models/algorithms/simmim.py index 1e33ac89..ae0e80de 100644 --- a/mmselfsup/models/algorithms/simmim.py +++ b/mmselfsup/models/algorithms/simmim.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional +from typing import Dict, List, Optional, Tuple, Union import torch +from mmselfsup.core import SelfSupDataSample from ..builder import ALGORITHMS, build_backbone, build_head, build_neck from .base import BaseModel @@ -15,19 +16,21 @@ class SimMIM(BaseModel): `_. Args: - backbone (dict): Config dict for encoder. Defaults to None. - neck (dict): Config dict for encoder. Defaults to None. - head (dict): Config dict for loss functions. Defaults to None. - init_cfg (dict, optional): Config dict for weight initialization. - Defaults to None. + backbone (Dict): Config dict for encoder. Defaults to None. + neck (Dict): Config dict for neck. Defaults to None. + head (Dict): Config dict for loss functions. Defaults to None. + preprocess_cfg (Dict): Config to preprocess images. + init_cfg (Union[List[Dict], Dict], optional): Config dict for weight + initialization. Defaults to None. 
""" def __init__(self, - backbone: dict, - neck: dict, - head: dict, - init_cfg: Optional[dict] = None) -> None: - super(SimMIM, self).__init__(init_cfg) + backbone: Dict, + neck: Dict, + head: Dict, + preprocess_cfg: Dict, + init_cfg: Optional[Union[List[Dict], Dict]] = None) -> None: + super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg) assert backbone is not None self.backbone = build_backbone(backbone) assert neck is not None @@ -35,27 +38,39 @@ class SimMIM(BaseModel): assert head is not None self.head = build_head(head) - def extract_feat(self, img: torch.Tensor) -> tuple: - """Function to extract features from backbone. + def extract_feat(self, inputs: List[torch.Tensor], + data_samples: List[SelfSupDataSample], + **kwarg) -> Tuple[torch.Tensor]: + """The forward function to extract features. Args: - img (torch.Tensor): Input images of shape (N, C, H, W). + inputs (List[torch.Tensor]): The input images. + data_samples (List[SelfSupDataSample]): All elements required + during the forward function. Returns: - tuple[Tensor]: Latent representations of images. + Tuple[torch.Tensor]: backbone outputs. """ - return self.backbone(img) + mask = torch.stack( + [data_sample.mask.value for data_sample in data_samples]) + return self.backbone(inputs[0], mask) - def forward_train(self, x: List[torch.Tensor], **kwargs) -> dict: - """Forward the masked image and get the reconstruction loss. + def forward_train(self, inputs: List[torch.Tensor], + data_samples: List[SelfSupDataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. Args: - x (List[torch.Tensor, torch.Tensor]): Images and masks. + inputs (List[torch.Tensor]): The input images. + data_samples (List[SelfSupDataSample]): All elements required + during the forward function. Returns: - dict: Reconstructed loss. + Dict[str, Tensor]: A dictionary of loss components. 
""" - img, mask = x + mask = torch.stack( + [data_sample.mask.value for data_sample in data_samples]) + img = inputs[0] img_latent = self.backbone(img, mask) img_rec = self.neck(img_latent[0]) diff --git a/tests/test_models/test_algorithms/test_simmim.py b/tests/test_models/test_algorithms/test_simmim.py new file mode 100644 index 00000000..d1e3e5a1 --- /dev/null +++ b/tests/test_models/test_algorithms/test_simmim.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest +import torch +from mmengine.data import BaseDataElement as PixelData + +from mmselfsup.core import SelfSupDataSample +from mmselfsup.models.algorithms import SimMIM + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_simmim(): + + # model config + model_config = dict( + backbone=dict( + type='SimMIMSwinTransformer', + arch='B', + img_size=192, + stage_cfgs=dict(block_cfgs=dict(window_size=6))), + neck=dict( + type='SimMIMNeck', in_channels=128 * 2**3, encoder_stride=32), + head=dict(type='SimMIMHead', patch_size=4, encoder_in_channels=3), + preprocess_cfg={ + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'to_rgb': True + }) + model = SimMIM(**model_config) + + # test forward_train + fake_data_sample = SelfSupDataSample() + fake_mask = PixelData(value=torch.rand((48, 48))) + fake_data_sample.mask = fake_mask + fake_data = [{ + 'inputs': [torch.randn((3, 192, 192))], + 'data_sample': fake_data_sample + } for _ in range(2)] + outputs = model(fake_data, return_loss=True) + assert isinstance(outputs['loss'], torch.Tensor) + + # test extract_feat + fake_inputs, fake_data_samples = model.preprocss_data(fake_data) + fake_feat = model.extract_feat( + inputs=fake_inputs, data_samples=fake_data_samples) + assert list(fake_feat[0].shape) == [2, 1024, 6, 6]