mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
[Fix]: Move CAE data preprocessor to data preprocessor file and fix UT
This commit is contained in:
parent
8910743c6e
commit
d6dfa9fe40
@ -1,7 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .cae_data_preprocessor import CAEDataPreprocessor
|
||||
from .dall_e import Encoder
|
||||
from .data_preprocessor import (RelativeLocDataPreprocessor,
|
||||
from .data_preprocessor import (CAEDataPreprocessor,
|
||||
RelativeLocDataPreprocessor,
|
||||
RotationPredDataPreprocessor,
|
||||
SelfSupDataPreprocessor)
|
||||
from .ema import CosineEMA
|
||||
|
@ -1,54 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from mmselfsup.registry import MODELS
|
||||
from .data_preprocessor import SelfSupDataPreprocessor
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CAEDataPreprocessor(SelfSupDataPreprocessor):
|
||||
"""Image pre-processor for CAE.
|
||||
|
||||
Compared with the :class:`mmselfsup.SelfSupDataPreprocessor`, this module
|
||||
will normalize the prediction image and target image with different
|
||||
normalization parameters.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
data: Sequence[dict],
|
||||
training: bool = False
|
||||
) -> Tuple[List[torch.Tensor], Optional[list]]:
|
||||
"""Performs normalization、padding and bgr2rgb conversion based on
|
||||
``BaseDataPreprocessor``.
|
||||
|
||||
Args:
|
||||
data (Sequence[dict]): data sampled from dataloader.
|
||||
training (bool): Whether to enable training time augmentation. If
|
||||
subclasses override this method, they can perform different
|
||||
preprocessing strategies for training and testing based on the
|
||||
value of ``training``.
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
|
||||
model input.
|
||||
"""
|
||||
inputs, batch_data_samples = self.collate_data(data)
|
||||
# channel transform
|
||||
if self.channel_conversion:
|
||||
inputs = [[img_[[2, 1, 0], ...] for img_ in _input]
|
||||
for _input in inputs]
|
||||
|
||||
# Normalization. Here is what is different from
|
||||
# :class:`mmselfsup.SelfSupDataPreprocessor`. Normalize the target
|
||||
# image and prediction image with different normalization params
|
||||
inputs = [[(_input[0] - self.mean) / self.std,
|
||||
_input[1] / 255. * 0.8 + 0.1] for _input in inputs]
|
||||
|
||||
batch_inputs = []
|
||||
for i in range(len(inputs[0])):
|
||||
cur_batch = [img[i] for img in inputs]
|
||||
batch_inputs.append(torch.stack(cur_batch))
|
||||
|
||||
return batch_inputs, batch_data_samples
|
@ -196,3 +196,50 @@ class RotationPredDataPreprocessor(SelfSupDataPreprocessor):
|
||||
batch_inputs = [img]
|
||||
|
||||
return batch_inputs, batch_data_samples
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CAEDataPreprocessor(SelfSupDataPreprocessor):
|
||||
"""Image pre-processor for CAE.
|
||||
|
||||
Compared with the :class:`mmselfsup.SelfSupDataPreprocessor`, this module
|
||||
will normalize the prediction image and target image with different
|
||||
normalization parameters.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
data: Sequence[dict],
|
||||
training: bool = False
|
||||
) -> Tuple[List[torch.Tensor], Optional[list]]:
|
||||
"""Performs normalization、padding and bgr2rgb conversion based on
|
||||
``BaseDataPreprocessor``.
|
||||
|
||||
Args:
|
||||
data (Sequence[dict]): data sampled from dataloader.
|
||||
training (bool): Whether to enable training time augmentation. If
|
||||
subclasses override this method, they can perform different
|
||||
preprocessing strategies for training and testing based on the
|
||||
value of ``training``.
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
|
||||
model input.
|
||||
"""
|
||||
inputs, batch_data_samples = self.collate_data(data)
|
||||
# channel transform
|
||||
if self.channel_conversion:
|
||||
inputs = [[img_[[2, 1, 0], ...] for img_ in _input]
|
||||
for _input in inputs]
|
||||
|
||||
# Normalization. Here is what is different from
|
||||
# :class:`mmselfsup.SelfSupDataPreprocessor`. Normalize the target
|
||||
# image and prediction image with different normalization params
|
||||
inputs = [[(_input[0] - self.mean) / self.std,
|
||||
_input[1] / 255. * 0.8 + 0.1] for _input in inputs]
|
||||
|
||||
batch_inputs = []
|
||||
for i in range(len(inputs[0])):
|
||||
cur_batch = [img[i] for img in inputs]
|
||||
batch_inputs.append(torch.stack(cur_batch))
|
||||
|
||||
return batch_inputs, batch_data_samples
|
||||
|
@ -6,7 +6,8 @@ from mmselfsup.models.utils import SelfSupDataPreprocessor
|
||||
|
||||
|
||||
def test_selfsup_data_preprocessor():
|
||||
data_preprocessor = SelfSupDataPreprocessor(rgb_to_bgr=True)
|
||||
data_preprocessor = SelfSupDataPreprocessor(
|
||||
rgb_to_bgr=True, mean=[124, 117, 104], std=[59, 58, 58])
|
||||
fake_data = [{
|
||||
'inputs': [torch.randn((3, 224, 224))],
|
||||
'data_sample': SelfSupDataSample()
|
||||
|
Loading…
x
Reference in New Issue
Block a user