[Refactor]: Refactor base model and create a sample for MAE

YuanLiuuuuuu authored 2022-06-10 04:10:06 +00:00; committed by fangyixiao18
parent 962f9b9752
commit 35e0988527
16 changed files with 344 additions and 293 deletions

View File

@@ -1,4 +1,5 @@
 # dataset settings
+custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
 dataset_type = 'mmcls.ImageNet'
 data_root = 'data/imagenet/'
 file_client_args = dict(backend='disk')
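The added custom_imports line matters because the config refers to the dataset by the `mmcls.` prefix; MMEngine imports the listed modules while parsing the config so those registry entries exist. A rough equivalent of what happens at config-load time (a sketch, assuming mmcls is installed):

    import importlib

    # Importing the module runs its register_module() decorators, so
    # 'mmcls.ImageNet' can later be resolved from the registry.
    importlib.import_module('mmcls.datasets')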

View File

@@ -2,7 +2,6 @@ default_scope = 'mmselfsup'
 default_hooks = dict(
     runtime_info=dict(type='RuntimeInfoHook'),
-    optimizer=dict(type='OptimizerHook', grad_clip=None),
     timer=dict(type='IterTimerHook'),
     logger=dict(type='LoggerHook', interval=50),
     param_scheduler=dict(type='ParamSchedulerHook'),
@@ -17,14 +16,14 @@ env_cfg = dict(
 )

 log_processor = dict(
-    interval=50,
-    custom_keys=[dict(data_src='', method='mean', windows_size='global')])
+    window_size=10,
+    custom_cfg=[dict(data_src='', method='mean', windows_size='global')])

-vis_backends = [dict(type='LocalVisBackend')]
-visualizer = dict(
-    type='SelfSupLocalVisualizer',
-    vis_backends=vis_backends,
-    name='visualizer')
+# vis_backends = [dict(type='LocalVisBackend')]
+# visualizer = dict(
+#     type='SelfSupLocalVisualizer',
+#     vis_backends=vis_backends,
+#     name='visualizer')
 # custom_hooks = [dict(type='SelfSupVisualizationHook', interval=10)]

 log_level = 'INFO'

View File

@@ -1,6 +1,8 @@
 # model settings
 model = dict(
     type='MAE',
+    data_preprocessor=dict(
+        mean=[124, 117, 104], std=[59, 58, 58], bgr_to_rgb=True),
     backbone=dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75),
     neck=dict(
         type='MAEPretrainDecoder',
@@ -12,5 +14,8 @@ model = dict(
         decoder_num_heads=16,
         mlp_ratio=4.,
     ),
-    head=dict(type='MAEPretrainHead', norm_pix=True, patch_size=16),
-    loss=dict(type='MAEReconstructionLoss'))
+    head=dict(
+        type='MAEPretrainHead',
+        norm_pix=True,
+        patch_size=16,
+        loss=dict(type='MAEReconstructionLoss')))
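With this change the normalization statistics move out of the algorithm and into a data_preprocessor owned by the model. A sketch of what the preprocessor does with the values above (see SelfSupDataPreprocessor later in this commit; the fake image is an assumption):

    import torch

    mean = torch.tensor([124, 117, 104]).view(3, 1, 1)
    std = torch.tensor([59, 58, 58]).view(3, 1, 1)

    img_bgr = torch.randint(0, 256, (3, 224, 224)).float()  # fake BGR image
    img_rgb = img_bgr[[2, 1, 0], ...]      # bgr_to_rgb=True flips channels
    img = (img_rgb - mean) / std           # per-channel normalization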

View File

@@ -1,6 +1,6 @@
-# optimizer
+# optimizer_wrapper
 optimizer = dict(type='AdamW', lr=1.5e-4, betas=(0.9, 0.95), weight_decay=0.05)
-optimizer_config = dict()  # grad_clip, coalesce, bucket_size_mb
+optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)

 # learning rate scheduler
 param_scheduler = [
@@ -16,4 +16,4 @@ param_scheduler = [
 ]

 # runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=300)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)
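The OptimWrapper now owns the step logic that OptimizerHook used to provide, which is why the hook disappears from default_runtime above. Gradient clipping, for instance, moves into the wrapper config; a sketch based on MMEngine's OptimWrapper API (the clip value is an assumption, not part of this commit):

    optimizer = dict(
        type='AdamW', lr=1.5e-4, betas=(0.9, 0.95), weight_decay=0.05)
    optim_wrapper = dict(
        type='OptimWrapper',
        optimizer=optimizer,
        clip_grad=dict(max_norm=1.0))  # replaces grad_clip in OptimizerHook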

View File

@@ -1,4 +1,4 @@
 _base_ = 'mae_vit-base-p16_8xb512-coslr-400e_in1k.py'

 # mixed precision
-fp16 = dict(loss_scale='dynamic')
+optim_wrapper = dict(type='AmpOptimWrapper')
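AmpOptimWrapper subsumes the old fp16 dict: dynamic loss scaling is its default and can also be set explicitly. A sketch assuming MMEngine's AmpOptimWrapper signature:

    optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic')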

View File

@@ -5,23 +5,26 @@ _base_ = [
     '../_base_/default_runtime.py',
 ]

-# dataset
-data = dict(samples_per_gpu=512, workers_per_gpu=32)
+# dataset 8 x 512
+train_dataloader = dict(batch_size=512, num_workers=16)

-# optimizer
-optimizer = dict(
-    lr=1.5e-4 * 4096 / 256,
-    paramwise_options={
-        'norm': dict(weight_decay=0.),
-        'bias': dict(weight_decay=0.),
-        'pos_embed': dict(weight_decay=0.),
-        'mask_token': dict(weight_decay=0.),
-        'cls_token': dict(weight_decay=0.)
-    })
-optimizer_config = dict()
+# optimizer wrapper
+optimizer = dict(
+    type='AdamW', lr=1.5e-4 * 4096 / 256, betas=(0.9, 0.95), weight_decay=0.05)
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=optimizer,
+    paramwise_cfg=dict(
+        custom_keys={
+            'ln': dict(decay_mult=0.0),
+            'bias': dict(decay_mult=0.0),
+            'pos_embed': dict(decay_mult=0.),
+            'mask_token': dict(decay_mult=0.),
+            'cls_token': dict(decay_mult=0.)
+        }))

 # learning rate scheduler
-scheduler = [
+param_scheduler = [
     dict(
         type='LinearLR',
         start_factor=1e-4,
@@ -38,13 +41,6 @@ scheduler = [
         convert_to_iter_based=True)
 ]

-# schedule
-runner = dict(max_epochs=400)
-
-# runtime
-checkpoint_config = dict(interval=1, max_keep_ckpts=3, out_dir='')
-persistent_workers = True
-log_config = dict(
-    interval=100, hooks=[
-        dict(type='TextLoggerHook'),
-    ])
+# runtime settings
+# pre-train for 400 epochs
+train_cfg = dict(max_epochs=400)
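The learning rate above follows the linear scaling rule, with the effective batch size coming from the 8 x 512 layout in the config name:

    gpus, batch_per_gpu = 8, 512          # the '8xb512' in the file name
    total_batch = gpus * batch_per_gpu    # 4096
    lr = 1.5e-4 * total_batch / 256       # = 2.4e-3, the value configured above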

View File

@@ -1,40 +1,66 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from abc import ABCMeta, abstractmethod
-from collections import OrderedDict
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Union

 import torch
-import torch.distributed as dist
-from mmcv.runner import BaseModule, auto_fp16
+from mmengine.model import BaseModel as _BaseModel
+from torch import nn

 from mmselfsup.core import SelfSupDataSample
-from mmselfsup.utils import get_module_device
+from mmselfsup.registry import MODELS


-class BaseModel(BaseModule, metaclass=ABCMeta):
-    """Base model class for self-supervised learning.
-
-    All algorithms should inherit this module.
+class BaseModel(_BaseModel):
+    """BaseModel for SelfSup.

     Args:
-        preprocess_cfg (Dict): Config to preprocess images.
-        init_cfg (Dict, optional): Config to initialize models.
+        backbone (dict): The backbone module. See
+            :mod:`mmcls.models.backbones`.
+        neck (dict, optional): The neck module to process features from
+            backbone. See :mod:`mmcls.models.necks`. Defaults to None.
+        head (dict, optional): The head module to do prediction and calculate
+            loss from processed features. See :mod:`mmcls.models.heads`.
+            Notice that if the head is not set, almost all methods cannot be
+            used except :meth:`extract_feat`. Defaults to None.
+        pretrained (str, optional): The pretrained checkpoint path, supporting
+            local and remote paths. Defaults to None.
+        data_preprocessor (Union[dict, nn.Module], optional): The config for
+            preprocessing input data. If None or no type is specified, it
+            will use "SelfSupDataPreprocessor" as the type.
+            See :class:`SelfSupDataPreprocessor` for more details.
+            Defaults to None.
+        init_cfg (dict, optional): The config to control the initialization.
             Defaults to None.
     """

     def __init__(self,
-                 preprocess_cfg: Dict,
-                 init_cfg: Optional[Dict] = None) -> None:
-        super(BaseModel, self).__init__(init_cfg)
-        self.fp16_enabled = False
-        assert 'mean' in preprocess_cfg
-        self.register_buffer(
-            'mean_norm',
-            torch.tensor(preprocess_cfg.pop('mean')).view(3, 1, 1))
-        assert 'std' in preprocess_cfg
-        self.register_buffer(
-            'std_norm',
-            torch.tensor(preprocess_cfg.pop('std')).view(3, 1, 1))
-        assert 'to_rgb' in preprocess_cfg
-        self.to_rgb = preprocess_cfg.pop('to_rgb')
+                 backbone: dict,
+                 neck: Optional[dict] = None,
+                 head: Optional[dict] = None,
+                 pretrained: Optional[str] = None,
+                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
+                 init_cfg: Optional[dict] = None):
+        if pretrained is not None:
+            init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        # The build process is in MMEngine, so we need to add scope here.
+        data_preprocessor.setdefault('type',
+                                     'mmselfsup.SelfSupDataPreprocessor')
+
+        super().__init__(
+            init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+        self.backbone = MODELS.build(backbone)
+        if neck is not None:
+            self.neck = MODELS.build(neck)
+        if head is not None:
+            self.head = MODELS.build(head)

     @property
     def with_neck(self) -> bool:
@@ -44,156 +70,98 @@ class BaseModel(BaseModule, metaclass=ABCMeta):
     def with_head(self) -> bool:
         return hasattr(self, 'head') and self.head is not None

-    @abstractmethod
-    def extract_feat(self, inputs: List[torch.Tensor],
-                     data_samples: List[SelfSupDataSample],
-                     **kwargs) -> object:
-        """The forward function to extract features.
-
-        Args:
-            inputs (List[torch.Tensor]): The input images.
-            data_samples (List[SelfSupDataSample]): All elements required
-                during the forward function.
-        """
-        raise NotImplementedError('``extract_feat`` should be implemented')
-
-    @abstractmethod
-    def forward_train(self, inputs: List[torch.Tensor],
-                      data_samples: List[SelfSupDataSample],
-                      **kwargs) -> object:
-        """The forward function in training
-
-        Args:
-            inputs (List[torch.Tensor]): The input images.
-            data_samples (List[SelfSupDataSample]): All elements required
-                during the forward function.
-        """
-        raise NotImplementedError('``forward_train`` should be implemented')
-
-    def forward_test(self, inputs: List[torch.Tensor],
-                     data_samples: List[SelfSupDataSample],
-                     **kwargs) -> object:
-        """The forward function in testing
-
-        Args:
-            inputs (List[torch.Tensor]): The input images.
-            data_samples (List[SelfSupDataSample]): All elements required
-                during the forward function.
-        """
-        raise NotImplementedError('``forward_test`` should be implemented')
-
-    @auto_fp16(apply_to=('data', ))
-    def forward(self,
-                data: List[Dict],
-                return_loss: bool = False,
-                extract: bool = False,
-                **kwargs) -> object:
-        """Forward function of model.
-
-        Calls either forward_train, forward_test or extract_feat function
-        according to the mode.
-
-        Args:
-            data (List[Dict]): The input data for model.
-            return_loss (bool): Train mode or test mode. Defaults to False.
-            extract (bool): Whether or not only extract features from model.
-                If set to True, the ``return_loss`` will be ignored. Defaults
-                to False.
-        """
-        # preprocess images
-        inputs, data_samples = self.preprocss_data(data)
-
-        # Whether or not extract features. If set to True, the ``return_loss``
-        # will be ignored.
-        if extract:
-            return self.extract_feat(
-                inputs=inputs, data_samples=data_samples, **kwargs)
-
-        if return_loss:
-            losses = self.forward_train(
-                inputs=inputs, data_samples=data_samples, **kwargs)
-            loss, log_vars = self._parse_losses(losses)
-            outputs = dict(loss=loss, log_vars=log_vars)
-            return outputs
-        else:
-            # should be a list of SelfSupDataSample
-            return self.forward_test(
-                inputs=inputs, data_samples=data_samples, **kwargs)
-
-    def _parse_losses(self, losses: Dict) -> Tuple[torch.Tensor, Dict]:
-        """Parse the raw outputs (losses) of the network.
-
-        Args:
-            losses (Dict): Raw output of the network, which usually contain
-                losses and other necessary information.
-
-        Returns:
-            tuple[torch.Tensor, Dict]: (loss, log_vars), loss is the loss
-                tensor which may be a weighted sum of all losses, log_vars
-                contains all the variables to be sent to the logger.
-        """
-        log_vars = OrderedDict()
-        for loss_name, loss_value in losses.items():
-            if isinstance(loss_value, torch.Tensor):
-                log_vars[loss_name] = loss_value.mean()
-            elif isinstance(loss_value, list):
-                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
-            elif isinstance(loss_value, dict):
-                for name, value in loss_value.items():
-                    log_vars[name] = value
-            else:
-                raise TypeError(
-                    f'{loss_name} is not a tensor or list of tensors')
-
-        loss = sum(_value for _key, _value in log_vars.items()
-                   if 'loss' in _key)
-        log_vars['loss'] = loss
-        for loss_name, loss_value in log_vars.items():
-            # reduce loss when distributed training
-            if dist.is_available() and dist.is_initialized():
-                loss_value = loss_value.data.clone()
-                dist.all_reduce(loss_value.div_(dist.get_world_size()))
-            log_vars[loss_name] = loss_value.item()
-
-        return loss, log_vars
-
-    def preprocss_data(
-            self,
-            data: List[Dict]) -> Tuple[List[torch.Tensor], SelfSupDataSample]:
-        """Process input data during training, testing or extracting.
-
-        Args:
-            data (List[Dict]): The data to be processed, which
-                comes from dataloader.
-
-        Returns:
-            tuple: It should contain 2 item.
-              - batch_images (List[torch.Tensor]): The batch image tensor.
-              - data_samples (List[SelfSupDataSample], Optional): The Data
-                Samples. It usually includes information such as
-                `gt_label`. Return None If the input data does not
-                contain `data_sample`.
-        """
-        # data_['inputs] is a list
-        images = [data_['inputs'] for data_ in data]
-        data_samples = [data_['data_sample'] for data_ in data]
-
-        device = get_module_device(self)
-        data_samples = [data_sample.to(device) for data_sample in data_samples]
-        images = [[img_.to(device) for img_ in img] for img in images]
-
-        # convert images to rgb
-        if self.to_rgb and images[0][0].size(0) == 3:
-            images = [[img_[[2, 1, 0], ...] for img_ in img] for img in images]
-
-        # normalize images
-        images = [[(img_ - self.mean_norm) / self.std_norm for img_ in img]
-                  for img in images]
-
-        # reconstruct images into several batches. For example, SimCLR needs
-        # two crops for each image, and this code snippet will convert images
-        # into two batches, each containing one crop of an image.
-        batch_images = []
-        for i in range(len(images[0])):
-            cur_batch = [img[i] for img in images]
-            batch_images.append(torch.stack(cur_batch))
-
-        return batch_images, data_samples
+    def forward(self,
+                batch_inputs: torch.Tensor,
+                data_samples: Optional[List[SelfSupDataSample]] = None,
+                mode: str = 'tensor'):
+        """Return losses or predictions for training, validation, testing and
+        simple inference.
+
+        This method overwrites the abstract method in ``BaseModel``.
+
+        Args:
+            batch_inputs (torch.Tensor): Batch input tensor collated by
+                :attr:`data_preprocessor`.
+            data_samples (List[BaseDataElement], optional): Data samples
+                collated by :attr:`data_preprocessor`.
+            mode (str): Should be one of ``loss``, ``predict`` and ``tensor``.
+
+                - ``loss``: Called by ``train_step`` and return a loss ``dict``
+                  used for logging.
+                - ``predict``: Called by ``val_step`` and ``test_step``
+                  and return a list of ``BaseDataElement`` results used for
+                  computing metrics.
+                - ``tensor``: Called by custom use to get ``Tensor`` type
+                  results.
+
+        Returns:
+            ForwardResults:
+
+                - If ``mode == loss``, return a ``dict`` of loss tensors used
+                  for backward and logging.
+                - If ``mode == predict``, return a ``list`` of
+                  :obj:`BaseDataElement` for computing metrics
+                  and getting inference results.
+                - If ``mode == tensor``, return a tensor, ``tuple`` of tensors
+                  or ``dict`` of tensors for custom use.
+        """
+        if mode == 'tensor':
+            feats = self.extract_feat(batch_inputs)
+            return feats
+        elif mode == 'loss':
+            return self.loss(batch_inputs, data_samples)
+        elif mode == 'predict':
+            return self.predict(batch_inputs, data_samples)
+        else:
+            raise RuntimeError(f'Invalid mode "{mode}".')
+
+    def extract_feat(self, batch_inputs):
+        """Extract features from the input tensor with shape (N, C, ...).
+
+        This is an abstract method; subclasses should overwrite it if needed.
+
+        Args:
+            batch_inputs (Tensor): A batch of inputs with shape
+                ``(num_samples, num_channels, *img_shape)``.
+
+        Returns:
+            tuple | Tensor: The output of the specified stage. The output
+                depends on the detailed implementation.
+        """
+        raise NotImplementedError
+
+    def loss(self, batch_inputs: torch.Tensor,
+             data_samples: List[SelfSupDataSample]) -> dict:
+        """Calculate losses from a batch of inputs and data samples.
+
+        This is an abstract method; subclasses should overwrite it if needed.
+
+        Args:
+            batch_inputs (torch.Tensor): The input tensor with shape
+                (N, C, ...) in general.
+            data_samples (List[SelfSupDataSample]): The annotation data of
+                every sample.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        raise NotImplementedError
+
+    def predict(self,
+                batch_inputs: tuple,
+                data_samples: Optional[List[SelfSupDataSample]] = None,
+                **kwargs) -> List[SelfSupDataSample]:
+        """Predict results from the extracted features.
+
+        This module returns the logits before loss, which are used to compute
+        all kinds of metrics. This is an abstract method; subclasses should
+        overwrite it if needed.
+
+        Args:
+            batch_inputs (tuple): The features extracted from the backbone.
+            data_samples (List[BaseDataElement], optional): The annotation
+                data of every sample. Defaults to None.
+            **kwargs: Other keyword arguments accepted by the ``predict``
+                method of :attr:`head`.
+        """
+        raise NotImplementedError
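To make the new contract concrete, a minimal sketch of an algorithm built on the refactored BaseModel (the class and its modules are hypothetical; only the method names and the mode dispatch come from the code above):

    from typing import List

    import torch

    from mmselfsup.core import SelfSupDataSample
    from mmselfsup.models.algorithms.base import BaseModel  # assumed path
    from mmselfsup.registry import MODELS


    @MODELS.register_module()
    class ToyAlgorithm(BaseModel):
        """Sketch: only the hooks that forward() dispatches to."""

        def extract_feat(self, batch_inputs: List[torch.Tensor]):
            # mode='tensor' lands here.
            return self.backbone(batch_inputs[0])

        def loss(self, batch_inputs: List[torch.Tensor],
                 data_samples: List[SelfSupDataSample]) -> dict:
            # mode='loss' lands here; train_step logs the returned dict.
            feats = self.extract_feat(batch_inputs)
            # Assuming the backbone returns a tuple of tensors.
            return dict(loss=feats[0].mean())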

View File

@@ -1,86 +1,50 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple

 import torch

 from mmselfsup.core import SelfSupDataSample
-from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
-                       build_neck)
+from ..builder import MODELS
 from .base import BaseModel


-@ALGORITHMS.register_module()
+@MODELS.register_module()
 class MAE(BaseModel):
     """MAE.

     Implementation of `Masked Autoencoders Are Scalable Vision Learners
     <https://arxiv.org/abs/2111.06377>`_.
-
-    Args:
-        backbone (Dict, optional): Config dict for encoder. Defaults to None.
-        neck (Dict, optional): Config dict for encoder. Defaults to None.
-        head (Dict, optional): Config dict for head functions.
-            Defaults to None.
-        loss (Dict, optional): Config dict for loss functions.
-            Defaults to None.
-        preprocess_cfg (Dict, optional): Config to preprocess images.
-            Defaults to None.
-        init_cfg (Dict or List[Dict], optional): Config dict for weight
-            initialization. Defaults to None.
     """

-    def __init__(self,
-                 backbone: Optional[Dict] = None,
-                 neck: Optional[Dict] = None,
-                 head: Optional[Dict] = None,
-                 loss: Optional[Dict] = None,
-                 preprocess_cfg: Optional[Dict] = None,
-                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
-        super(MAE, self).__init__(
-            preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
-        assert backbone is not None
-        self.backbone = build_backbone(backbone)
-        assert neck is not None
-        self.neck = build_neck(neck)
-        self.neck.num_patches = self.backbone.num_patches
-        assert head is not None
-        self.head = build_head(head)
-        assert loss is not None
-        self.loss = build_loss(loss)
-
-    def init_weights(self) -> None:
-        super(MAE, self).init_weights()
-
-    def extract_feat(self, inputs: List[torch.Tensor],
-                     data_samples: List[SelfSupDataSample],
+    def extract_feat(self, batch_inputs: List[torch.Tensor],
                      **kwarg) -> Tuple[torch.Tensor]:
-        """The forward function to extract features.
+        """The forward function to extract features from the neck.

         Args:
-            inputs (List[torch.Tensor]): The input images.
-            data_samples (List[SelfSupDataSample]): All elements required
-                during the forward function.
+            batch_inputs (List[torch.Tensor]): The input images.

         Returns:
-            Tuple[torch.Tensor]: backbone outputs.
+            torch.Tensor: Outputs from the neck.
         """
-        return self.backbone(inputs[0])
+        latent, _, ids_restore = self.backbone(batch_inputs[0])
+        pred = self.neck(latent, ids_restore)
+        return pred

-    def forward_train(self, inputs: List[torch.Tensor],
-                      data_samples: List[SelfSupDataSample],
-                      **kwargs) -> Dict[str, torch.Tensor]:
+    def loss(self, batch_inputs: List[torch.Tensor],
+             data_samples: List[SelfSupDataSample],
+             **kwargs) -> Dict[str, torch.Tensor]:
         """The forward function in training.

         Args:
-            inputs (List[torch.Tensor]): The input images.
+            batch_inputs (List[torch.Tensor]): The input images.
             data_samples (List[SelfSupDataSample]): All elements required
                 during the forward function.

         Returns:
             Dict[str, Tensor]: A dictionary of loss components.
         """
-        latent, mask, ids_restore = self.backbone(inputs[0])
+        latent, mask, ids_restore = self.backbone(batch_inputs[0])
         pred = self.neck(latent, ids_restore)
-        target = self.head(inputs[0])
-        loss = self.loss(pred, target, mask)
+        loss = self.head(pred, batch_inputs[0], mask)
         losses = dict(loss=loss)
         return losses
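extract_feat now runs the full encoder-decoder path instead of returning the raw backbone output, which is what changes the expected feature shape in the updated test below. The bookkeeping for the ViT-B/16 settings in this commit:

    num_patches = (224 // 16) ** 2            # 196 patches per image
    visible = int(num_patches * (1 - 0.75))   # 49 patches survive masking
    latent_len = visible + 1                  # +1 cls token -> 50 (old shape)
    pred_dim = 16 * 16 * 3                    # 768 values per decoded patch
    assert (num_patches, latent_len, pred_dim) == (196, 50, 768)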

View File

@@ -4,26 +4,31 @@ from typing import Dict, List
 import torch
 from mmcls.models import LabelSmoothLoss
 from mmcv.cnn.utils.weight_init import trunc_normal_
-from mmcv.runner import BaseModule
+from mmengine.model import BaseModule
 from torch import nn

-from ..builder import HEADS
+from ..builder import MODELS


-@HEADS.register_module()
+@MODELS.register_module()
 class MAEPretrainHead(BaseModule):
     """Pre-training head for MAE.

     Args:
+        loss (dict): Config of loss.
         norm_pix_loss (bool): Whether or not normalize target.
             Defaults to False.
         patch_size (int): Patch size. Defaults to 16.
     """

-    def __init__(self, norm_pix: bool = False, patch_size: int = 16) -> None:
+    def __init__(self,
+                 loss: dict,
+                 norm_pix: bool = False,
+                 patch_size: int = 16) -> None:
         super().__init__()
         self.norm_pix = norm_pix
         self.patch_size = patch_size
+        self.loss = MODELS.build(loss)

     def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
@@ -36,9 +41,19 @@ class MAEPretrainHead(BaseModule):
         x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
         return x

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def construct_target(self, target: torch.Tensor) -> torch.Tensor:
+        """Construct the reconstruction target.
+
+        In addition to splitting images into tokens, this module will also
+        normalize the image according to ``norm_pix``.

-        target = self.patchify(x)
+        Args:
+            target (torch.Tensor): Image with the shape of B x 3 x H x W.
+
+        Returns:
+            torch.Tensor: Tokenized images with the shape of B x L x C.
+        """
+        target = self.patchify(target)
         if self.norm_pix:
             mean = target.mean(dim=-1, keepdim=True)
             var = target.var(dim=-1, keepdim=True)
@@ -46,8 +61,25 @@ class MAEPretrainHead(BaseModule):

         return target

+    def forward(self, pred: torch.Tensor, target: torch.Tensor,
+                mask: torch.Tensor) -> torch.Tensor:
+        """Forward function of MAE head.
+
+        Args:
+            pred (torch.Tensor): The reconstructed image.
+            target (torch.Tensor): The target image.
+            mask (torch.Tensor): The mask of the target image.
+
+        Returns:
+            torch.Tensor: The reconstruction loss.
+        """
+        target = self.construct_target(target)
+        loss = self.loss(pred, target, mask)
+
+        return loss


-@HEADS.register_module()
+@MODELS.register_module()
 class MAEFinetuneHead(BaseModule):
     """Fine-tuning head for MAE.

@@ -83,7 +115,7 @@ class MAEFinetuneHead(BaseModule):
         return losses


-@HEADS.register_module()
+@MODELS.register_module()
 class MAELinprobeHead(BaseModule):
     """Linear probing head for MAE.

View File

@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
-from mmcv.runner import BaseModule
+from mmengine.model import BaseModule

 from ..builder import LOSSES

View File

@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
 from mmcv.cnn import build_norm_layer
-from mmcv.runner import BaseModule
+from mmengine.model import BaseModule

 from ..builder import NECKS
 from ..utils import build_2d_sincos_position_embedding

View File

@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .dall_e import Encoder
+from .data_preprocessor import SelfSupDataPreprocessor
 from .ema import CosineEMA
 from .extractor import Extractor
 from .gather_layer import GatherLayer
@@ -14,5 +15,5 @@ __all__ = [
     'Extractor', 'GatherLayer', 'MultiPooling', 'MultiPrototypes',
     'build_2d_sincos_position_embedding', 'Sobel', 'MultiheadAttention',
     'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'Encoder',
-    'CosineEMA'
+    'CosineEMA', 'SelfSupDataPreprocessor'
 ]

View File

@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Sequence, Tuple
+
+import torch
+from mmengine.data import BaseDataElement
+from mmengine.model import ImgDataPreprocessor
+
+from mmselfsup.registry import MODELS
+
+
+@MODELS.register_module()
+class SelfSupDataPreprocessor(ImgDataPreprocessor):
+    """Image pre-processor for operations like normalization and bgr-to-rgb
+    conversion.
+
+    Compared with :class:`mmengine.ImgDataPreprocessor`, this module treats
+    each item in ``inputs`` of the input data as a list, instead of a
+    torch.Tensor.
+    """
+
+    def collate_data(
+            self,
+            data: Sequence[dict]) -> Tuple[List[torch.Tensor], Optional[list]]:
+        """Collate and copy data to the target device.
+
+        This method overwrites the default one by treating each item in
+        ``inputs`` of the input data as a list.
+
+        Collates the data sampled from the dataloader into a list of tensors
+        and a list of labels, then copies the tensors to the target device.
+        Subclasses can override it to be compatible with custom-format data
+        sampled from custom dataloaders.
+
+        Args:
+            data (Sequence[dict]): Data sampled from the dataloader.
+
+        Returns:
+            Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
+            tensors and list of labels on the target device.
+        """
+        inputs = [[img.to(self.device) for img in _data['inputs']]
+                  for _data in data]
+        batch_data_samples: List[BaseDataElement] = []
+
+        # Model can get predictions without any data samples.
+        for _data in data:
+            if 'data_sample' in _data:
+                batch_data_samples.append(_data['data_sample'])
+        # Move data from CPU to the corresponding device.
+        batch_data_samples = [
+            data_sample.to(self.device) for data_sample in batch_data_samples
+        ]
+
+        if not batch_data_samples:
+            batch_data_samples = None  # type: ignore
+
+        return inputs, batch_data_samples
+
+    def forward(
+            self,
+            data: Sequence[dict],
+            training: bool = False
+    ) -> Tuple[List[torch.Tensor], Optional[list]]:
+        """Perform normalization, padding and bgr2rgb conversion based on
+        ``BaseDataPreprocessor``.
+
+        Args:
+            data (Sequence[dict]): Data sampled from the dataloader.
+            training (bool): Whether to enable training-time augmentation. If
+                subclasses override this method, they can perform different
+                preprocessing strategies for training and testing based on the
+                value of ``training``.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[list]]: Data in the same format as
+            the model input.
+        """
+        inputs, batch_data_samples = self.collate_data(data)
+
+        # channel transform
+        if self.channel_conversion:
+            inputs = [[img_[[2, 1, 0], ...] for img_ in _input]
+                      for _input in inputs]
+
+        # Normalization. Here is what is different from
+        # :class:`mmengine.ImgDataPreprocessor`. Since some algorithms, e.g.
+        # SimCLR, take multiple views of an image, each item in inputs is a
+        # list containing the views of one image.
+        inputs = [[(img_ - self.mean) / self.std for img_ in _input]
+                  for _input in inputs]
+
+        batch_inputs = []
+        for i in range(len(inputs[0])):
+            cur_batch = [img[i] for img in inputs]
+            batch_inputs.append(torch.stack(cur_batch))
+
+        return batch_inputs, batch_data_samples
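The point of the list-of-lists handling is multi-view pipelines: every sample carries a list of crops, and forward() returns one batched tensor per view. A usage sketch with two views (stats borrowed from the MAE config above):

    import torch

    from mmselfsup.core import SelfSupDataSample
    from mmselfsup.models.utils import SelfSupDataPreprocessor

    preprocessor = SelfSupDataPreprocessor(
        mean=[124, 117, 104], std=[59, 58, 58], bgr_to_rgb=True)
    fake_data = [{
        'inputs': [torch.randn(3, 224, 224), torch.randn(3, 224, 224)],
        'data_sample': SelfSupDataSample()
    } for _ in range(4)]
    batch_inputs, data_samples = preprocessor(fake_data)
    assert len(batch_inputs) == 2                     # one tensor per view
    assert batch_inputs[0].shape == (4, 3, 224, 224)  # batched first view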

View File

@@ -20,54 +20,32 @@ neck = dict(
     decoder_num_heads=16,
     mlp_ratio=4.,
 )
-head = dict(type='MAEPretrainHead', norm_pix=False, patch_size=16)
 loss = dict(type='MAEReconstructionLoss')
+head = dict(type='MAEPretrainHead', norm_pix=False, patch_size=16, loss=loss)


 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
 def test_mae():
-    preprocess_cfg = {
+    data_preprocessor = {
         'mean': [0.5, 0.5, 0.5],
         'std': [0.5, 0.5, 0.5],
-        'to_rgb': True
+        'bgr_to_rgb': True
     }
-    with pytest.raises(AssertionError):
-        alg = MAE(
-            backbone=backbone,
-            neck=None,
-            head=head,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = MAE(
-            backbone=backbone,
-            neck=neck,
-            head=None,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = MAE(
-            backbone=None,
-            neck=neck,
-            head=head,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))

     alg = MAE(
         backbone=backbone,
         neck=neck,
         head=head,
-        loss=loss,
-        preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    alg.init_weights()
+        data_preprocessor=copy.deepcopy(data_preprocessor))

     fake_data = [{
         'inputs': [torch.randn((3, 224, 224))],
         'data_sample': SelfSupDataSample()
     } for _ in range(2)]
-    fake_outputs = alg(fake_data, return_loss=True)
+
+    fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
+    fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
     assert isinstance(fake_outputs['loss'].item(), float)

-    fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
-    fake_feat = alg.extract_feat(
-        inputs=fake_inputs, data_samples=fake_data_samples)
-    assert list(fake_feat[0].shape) == [2, 50, 768]
+    fake_feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
+    assert list(fake_feats.shape) == [2, 196, 768]

View File

@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmselfsup.core import SelfSupDataSample
+from mmselfsup.models.utils import SelfSupDataPreprocessor
+
+
+def test_selfsup_data_preprocessor():
+    data_preprocessor = SelfSupDataPreprocessor(rgb_to_bgr=True)
+    fake_data = [{
+        'inputs': [torch.randn((3, 224, 224))],
+        'data_sample': SelfSupDataSample()
+    } for _ in range(2)]
+    fake_batches, fake_samples = data_preprocessor(fake_data)
+    assert len(fake_batches) == 1
+    assert len(fake_samples) == 2

View File

@@ -5,12 +5,11 @@ set -x
 PARTITION=$1
 JOB_NAME=$2
 CONFIG=$3
-WORK_DIR=$4
 GPUS=${GPUS:-8}
 GPUS_PER_NODE=${GPUS_PER_NODE:-8}
 CPUS_PER_TASK=${CPUS_PER_TASK:-5}
 SRUN_ARGS=${SRUN_ARGS:-""}
-PY_ARGS=${@:5}
+PY_ARGS=${@:4}

 PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
 srun -p ${PARTITION} \
@@ -21,5 +20,4 @@ srun -p ${PARTITION} \
     --cpus-per-task=${CPUS_PER_TASK} \
     --kill-on-bad-exit=1 \
     ${SRUN_ARGS} \
-    python -u tools/train.py ${CONFIG} \
-        --work-dir=${WORK_DIR} --seed 0 --launcher="slurm" ${PY_ARGS}
+    python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS}
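With WORK_DIR dropped from the positional arguments, everything after the config now rides in through PY_ARGS. A hypothetical invocation (partition name and paths are placeholders):

    GPUS=8 sh tools/slurm_train.sh my_partition mae_pretrain \
        configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py \
        --work-dir work_dirs/mae_pretrain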