Mirror of https://github.com/open-mmlab/mmselfsup.git (synced 2025-06-03 14:59:38 +08:00)
[Refactor]: Change the interface of SimMIM
parent: e87be11a98
commit: e687aff595
mmselfsup/models/algorithms/simmim.py

@@ -1,8 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Optional
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 
+from mmselfsup.core import SelfSupDataSample
 from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
 from .base import BaseModel
 
@@ -15,19 +16,21 @@ class SimMIM(BaseModel):
     <https://arxiv.org/abs/2111.09886>`_.
 
     Args:
-        backbone (dict): Config dict for encoder. Defaults to None.
-        neck (dict): Config dict for encoder. Defaults to None.
-        head (dict): Config dict for loss functions. Defaults to None.
-        init_cfg (dict, optional): Config dict for weight initialization.
-            Defaults to None.
+        backbone (Dict): Config dict for encoder. Defaults to None.
+        neck (Dict): Config dict for encoder. Defaults to None.
+        head (Dict): Config dict for loss functions. Defaults to None.
+        preprocess_cfg (Dict): Config to preprocess images.
+        init_cfg (Union[List[Dict], Dict], optional): Config dict for weight
+            initialization. Defaults to None.
     """
 
     def __init__(self,
-                 backbone: dict,
-                 neck: dict,
-                 head: dict,
-                 init_cfg: Optional[dict] = None) -> None:
-        super(SimMIM, self).__init__(init_cfg)
+                 backbone: Dict,
+                 neck: Dict,
+                 head: Dict,
+                 preprocess_cfg: Dict,
+                 init_cfg: Optional[Union[List[Dict], Dict]] = None) -> None:
+        super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
         assert backbone is not None
         self.backbone = build_backbone(backbone)
         assert neck is not None
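For orientation, a minimal before/after sketch of what this constructor change means for callers. The registry types and config values are borrowed from the unit test added further down in this commit; treat it as an illustration under those assumptions, not the only valid configuration.

from mmselfsup.models.algorithms import SimMIM

# Previously the constructor took only backbone/neck/head (plus an optional
# init_cfg); after this refactor a preprocess_cfg dict is required as well
# and is forwarded to BaseModel via super().__init__().
model = SimMIM(
    backbone=dict(
        type='SimMIMSwinTransformer',
        arch='B',
        img_size=192,
        stage_cfgs=dict(block_cfgs=dict(window_size=6))),
    neck=dict(type='SimMIMNeck', in_channels=128 * 2**3, encoder_stride=32),
    head=dict(type='SimMIMHead', patch_size=4, encoder_in_channels=3),
    preprocess_cfg=dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], to_rgb=True))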
@@ -35,27 +38,39 @@ class SimMIM(BaseModel):
         assert head is not None
         self.head = build_head(head)
 
-    def extract_feat(self, img: torch.Tensor) -> tuple:
-        """Function to extract features from backbone.
+    def extract_feat(self, inputs: List[torch.Tensor],
+                     data_samples: List[SelfSupDataSample],
+                     **kwarg) -> Tuple[torch.Tensor]:
+        """The forward function to extract features.
 
         Args:
-            img (torch.Tensor): Input images of shape (N, C, H, W).
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[SelfSupDataSample]): All elements required
+                during the forward function.
 
         Returns:
-            tuple[Tensor]: Latent representations of images.
+            Tuple[torch.Tensor]: backbone outputs.
         """
-        return self.backbone(img)
+        mask = torch.stack(
+            [data_sample.mask.value for data_sample in data_samples])
+        return self.backbone(inputs[0], mask)
 
-    def forward_train(self, x: List[torch.Tensor], **kwargs) -> dict:
-        """Forward the masked image and get the reconstruction loss.
+    def forward_train(self, inputs: List[torch.Tensor],
+                      data_samples: List[SelfSupDataSample],
+                      **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.
 
         Args:
-            x (List[torch.Tensor, torch.Tensor]): Images and masks.
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[SelfSupDataSample]): All elements required
+                during the forward function.
 
         Returns:
-            dict: Reconstructed loss.
+            Dict[str, Tensor]: A dictionary of loss components.
         """
-        img, mask = x
+        mask = torch.stack(
+            [data_sample.mask.value for data_sample in data_samples])
+        img = inputs[0]
 
         img_latent = self.backbone(img, mask)
         img_rec = self.neck(img_latent[0])
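The key behavioural change in this hunk is that the mask no longer arrives as the second element of an (img, mask) pair; each SelfSupDataSample carries its own mask and the model stacks them internally. Below is a small standalone sketch of that stacking, using the same 48x48 mask size as the test added later in this commit (the BaseDataElement-as-PixelData alias is copied from that test).

import torch
from mmengine.data import BaseDataElement as PixelData

from mmselfsup.core import SelfSupDataSample

# Build two samples, each holding its own random 48x48 mask.
data_samples = []
for _ in range(2):
    sample = SelfSupDataSample()
    sample.mask = PixelData(value=torch.rand((48, 48)))
    data_samples.append(sample)

# This mirrors the new extract_feat/forward_train bodies: per-sample masks
# are stacked into one (N, H, W) tensor before being handed to the backbone.
mask = torch.stack([sample.mask.value for sample in data_samples])
assert mask.shape == (2, 48, 48)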
tests/test_models/test_algorithms/test_simmim.py (new file, 47 lines)

@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+
+import pytest
+import torch
+from mmengine.data import BaseDataElement as PixelData
+
+from mmselfsup.core import SelfSupDataSample
+from mmselfsup.models.algorithms import SimMIM
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
+def test_simmim():
+
+    # model config
+    model_config = dict(
+        backbone=dict(
+            type='SimMIMSwinTransformer',
+            arch='B',
+            img_size=192,
+            stage_cfgs=dict(block_cfgs=dict(window_size=6))),
+        neck=dict(
+            type='SimMIMNeck', in_channels=128 * 2**3, encoder_stride=32),
+        head=dict(type='SimMIMHead', patch_size=4, encoder_in_channels=3),
+        preprocess_cfg={
+            'mean': [0.5, 0.5, 0.5],
+            'std': [0.5, 0.5, 0.5],
+            'to_rgb': True
+        })
+    model = SimMIM(**model_config)
+
+    # test forward_train
+    fake_data_sample = SelfSupDataSample()
+    fake_mask = PixelData(value=torch.rand((48, 48)))
+    fake_data_sample.mask = fake_mask
+    fake_data = [{
+        'inputs': [torch.randn((3, 192, 192))],
+        'data_sample': fake_data_sample
+    } for _ in range(2)]
+    outputs = model(fake_data, return_loss=True)
+    assert isinstance(outputs['loss'], torch.Tensor)
+
+    # test extract_feat
+    fake_inputs, fake_data_samples = model.preprocss_data(fake_data)
+    fake_feat = model.extract_feat(
+        inputs=fake_inputs, data_samples=fake_data_samples)
+    assert list(fake_feat[0].shape) == [2, 1024, 6, 6]
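The expected shape in the final assertion follows from the config values alone. A quick arithmetic sketch, assuming the usual Swin widening (channels double over three stage transitions from the 'B' stem width of 128) and the 32x total downsampling implied by encoder_stride=32:

base_channels = 128                   # Swin-B stem width
out_channels = base_channels * 2**3   # 1024, matches the neck's in_channels
feat_size = 192 // 32                 # 6: img_size divided by the total stride
batch_size = 2                        # two fake samples in fake_data
print([batch_size, out_channels, feat_size, feat_size])  # [2, 1024, 6, 6]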