diff --git a/configs/selfsup/_base_/datasets/imagenet_mae.py b/configs/selfsup/_base_/datasets/imagenet_mae.py
index b9e5c950..af59d946 100644
--- a/configs/selfsup/_base_/datasets/imagenet_mae.py
+++ b/configs/selfsup/_base_/datasets/imagenet_mae.py
@@ -1,4 +1,5 @@
 # dataset settings
+custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
 dataset_type = 'mmcls.ImageNet'
 data_root = 'data/imagenet/'
 file_client_args = dict(backend='disk')
diff --git a/configs/selfsup/_base_/default_runtime.py b/configs/selfsup/_base_/default_runtime.py
index f7fb3b59..c58a8f24 100644
--- a/configs/selfsup/_base_/default_runtime.py
+++ b/configs/selfsup/_base_/default_runtime.py
@@ -2,7 +2,6 @@ default_scope = 'mmselfsup'
 
 default_hooks = dict(
     runtime_info=dict(type='RuntimeInfoHook'),
-    optimizer=dict(type='OptimizerHook', grad_clip=None),
     timer=dict(type='IterTimerHook'),
     logger=dict(type='LoggerHook', interval=50),
     param_scheduler=dict(type='ParamSchedulerHook'),
@@ -17,14 +16,14 @@ env_cfg = dict(
 )
 
 log_processor = dict(
-    interval=50,
-    custom_keys=[dict(data_src='', method='mean', windows_size='global')])
+    window_size=10,
+    custom_cfg=[dict(data_src='', method='mean', window_size='global')])
 
-vis_backends = [dict(type='LocalVisBackend')]
-visualizer = dict(
-    type='SelfSupLocalVisualizer',
-    vis_backends=vis_backends,
-    name='visualizer')
+# vis_backends = [dict(type='LocalVisBackend')]
+# visualizer = dict(
+#     type='SelfSupLocalVisualizer',
+#     vis_backends=vis_backends,
+#     name='visualizer')
 # custom_hooks = [dict(type='SelfSupVisualizationHook', interval=10)]
 
 log_level = 'INFO'
diff --git a/configs/selfsup/_base_/models/mae_vit-base-p16.py b/configs/selfsup/_base_/models/mae_vit-base-p16.py
index 25f94c51..2412fed2 100644
--- a/configs/selfsup/_base_/models/mae_vit-base-p16.py
+++ b/configs/selfsup/_base_/models/mae_vit-base-p16.py
@@ -1,6 +1,8 @@
 # model settings
 model = dict(
     type='MAE',
+    data_preprocessor=dict(
+        mean=[124, 117, 104], std=[59, 58, 58], bgr_to_rgb=True),
     backbone=dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75),
     neck=dict(
         type='MAEPretrainDecoder',
@@ -12,5 +14,8 @@ model = dict(
         decoder_num_heads=16,
         mlp_ratio=4.,
     ),
-    head=dict(type='MAEPretrainHead', norm_pix=True, patch_size=16),
-    loss=dict(type='MAEReconstructionLoss'))
+    head=dict(
+        type='MAEPretrainHead',
+        norm_pix=True,
+        patch_size=16,
+        loss=dict(type='MAEReconstructionLoss')))
diff --git a/configs/selfsup/_base_/schedules/adamw_coslr-200e_in1k.py b/configs/selfsup/_base_/schedules/adamw_coslr-200e_in1k.py
index d964c1ac..e98ef0e6 100644
--- a/configs/selfsup/_base_/schedules/adamw_coslr-200e_in1k.py
+++ b/configs/selfsup/_base_/schedules/adamw_coslr-200e_in1k.py
@@ -1,6 +1,6 @@
-# optimizer
+# optimizer wrapper
 optimizer = dict(type='AdamW', lr=1.5e-4, betas=(0.9, 0.95), weight_decay=0.05)
-optimizer_config = dict()  # grad_clip, coalesce, bucket_size_mb
+optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
 
 # learning rate scheduler
 param_scheduler = [
@@ -16,4 +16,4 @@ param_scheduler = [
 ]
 
 # runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=300)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)
diff --git a/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e-fp16_in1k.py b/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e-fp16_in1k.py
index cd63e9bc..4582071e 100644
--- a/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e-fp16_in1k.py
+++ b/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e-fp16_in1k.py
@@ -1,4 +1,4 @@
 _base_ = 'mae_vit-base-p16_8xb512-coslr-400e_in1k.py'
 
 # mixed precision
-fp16 = dict(loss_scale='dynamic')
+optim_wrapper = dict(type='AmpOptimWrapper')
diff --git a/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py b/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py
index e37f9093..6280ede2 100644
--- a/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py
+++ b/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py
@@ -5,23 +5,26 @@ _base_ = [
     '../_base_/default_runtime.py',
 ]
 
-# dataset
-data = dict(samples_per_gpu=512, workers_per_gpu=32)
+# dataset (8 GPUs x 512 samples per GPU)
+train_dataloader = dict(batch_size=512, num_workers=16)
 
-# optimizer
+# optimizer wrapper
 optimizer = dict(
-    lr=1.5e-4 * 4096 / 256,
-    paramwise_options={
-        'norm': dict(weight_decay=0.),
-        'bias': dict(weight_decay=0.),
-        'pos_embed': dict(weight_decay=0.),
-        'mask_token': dict(weight_decay=0.),
-        'cls_token': dict(weight_decay=0.)
-    })
-optimizer_config = dict()
+    type='AdamW', lr=1.5e-4 * 4096 / 256, betas=(0.9, 0.95), weight_decay=0.05)
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=optimizer,
+    paramwise_cfg=dict(
+        custom_keys={
+            'ln': dict(decay_mult=0.0),
+            'bias': dict(decay_mult=0.0),
+            'pos_embed': dict(decay_mult=0.0),
+            'mask_token': dict(decay_mult=0.0),
+            'cls_token': dict(decay_mult=0.0)
+        }))
 
 # learning rate scheduler
-scheduler = [
+param_scheduler = [
     dict(
         type='LinearLR',
         start_factor=1e-4,
@@ -38,13 +41,6 @@ scheduler = [
         convert_to_iter_based=True)
 ]
 
-# schedule
-runner = dict(max_epochs=400)
-
-# runtime
-checkpoint_config = dict(interval=1, max_keep_ckpts=3, out_dir='')
-persistent_workers = True
-log_config = dict(
-    interval=100, hooks=[
-        dict(type='TextLoggerHook'),
-    ])
+# runtime settings
+# pre-train for 400 epochs
+train_cfg = dict(max_epochs=400)
diff --git a/mmselfsup/models/algorithms/base.py b/mmselfsup/models/algorithms/base.py
index 30ce67d0..6ccefdb8 100644
--- a/mmselfsup/models/algorithms/base.py
+++ b/mmselfsup/models/algorithms/base.py
@@ -1,40 +1,66 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from abc import ABCMeta, abstractmethod
-from collections import OrderedDict
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Union
 
 import torch
-import torch.distributed as dist
-from mmcv.runner import BaseModule, auto_fp16
+from mmengine.model import BaseModel as _BaseModel
+from torch import nn
 
 from mmselfsup.core import SelfSupDataSample
-from mmselfsup.utils import get_module_device
+from mmselfsup.registry import MODELS
 
 
-class BaseModel(BaseModule, metaclass=ABCMeta):
-    """Base model class for self-supervised learning.
+class BaseModel(_BaseModel):
+    """BaseModel for SelfSup.
+
+    All algorithms should inherit from this module.
 
     Args:
-        preprocess_cfg (Dict): Config to preprocess images.
-        init_cfg (Dict, optional): Config to initialize models.
+        backbone (dict): The backbone module. See
+            :mod:`mmcls.models.backbones`.
+        neck (dict, optional): The neck module to process features from the
+            backbone. See :mod:`mmcls.models.necks`. Defaults to None.
+        head (dict, optional): The head module to do prediction and calculate
+            loss from the processed features. See :mod:`mmcls.models.heads`.
+            Notice that if the head is not set, almost all methods cannot be
+            used except :meth:`extract_feat`. Defaults to None.
+        pretrained (str, optional): The pretrained checkpoint path, supporting
+            local and remote paths. Defaults to None.
+        data_preprocessor (Union[dict, nn.Module], optional): The config to
+            preprocess input data. If None or the ``type`` key is not
+            specified, ``SelfSupDataPreprocessor`` will be used as the type.
+            See :class:`SelfSupDataPreprocessor` for more details.
+            Defaults to None.
+        init_cfg (dict, optional): The config to control the initialization.
             Defaults to None.
     """
 
     def __init__(self,
-                 preprocess_cfg: Dict,
-                 init_cfg: Optional[Dict] = None) -> None:
-        super(BaseModel, self).__init__(init_cfg)
-        self.fp16_enabled = False
-        assert 'mean' in preprocess_cfg
-        self.register_buffer(
-            'mean_norm',
-            torch.tensor(preprocess_cfg.pop('mean')).view(3, 1, 1))
-        assert 'std' in preprocess_cfg
-        self.register_buffer(
-            'std_norm',
-            torch.tensor(preprocess_cfg.pop('std')).view(3, 1, 1))
-        assert 'to_rgb' in preprocess_cfg
-        self.to_rgb = preprocess_cfg.pop('to_rgb')
+                 backbone: dict,
+                 neck: Optional[dict] = None,
+                 head: Optional[dict] = None,
+                 pretrained: Optional[str] = None,
+                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
+                 init_cfg: Optional[dict] = None):
+
+        if pretrained is not None:
+            init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        # The build process is in MMEngine, so we need to add scope here.
+        data_preprocessor.setdefault('type',
+                                     'mmselfsup.SelfSupDataPreprocessor')
+
+        super().__init__(
+            init_cfg=init_cfg, data_preprocessor=data_preprocessor)
+
+        self.backbone = MODELS.build(backbone)
+
+        if neck is not None:
+            self.neck = MODELS.build(neck)
+
+        if head is not None:
+            self.head = MODELS.build(head)
 
     @property
     def with_neck(self) -> bool:
@@ -44,156 +70,98 @@ class BaseModel(BaseModule, metaclass=ABCMeta):
     def with_head(self) -> bool:
         return hasattr(self, 'head') and self.head is not None
 
-    @abstractmethod
-    def extract_feat(self, inputs: List[torch.Tensor],
-                     data_samples: List[SelfSupDataSample],
-                     **kwargs) -> object:
-        """The forward function to extract features.
-
-        Args:
-            inputs (List[torch.Tensor]): The input images.
-            data_samples (List[SelfSupDataSample]): All elements required
-                during the forward function.
-        """
-        raise NotImplementedError('``extract_feat`` should be implemented')
-
-    @abstractmethod
-    def forward_train(self, inputs: List[torch.Tensor],
-                      data_samples: List[SelfSupDataSample],
-                      **kwargs) -> object:
-        """The forward function in training
-        Args:
-            inputs (List[torch.Tensor]): The input images.
-            data_samples (List[SelfSupDataSample]): All elements required
-                during the forward function.
-        """
-        raise NotImplementedError('``forward_train`` should be implemented')
-
-    def forward_test(self, inputs: List[torch.Tensor],
-                     data_samples: List[SelfSupDataSample],
-                     **kwargs) -> object:
-        """The forward function in testing
-        Args:
-            inputs (List[torch.Tensor]): The input images.
-            data_samples (List[SelfSupDataSample]): All elements required
-                during the forward function.
-        """
-        raise NotImplementedError('``forward_test`` should be implemented')
-
-    @auto_fp16(apply_to=('data', ))
     def forward(self,
-                data: List[Dict],
-                return_loss: bool = False,
-                extract: bool = False,
-                **kwargs) -> object:
-        """Forward function of model.
+                batch_inputs: torch.Tensor,
+                data_samples: Optional[List[SelfSupDataSample]] = None,
+                mode: str = 'tensor'):
+        """Returns losses or predictions of training, validation, testing, and
+        simple inference process.
 
-        Calls either forward_train, forward_test or extract_feat function
-        according to the mode.
+        This method overwrites the abstract method in ``BaseModel``.
 
         Args:
-            data (List[Dict]): The input data for model.
-            return_loss (bool): Train mode or test mode. Defaults to False.
-            extract (bool): Whether or not only extract features from model.
-                If set to True, the ``return_loss`` will be ignored. Defaults
-                to False.
+            batch_inputs (torch.Tensor): Batch input tensor collated by
+                :attr:`data_preprocessor`.
+            data_samples (List[BaseDataElement], optional):
+                Data samples collated by :attr:`data_preprocessor`.
+            mode (str): Mode should be one of ``loss``, ``predict`` and
+                ``tensor``.
+
+                - ``loss``: Called by ``train_step`` and returns a loss
+                  ``dict`` used for logging.
+                - ``predict``: Called by ``val_step`` and ``test_step``
+                  and returns a list of ``BaseDataElement`` results used for
+                  computing metrics.
+                - ``tensor``: Called by custom use to get ``Tensor`` type
+                  results.
+
+        Returns:
+            ForwardResults:
+
+            - If ``mode == loss``, return a ``dict`` of loss tensors used
+              for backward and logging.
+            - If ``mode == predict``, return a ``list`` of
+              :obj:`BaseDataElement` for computing metrics
+              and getting inference results.
+            - If ``mode == tensor``, return a tensor or ``tuple`` of tensors
+              or a ``dict`` of tensors for custom use.
         """
-        # preprocess images
-        inputs, data_samples = self.preprocss_data(data)
-
-        # Whether or not extract features. If set to True, the ``return_loss``
-        # will be ignored.
-        if extract:
-            return self.extract_feat(
-                inputs=inputs, data_samples=data_samples, **kwargs)
-
-        if return_loss:
-            losses = self.forward_train(
-                inputs=inputs, data_samples=data_samples, **kwargs)
-            loss, log_vars = self._parse_losses(losses)
-            outputs = dict(loss=loss, log_vars=log_vars)
-            return outputs
+        if mode == 'tensor':
+            feats = self.extract_feat(batch_inputs)
+            return feats
+        elif mode == 'loss':
+            return self.loss(batch_inputs, data_samples)
+        elif mode == 'predict':
+            return self.predict(batch_inputs, data_samples)
         else:
-            # should be a list of SelfSupDataSample
-            return self.forward_test(
-                inputs=inputs, data_samples=data_samples, **kwargs)
+            raise RuntimeError(f'Invalid mode "{mode}".')
 
-    def _parse_losses(self, losses: Dict) -> Tuple[torch.Tensor, Dict]:
-        """Parse the raw outputs (losses) of the network.
+    def extract_feat(self, batch_inputs: torch.Tensor):
+        """Extract features from the input tensor with shape (N, C, ...).
+
+        This is an abstract method, and subclasses should overwrite it
+        if needed.
 
         Args:
-            losses (Dict): Raw output of the network, which usually contain
-                losses and other necessary information.
+            batch_inputs (Tensor): A batch of inputs. The shape of it should
+                be ``(num_samples, num_channels, *img_shape)``.
+
         Returns:
-            tuple[torch.Tensor, Dict]: (loss, log_vars), loss is the loss
-                tensor which may be a weighted sum of all losses, log_vars
-                contains all the variables to be sent to the logger.
+            tuple | Tensor: The output of the specified stage.
+            The output depends on the detailed implementation.
""" - log_vars = OrderedDict() - for loss_name, loss_value in losses.items(): - if isinstance(loss_value, torch.Tensor): - log_vars[loss_name] = loss_value.mean() - elif isinstance(loss_value, list): - log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) - elif isinstance(loss_value, dict): - for name, value in loss_value.items(): - log_vars[name] = value - else: - raise TypeError( - f'{loss_name} is not a tensor or list of tensors') + raise NotImplementedError - loss = sum(_value for _key, _value in log_vars.items() - if 'loss' in _key) + def loss(self, batch_inputs: torch.Tensor, + data_samples: List[SelfSupDataSample]) -> dict: + """Calculate losses from a batch of inputs and data samples. - log_vars['loss'] = loss - for loss_name, loss_value in log_vars.items(): - # reduce loss when distributed training - if dist.is_available() and dist.is_initialized(): - loss_value = loss_value.data.clone() - dist.all_reduce(loss_value.div_(dist.get_world_size())) - log_vars[loss_name] = loss_value.item() - - return loss, log_vars - - def preprocss_data( - self, - data: List[Dict]) -> Tuple[List[torch.Tensor], SelfSupDataSample]: - """Process input data during training, testing or extracting. + This is a abstract method, and subclass should overwrite this methods + if needed. Args: - data (List[Dict]): The data to be processed, which - comes from dataloader. + batch_inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[SelfSupDataSample]): The annotation data of + every samples. Returns: - tuple: It should contain 2 item. - - batch_images (List[torch.Tensor]): The batch image tensor. - - data_samples (List[SelfSupDataSample], Optional): The Data - Samples. It usually includes information such as - `gt_label`. Return None If the input data does not - contain `data_sample`. + dict[str, Tensor]: a dictionary of loss components """ - # data_['inputs] is a list - images = [data_['inputs'] for data_ in data] - data_samples = [data_['data_sample'] for data_ in data] + raise NotImplementedError - device = get_module_device(self) - data_samples = [data_sample.to(device) for data_sample in data_samples] - images = [[img_.to(device) for img_ in img] for img in images] + def predict(self, + batch_inputs: tuple, + data_samples: Optional[List[SelfSupDataSample]] = None, + **kwargs) -> List[SelfSupDataSample]: + """Predict results from the extracted features. - # convert images to rgb - if self.to_rgb and images[0][0].size(0) == 3: - images = [[img_[[2, 1, 0], ...] for img_ in img] for img in images] + This module returns the logits before loss, which are used to compute + all kinds of metrics. This is a abstract method, and subclass should + overwrite this methods if needed. - # normalize images - images = [[(img_ - self.mean_norm) / self.std_norm for img_ in img] - for img in images] - - # reconstruct images into several batches. For example, SimCLR needs - # two crops for each image, and this code snippet will convert images - # into two batches, each containing one crop of an image. - batch_images = [] - for i in range(len(images[0])): - cur_batch = [img[i] for img in images] - batch_images.append(torch.stack(cur_batch)) - - return batch_images, data_samples + Args: + feats (tuple): The features extracted from the backbone. + data_samples (List[BaseDataElement], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. 
+ """ + raise NotImplementedError diff --git a/mmselfsup/models/algorithms/mae.py b/mmselfsup/models/algorithms/mae.py index 13e71070..d7fa3d37 100644 --- a/mmselfsup/models/algorithms/mae.py +++ b/mmselfsup/models/algorithms/mae.py @@ -1,86 +1,50 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Tuple import torch from mmselfsup.core import SelfSupDataSample -from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss, - build_neck) +from ..builder import MODELS from .base import BaseModel -@ALGORITHMS.register_module() +@MODELS.register_module() class MAE(BaseModel): """MAE. Implementation of `Masked Autoencoders Are Scalable Vision Learners - `_. - Args: - backbone (Dict, optional): Config dict for encoder. Defaults to None. - neck (Dict, optional): Config dict for encoder. Defaults to None. - head (Dict, optional): Config dict for head functions. - Defaults to None. - loss (Dict, optional): Config dict for loss functions. - Defaults to None. - preprocess_cfg (Dict, optional): Config to preprocess images. - Defaults to None. - init_cfg (Dict or List[Dict], optional): Config dict for weight - initialization. Defaults to None. + `_. """ - def __init__(self, - backbone: Optional[Dict] = None, - neck: Optional[Dict] = None, - head: Optional[Dict] = None, - loss: Optional[Dict] = None, - preprocess_cfg: Optional[Dict] = None, - init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None: - super(MAE, self).__init__( - preprocess_cfg=preprocess_cfg, init_cfg=init_cfg) - assert backbone is not None - self.backbone = build_backbone(backbone) - assert neck is not None - self.neck = build_neck(neck) - self.neck.num_patches = self.backbone.num_patches - assert head is not None - self.head = build_head(head) - assert loss is not None - self.loss = build_loss(loss) - - def init_weights(self) -> None: - super(MAE, self).init_weights() - - def extract_feat(self, inputs: List[torch.Tensor], - data_samples: List[SelfSupDataSample], + def extract_feat(self, batch_inputs: List[torch.Tensor], **kwarg) -> Tuple[torch.Tensor]: - """The forward function to extract features. + """The forward function to extract features from neck. Args: - inputs (List[torch.Tensor]): The input images. - data_samples (List[SelfSupDataSample]): All elements required - during the forward function. + batch_inputs (List[torch.Tensor]): The input images. Returns: - Tuple[torch.Tensor]: backbone outputs. + torch.Tensor: Outputs from neck. """ - return self.backbone(inputs[0]) + latent, _, ids_restore = self.backbone(batch_inputs[0]) + pred = self.neck(latent, ids_restore) + return pred - def forward_train(self, inputs: List[torch.Tensor], - data_samples: List[SelfSupDataSample], - **kwargs) -> Dict[str, torch.Tensor]: + def loss(self, batch_inputs: List[torch.Tensor], + data_samples: List[SelfSupDataSample], + **kwargs) -> Dict[str, torch.Tensor]: """The forward function in training. Args: - inputs (List[torch.Tensor]): The input images. + batch_inputs (List[torch.Tensor]): The input images. data_samples (List[SelfSupDataSample]): All elements required during the forward function. Returns: Dict[str, Tensor]: A dictionary of loss components. 
""" - latent, mask, ids_restore = self.backbone(inputs[0]) + latent, mask, ids_restore = self.backbone(batch_inputs[0]) pred = self.neck(latent, ids_restore) - target = self.head(inputs[0]) - loss = self.loss(pred, target, mask) + loss = self.head(pred, batch_inputs[0], mask) losses = dict(loss=loss) return losses diff --git a/mmselfsup/models/heads/mae_head.py b/mmselfsup/models/heads/mae_head.py index 6b688eff..64cb6279 100644 --- a/mmselfsup/models/heads/mae_head.py +++ b/mmselfsup/models/heads/mae_head.py @@ -4,26 +4,31 @@ from typing import Dict, List import torch from mmcls.models import LabelSmoothLoss from mmcv.cnn.utils.weight_init import trunc_normal_ -from mmcv.runner import BaseModule +from mmengine.model import BaseModule from torch import nn -from ..builder import HEADS +from ..builder import MODELS -@HEADS.register_module() +@MODELS.register_module() class MAEPretrainHead(BaseModule): """Pre-training head for MAE. Args: + loss (dict): Config of loss. norm_pix_loss (bool): Whether or not normalize target. Defaults to False. patch_size (int): Patch size. Defaults to 16. """ - def __init__(self, norm_pix: bool = False, patch_size: int = 16) -> None: + def __init__(self, + loss: dict, + norm_pix: bool = False, + patch_size: int = 16) -> None: super().__init__() self.norm_pix = norm_pix self.patch_size = patch_size + self.loss = MODELS.build(loss) def patchify(self, imgs: torch.Tensor) -> torch.Tensor: @@ -36,9 +41,19 @@ class MAEPretrainHead(BaseModule): x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) return x - def forward(self, x: torch.Tensor) -> torch.Tensor: + def construct_target(self, target: torch.Tensor) -> torch.Tensor: + """Construct the reconstruction target. - target = self.patchify(x) + In addition to splitting images into tokens, this module will also + normalize the image according to ``norm_pix``. + + Args: + target (torch.Tensor): Image with the shape of B x 3 x H x W + + Returns: + torch.Tensor: Tokenized images with the shape of B x L x C + """ + target = self.patchify(target) if self.norm_pix: mean = target.mean(dim=-1, keepdim=True) var = target.var(dim=-1, keepdim=True) @@ -46,8 +61,25 @@ class MAEPretrainHead(BaseModule): return target + def forward(self, pred: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Forward function of MAE head. -@HEADS.register_module() + Args: + pred (torch.Tensor): The reconstructed image. + target (torch.Tensor): The target image. + mask (torch.Tensor): The mask of the target image. + + Returns: + torch.Tensor: The reconstruction loss. + """ + target = self.construct_target(target) + loss = self.loss(pred, target, mask) + + return loss + + +@MODELS.register_module() class MAEFinetuneHead(BaseModule): """Fine-tuning head for MAE. @@ -83,7 +115,7 @@ class MAEFinetuneHead(BaseModule): return losses -@HEADS.register_module() +@MODELS.register_module() class MAELinprobeHead(BaseModule): """Linear probing head for MAE. diff --git a/mmselfsup/models/losses/mae_loss.py b/mmselfsup/models/losses/mae_loss.py index 66c8546e..e6333f06 100644 --- a/mmselfsup/models/losses/mae_loss.py +++ b/mmselfsup/models/losses/mae_loss.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
 import torch
-from mmcv.runner import BaseModule
+from mmengine.model import BaseModule
 
 from ..builder import LOSSES
diff --git a/mmselfsup/models/necks/mae_neck.py b/mmselfsup/models/necks/mae_neck.py
index 15954f77..75e9a8b1 100644
--- a/mmselfsup/models/necks/mae_neck.py
+++ b/mmselfsup/models/necks/mae_neck.py
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
 from mmcv.cnn import build_norm_layer
-from mmcv.runner import BaseModule
+from mmengine.model import BaseModule
 
 from ..builder import NECKS
 from ..utils import build_2d_sincos_position_embedding
diff --git a/mmselfsup/models/utils/__init__.py b/mmselfsup/models/utils/__init__.py
index dc79a52d..87b28798 100644
--- a/mmselfsup/models/utils/__init__.py
+++ b/mmselfsup/models/utils/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .dall_e import Encoder
+from .data_preprocessor import SelfSupDataPreprocessor
 from .ema import CosineEMA
 from .extractor import Extractor
 from .gather_layer import GatherLayer
@@ -14,5 +15,5 @@ __all__ = [
     'Extractor', 'GatherLayer', 'MultiPooling', 'MultiPrototypes',
     'build_2d_sincos_position_embedding', 'Sobel', 'MultiheadAttention',
     'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'Encoder',
-    'CosineEMA'
+    'CosineEMA', 'SelfSupDataPreprocessor'
 ]
diff --git a/mmselfsup/models/utils/data_preprocessor.py b/mmselfsup/models/utils/data_preprocessor.py
new file mode 100644
index 00000000..72f40304
--- /dev/null
+++ b/mmselfsup/models/utils/data_preprocessor.py
@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Sequence, Tuple
+
+import torch
+from mmengine.data import BaseDataElement
+from mmengine.model import ImgDataPreprocessor
+
+from mmselfsup.registry import MODELS
+
+
+@MODELS.register_module()
+class SelfSupDataPreprocessor(ImgDataPreprocessor):
+    """Image pre-processor for operations, like normalization and BGR to RGB.
+
+    Compared with the :class:`mmengine.ImgDataPreprocessor`, this module
+    treats each item in ``inputs`` of the input data as a list, instead of a
+    torch.Tensor.
+    """
+
+    def collate_data(
+        self,
+        data: Sequence[dict]) -> Tuple[List[torch.Tensor], Optional[list]]:
+        """Collating and copying data to the target device.
+
+        This module overwrites the default method by treating each item in
+        ``inputs`` of the input data as a list.
+
+        Collates the data sampled from the dataloader into a list of tensors
+        and a list of labels, and then copies the tensors to the target
+        device.
+
+        Subclasses could override it to be compatible with the custom format
+        data sampled from custom dataloaders.
+
+        Args:
+            data (Sequence[dict]): Data sampled from the dataloader.
+
+        Returns:
+            Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of
+            input tensors and list of labels on the target device.
+        """
+        inputs = [[img.to(self.device) for img in _data['inputs']]
+                  for _data in data]
+        batch_data_samples: List[BaseDataElement] = []
+        # Model can get predictions without any data samples.
+        for _data in data:
+            if 'data_sample' in _data:
+                batch_data_samples.append(_data['data_sample'])
+        # Move data from CPU to the corresponding device.
+        batch_data_samples = [
+            data_sample.to(self.device) for data_sample in batch_data_samples
+        ]
+
+        if not batch_data_samples:
+            batch_data_samples = None  # type: ignore
+
+        return inputs, batch_data_samples
+
+    def forward(
+        self,
+        data: Sequence[dict],
+        training: bool = False
+    ) -> Tuple[List[torch.Tensor], Optional[list]]:
+        """Performs normalization, padding and BGR to RGB conversion based on
+        ``BaseDataPreprocessor``.
+
+        Args:
+            data (Sequence[dict]): Data sampled from the dataloader.
+            training (bool): Whether to enable training time augmentation. If
+                subclasses override this method, they can perform different
+                preprocessing strategies for training and testing based on the
+                value of ``training``.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[list]]: Data in the same format as
+            the model input.
+        """
+        inputs, batch_data_samples = self.collate_data(data)
+        # channel transform
+        if self.channel_conversion:
+            inputs = [[img_[[2, 1, 0], ...] for img_ in _input]
+                      for _input in inputs]
+
+        # Normalization. Here is what is different from
+        # :class:`mmengine.ImgDataPreprocessor`. Since there are multiple
+        # views for an image for some algorithms, e.g. SimCLR, each item in
+        # inputs is a list, containing multi-views for an image.
+        inputs = [[(img_ - self.mean) / self.std for img_ in _input]
+                  for _input in inputs]
+
+        batch_inputs = []
+        for i in range(len(inputs[0])):
+            cur_batch = [img[i] for img in inputs]
+            batch_inputs.append(torch.stack(cur_batch))
+
+        return batch_inputs, batch_data_samples
diff --git a/tests/test_models/test_algorithms/test_mae.py b/tests/test_models/test_algorithms/test_mae.py
index 487a4e84..70c141ed 100644
--- a/tests/test_models/test_algorithms/test_mae.py
+++ b/tests/test_models/test_algorithms/test_mae.py
@@ -20,54 +20,32 @@ neck = dict(
     decoder_num_heads=16,
     mlp_ratio=4.,
 )
-head = dict(type='MAEPretrainHead', norm_pix=False, patch_size=16)
 loss = dict(type='MAEReconstructionLoss')
+head = dict(type='MAEPretrainHead', norm_pix=False, patch_size=16, loss=loss)
 
 
 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
 def test_mae():
-    preprocess_cfg = {
+    data_preprocessor = {
         'mean': [0.5, 0.5, 0.5],
         'std': [0.5, 0.5, 0.5],
-        'to_rgb': True
+        'bgr_to_rgb': True
     }
-    with pytest.raises(AssertionError):
-        alg = MAE(
-            backbone=backbone,
-            neck=None,
-            head=head,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = MAE(
-            backbone=backbone,
-            neck=neck,
-            head=None,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = MAE(
-            backbone=None,
-            neck=neck,
-            head=head,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
+
     alg = MAE(
         backbone=backbone,
         neck=neck,
         head=head,
-        loss=loss,
-        preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    alg.init_weights()
+        data_preprocessor=copy.deepcopy(data_preprocessor))
 
     fake_data = [{
         'inputs': [torch.randn((3, 224, 224))],
         'data_sample': SelfSupDataSample()
     } for _ in range(2)]
-    fake_outputs = alg(fake_data, return_loss=True)
+
+    fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
+    fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
     assert isinstance(fake_outputs['loss'].item(), float)
 
-    fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
-    fake_feat = alg.extract_feat(
-        inputs=fake_inputs, data_samples=fake_data_samples)
-    assert list(fake_feat[0].shape) == [2, 50, 768]
+    fake_feats = alg(fake_batch_inputs, fake_data_samples,
+                     mode='tensor')
+    assert list(fake_feats.shape) == [2, 196, 768]
diff --git a/tests/test_models/test_utils/test_data_preprocessor.py b/tests/test_models/test_utils/test_data_preprocessor.py
new file mode 100644
index 00000000..1d3bb92d
--- /dev/null
+++ b/tests/test_models/test_utils/test_data_preprocessor.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmselfsup.core import SelfSupDataSample
+from mmselfsup.models.utils import SelfSupDataPreprocessor
+
+
+def test_selfsup_data_preprocessor():
+    data_preprocessor = SelfSupDataPreprocessor(rgb_to_bgr=True)
+    fake_data = [{
+        'inputs': [torch.randn((3, 224, 224))],
+        'data_sample': SelfSupDataSample()
+    } for _ in range(2)]
+    fake_batches, fake_samples = data_preprocessor(fake_data)
+    assert len(fake_batches) == 1
+    assert len(fake_samples) == 2
diff --git a/tools/slurm_train.sh b/tools/slurm_train.sh
index db7dc539..ac36d508 100644
--- a/tools/slurm_train.sh
+++ b/tools/slurm_train.sh
@@ -5,12 +5,11 @@ set -x
 PARTITION=$1
 JOB_NAME=$2
 CONFIG=$3
-WORK_DIR=$4
 GPUS=${GPUS:-8}
 GPUS_PER_NODE=${GPUS_PER_NODE:-8}
 CPUS_PER_TASK=${CPUS_PER_TASK:-5}
 SRUN_ARGS=${SRUN_ARGS:-""}
-PY_ARGS=${@:5}
+PY_ARGS=${@:4}
 
 PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
 srun -p ${PARTITION} \
@@ -21,5 +20,4 @@ srun -p ${PARTITION} \
     --cpus-per-task=${CPUS_PER_TASK} \
     --kill-on-bad-exit=1 \
     ${SRUN_ARGS} \
-    python -u tools/train.py ${CONFIG} \
-    --work-dir=${WORK_DIR} --seed 0 --launcher="slurm" ${PY_ARGS}
+    python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS}
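Usage after this migration (a minimal sketch, mirroring `tests/test_models/test_algorithms/test_mae.py`): the model now takes a `data_preprocessor` instead of `preprocess_cfg`, the preprocessor collates and normalizes the raw dataloader dicts, and `forward` dispatches on `mode`. The decoder fields elided by the hunks above (`patch_size`, `in_chans`, `embed_dim`, `decoder_embed_dim`, `decoder_depth`) are filled in with the usual ViT-B/16 MAE settings, and the `from mmselfsup.models.algorithms import MAE` import path is assumed; treat both as assumptions rather than part of this patch.

```python
import torch

from mmselfsup.core import SelfSupDataSample
from mmselfsup.models.algorithms import MAE  # assumed import path

backbone = dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75)
# Decoder kwargs below are assumed ViT-B/16 defaults; only the last two
# appear verbatim in the hunks above.
neck = dict(
    type='MAEPretrainDecoder',
    patch_size=16,
    in_chans=3,
    embed_dim=768,
    decoder_embed_dim=512,
    decoder_depth=8,
    decoder_num_heads=16,
    mlp_ratio=4.)
# The loss config now lives inside the head, per the mae_head.py change.
head = dict(
    type='MAEPretrainHead',
    norm_pix=True,
    patch_size=16,
    loss=dict(type='MAEReconstructionLoss'))
data_preprocessor = dict(
    mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], bgr_to_rgb=True)

model = MAE(
    backbone=backbone,
    neck=neck,
    head=head,
    data_preprocessor=data_preprocessor)

# The dataloader yields dicts holding a list of views per image; the
# data_preprocessor moves them to the device, converts BGR to RGB,
# normalizes, and stacks them into one batch tensor per view.
fake_data = [
    dict(inputs=[torch.randn(3, 224, 224)], data_sample=SelfSupDataSample())
    for _ in range(2)
]
batch_inputs, data_samples = model.data_preprocessor(fake_data)

losses = model(batch_inputs, data_samples, mode='loss')   # {'loss': Tensor}
feats = model(batch_inputs, data_samples, mode='tensor')  # shape (2, 196, 768)
```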