[Refactor]: Refactor base model and create a sample for MAE

This commit is contained in:
YuanLiuuuuuu 2022-06-10 04:10:06 +00:00 committed by fangyixiao18
parent 962f9b9752
commit 35e0988527
16 changed files with 344 additions and 293 deletions

View File

@ -1,4 +1,5 @@
# dataset settings
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
dataset_type = 'mmcls.ImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')

View File

@ -2,7 +2,6 @@ default_scope = 'mmselfsup'
default_hooks = dict(
runtime_info=dict(type='RuntimeInfoHook'),
optimizer=dict(type='OptimizerHook', grad_clip=None),
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50),
param_scheduler=dict(type='ParamSchedulerHook'),
@ -17,14 +16,14 @@ env_cfg = dict(
)
log_processor = dict(
interval=50,
custom_keys=[dict(data_src='', method='mean', windows_size='global')])
window_size=10,
custom_cfg=[dict(data_src='', method='mean', windows_size='global')])
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SelfSupLocalVisualizer',
vis_backends=vis_backends,
name='visualizer')
# vis_backends = [dict(type='LocalVisBackend')]
# visualizer = dict(
# type='SelfSupLocalVisualizer',
# vis_backends=vis_backends,
# name='visualizer')
# custom_hooks = [dict(type='SelfSupVisualizationHook', interval=10)]
log_level = 'INFO'

View File

@ -1,6 +1,8 @@
# model settings
model = dict(
type='MAE',
data_preprocessor=dict(
mean=[124, 117, 104], std=[59, 58, 58], bgr_to_rgb=True),
backbone=dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75),
neck=dict(
type='MAEPretrainDecoder',
@ -12,5 +14,8 @@ model = dict(
decoder_num_heads=16,
mlp_ratio=4.,
),
head=dict(type='MAEPretrainHead', norm_pix=True, patch_size=16),
loss=dict(type='MAEReconstructionLoss'))
head=dict(
type='MAEPretrainHead',
norm_pix=True,
patch_size=16,
loss=dict(type='MAEReconstructionLoss')))

View File

@ -1,6 +1,6 @@
# optimizer
# optimizer_wrapper
optimizer = dict(type='AdamW', lr=1.5e-4, betas=(0.9, 0.95), weight_decay=0.05)
optimizer_config = dict() # grad_clip, coalesce, bucket_size_mb
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
# learning rate scheduler
param_scheduler = [
@ -16,4 +16,4 @@ param_scheduler = [
]
# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=300)
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)

View File

@ -1,4 +1,4 @@
_base_ = 'mae_vit-base-p16_8xb512-coslr-400e_in1k.py'
# mixed precision
fp16 = dict(loss_scale='dynamic')
optim_wrapper = dict(type='AmpOptimWrapper')

View File

@ -5,23 +5,26 @@ _base_ = [
'../_base_/default_runtime.py',
]
# dataset
data = dict(samples_per_gpu=512, workers_per_gpu=32)
# dataset 8 x 512
train_dataloader = dict(batch_size=512, num_workers=16)
# optimizer
# optimizer wrapper
optimizer = dict(
lr=1.5e-4 * 4096 / 256,
paramwise_options={
'norm': dict(weight_decay=0.),
'bias': dict(weight_decay=0.),
'pos_embed': dict(weight_decay=0.),
'mask_token': dict(weight_decay=0.),
'cls_token': dict(weight_decay=0.)
})
optimizer_config = dict()
type='AdamW', lr=1.5e-4 * 4096 / 256, betas=(0.9, 0.95), weight_decay=0.05)
optim_wrapper = dict(
type='OptimWrapper',
optimizer=optimizer,
paramwise_cfg=dict(
custom_keys={
'ln': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
'cls_token': dict(decay_mult=0.)
}))
# learning rate scheduler
scheduler = [
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
@ -38,13 +41,6 @@ scheduler = [
convert_to_iter_based=True)
]
# schedule
runner = dict(max_epochs=400)
# runtime
checkpoint_config = dict(interval=1, max_keep_ckpts=3, out_dir='')
persistent_workers = True
log_config = dict(
interval=100, hooks=[
dict(type='TextLoggerHook'),
])
# runtime settings
# pre-train for 400 epochs
train_cfg = dict(max_epochs=400)

