mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
[Refactor]: Refactor base model and create a sample for MAE
This commit is contained in:
parent
962f9b9752
commit
35e0988527
@ -1,4 +1,5 @@
|
||||
# dataset settings
|
||||
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
|
||||
dataset_type = 'mmcls.ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
file_client_args = dict(backend='disk')
|
||||
|
@ -2,7 +2,6 @@ default_scope = 'mmselfsup'
|
||||
|
||||
default_hooks = dict(
|
||||
runtime_info=dict(type='RuntimeInfoHook'),
|
||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
logger=dict(type='LoggerHook', interval=50),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
@ -17,14 +16,14 @@ env_cfg = dict(
|
||||
)
|
||||
|
||||
log_processor = dict(
|
||||
interval=50,
|
||||
custom_keys=[dict(data_src='', method='mean', windows_size='global')])
|
||||
window_size=10,
|
||||
custom_cfg=[dict(data_src='', method='mean', windows_size='global')])
|
||||
|
||||
vis_backends = [dict(type='LocalVisBackend')]
|
||||
visualizer = dict(
|
||||
type='SelfSupLocalVisualizer',
|
||||
vis_backends=vis_backends,
|
||||
name='visualizer')
|
||||
# vis_backends = [dict(type='LocalVisBackend')]
|
||||
# visualizer = dict(
|
||||
# type='SelfSupLocalVisualizer',
|
||||
# vis_backends=vis_backends,
|
||||
# name='visualizer')
|
||||
# custom_hooks = [dict(type='SelfSupVisualizationHook', interval=10)]
|
||||
|
||||
log_level = 'INFO'
|
||||
|
@ -1,6 +1,8 @@
|
||||
# model settings
|
||||
model = dict(
|
||||
type='MAE',
|
||||
data_preprocessor=dict(
|
||||
mean=[124, 117, 104], std=[59, 58, 58], bgr_to_rgb=True),
|
||||
backbone=dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75),
|
||||
neck=dict(
|
||||
type='MAEPretrainDecoder',
|
||||
@ -12,5 +14,8 @@ model = dict(
|
||||
decoder_num_heads=16,
|
||||
mlp_ratio=4.,
|
||||
),
|
||||
head=dict(type='MAEPretrainHead', norm_pix=True, patch_size=16),
|
||||
loss=dict(type='MAEReconstructionLoss'))
|
||||
head=dict(
|
||||
type='MAEPretrainHead',
|
||||
norm_pix=True,
|
||||
patch_size=16,
|
||||
loss=dict(type='MAEReconstructionLoss')))
|
||||
|
@ -1,6 +1,6 @@
|
||||
# optimizer
|
||||
# optimizer_wrapper
|
||||
optimizer = dict(type='AdamW', lr=1.5e-4, betas=(0.9, 0.95), weight_decay=0.05)
|
||||
optimizer_config = dict() # grad_clip, coalesce, bucket_size_mb
|
||||
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
@ -16,4 +16,4 @@ param_scheduler = [
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=300)
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)
|
||||
|
@ -1,4 +1,4 @@
|
||||
_base_ = 'mae_vit-base-p16_8xb512-coslr-400e_in1k.py'
|
||||
|
||||
# mixed precision
|
||||
fp16 = dict(loss_scale='dynamic')
|
||||
optim_wrapper = dict(type='AmpOptimWrapper')
|
||||
|
@ -5,23 +5,26 @@ _base_ = [
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# dataset
|
||||
data = dict(samples_per_gpu=512, workers_per_gpu=32)
|
||||
# dataset 8 x 512
|
||||
train_dataloader = dict(batch_size=512, num_workers=16)
|
||||
|
||||
# optimizer
|
||||
# optimizer wrapper
|
||||
optimizer = dict(
|
||||
lr=1.5e-4 * 4096 / 256,
|
||||
paramwise_options={
|
||||
'norm': dict(weight_decay=0.),
|
||||
'bias': dict(weight_decay=0.),
|
||||
'pos_embed': dict(weight_decay=0.),
|
||||
'mask_token': dict(weight_decay=0.),
|
||||
'cls_token': dict(weight_decay=0.)
|
||||
})
|
||||
optimizer_config = dict()
|
||||
type='AdamW', lr=1.5e-4 * 4096 / 256, betas=(0.9, 0.95), weight_decay=0.05)
|
||||
optim_wrapper = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer,
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'ln': dict(decay_mult=0.0),
|
||||
'bias': dict(decay_mult=0.0),
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'mask_token': dict(decay_mult=0.),
|
||||
'cls_token': dict(decay_mult=0.)
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = [
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-4,
|
||||
@ -38,13 +41,6 @@ scheduler = [
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# schedule
|
||||
runner = dict(max_epochs=400)
|
||||
|
||||
# runtime
|
||||
checkpoint_config = dict(interval=1, max_keep_ckpts=3, out_dir='')
|
||||
persistent_workers = True
|
||||
log_config = dict(
|
||||
interval=100, hooks=[
|
||||
dict(type='TextLoggerHook'),
|
||||
])
|
||||
# runtime settings
|
||||
# pre-train for 400 epochs
|
||||
train_cfg = dict(max_epochs=400)
|
||||
|
@ -1,40 +1,66 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv.runner import BaseModule, auto_fp16
|
||||
from mmengine.model import BaseModel as _BaseModel
|
||||
from torch import nn
|
||||
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from mmselfsup.utils import get_module_device
|
||||
from mmselfsup.registry import MODELS
|
||||
|
||||
|
||||
class BaseModel(BaseModule, metaclass=ABCMeta):
|
||||
"""Base model class for self-supervised learning.
|
||||
class BaseModel(_BaseModel):
|
||||
"""BaseModel for SelfSup.
|
||||
|
||||
All algorithms should inherit this module.
|
||||
|
||||
Args:
|
||||
preprocess_cfg (Dict): Config to preprocess images.
|
||||
init_cfg (Dict, optional): Config to initialize models.
|
||||
backbone (dict): The backbone module. See
|
||||
:mod:`mmcls.models.backbones`.
|
||||
neck (dict, optional): The neck module to process features from
|
||||
backbone. See :mod:`mmcls.models.necks`. Defaults to None.
|
||||
head (dict, optional): The head module to do prediction and calculate
|
||||
loss from processed features. See :mod:`mmcls.models.heads`.
|
||||
Notice that if the head is not set, almost all methods cannot be
|
||||
used except :meth:`extract_feat`. Defaults to None.
|
||||
pretrained (str, optional): The pretrained checkpoint path, support
|
||||
local path and remote path. Defaults to None.
|
||||
data_preprocessor (Union[dict, nn.Module], optional): The config for
|
||||
preprocessing input data. If None or no specified type, it will use
|
||||
"SelfSupDataPreprocessor" as type.
|
||||
See :class:`SelfSupDataPreprocessor` for more details.
|
||||
Defaults to None.
|
||||
init_cfg (dict, optional): the config to control the initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
preprocess_cfg: Dict,
|
||||
init_cfg: Optional[Dict] = None) -> None:
|
||||
super(BaseModel, self).__init__(init_cfg)
|
||||
self.fp16_enabled = False
|
||||
assert 'mean' in preprocess_cfg
|
||||
self.register_buffer(
|
||||
'mean_norm',
|
||||
torch.tensor(preprocess_cfg.pop('mean')).view(3, 1, 1))
|
||||
assert 'std' in preprocess_cfg
|
||||
self.register_buffer(
|
||||
'std_norm',
|
||||
torch.tensor(preprocess_cfg.pop('std')).view(3, 1, 1))
|
||||
assert 'to_rgb' in preprocess_cfg
|
||||
self.to_rgb = preprocess_cfg.pop('to_rgb')
|
||||
backbone: dict,
|
||||
neck: Optional[dict] = None,
|
||||
head: Optional[dict] = None,
|
||||
pretrained: Optional[str] = None,
|
||||
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
|
||||
init_cfg: Optional[dict] = None):
|
||||
|
||||
if pretrained is not None:
|
||||
init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
|
||||
if data_preprocessor is None:
|
||||
data_preprocessor = {}
|
||||
# The build process is in MMEngine, so we need to add scope here.
|
||||
data_preprocessor.setdefault('type',
|
||||
'mmselfsup.SelfSupDataPreprocessor')
|
||||
|
||||
super().__init__(
|
||||
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
|
||||
|
||||
self.backbone = MODELS.build(backbone)
|
||||
|
||||
if neck is not None:
|
||||
self.neck = MODELS.build(neck)
|
||||
|
||||
if head is not None:
|
||||
self.head = MODELS.build(head)
|
||||
|
||||
@property
|
||||
def with_neck(self) -> bool:
|
||||
@ -44,156 +70,98 @@ class BaseModel(BaseModule, metaclass=ABCMeta):
|
||||
def with_head(self) -> bool:
|
||||
return hasattr(self, 'head') and self.head is not None
|
||||
|
||||
@abstractmethod
|
||||
def extract_feat(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> object:
|
||||
"""The forward function to extract features.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
"""
|
||||
raise NotImplementedError('``extract_feat`` should be implemented')
|
||||
|
||||
@abstractmethod
|
||||
def forward_train(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> object:
|
||||
"""The forward function in training
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
"""
|
||||
raise NotImplementedError('``forward_train`` should be implemented')
|
||||
|
||||
def forward_test(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> object:
|
||||
"""The forward function in testing
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
"""
|
||||
raise NotImplementedError('``forward_test`` should be implemented')
|
||||
|
||||
@auto_fp16(apply_to=('data', ))
|
||||
def forward(self,
|
||||
data: List[Dict],
|
||||
return_loss: bool = False,
|
||||
extract: bool = False,
|
||||
**kwargs) -> object:
|
||||
"""Forward function of model.
|
||||
batch_inputs: torch.Tensor,
|
||||
data_samples: Optional[List[SelfSupDataSample]] = None,
|
||||
mode: str = 'tensor'):
|
||||
"""Returns losses or predictions of training, validation, testing, and
|
||||
simple inference process.
|
||||
|
||||
Calls either forward_train, forward_test or extract_feat function
|
||||
according to the mode.
|
||||
This module overwrites the abstract method in ``BaseModel``.
|
||||
|
||||
Args:
|
||||
data (List[Dict]): The input data for model.
|
||||
return_loss (bool): Train mode or test mode. Defaults to False.
|
||||
extract (bool): Whether or not only extract features from model.
|
||||
If set to True, the ``return_loss`` will be ignored. Defaults
|
||||
to False.
|
||||
batch_inputs (torch.Tensor): batch input tensor collated by
|
||||
:attr:`data_preprocessor`.
|
||||
data_samples (List[BaseDataElement], optional):
|
||||
data samples collated by :attr:`data_preprocessor`.
|
||||
mode (str): mode should be one of ``loss``, ``predict`` and
|
||||
``tensor``
|
||||
- ``loss``: Called by ``train_step`` and return loss ``dict``
|
||||
used for logging
|
||||
- ``predict``: Called by ``val_step`` and ``test_step``
|
||||
and return list of ``BaseDataElement`` results used for
|
||||
computing metric.
|
||||
- ``tensor``: Called by custom use to get ``Tensor`` type
|
||||
results.
|
||||
Returns:
|
||||
ForwardResults:
|
||||
- If ``mode == loss``, return a ``dict`` of loss tensor used
|
||||
for backward and logging.
|
||||
- If ``mode == predict``, return a ``list`` of
|
||||
:obj:`BaseDataElement` for computing metric
|
||||
and getting inference result.
|
||||
- If ``mode == tensor``, return a tensor or ``tuple`` of tensor
|
||||
or ``dict`` of tensor for custom use.
|
||||
"""
|
||||
# preprocess images
|
||||
inputs, data_samples = self.preprocss_data(data)
|
||||
|
||||
# Whether or not extract features. If set to True, the ``return_loss``
|
||||
# will be ignored.
|
||||
if extract:
|
||||
return self.extract_feat(
|
||||
inputs=inputs, data_samples=data_samples, **kwargs)
|
||||
|
||||
if return_loss:
|
||||
losses = self.forward_train(
|
||||
inputs=inputs, data_samples=data_samples, **kwargs)
|
||||
loss, log_vars = self._parse_losses(losses)
|
||||
outputs = dict(loss=loss, log_vars=log_vars)
|
||||
return outputs
|
||||
if mode == 'tensor':
|
||||
feats = self.extract_feat(batch_inputs)
|
||||
return feats
|
||||
elif mode == 'loss':
|
||||
return self.loss(batch_inputs, data_samples)
|
||||
elif mode == 'predict':
|
||||
return self.predict(batch_inputs, data_samples)
|
||||
else:
|
||||
# should be a list of SelfSupDataSample
|
||||
return self.forward_test(
|
||||
inputs=inputs, data_samples=data_samples, **kwargs)
|
||||
raise RuntimeError(f'Invalid mode "{mode}".')
|
||||
|
||||
def _parse_losses(self, losses: Dict) -> Tuple[torch.Tensor, Dict]:
|
||||
"""Parse the raw outputs (losses) of the network.
|
||||
def extract_feat(self, batch_inputs):
|
||||
"""Extract features from the input tensor with shape (N, C, ...).
|
||||
|
||||
This is an abstract method, and subclasses should override this method
|
||||
if needed.
|
||||
|
||||
Args:
|
||||
losses (Dict): Raw output of the network, which usually contain
|
||||
losses and other necessary information.
|
||||
batch_inputs (Tensor): A batch of inputs. The shape of it should be
|
||||
``(num_samples, num_channels, *img_shape)``.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, Dict]: (loss, log_vars), loss is the loss
|
||||
tensor which may be a weighted sum of all losses, log_vars
|
||||
contains all the variables to be sent to the logger.
|
||||
tuple | Tensor: The output of specified stage.
|
||||
The output depends on detailed implementation.
|
||||
"""
|
||||
log_vars = OrderedDict()
|
||||
for loss_name, loss_value in losses.items():
|
||||
if isinstance(loss_value, torch.Tensor):
|
||||
log_vars[loss_name] = loss_value.mean()
|
||||
elif isinstance(loss_value, list):
|
||||
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
|
||||
elif isinstance(loss_value, dict):
|
||||
for name, value in loss_value.items():
|
||||
log_vars[name] = value
|
||||
else:
|
||||
raise TypeError(
|
||||
f'{loss_name} is not a tensor or list of tensors')
|
||||
raise NotImplementedError
|
||||
|
||||
loss = sum(_value for _key, _value in log_vars.items()
|
||||
if 'loss' in _key)
|
||||
def loss(self, batch_inputs: torch.Tensor,
|
||||
data_samples: List[SelfSupDataSample]) -> dict:
|
||||
"""Calculate losses from a batch of inputs and data samples.
|
||||
|
||||
log_vars['loss'] = loss
|
||||
for loss_name, loss_value in log_vars.items():
|
||||
# reduce loss when distributed training
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
loss_value = loss_value.data.clone()
|
||||
dist.all_reduce(loss_value.div_(dist.get_world_size()))
|
||||
log_vars[loss_name] = loss_value.item()
|
||||
|
||||
return loss, log_vars
|
||||
|
||||
def preprocss_data(
|
||||
self,
|
||||
data: List[Dict]) -> Tuple[List[torch.Tensor], SelfSupDataSample]:
|
||||
"""Process input data during training, testing or extracting.
|
||||
This is an abstract method, and subclasses should override this method
|
||||
if needed.
|
||||
|
||||
Args:
|
||||
data (List[Dict]): The data to be processed, which
|
||||
comes from dataloader.
|
||||
batch_inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[SelfSupDataSample]): The annotation data of
|
||||
every samples.
|
||||
|
||||
Returns:
|
||||
tuple: It should contain 2 item.
|
||||
- batch_images (List[torch.Tensor]): The batch image tensor.
|
||||
- data_samples (List[SelfSupDataSample], Optional): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_label`. Return None If the input data does not
|
||||
contain `data_sample`.
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
# data_['inputs] is a list
|
||||
images = [data_['inputs'] for data_ in data]
|
||||
data_samples = [data_['data_sample'] for data_ in data]
|
||||
raise NotImplementedError
|
||||
|
||||
device = get_module_device(self)
|
||||
data_samples = [data_sample.to(device) for data_sample in data_samples]
|
||||
images = [[img_.to(device) for img_ in img] for img in images]
|
||||
def predict(self,
|
||||
batch_inputs: tuple,
|
||||
data_samples: Optional[List[SelfSupDataSample]] = None,
|
||||
**kwargs) -> List[SelfSupDataSample]:
|
||||
"""Predict results from the extracted features.
|
||||
|
||||
# convert images to rgb
|
||||
if self.to_rgb and images[0][0].size(0) == 3:
|
||||
images = [[img_[[2, 1, 0], ...] for img_ in img] for img in images]
|
||||
This module returns the logits before loss, which are used to compute
|
||||
all kinds of metrics. This is an abstract method, and subclasses should
|
||||
override this method if needed.
|
||||
|
||||
# normalize images
|
||||
images = [[(img_ - self.mean_norm) / self.std_norm for img_ in img]
|
||||
for img in images]
|
||||
|
||||
# reconstruct images into several batches. For example, SimCLR needs
|
||||
# two crops for each image, and this code snippet will convert images
|
||||
# into two batches, each containing one crop of an image.
|
||||
batch_images = []
|
||||
for i in range(len(images[0])):
|
||||
cur_batch = [img[i] for img in images]
|
||||
batch_images.append(torch.stack(cur_batch))
|
||||
|
||||
return batch_images, data_samples
|
||||
Args:
|
||||
batch_inputs (tuple): The features extracted from the backbone.
|
||||
data_samples (List[BaseDataElement], optional): The annotation
|
||||
data of every samples. Defaults to None.
|
||||
**kwargs: Other keyword arguments accepted by the ``predict``
|
||||
method of :attr:`head`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
@ -1,86 +1,50 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
|
||||
build_neck)
|
||||
from ..builder import MODELS
|
||||
from .base import BaseModel
|
||||
|
||||
|
||||
@ALGORITHMS.register_module()
|
||||
@MODELS.register_module()
|
||||
class MAE(BaseModel):
|
||||
"""MAE.
|
||||
|
||||
Implementation of `Masked Autoencoders Are Scalable Vision Learners
|
||||
<https://arxiv.org/abs/2111.06377>`_.
|
||||
Args:
|
||||
backbone (Dict, optional): Config dict for encoder. Defaults to None.
|
||||
neck (Dict, optional): Config dict for encoder. Defaults to None.
|
||||
head (Dict, optional): Config dict for head functions.
|
||||
Defaults to None.
|
||||
loss (Dict, optional): Config dict for loss functions.
|
||||
Defaults to None.
|
||||
preprocess_cfg (Dict, optional): Config to preprocess images.
|
||||
Defaults to None.
|
||||
init_cfg (Dict or List[Dict], optional): Config dict for weight
|
||||
initialization. Defaults to None.
|
||||
<https://arxiv.org/abs/2111.06377>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone: Optional[Dict] = None,
|
||||
neck: Optional[Dict] = None,
|
||||
head: Optional[Dict] = None,
|
||||
loss: Optional[Dict] = None,
|
||||
preprocess_cfg: Optional[Dict] = None,
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||
super(MAE, self).__init__(
|
||||
preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
|
||||
assert backbone is not None
|
||||
self.backbone = build_backbone(backbone)
|
||||
assert neck is not None
|
||||
self.neck = build_neck(neck)
|
||||
self.neck.num_patches = self.backbone.num_patches
|
||||
assert head is not None
|
||||
self.head = build_head(head)
|
||||
assert loss is not None
|
||||
self.loss = build_loss(loss)
|
||||
|
||||
def init_weights(self) -> None:
|
||||
super(MAE, self).init_weights()
|
||||
|
||||
def extract_feat(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
def extract_feat(self, batch_inputs: List[torch.Tensor],
|
||||
**kwarg) -> Tuple[torch.Tensor]:
|
||||
"""The forward function to extract features.
|
||||
"""The forward function to extract features from neck.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
batch_inputs (List[torch.Tensor]): The input images.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor]: backbone outputs.
|
||||
torch.Tensor: Outputs from neck.
|
||||
"""
|
||||
return self.backbone(inputs[0])
|
||||
latent, _, ids_restore = self.backbone(batch_inputs[0])
|
||||
pred = self.neck(latent, ids_restore)
|
||||
return pred
|
||||
|
||||
def forward_train(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
def loss(self, batch_inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
"""The forward function in training.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
batch_inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
Returns:
|
||||
Dict[str, Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
latent, mask, ids_restore = self.backbone(inputs[0])
|
||||
latent, mask, ids_restore = self.backbone(batch_inputs[0])
|
||||
pred = self.neck(latent, ids_restore)
|
||||
target = self.head(inputs[0])
|
||||
loss = self.loss(pred, target, mask)
|
||||
loss = self.head(pred, batch_inputs[0], mask)
|
||||
losses = dict(loss=loss)
|
||||
return losses
|
||||
|
@ -4,26 +4,31 @@ from typing import Dict, List
|
||||
import torch
|
||||
from mmcls.models import LabelSmoothLoss
|
||||
from mmcv.cnn.utils.weight_init import trunc_normal_
|
||||
from mmcv.runner import BaseModule
|
||||
from mmengine.model import BaseModule
|
||||
from torch import nn
|
||||
|
||||
from ..builder import HEADS
|
||||
from ..builder import MODELS
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class MAEPretrainHead(BaseModule):
|
||||
"""Pre-training head for MAE.
|
||||
|
||||
Args:
|
||||
loss (dict): Config of loss.
|
||||
norm_pix (bool): Whether or not to normalize the target.
|
||||
Defaults to False.
|
||||
patch_size (int): Patch size. Defaults to 16.
|
||||
"""
|
||||
|
||||
def __init__(self, norm_pix: bool = False, patch_size: int = 16) -> None:
|
||||
def __init__(self,
|
||||
loss: dict,
|
||||
norm_pix: bool = False,
|
||||
patch_size: int = 16) -> None:
|
||||
super().__init__()
|
||||
self.norm_pix = norm_pix
|
||||
self.patch_size = patch_size
|
||||
self.loss = MODELS.build(loss)
|
||||
|
||||
def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -36,9 +41,19 @@ class MAEPretrainHead(BaseModule):
|
||||
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def construct_target(self, target: torch.Tensor) -> torch.Tensor:
|
||||
"""Construct the reconstruction target.
|
||||
|
||||
target = self.patchify(x)
|
||||
In addition to splitting images into tokens, this module will also
|
||||
normalize the image according to ``norm_pix``.
|
||||
|
||||
Args:
|
||||
target (torch.Tensor): Image with the shape of B x 3 x H x W
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Tokenized images with the shape of B x L x C
|
||||
"""
|
||||
target = self.patchify(target)
|
||||
if self.norm_pix:
|
||||
mean = target.mean(dim=-1, keepdim=True)
|
||||
var = target.var(dim=-1, keepdim=True)
|
||||
@ -46,8 +61,25 @@ class MAEPretrainHead(BaseModule):
|
||||
|
||||
return target
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor,
|
||||
mask: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function of MAE head.
|
||||
|
||||
@HEADS.register_module()
|
||||
Args:
|
||||
pred (torch.Tensor): The reconstructed image.
|
||||
target (torch.Tensor): The target image.
|
||||
mask (torch.Tensor): The mask of the target image.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The reconstruction loss.
|
||||
"""
|
||||
target = self.construct_target(target)
|
||||
loss = self.loss(pred, target, mask)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MAEFinetuneHead(BaseModule):
|
||||
"""Fine-tuning head for MAE.
|
||||
|
||||
@ -83,7 +115,7 @@ class MAEFinetuneHead(BaseModule):
|
||||
return losses
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class MAELinprobeHead(BaseModule):
|
||||
"""Linear probing head for MAE.
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import torch
|
||||
from mmcv.runner import BaseModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from ..builder import LOSSES
|
||||
|
||||
|
@ -3,7 +3,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.runner import BaseModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from ..builder import NECKS
|
||||
from ..utils import build_2d_sincos_position_embedding
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .dall_e import Encoder
|
||||
from .data_preprocessor import SelfSupDataPreprocessor
|
||||
from .ema import CosineEMA
|
||||
from .extractor import Extractor
|
||||
from .gather_layer import GatherLayer
|
||||
@ -14,5 +15,5 @@ __all__ = [
|
||||
'Extractor', 'GatherLayer', 'MultiPooling', 'MultiPrototypes',
|
||||
'build_2d_sincos_position_embedding', 'Sobel', 'MultiheadAttention',
|
||||
'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'Encoder',
|
||||
'CosineEMA'
|
||||
'CosineEMA', 'SelfSupDataPreprocessor'
|
||||
]
|
||||
|
93
mmselfsup/models/utils/data_preprocessor.py
Normal file
93
mmselfsup/models/utils/data_preprocessor.py
Normal file
@ -0,0 +1,93 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.model import ImgDataPreprocessor
|
||||
|
||||
from mmselfsup.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
class SelfSupDataPreprocessor(ImgDataPreprocessor):
    """Image pre-processor for multi-view inputs.

    Compared with :class:`mmengine.ImgDataPreprocessor`, this module treats
    the ``inputs`` item of every sample as a list of image tensors (e.g. the
    several views SimCLR needs for one image) instead of a single
    ``torch.Tensor``.

    NOTE(review): ``self.device``, ``self.channel_conversion``, ``self.mean``
    and ``self.std`` are not defined in this class; they are presumably
    provided by :class:`mmengine.ImgDataPreprocessor` - confirm against the
    installed MMEngine version.
    """

    def collate_data(
        self, data: Sequence[dict]
    ) -> Tuple[List[List[torch.Tensor]], Optional[list]]:
        """Collate the data and copy it to the target device.

        This overrides the default method by treating the ``inputs`` item of
        every sample as a list of tensors. Nothing is stacked here; the
        per-sample structure is kept.

        Args:
            data (Sequence[dict]): Data sampled from dataloader. Every item
                must hold an ``inputs`` list of tensors and may hold a
                ``data_sample``.

        Returns:
            Tuple[List[List[torch.Tensor]], Optional[list]]: The per-sample
            lists of image tensors on ``self.device``, and the list of data
            samples on ``self.device`` (``None`` if no item carries a
            ``data_sample``).
        """
        inputs = []
        batch_data_samples: List[BaseDataElement] = []
        for _data in data:
            inputs.append([img.to(self.device) for img in _data['inputs']])
            # Model can get predictions without any data samples.
            if 'data_sample' in _data:
                batch_data_samples.append(
                    _data['data_sample'].to(self.device))

        # An empty list is reported as ``None``.
        return inputs, batch_data_samples or None

    def forward(
        self,
        data: Sequence[dict],
        training: bool = False
    ) -> Tuple[List[torch.Tensor], Optional[list]]:
        """Perform channel conversion, normalization and per-view batching.

        Args:
            data (Sequence[dict]): Data sampled from dataloader.
            training (bool): Not used by this implementation; kept so that
                the signature stays compatible with callers that pass it.
                Defaults to False.

        Returns:
            Tuple[List[torch.Tensor], Optional[list]]: A list with one
            stacked tensor per view (the i-th tensor stacks the i-th view of
            every sample), and the data samples returned by
            :meth:`collate_data`. The list of tensors is empty if ``data``
            is empty.
        """
        inputs, batch_data_samples = self.collate_data(data)

        # An empty batch has no views to stack; return early instead of
        # failing on ``inputs[0]`` below.
        if not inputs:
            return [], batch_data_samples

        # Channel transform: reverse the order of the first three channels.
        if self.channel_conversion:
            inputs = [[img_[[2, 1, 0], ...] for img_ in _input]
                      for _input in inputs]

        # Normalization. Here is what is different from
        # :class:`mmengine.ImgDataPreprocessor`. Since there are multiple
        # views for an image for some algorithms, e.g. SimCLR, each item in
        # inputs is a list, containing multi-views for an image.
        inputs = [[(img_ - self.mean) / self.std for img_ in _input]
                  for _input in inputs]

        # The number of views is taken from the first sample; every sample
        # is indexed with it.
        batch_inputs = [
            torch.stack([views[i] for views in inputs])
            for i in range(len(inputs[0]))
        ]

        return batch_inputs, batch_data_samples
|
@ -20,54 +20,32 @@ neck = dict(
|
||||
decoder_num_heads=16,
|
||||
mlp_ratio=4.,
|
||||
)
|
||||
head = dict(type='MAEPretrainHead', norm_pix=False, patch_size=16)
|
||||
loss = dict(type='MAEReconstructionLoss')
|
||||
head = dict(type='MAEPretrainHead', norm_pix=False, patch_size=16, loss=loss)
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_mae():
|
||||
preprocess_cfg = {
|
||||
data_preprocessor = {
|
||||
'mean': [0.5, 0.5, 0.5],
|
||||
'std': [0.5, 0.5, 0.5],
|
||||
'to_rgb': True
|
||||
'bgr_to_rgb': True
|
||||
}
|
||||
with pytest.raises(AssertionError):
|
||||
alg = MAE(
|
||||
backbone=backbone,
|
||||
neck=None,
|
||||
head=head,
|
||||
loss=loss,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
alg = MAE(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=None,
|
||||
loss=loss,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
alg = MAE(
|
||||
backbone=None,
|
||||
neck=neck,
|
||||
head=head,
|
||||
loss=loss,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
|
||||
alg = MAE(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
loss=loss,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
alg.init_weights()
|
||||
data_preprocessor=copy.deepcopy(data_preprocessor))
|
||||
|
||||
fake_data = [{
|
||||
'inputs': [torch.randn((3, 224, 224))],
|
||||
'data_sample': SelfSupDataSample()
|
||||
} for _ in range(2)]
|
||||
fake_outputs = alg(fake_data, return_loss=True)
|
||||
|
||||
fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
|
||||
fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
|
||||
assert isinstance(fake_outputs['loss'].item(), float)
|
||||
|
||||
fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
|
||||
fake_feat = alg.extract_feat(
|
||||
inputs=fake_inputs, data_samples=fake_data_samples)
|
||||
assert list(fake_feat[0].shape) == [2, 50, 768]
|
||||
fake_feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
|
||||
assert list(fake_feats.shape) == [2, 196, 768]
|
||||
|
16
tests/test_models/test_utils/test_data_preprocessor.py
Normal file
16
tests/test_models/test_utils/test_data_preprocessor.py
Normal file
@ -0,0 +1,16 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from mmselfsup.models.utils import SelfSupDataPreprocessor
|
||||
|
||||
|
||||
def test_selfsup_data_preprocessor():
    """Check that single-view samples are collated into one batch entry."""
    preprocessor = SelfSupDataPreprocessor(rgb_to_bgr=True)

    # Two samples, each carrying exactly one view and one data sample.
    num_samples = 2
    samples = []
    for _ in range(num_samples):
        samples.append({
            'inputs': [torch.randn((3, 224, 224))],
            'data_sample': SelfSupDataSample()
        })

    batches, data_samples = preprocessor(samples)

    # One view per sample -> one batch; one data sample per input sample.
    assert len(batches) == 1
    assert len(data_samples) == 2
|
@ -5,12 +5,11 @@ set -x
|
||||
PARTITION=$1
|
||||
JOB_NAME=$2
|
||||
CONFIG=$3
|
||||
WORK_DIR=$4
|
||||
GPUS=${GPUS:-8}
|
||||
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
|
||||
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
|
||||
SRUN_ARGS=${SRUN_ARGS:-""}
|
||||
PY_ARGS=${@:5}
|
||||
PY_ARGS=${@:4}
|
||||
|
||||
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
|
||||
srun -p ${PARTITION} \
|
||||
@ -21,5 +20,4 @@ srun -p ${PARTITION} \
|
||||
--cpus-per-task=${CPUS_PER_TASK} \
|
||||
--kill-on-bad-exit=1 \
|
||||
${SRUN_ARGS} \
|
||||
python -u tools/train.py ${CONFIG} \
|
||||
--work-dir=${WORK_DIR} --seed 0 --launcher="slurm" ${PY_ARGS}
|
||||
python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS}
|
||||
|
Loading…
x
Reference in New Issue
Block a user