[Fix] Fix BaseDataPreprocessor and BaseModel (#285)

* fix BaseDataPreprocessor

* fix BaseDataPreprocessor

* change device type to torch.device

* change device type to torch.device

* fix cpu method of base model
Mashiro 2022-06-09 11:45:19 +08:00 committed by GitHub
parent 6f321f88ee
commit a9afdad7a8
6 changed files with 66 additions and 33 deletions


@@ -3,7 +3,7 @@ from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
StochasticWeightAverage)
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from .base_module import BaseModule
from .utils import detect_anomalous_params, merge_dict, stach_batch_imgs
from .utils import detect_anomalous_params, merge_dict, stack_batch
from .wrappers import (MMDistributedDataParallel,
MMSeparateDistributedDataParallel, is_model_wrapper)
@@ -11,6 +11,6 @@ __all__ = [
'MMDistributedDataParallel', 'is_model_wrapper', 'StochasticWeightAverage',
'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
'BaseDataPreprocessor', 'ImgDataPreprocessor',
'MMSeparateDistributedDataParallel', 'BaseModule', 'stach_batch_imgs',
'MMSeparateDistributedDataParallel', 'BaseModule', 'stack_batch',
'merge_dict', 'detect_anomalous_params'
]
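
With the export list updated, downstream code imports the utility under its new spelling; the old name stach_batch_imgs is gone. A one-line sketch of the updated import (assuming mmengine.model re-exports it as listed in __all__ above):

from mmengine.model import stack_batch   # previously: stach_batch_imgs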


@@ -179,8 +179,8 @@ class BaseModel(BaseModule):
def to(self, device: Optional[Union[int, torch.device]], *args,
**kwargs) -> nn.Module:
"""Overrides this method to set the ``device`` attribute of
:obj:`BaseDataPreprocessor` additionally
"""Overrides this method to call :meth:`BaseDataPreprocessor.to`
additionally.
Args:
device (int or torch.device, optional): the desired device of the
@@ -189,19 +189,29 @@ class BaseModel(BaseModule):
Returns:
nn.Module: The model itself.
"""
self.data_preprocessor.device = torch.device(device)
self.data_preprocessor.to(device)
return super().to(device)
def cuda(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the ``device`` attribute of
:obj:`BaseDataPreprocessor` additionally
"""Overrides this method to call :meth:`BaseDataPreprocessor.cuda`
additionally.
Returns:
nn.Module: The model itself.
"""
self.data_preprocessor.device = torch.cuda.current_device()
self.data_preprocessor.cuda()
return super().cuda()
def cpu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
additionally.
Returns:
nn.Module: The model itself.
"""
self.data_preprocessor.cpu()
return super().cpu()
@abstractmethod
def forward(self,
batch_inputs: torch.Tensor,
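
With these overrides, moving the model between devices also moves its data preprocessor, so the preprocessor's device can no longer drift out of sync. A minimal sketch of the intended behaviour (DummyModel is a hypothetical stand-in; the real tests further below use a similar ToyModel):

import torch
from mmengine.model import BaseModel


class DummyModel(BaseModel):
    # Hypothetical subclass, only needed because BaseModel.forward is abstract.
    def forward(self, batch_inputs, data_samples=None, mode='tensor'):
        return batch_inputs


model = DummyModel()
# The preprocessor starts on CPU and now follows every device move.
assert model.data_preprocessor.device == torch.device('cpu')

if torch.cuda.is_available():
    model = model.cuda()   # also calls BaseDataPreprocessor.cuda()
    assert model.data_preprocessor.device.type == 'cuda'

    model = model.cpu()    # the new cpu() override moves it back
    assert model.data_preprocessor.device == torch.device('cpu')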


@@ -6,7 +6,7 @@ import torch.nn as nn
from mmengine.data import BaseDataElement
from mmengine.registry import MODELS
from ..utils import stach_batch_imgs
from ..utils import stack_batch
@MODELS.register_module()
@@ -25,18 +25,15 @@ class BaseDataPreprocessor(nn.Module):
forward method to implement custom data pre-processing, such as
batch-resize, MixUp, or CutMix.
Args:
device (int or torch.device): Target device.
Warnings:
Each item of data sampled from dataloader must be a dict and at least
contain the ``inputs`` key. Furthermore, the value of ``inputs``
must be a ``Tensor`` with the same shape.
"""
def __init__(self, device: Union[int, torch.device] = 'cpu'):
def __init__(self):
super().__init__()
self.device = device
self._device = torch.device('cpu')
def collate_data(
self,
@@ -56,7 +53,7 @@ class BaseDataPreprocessor(nn.Module):
Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
tensor and list of labels at target device.
"""
inputs = [_data['inputs'].to(self.device) for _data in data]
inputs = [_data['inputs'].to(self._device) for _data in data]
batch_data_samples: List[BaseDataElement] = []
# Model can get predictions without any data samples.
for _data in data:
@@ -64,7 +61,7 @@ class BaseDataPreprocessor(nn.Module):
batch_data_samples.append(_data['data_sample'])
# Move data from CPU to corresponding device.
batch_data_samples = [
data_sample.to(self.device) for data_sample in batch_data_samples
data_sample.to(self._device) for data_sample in batch_data_samples
]
if not batch_data_samples:
@@ -93,6 +90,10 @@ class BaseDataPreprocessor(nn.Module):
batch_inputs = torch.stack(inputs, dim=0)
return batch_inputs, batch_data_samples
@property
def device(self):
return self._device
def to(self, device: Optional[Union[int, torch.device]], *args,
**kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`
@@ -104,7 +105,7 @@ class BaseDataPreprocessor(nn.Module):
Returns:
nn.Module: The model itself.
"""
self.device = torch.device(device)
self._device = torch.device(device)
return super().to(device)
def cuda(self, *args, **kwargs) -> nn.Module:
@@ -113,9 +114,18 @@ class BaseDataPreprocessor(nn.Module):
Returns:
nn.Module: The model itself.
"""
self.device = torch.cuda.current_device()
self._device = torch.device(torch.cuda.current_device())
return super().cuda()
def cpu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`
Returns:
nn.Module: The model itself.
"""
self._device = torch.device('cpu')
return super().cpu()
@MODELS.register_module()
class ImgDataPreprocessor(BaseDataPreprocessor):
@@ -158,7 +168,6 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
Defaults to False.
rgb_to_bgr (bool): whether to convert image from RGB to BGR.
Defaults to False.
device (int or torch.device): Target device.
"""
def __init__(self,
@@ -167,9 +176,8 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
pad_size_divisor: int = 1,
pad_value: Union[float, int] = 0,
bgr_to_rgb: bool = False,
rgb_to_bgr: bool = False,
device: Union[int, torch.device] = 'cpu'):
super().__init__(device)
rgb_to_bgr: bool = False):
super().__init__()
assert len(mean) == 3 or len(mean) == 1, (
'The length of mean should be 1 or 3 to be compatible with RGB '
f'or gray image, but got {len(mean)}')
@@ -208,6 +216,6 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
# Normalization.
inputs = [(_input - self.mean) / self.std for _input in inputs]
# Pad and stack Tensor.
batch_inputs = stach_batch_imgs(inputs, self.pad_size_divisor,
self.pad_value)
batch_inputs = stack_batch(inputs, self.pad_size_divisor,
self.pad_value)
return batch_inputs, batch_data_samples
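
Taken together, BaseDataPreprocessor no longer accepts a device argument: it starts on CPU and its device property simply tracks whichever to()/cuda()/cpu() call is made on it (or on the owning model). A rough usage sketch under that assumption, mirroring the updated tests below:

import torch
from mmengine.model import BaseDataPreprocessor

preprocessor = BaseDataPreprocessor()        # no `device` argument any more
print(preprocessor.device)                   # torch.device('cpu') by default

data = [dict(inputs=torch.randn(3, 1, 1)), dict(inputs=torch.randn(3, 1, 1))]

if torch.cuda.is_available():
    preprocessor = preprocessor.to('cuda:0')                # updates _device internally
    batch_inputs, batch_data_samples = preprocessor(data)   # inputs land on cuda:0
    print(batch_inputs.device)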


@@ -673,10 +673,10 @@ def trunc_normal_(tensor: Tensor,
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def stach_batch_imgs(tensor_list: List[torch.Tensor],
pad_size_divisor: int = 1,
pad_value: Union[int, float] = 0) -> torch.Tensor:
"""Stack multiple tensors to form a batch and pad the images to the max
def stack_batch(tensor_list: List[torch.Tensor],
pad_size_divisor: int = 1,
pad_value: Union[int, float] = 0) -> torch.Tensor:
"""Stack multiple tensors to form a batch and pad the tensor to the max
shape use the right bottom padding mode in these images. If
``pad_size_divisor > 0``, add padding to ensure the shape of each dim is
divisible by ``pad_size_divisor``.
@@ -690,7 +690,7 @@ def stach_batch_imgs(tensor_list: List[torch.Tensor],
pad_value (int, float): The padding value. Defaults to 0.
Returns:
Tensor: The 4D-tensor.
Tensor: The n-dim tensor.
"""
assert isinstance(
tensor_list,
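
Apart from the rename, the behaviour is unchanged: each tensor is padded at the bottom/right up to the per-dimension maximum (optionally rounded up to pad_size_divisor) and the results are stacked along a new batch dimension. A small sketch, assuming the function is re-exported from mmengine.model as shown in the first file of this diff:

import torch
from mmengine.model import stack_batch

# Two image-like tensors with different spatial sizes.
imgs = [torch.ones(3, 10, 12), torch.ones(3, 14, 9)]
batch = stack_batch(imgs, pad_size_divisor=1, pad_value=0)
print(batch.shape)   # expected: torch.Size([2, 3, 14, 12])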


@@ -115,11 +115,13 @@ class TestBaseModel(TestCase):
inputs = torch.randn(3, 1, 1).cuda()
data = dict(inputs=inputs)
model = ToyModel().cuda()
model.val_step([data])
out = model.val_step([data])
self.assertEqual(out.device.type, 'cuda')
@unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
def test_to(self):
inputs = torch.randn(3, 1, 1).cuda()
inputs = torch.randn(3, 1, 1).to('cuda:0')
data = dict(inputs=inputs)
model = ToyModel().to(torch.cuda.current_device())
model.val_step([data])
out = model.val_step([data])
self.assertEqual(out.device.type, 'cuda')


@@ -13,7 +13,7 @@ class TestBaseDataPreprocessor(TestCase):
def test_init(self):
base_data_preprocessor = BaseDataPreprocessor()
self.assertEqual(base_data_preprocessor.device, 'cpu')
self.assertEqual(base_data_preprocessor._device.type, 'cpu')
def test_forward(self):
base_data_preprocessor = BaseDataPreprocessor()
@@ -35,6 +35,19 @@ class TestBaseDataPreprocessor(TestCase):
assert_allclose(label1, batch_labels[0])
assert_allclose(label2, batch_labels[1])
if torch.cuda.is_available():
base_data_preprocessor = base_data_preprocessor.cuda()
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertEqual(batch_inputs.device.type, 'cuda')
base_data_preprocessor = base_data_preprocessor.cpu()
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertEqual(batch_inputs.device.type, 'cpu')
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertEqual(batch_inputs.device.type, 'cuda')
class TestImageDataPreprocessor(TestBaseDataPreprocessor):