View File

@ -1,40 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Union
import torch
import torch.distributed as dist
from mmcv.runner import BaseModule, auto_fp16
from mmengine.model import BaseModel as _BaseModel
from torch import nn
from mmselfsup.core import SelfSupDataSample
from mmselfsup.utils import get_module_device
from mmselfsup.registry import MODELS
class BaseModel(BaseModule, metaclass=ABCMeta):
"""Base model class for self-supervised learning.
class BaseModel(_BaseModel):
"""BaseModel for SelfSup.
All algorithms should inherit this module.
Args:
preprocess_cfg (Dict): Config to preprocess images.
init_cfg (Dict, optional): Config to initialize models.
backbone (dict): The backbone module. See
:mod:`mmcls.models.backbones`.
neck (dict, optional): The neck module to process features from
backbone. See :mod:`mmcls.models.necks`. Defaults to None.
head (dict, optional): The head module to do prediction and calculate
loss from processed features. See :mod:`mmcls.models.heads`.
Notice that if the head is not set, almost all methods cannot be
used except :meth:`extract_feat`. Defaults to None.
pretrained (str, optional): The pretrained checkpoint path, support
local path and remote path. Defaults to None.
data_preprocessor (Union[dict, nn.Module], optional): The config for
preprocessing input data. If None or no specified type, it will use
"SelfSupDataPreprocessor" as type.
See :class:`SelfSupDataPreprocessor` for more details.
Defaults to None.
init_cfg (dict, optional): the config to control the initialization.
Defaults to None.
"""
def __init__(self,
preprocess_cfg: Dict,
init_cfg: Optional[Dict] = None) -> None:
super(BaseModel, self).__init__(init_cfg)
self.fp16_enabled = False
assert 'mean' in preprocess_cfg
self.register_buffer(
'mean_norm',
torch.tensor(preprocess_cfg.pop('mean')).view(3, 1, 1))
assert 'std' in preprocess_cfg
self.register_buffer(
'std_norm',
torch.tensor(preprocess_cfg.pop('std')).view(3, 1, 1))
assert 'to_rgb' in preprocess_cfg
self.to_rgb = preprocess_cfg.pop('to_rgb')
backbone: dict,
neck: Optional[dict] = None,
head: Optional[dict] = None,
pretrained: Optional[str] = None,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
init_cfg: Optional[dict] = None):
if pretrained is not None:
init_cfg = dict(type='Pretrained', checkpoint=pretrained)
if data_preprocessor is None:
data_preprocessor = {}
# The build process is in MMEngine, so we need to add scope here.
data_preprocessor.setdefault('type',
'mmselfsup.SelfSupDataPreprocessor')
super().__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
self.backbone = MODELS.build(backbone)
if neck is not None:
self.neck = MODELS.build(neck)
if head is not None:
self.head = MODELS.build(head)
@property
def with_neck(self) -> bool:
@ -44,156 +70,98 @@ class BaseModel(BaseModule, metaclass=ABCMeta):
def with_head(self) -> bool:
return hasattr(self, 'head') and self.head is not None
@abstractmethod
def extract_feat(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> object:
"""The forward function to extract features.
Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
"""
raise NotImplementedError('``extract_feat`` should be implemented')
@abstractmethod
def forward_train(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> object:
"""The forward function in training
Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
"""
raise NotImplementedError('``forward_train`` should be implemented')
def forward_test(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> object:
"""The forward function in testing
Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
"""
raise NotImplementedError('``forward_test`` should be implemented')
@auto_fp16(apply_to=('data', ))
def forward(self,
data: List[Dict],
return_loss: bool = False,
extract: bool = False,
**kwargs) -> object:
"""Forward function of model.
batch_inputs: torch.Tensor,
data_samples: Optional[List[SelfSupDataSample]] = None,
mode: str = 'tensor'):
"""Returns losses or predictions of training, validation, testing, and
simple inference process.
Calls either forward_train, forward_test or extract_feat function
according to the mode.
This module overwrites the abstract method in ``BaseModel``.
Args:
data (List[Dict]): The input data for model.
return_loss (bool): Train mode or test mode. Defaults to False.
extract (bool): Whether or not only extract features from model.
If set to True, the ``return_loss`` will be ignored. Defaults
to False.
batch_inputs (torch.Tensor): batch input tensor collated by
:attr:`data_preprocessor`.
data_samples (List[BaseDataElement], optional):
data samples collated by :attr:`data_preprocessor`.
mode (str): mode should be one of ``loss``, ``predict`` and
``tensor``
- ``loss``: Called by ``train_step`` and return loss ``dict``
used for logging
- ``predict``: Called by ``val_step`` and ``test_step``
and return list of ``BaseDataElement`` results used for
computing metric.
- ``tensor``: Called by custom use to get ``Tensor`` type
results.
Returns:
ForwardResults:
- If ``mode == loss``, return a ``dict`` of loss tensor used
for backward and logging.
- If ``mode == predict``, return a ``list`` of
:obj:`BaseDataElement` for computing metric
and getting inference result.
- If ``mode == tensor``, return a tensor or ``tuple`` of tensor
or ``dict of tensor for custom use.
"""
# preprocess images
inputs, data_samples = self.preprocss_data(data)
# Whether or not extract features. If set to True, the ``return_loss``
# will be ignored.
if extract:
return self.extract_feat(
inputs=inputs, data_samples=data_samples, **kwargs)
if return_loss:
losses = self.forward_train(
inputs=inputs, data_samples=data_samples, **kwargs)
loss, log_vars = self._parse_losses(losses)
outputs = dict(loss=loss, log_vars=log_vars)
return outputs
if mode == 'tensor':
feats = self.extract_feat(batch_inputs)
return feats
elif mode == 'loss':
return self.loss(batch_inputs, data_samples)
elif mode == 'predict':
return self.predict(batch_inputs, data_samples)
else:
# should be a list of SelfSupDataSample
return self.forward_test(
inputs=inputs, data_samples=data_samples, **kwargs)
raise RuntimeError(f'Invalid mode "{mode}".')
def _parse_losses(self, losses: Dict) -> Tuple[torch.Tensor, Dict]:
"""Parse the raw outputs (losses) of the network.
def extract_feat(self, batch_inputs):
"""Extract features from the input tensor with shape (N, C, ...).
This is a abstract method, and subclass should overwrite this methods
if needed.
Args:
losses (Dict): Raw output of the network, which usually contain
losses and other necessary information.
batch_inputs (Tensor): A batch of inputs. The shape of it should be
``(num_samples, num_channels, *img_shape)``.
Returns:
tuple[torch.Tensor, Dict]: (loss, log_vars), loss is the loss
tensor which may be a weighted sum of all losses, log_vars
contains all the variables to be sent to the logger.
tuple | Tensor: The output of specified stage.
The output depends on detailed implementation.
"""
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
elif isinstance(loss_value, dict):
for name, value in loss_value.items():
log_vars[name] = value
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')
raise NotImplementedError
loss = sum(_value for _key, _value in log_vars.items()
if 'loss' in _key)
def loss(self, batch_inputs: torch.Tensor,
data_samples: List[SelfSupDataSample]) -> dict:
"""Calculate losses from a batch of inputs and data samples.
log_vars['loss'] = loss
for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()
return loss, log_vars
def preprocss_data(
self,
data: List[Dict]) -> Tuple[List[torch.Tensor], SelfSupDataSample]:
"""Process input data during training, testing or extracting.
This is a abstract method, and subclass should overwrite this methods
if needed.
Args:
data (List[Dict]): The data to be processed, which
comes from dataloader.
batch_inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[SelfSupDataSample]): The annotation data of
every samples.
Returns:
tuple: It should contain 2 item.
- batch_images (List[torch.Tensor]): The batch image tensor.
- data_samples (List[SelfSupDataSample], Optional): The Data
Samples. It usually includes information such as
`gt_label`. Return None If the input data does not
contain `data_sample`.
dict[str, Tensor]: a dictionary of loss components
"""
# data_['inputs] is a list
images = [data_['inputs'] for data_ in data]
data_samples = [data_['data_sample'] for data_ in data]
raise NotImplementedError
device = get_module_device(self)
data_samples = [data_sample.to(device) for data_sample in data_samples]
images = [[img_.to(device) for img_ in img] for img in images]
def predict(self,
batch_inputs: tuple,
data_samples: Optional[List[SelfSupDataSample]] = None,
**kwargs) -> List[SelfSupDataSample]:
"""Predict results from the extracted features.
# convert images to rgb
if self.to_rgb and images[0][0].size(0) == 3:
images = [[img_[[2, 1, 0], ...] for img_ in img] for img in images]
This module returns the logits before loss, which are used to compute
all kinds of metrics. This is a abstract method, and subclass should
overwrite this methods if needed.
# normalize images
images = [[(img_ - self.mean_norm) / self.std_norm for img_ in img]
for img in images]
# reconstruct images into several batches. For example, SimCLR needs
# two crops for each image, and this code snippet will convert images
# into two batches, each containing one crop of an image.
batch_images = []
for i in range(len(images[0])):
cur_batch = [img[i] for img in images]
batch_images.append(torch.stack(cur_batch))
return batch_images, data_samples
Args:
feats (tuple): The features extracted from the backbone.
data_samples (List[BaseDataElement], optional): The annotation
data of every samples. Defaults to None.
**kwargs: Other keyword arguments accepted by the ``predict``
method of :attr:`head`.
"""
raise NotImplementedError

