diff --git a/mmselfsup/models/utils/__init__.py b/mmselfsup/models/utils/__init__.py index ce5e435d..b810cfdc 100644 --- a/mmselfsup/models/utils/__init__.py +++ b/mmselfsup/models/utils/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .cae_data_preprocessor import CAEDataPreprocessor from .dall_e import Encoder -from .data_preprocessor import (RelativeLocDataPreprocessor, +from .data_preprocessor import (CAEDataPreprocessor, + RelativeLocDataPreprocessor, RotationPredDataPreprocessor, SelfSupDataPreprocessor) from .ema import CosineEMA diff --git a/mmselfsup/models/utils/cae_data_preprocessor.py b/mmselfsup/models/utils/cae_data_preprocessor.py deleted file mode 100644 index 5892fd26..00000000 --- a/mmselfsup/models/utils/cae_data_preprocessor.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Sequence, Tuple - -import torch - -from mmselfsup.registry import MODELS -from .data_preprocessor import SelfSupDataPreprocessor - - -@MODELS.register_module() -class CAEDataPreprocessor(SelfSupDataPreprocessor): - """Image pre-processor for CAE. - - Compared with the :class:`mmselfsup.SelfSupDataPreprocessor`, this module - will normalize the prediction image and target image with different - normalization parameters. - """ - - def forward( - self, - data: Sequence[dict], - training: bool = False - ) -> Tuple[List[torch.Tensor], Optional[list]]: - """Performs normalization、padding and bgr2rgb conversion based on - ``BaseDataPreprocessor``. - - Args: - data (Sequence[dict]): data sampled from dataloader. - training (bool): Whether to enable training time augmentation. If - subclasses override this method, they can perform different - preprocessing strategies for training and testing based on the - value of ``training``. - Returns: - Tuple[torch.Tensor, Optional[list]]: Data in the same format as the - model input. 
- """ - inputs, batch_data_samples = self.collate_data(data) - # channel transform - if self.channel_conversion: - inputs = [[img_[[2, 1, 0], ...] for img_ in _input] - for _input in inputs] - - # Normalization. Here is what is different from - # :class:`mmselfsup.SelfSupDataPreprocessor`. Normalize the target - # image and prediction image with different normalization params - inputs = [[(_input[0] - self.mean) / self.std, - _input[1] / 255. * 0.8 + 0.1] for _input in inputs] - - batch_inputs = [] - for i in range(len(inputs[0])): - cur_batch = [img[i] for img in inputs] - batch_inputs.append(torch.stack(cur_batch)) - - return batch_inputs, batch_data_samples diff --git a/mmselfsup/models/utils/data_preprocessor.py b/mmselfsup/models/utils/data_preprocessor.py index 79a62078..8d4350be 100644 --- a/mmselfsup/models/utils/data_preprocessor.py +++ b/mmselfsup/models/utils/data_preprocessor.py @@ -196,3 +196,50 @@ class RotationPredDataPreprocessor(SelfSupDataPreprocessor): batch_inputs = [img] return batch_inputs, batch_data_samples + + +@MODELS.register_module() +class CAEDataPreprocessor(SelfSupDataPreprocessor): + """Image pre-processor for CAE. + + Compared with the :class:`mmselfsup.SelfSupDataPreprocessor`, this module + will normalize the prediction image and target image with different + normalization parameters. + """ + + def forward( + self, + data: Sequence[dict], + training: bool = False + ) -> Tuple[List[torch.Tensor], Optional[list]]: + """Performs normalization, padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (Sequence[dict]): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + Returns: + Tuple[List[torch.Tensor], Optional[list]]: Data in the same + format as the model input. 
+ """ + inputs, batch_data_samples = self.collate_data(data) + # channel transform + if self.channel_conversion: + inputs = [[img_[[2, 1, 0], ...] for img_ in _input] + for _input in inputs] + + # Normalization. Here is what is different from + # :class:`mmselfsup.SelfSupDataPreprocessor`. Normalize the target + # image and prediction image with different normalization params + inputs = [[(_input[0] - self.mean) / self.std, + _input[1] / 255. * 0.8 + 0.1] for _input in inputs] + + batch_inputs = [] + for i in range(len(inputs[0])): + cur_batch = [img[i] for img in inputs] + batch_inputs.append(torch.stack(cur_batch)) + + return batch_inputs, batch_data_samples diff --git a/tests/test_models/test_utils/test_data_preprocessor.py b/tests/test_models/test_utils/test_data_preprocessor.py index 1d3bb92d..d1bcb508 100644 --- a/tests/test_models/test_utils/test_data_preprocessor.py +++ b/tests/test_models/test_utils/test_data_preprocessor.py @@ -6,7 +6,8 @@ from mmselfsup.models.utils import SelfSupDataPreprocessor def test_selfsup_data_preprocessor(): - data_preprocessor = SelfSupDataPreprocessor(rgb_to_bgr=True) + data_preprocessor = SelfSupDataPreprocessor( + rgb_to_bgr=True, mean=[124, 117, 104], std=[59, 58, 58]) fake_data = [{ 'inputs': [torch.randn((3, 224, 224))], 'data_sample': SelfSupDataSample()