mirror of https://github.com/open-mmlab/mmselfsup.git (synced 2025-06-03 14:59:38 +08:00)

[Refactor]: Change the interface of SimMIM

commit e687aff595, parent e87be11a98
mmselfsup/models/algorithms/simmim.py
@@ -1,8 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Optional
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 
+from mmselfsup.core import SelfSupDataSample
 from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
 from .base import BaseModel
 
@@ -15,19 +16,21 @@ class SimMIM(BaseModel):
     <https://arxiv.org/abs/2111.09886>`_.
 
     Args:
-        backbone (dict): Config dict for encoder. Defaults to None.
-        neck (dict): Config dict for encoder. Defaults to None.
-        head (dict): Config dict for loss functions. Defaults to None.
-        init_cfg (dict, optional): Config dict for weight initialization.
-            Defaults to None.
+        backbone (Dict): Config dict for encoder. Defaults to None.
+        neck (Dict): Config dict for encoder. Defaults to None.
+        head (Dict): Config dict for loss functions. Defaults to None.
+        preprocess_cfg (Dict): Config to preprocess images.
+        init_cfg (Union[List[Dict], Dict], optional): Config dict for weight
+            initialization. Defaults to None.
     """
 
     def __init__(self,
-                 backbone: dict,
-                 neck: dict,
-                 head: dict,
-                 init_cfg: Optional[dict] = None) -> None:
-        super(SimMIM, self).__init__(init_cfg)
+                 backbone: Dict,
+                 neck: Dict,
+                 head: Dict,
+                 preprocess_cfg: Dict,
+                 init_cfg: Optional[Union[List[Dict], Dict]] = None) -> None:
+        super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
         assert backbone is not None
         self.backbone = build_backbone(backbone)
         assert neck is not None
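For context on the new required preprocess_cfg argument: it is the image-normalization config that is now forwarded to BaseModel. The values below are copied from the test added later in this commit, so they are illustrative rather than prescribed defaults:

# Illustrative preprocess_cfg, mirroring the test added in this commit.
preprocess_cfg = dict(
    mean=[0.5, 0.5, 0.5],  # per-channel mean used to normalize inputs
    std=[0.5, 0.5, 0.5],   # per-channel std
    to_rgb=True)           # convert channel order to RGB before normalizing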
@@ -35,27 +38,39 @@ class SimMIM(BaseModel):
         assert head is not None
         self.head = build_head(head)
 
-    def extract_feat(self, img: torch.Tensor) -> tuple:
-        """Function to extract features from backbone.
+    def extract_feat(self, inputs: List[torch.Tensor],
+                     data_samples: List[SelfSupDataSample],
+                     **kwarg) -> Tuple[torch.Tensor]:
+        """The forward function to extract features.
 
         Args:
-            img (torch.Tensor): Input images of shape (N, C, H, W).
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[SelfSupDataSample]): All elements required
+                during the forward function.
 
         Returns:
-            tuple[Tensor]: Latent representations of images.
+            Tuple[torch.Tensor]: backbone outputs.
         """
-        return self.backbone(img)
+        mask = torch.stack(
+            [data_sample.mask.value for data_sample in data_samples])
+        return self.backbone(inputs[0], mask)
 
-    def forward_train(self, x: List[torch.Tensor], **kwargs) -> dict:
-        """Forward the masked image and get the reconstruction loss.
+    def forward_train(self, inputs: List[torch.Tensor],
+                      data_samples: List[SelfSupDataSample],
+                      **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.
 
         Args:
-            x (List[torch.Tensor, torch.Tensor]): Images and masks.
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[SelfSupDataSample]): All elements required
+                during the forward function.
 
         Returns:
-            dict: Reconstructed loss.
+            Dict[str, Tensor]: A dictionary of loss components.
         """
-        img, mask = x
+        mask = torch.stack(
+            [data_sample.mask.value for data_sample in data_samples])
+        img = inputs[0]
 
         img_latent = self.backbone(img, mask)
         img_rec = self.neck(img_latent[0])
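With the refactored interface, extract_feat and forward_train no longer receive the mask bundled with the image; each SelfSupDataSample carries its own mask, and the methods stack the per-sample masks into one batch tensor before calling the backbone. A minimal sketch of that pattern (the batch size of 2 and the 48x48 mask size are borrowed from the test below, purely for illustration):

# Minimal sketch of the per-sample mask batching used by the new interface.
import torch
from mmengine.data import BaseDataElement as PixelData

from mmselfsup.core import SelfSupDataSample

data_samples = []
for _ in range(2):
    sample = SelfSupDataSample()
    # each data sample carries its own mask under `mask.value`
    sample.mask = PixelData(value=torch.rand((48, 48)))
    data_samples.append(sample)

# extract_feat / forward_train collapse the per-sample masks into one tensor
mask = torch.stack([sample.mask.value for sample in data_samples])
assert mask.shape == (2, 48, 48)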
tests/test_models/test_algorithms/test_simmim.py (new file, +47 lines)
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+
+import pytest
+import torch
+from mmengine.data import BaseDataElement as PixelData
+
+from mmselfsup.core import SelfSupDataSample
+from mmselfsup.models.algorithms import SimMIM
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
+def test_simmim():
+
+    # model config
+    model_config = dict(
+        backbone=dict(
+            type='SimMIMSwinTransformer',
+            arch='B',
+            img_size=192,
+            stage_cfgs=dict(block_cfgs=dict(window_size=6))),
+        neck=dict(
+            type='SimMIMNeck', in_channels=128 * 2**3, encoder_stride=32),
+        head=dict(type='SimMIMHead', patch_size=4, encoder_in_channels=3),
+        preprocess_cfg={
+            'mean': [0.5, 0.5, 0.5],
+            'std': [0.5, 0.5, 0.5],
+            'to_rgb': True
+        })
+    model = SimMIM(**model_config)
+
+    # test forward_train
+    fake_data_sample = SelfSupDataSample()
+    fake_mask = PixelData(value=torch.rand((48, 48)))
+    fake_data_sample.mask = fake_mask
+    fake_data = [{
+        'inputs': [torch.randn((3, 192, 192))],
+        'data_sample': fake_data_sample
+    } for _ in range(2)]
+    outputs = model(fake_data, return_loss=True)
+    assert isinstance(outputs['loss'], torch.Tensor)
+
+    # test extract_feat
+    fake_inputs, fake_data_samples = model.preprocss_data(fake_data)
+    fake_feat = model.extract_feat(
+        inputs=fake_inputs, data_samples=fake_data_samples)
+    assert list(fake_feat[0].shape) == [2, 1024, 6, 6]
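The added test can be run on its own to exercise the refactored interface; one way to do so from the repository root (assuming pytest is installed) is:

# Run only the new SimMIM test module; equivalent to invoking
# `pytest tests/test_models/test_algorithms/test_simmim.py` on the command line.
import pytest

pytest.main(['tests/test_models/test_algorithms/test_simmim.py'])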