View File

@ -1,86 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple
import torch
from mmselfsup.core import SelfSupDataSample
from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
build_neck)
from ..builder import MODELS
from .base import BaseModel
@ALGORITHMS.register_module()
@MODELS.register_module()
class MAE(BaseModel):
"""MAE.
Implementation of `Masked Autoencoders Are Scalable Vision Learners
<https://arxiv.org/abs/2111.06377>`_.
Args:
backbone (Dict, optional): Config dict for encoder. Defaults to None.
neck (Dict, optional): Config dict for encoder. Defaults to None.
head (Dict, optional): Config dict for head functions.
Defaults to None.
loss (Dict, optional): Config dict for loss functions.
Defaults to None.
preprocess_cfg (Dict, optional): Config to preprocess images.
Defaults to None.
init_cfg (Dict or List[Dict], optional): Config dict for weight
initialization. Defaults to None.
<https://arxiv.org/abs/2111.06377>`_.
"""
def __init__(self,
backbone: Optional[Dict] = None,
neck: Optional[Dict] = None,
head: Optional[Dict] = None,
loss: Optional[Dict] = None,
preprocess_cfg: Optional[Dict] = None,
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
super(MAE, self).__init__(
preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
assert backbone is not None
self.backbone = build_backbone(backbone)
assert neck is not None
self.neck = build_neck(neck)
self.neck.num_patches = self.backbone.num_patches
assert head is not None
self.head = build_head(head)
assert loss is not None
self.loss = build_loss(loss)
def init_weights(self) -> None:
super(MAE, self).init_weights()
def extract_feat(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
def extract_feat(self, batch_inputs: List[torch.Tensor],
**kwarg) -> Tuple[torch.Tensor]:
"""The forward function to extract features.
"""The forward function to extract features from neck.
Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
batch_inputs (List[torch.Tensor]): The input images.
Returns:
Tuple[torch.Tensor]: backbone outputs.
torch.Tensor: Outputs from neck.
"""
return self.backbone(inputs[0])
latent, _, ids_restore = self.backbone(batch_inputs[0])
pred = self.neck(latent, ids_restore)
return pred
def forward_train(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
def loss(self, batch_inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
inputs (List[torch.Tensor]): The input images.
batch_inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
Dict[str, Tensor]: A dictionary of loss components.
"""
latent, mask, ids_restore = self.backbone(inputs[0])
latent, mask, ids_restore = self.backbone(batch_inputs[0])
pred = self.neck(latent, ids_restore)
target = self.head(inputs[0])
loss = self.loss(pred, target, mask)
loss = self.head(pred, batch_inputs[0], mask)
losses = dict(loss=loss)
return losses

View File

@ -4,26 +4,31 @@ from typing import Dict, List
import torch
from mmcls.models import LabelSmoothLoss
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from torch import nn
from ..builder import HEADS
from ..builder import MODELS
@HEADS.register_module()
@MODELS.register_module()
class MAEPretrainHead(BaseModule):
"""Pre-training head for MAE.
Args:
loss (dict): Config of loss.
norm_pix_loss (bool): Whether or not normalize target.
Defaults to False.
patch_size (int): Patch size. Defaults to 16.
"""
def __init__(self, norm_pix: bool = False, patch_size: int = 16) -> None:
def __init__(self,
loss: dict,
norm_pix: bool = False,
patch_size: int = 16) -> None:
super().__init__()
self.norm_pix = norm_pix
self.patch_size = patch_size
self.loss = MODELS.build(loss)
def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
@ -36,9 +41,19 @@ class MAEPretrainHead(BaseModule):
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
def construct_target(self, target: torch.Tensor) -> torch.Tensor:
"""Construct the reconstruction target.
target = self.patchify(x)
In addition to splitting images into tokens, this module will also
normalize the image according to ``norm_pix``.
Args:
target (torch.Tensor): Image with the shape of B x 3 x H x W
Returns:
torch.Tensor: Tokenized images with the shape of B x L x C
"""
target = self.patchify(target)
if self.norm_pix:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
@ -46,8 +61,25 @@ class MAEPretrainHead(BaseModule):
return target
def forward(self, pred: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
"""Forward function of MAE head.
@HEADS.register_module()
Args:
pred (torch.Tensor): The reconstructed image.
target (torch.Tensor): The target image.
mask (torch.Tensor): The mask of the target image.
Returns:
torch.Tensor: The reconstruction loss.
"""
target = self.construct_target(target)
loss = self.loss(pred, target, mask)
return loss
@MODELS.register_module()
class MAEFinetuneHead(BaseModule):
"""Fine-tuning head for MAE.
@ -83,7 +115,7 @@ class MAEFinetuneHead(BaseModule):
return losses
@HEADS.register_module()
@MODELS.register_module()
class MAELinprobeHead(BaseModule):
"""Linear probing head for MAE.

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from ..builder import LOSSES

View File

@ -3,7 +3,7 @@ import torch
import torch.nn as nn
from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
from mmcv.cnn import build_norm_layer
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from ..builder import NECKS
from ..utils import build_2d_sincos_position_embedding

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dall_e import Encoder
from .data_preprocessor import SelfSupDataPreprocessor
from .ema import CosineEMA
from .extractor import Extractor
from .gather_layer import GatherLayer
@ -14,5 +15,5 @@ __all__ = [
'Extractor', 'GatherLayer', 'MultiPooling', 'MultiPrototypes',
'build_2d_sincos_position_embedding', 'Sobel', 'MultiheadAttention',
'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'Encoder',
'CosineEMA'
'CosineEMA', 'SelfSupDataPreprocessor'
]

View File

@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Tuple
import torch
from mmengine.data import BaseDataElement
from mmengine.model import ImgDataPreprocessor
from mmselfsup.registry import MODELS
@MODELS.register_module()
class SelfSupDataPreprocessor(ImgDataPreprocessor):
"""Image pre-processor for operations, like normalization and bgr to rgb.
Compared with the :class:`mmengine.ImgDataPreprocessor`, this module treats
each item in `inputs` of input data as a list, instead of torch.Tensor.
"""
def collate_data(
self,
data: Sequence[dict]) -> Tuple[List[torch.Tensor], Optional[list]]:
"""Collating and copying data to the target device.
This module overwrite the default method by treating each item in
``input`` of the input data as a list.
Collates the data sampled from dataloader into a list of tensor and
list of labels, and then copies tensor to the target device.
Subclasses could override it to be compatible with the custom format
data sampled from custom dataloader.
Args:
data (Sequence[dict]): Data sampled from dataloader.
Returns:
Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
tensor and list of labels at target device.
"""
inputs = [[img.to(self.device) for img in _data['inputs']]
for _data in data]
batch_data_samples: List[BaseDataElement] = []
# Model can get predictions without any data samples.
for _data in data:
if 'data_sample' in _data:
batch_data_samples.append(_data['data_sample'])
# Move data from CPU to corresponding device.
batch_data_samples = [
data_sample.to(self.device) for data_sample in batch_data_samples
]
if not batch_data_samples:
batch_data_samples = None # type: ignore
return inputs, batch_data_samples
def forward(
self,
data: Sequence[dict],
training: bool = False
) -> Tuple[List[torch.Tensor], Optional[list]]:
"""Performs normalization、padding and bgr2rgb conversion based on
``BaseDataPreprocessor``.
Args:
data (Sequence[dict]): data sampled from dataloader.
training (bool): Whether to enable training time augmentation. If
subclasses override this method, they can perform different
preprocessing strategies for training and testing based on the
value of ``training``.
Returns:
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
model input.
"""
inputs, batch_data_samples = self.collate_data(data)
# channel transform
if self.channel_conversion:
inputs = [[img_[[2, 1, 0], ...] for img_ in _input]
for _input in inputs]
# Normalization. Here is what is different from
# :class:`mmengine.ImgDataPreprocessor`. Since there are multiple views
# for an image for some algorithms, e.g. SimCLR, each item in inputs
# is a list, containing multi-views for an image.
inputs = [[(img_ - self.mean) / self.std for img_ in _input]
for _input in inputs]
batch_inputs = []
for i in range(len(inputs[0])):
cur_batch = [img[i] for img in inputs]
batch_inputs.append(torch.stack(cur_batch))
return batch_inputs, batch_data_samples

View File

@ -20,54 +20,32 @@ neck = dict(
decoder_num_heads=16,
mlp_ratio=4.,
)
head = dict(type='MAEPretrainHead', norm_pix=False, patch_size=16)
loss = dict(type='MAEReconstructionLoss')
head = dict(type='MAEPretrainHead', norm_pix=False, patch_size=16, loss=loss)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_mae():
preprocess_cfg = {
data_preprocessor = {
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'to_rgb': True
'bgr_to_rgb': True
}
with pytest.raises(AssertionError):
alg = MAE(
backbone=backbone,
neck=None,
head=head,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = MAE(
backbone=backbone,
neck=neck,
head=None,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = MAE(
backbone=None,
neck=neck,
head=head,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
alg = MAE(
backbone=backbone,
neck=neck,
head=head,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
alg.init_weights()
data_preprocessor=copy.deepcopy(data_preprocessor))
fake_data = [{
'inputs': [torch.randn((3, 224, 224))],
'data_sample': SelfSupDataSample()
} for _ in range(2)]
fake_outputs = alg(fake_data, return_loss=True)
fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
assert isinstance(fake_outputs['loss'].item(), float)
fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
fake_feat = alg.extract_feat(
inputs=fake_inputs, data_samples=fake_data_samples)
assert list(fake_feat[0].shape) == [2, 50, 768]
fake_feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
assert list(fake_feats.shape) == [2, 196, 768]

View File

@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmselfsup.core import SelfSupDataSample
from mmselfsup.models.utils import SelfSupDataPreprocessor
def test_selfsup_data_preprocessor():
data_preprocessor = SelfSupDataPreprocessor(rgb_to_bgr=True)
fake_data = [{
'inputs': [torch.randn((3, 224, 224))],
'data_sample': SelfSupDataSample()
} for _ in range(2)]
fake_batches, fake_samples = data_preprocessor(fake_data)
assert len(fake_batches) == 1
assert len(fake_samples) == 2

View File

@ -5,12 +5,11 @@ set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
WORK_DIR=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${@:5}
PY_ARGS=${@:4}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
@ -21,5 +20,4 @@ srun -p ${PARTITION} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python -u tools/train.py ${CONFIG} \
--work-dir=${WORK_DIR} --seed 0 --launcher="slurm" ${PY_ARGS}
python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS}