[Fix]: Move CAE data preprocessor to data preprocessor file and fix UT
This commit is contained in:
parent 8910743c6e
commit d6dfa9fe40
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .cae_data_preprocessor import CAEDataPreprocessor
 from .dall_e import Encoder
-from .data_preprocessor import (RelativeLocDataPreprocessor,
+from .data_preprocessor import (CAEDataPreprocessor,
+                                RelativeLocDataPreprocessor,
                                 RotationPredDataPreprocessor,
                                 SelfSupDataPreprocessor)
 from .ema import CosineEMA
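With this hunk, CAEDataPreprocessor is re-exported from the consolidated data_preprocessor module instead of its own file. A minimal usage sketch of the downstream import, assuming this hunk edits the utils package __init__ and the package-level re-export otherwise stays as before:

# Hypothetical import sketch; the old cae_data_preprocessor module no longer
# exists after this commit, so the class is assumed to resolve through the
# package re-export shown in the hunk above.
from mmselfsup.models.utils import CAEDataPreprocessor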
@@ -1,54 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Optional, Sequence, Tuple
-
-import torch
-
-from mmselfsup.registry import MODELS
-from .data_preprocessor import SelfSupDataPreprocessor
-
-
-@MODELS.register_module()
-class CAEDataPreprocessor(SelfSupDataPreprocessor):
-    """Image pre-processor for CAE.
-
-    Compared with the :class:`mmselfsup.SelfSupDataPreprocessor`, this module
-    will normalize the prediction image and target image with different
-    normalization parameters.
-    """
-
-    def forward(
-        self,
-        data: Sequence[dict],
-        training: bool = False
-    ) -> Tuple[List[torch.Tensor], Optional[list]]:
-        """Performs normalization、padding and bgr2rgb conversion based on
-        ``BaseDataPreprocessor``.
-
-        Args:
-            data (Sequence[dict]): data sampled from dataloader.
-            training (bool): Whether to enable training time augmentation. If
-                subclasses override this method, they can perform different
-                preprocessing strategies for training and testing based on the
-                value of ``training``.
-        Returns:
-            Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
-                model input.
-        """
-        inputs, batch_data_samples = self.collate_data(data)
-        # channel transform
-        if self.channel_conversion:
-            inputs = [[img_[[2, 1, 0], ...] for img_ in _input]
-                      for _input in inputs]
-
-        # Normalization. Here is what is different from
-        # :class:`mmselfsup.SelfSupDataPreprocessor`. Normalize the target
-        # image and prediction image with different normalization params
-        inputs = [[(_input[0] - self.mean) / self.std,
-                   _input[1] / 255. * 0.8 + 0.1] for _input in inputs]
-
-        batch_inputs = []
-        for i in range(len(inputs[0])):
-            cur_batch = [img[i] for img in inputs]
-            batch_inputs.append(torch.stack(cur_batch))
-
-        return batch_inputs, batch_data_samples
@@ -196,3 +196,50 @@ class RotationPredDataPreprocessor(SelfSupDataPreprocessor):
         batch_inputs = [img]
 
         return batch_inputs, batch_data_samples
+
+
+@MODELS.register_module()
+class CAEDataPreprocessor(SelfSupDataPreprocessor):
+    """Image pre-processor for CAE.
+
+    Compared with the :class:`mmselfsup.SelfSupDataPreprocessor`, this module
+    will normalize the prediction image and target image with different
+    normalization parameters.
+    """
+
+    def forward(
+        self,
+        data: Sequence[dict],
+        training: bool = False
+    ) -> Tuple[List[torch.Tensor], Optional[list]]:
+        """Performs normalization、padding and bgr2rgb conversion based on
+        ``BaseDataPreprocessor``.
+
+        Args:
+            data (Sequence[dict]): data sampled from dataloader.
+            training (bool): Whether to enable training time augmentation. If
+                subclasses override this method, they can perform different
+                preprocessing strategies for training and testing based on the
+                value of ``training``.
+        Returns:
+            Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
+                model input.
+        """
+        inputs, batch_data_samples = self.collate_data(data)
+        # channel transform
+        if self.channel_conversion:
+            inputs = [[img_[[2, 1, 0], ...] for img_ in _input]
+                      for _input in inputs]
+
+        # Normalization. Here is what is different from
+        # :class:`mmselfsup.SelfSupDataPreprocessor`. Normalize the target
+        # image and prediction image with different normalization params
+        inputs = [[(_input[0] - self.mean) / self.std,
+                   _input[1] / 255. * 0.8 + 0.1] for _input in inputs]
+
+        batch_inputs = []
+        for i in range(len(inputs[0])):
+            cur_batch = [img[i] for img in inputs]
+            batch_inputs.append(torch.stack(cur_batch))
+
+        return batch_inputs, batch_data_samples
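The class body is moved verbatim; the detail worth noting is the asymmetric normalization in forward. A small illustrative sketch of what the two branches compute, using made-up mean/std values (not taken from this commit) and the (C, 1, 1) broadcasting shape such preprocessors typically register:

import torch

# Illustrative statistics only; real configs supply their own mean/std.
mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)

img = torch.randint(0, 256, (3, 224, 224)).float()      # prediction view
target = torch.randint(0, 256, (3, 224, 224)).float()   # target view

# Prediction branch: standard mean/std normalization.
pred_input = (img - mean) / std
# Target branch: rescale raw pixels from [0, 255] into [0.1, 0.9], matching
# `_input[1] / 255. * 0.8 + 0.1` in the forward pass above.
target_input = target / 255. * 0.8 + 0.1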
@@ -6,7 +6,8 @@ from mmselfsup.models.utils import SelfSupDataPreprocessor
 
 
 def test_selfsup_data_preprocessor():
-    data_preprocessor = SelfSupDataPreprocessor(rgb_to_bgr=True)
+    data_preprocessor = SelfSupDataPreprocessor(
+        rgb_to_bgr=True, mean=[124, 117, 104], std=[59, 58, 58])
     fake_data = [{
         'inputs': [torch.randn((3, 224, 224))],
         'data_sample': SelfSupDataSample()
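The UT fix passes explicit mean and std to SelfSupDataPreprocessor so its normalization buffers actually exist. A hedged sketch of how a matching test for the relocated CAEDataPreprocessor might look; the constructor kwargs are assumed to be inherited from SelfSupDataPreprocessor, and the import paths may differ by mmselfsup version:

import torch

from mmselfsup.models.utils import CAEDataPreprocessor  # assumed re-export
from mmselfsup.structures import SelfSupDataSample  # import path assumed


def test_cae_data_preprocessor():
    # Assumed to accept the same kwargs as SelfSupDataPreprocessor.
    data_preprocessor = CAEDataPreprocessor(
        rgb_to_bgr=True, mean=[124, 117, 104], std=[59, 58, 58])
    # CAE expects two views per sample: index 0 is mean/std normalized,
    # index 1 is rescaled into [0.1, 0.9] for the target branch.
    fake_data = [{
        'inputs': [torch.randn((3, 224, 224)),
                   torch.randn((3, 224, 224))],
        'data_sample': SelfSupDataSample()
    }]
    batch_inputs, batch_data_samples = data_preprocessor(fake_data)
    assert len(batch_inputs) == 2
    assert batch_inputs[0].shape == (1, 3, 224, 224)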