mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix BaseDataPreprocessor and BaseModel (#285)
* fix BaseDataPreprocessor * fix BaseDataPreprocessor * change device type to torch.device * change device type to torch.device * fix cpu method of base model
This commit is contained in:
parent
6f321f88ee
commit
a9afdad7a8
@ -3,7 +3,7 @@ from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
|
||||
StochasticWeightAverage)
|
||||
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
|
||||
from .base_module import BaseModule
|
||||
from .utils import detect_anomalous_params, merge_dict, stach_batch_imgs
|
||||
from .utils import detect_anomalous_params, merge_dict, stack_batch
|
||||
from .wrappers import (MMDistributedDataParallel,
|
||||
MMSeparateDistributedDataParallel, is_model_wrapper)
|
||||
|
||||
@ -11,6 +11,6 @@ __all__ = [
|
||||
'MMDistributedDataParallel', 'is_model_wrapper', 'StochasticWeightAverage',
|
||||
'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
|
||||
'BaseDataPreprocessor', 'ImgDataPreprocessor',
|
||||
'MMSeparateDistributedDataParallel', 'BaseModule', 'stach_batch_imgs',
|
||||
'MMSeparateDistributedDataParallel', 'BaseModule', 'stack_batch',
|
||||
'merge_dict', 'detect_anomalous_params'
|
||||
]
|
||||
|
@ -179,8 +179,8 @@ class BaseModel(BaseModule):
|
||||
|
||||
def to(self, device: Optional[Union[int, torch.device]], *args,
|
||||
**kwargs) -> nn.Module:
|
||||
"""Overrides this method to set the ``device`` attribute of
|
||||
:obj:`BaseDataPreprocessor` additionally
|
||||
"""Overrides this method to call :meth:`BaseDataPreprocessor.to`
|
||||
additionally.
|
||||
|
||||
Args:
|
||||
device (int or torch.device, optional): the desired device of the
|
||||
@ -189,19 +189,29 @@ class BaseModel(BaseModule):
|
||||
Returns:
|
||||
nn.Module: The model itself.
|
||||
"""
|
||||
self.data_preprocessor.device = torch.device(device)
|
||||
self.data_preprocessor.to(device)
|
||||
return super().to(device)
|
||||
|
||||
def cuda(self, *args, **kwargs) -> nn.Module:
|
||||
"""Overrides this method to set the ``device`` attribute of
|
||||
:obj:`BaseDataPreprocessor` additionally
|
||||
"""Overrides this method to call :meth:`BaseDataPreprocessor.cuda`
|
||||
additionally.
|
||||
|
||||
Returns:
|
||||
nn.Module: The model itself.
|
||||
"""
|
||||
self.data_preprocessor.device = torch.cuda.current_device()
|
||||
self.data_preprocessor.cuda()
|
||||
return super().cuda()
|
||||
|
||||
def cpu(self, *args, **kwargs) -> nn.Module:
|
||||
"""Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
|
||||
additionally.
|
||||
|
||||
Returns:
|
||||
nn.Module: The model itself.
|
||||
"""
|
||||
self.data_preprocessor.cpu()
|
||||
return super().cpu()
|
||||
|
||||
@abstractmethod
|
||||
def forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.registry import MODELS
|
||||
from ..utils import stach_batch_imgs
|
||||
from ..utils import stack_batch
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
@ -25,18 +25,15 @@ class BaseDataPreprocessor(nn.Module):
|
||||
forward method to implement custom data pre-processing, such as
|
||||
batch-resize, MixUp, or CutMix.
|
||||
|
||||
Args:
|
||||
device (int or torch.device): Target device.
|
||||
|
||||
Warnings:
|
||||
Each item of data sampled from dataloader must be a dict and at least
|
||||
contain the ``inputs`` key. Furthermore, the value of ``inputs``
|
||||
must be a ``Tensor`` with the same shape.
|
||||
"""
|
||||
|
||||
def __init__(self, device: Union[int, torch.device] = 'cpu'):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self._device = torch.device('cpu')
|
||||
|
||||
def collate_data(
|
||||
self,
|
||||
@ -56,7 +53,7 @@ class BaseDataPreprocessor(nn.Module):
|
||||
Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
|
||||
tensor and list of labels at target device.
|
||||
"""
|
||||
inputs = [_data['inputs'].to(self.device) for _data in data]
|
||||
inputs = [_data['inputs'].to(self._device) for _data in data]
|
||||
batch_data_samples: List[BaseDataElement] = []
|
||||
# Model can get predictions without any data samples.
|
||||
for _data in data:
|
||||
@ -64,7 +61,7 @@ class BaseDataPreprocessor(nn.Module):
|
||||
batch_data_samples.append(_data['data_sample'])
|
||||
# Move data from CPU to corresponding device.
|
||||
batch_data_samples = [
|
||||
data_sample.to(self.device) for data_sample in batch_data_samples
|
||||
data_sample.to(self._device) for data_sample in batch_data_samples
|
||||
]
|
||||
|
||||
if not batch_data_samples:
|
||||
@ -93,6 +90,10 @@ class BaseDataPreprocessor(nn.Module):
|
||||
batch_inputs = torch.stack(inputs, dim=0)
|
||||
return batch_inputs, batch_data_samples
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._device
|
||||
|
||||
def to(self, device: Optional[Union[int, torch.device]], *args,
|
||||
**kwargs) -> nn.Module:
|
||||
"""Overrides this method to set the :attr:`device`
|
||||
@ -104,7 +105,7 @@ class BaseDataPreprocessor(nn.Module):
|
||||
Returns:
|
||||
nn.Module: The model itself.
|
||||
"""
|
||||
self.device = torch.device(device)
|
||||
self._device = torch.device(device)
|
||||
return super().to(device)
|
||||
|
||||
def cuda(self, *args, **kwargs) -> nn.Module:
|
||||
@ -113,9 +114,18 @@ class BaseDataPreprocessor(nn.Module):
|
||||
Returns:
|
||||
nn.Module: The model itself.
|
||||
"""
|
||||
self.device = torch.cuda.current_device()
|
||||
self._device = torch.device(torch.cuda.current_device())
|
||||
return super().cuda()
|
||||
|
||||
def cpu(self, *args, **kwargs) -> nn.Module:
|
||||
"""Overrides this method to set the :attr:`device`
|
||||
|
||||
Returns:
|
||||
nn.Module: The model itself.
|
||||
"""
|
||||
self._device = torch.device('cpu')
|
||||
return super().cpu()
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ImgDataPreprocessor(BaseDataPreprocessor):
|
||||
@ -158,7 +168,6 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
|
||||
Defaults to False.
|
||||
rgb_to_bgr (bool): whether to convert image from RGB to RGB.
|
||||
Defaults to False.
|
||||
device (int or torch.device): Target device.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -167,9 +176,8 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
|
||||
pad_size_divisor: int = 1,
|
||||
pad_value: Union[float, int] = 0,
|
||||
bgr_to_rgb: bool = False,
|
||||
rgb_to_bgr: bool = False,
|
||||
device: Union[int, torch.device] = 'cpu'):
|
||||
super().__init__(device)
|
||||
rgb_to_bgr: bool = False):
|
||||
super().__init__()
|
||||
assert len(mean) == 3 or len(mean) == 1, (
|
||||
'The length of mean should be 1 or 3 to be compatible with RGB '
|
||||
f'or gray image, but got {len(mean)}')
|
||||
@ -208,6 +216,6 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
|
||||
# Normalization.
|
||||
inputs = [(_input - self.mean) / self.std for _input in inputs]
|
||||
# Pad and stack Tensor.
|
||||
batch_inputs = stach_batch_imgs(inputs, self.pad_size_divisor,
|
||||
self.pad_value)
|
||||
batch_inputs = stack_batch(inputs, self.pad_size_divisor,
|
||||
self.pad_value)
|
||||
return batch_inputs, batch_data_samples
|
||||
|
@ -673,10 +673,10 @@ def trunc_normal_(tensor: Tensor,
|
||||
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
||||
|
||||
|
||||
def stach_batch_imgs(tensor_list: List[torch.Tensor],
|
||||
pad_size_divisor: int = 1,
|
||||
pad_value: Union[int, float] = 0) -> torch.Tensor:
|
||||
"""Stack multiple tensors to form a batch and pad the images to the max
|
||||
def stack_batch(tensor_list: List[torch.Tensor],
|
||||
pad_size_divisor: int = 1,
|
||||
pad_value: Union[int, float] = 0) -> torch.Tensor:
|
||||
"""Stack multiple tensors to form a batch and pad the tensor to the max
|
||||
shape use the right bottom padding mode in these images. If
|
||||
``pad_size_divisor > 0``, add padding to ensure the shape of each dim is
|
||||
divisible by ``pad_size_divisor``.
|
||||
@ -690,7 +690,7 @@ def stach_batch_imgs(tensor_list: List[torch.Tensor],
|
||||
pad_value (int, float): The padding value. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
Tensor: The 4D-tensor.
|
||||
Tensor: The n dim tensor.
|
||||
"""
|
||||
assert isinstance(
|
||||
tensor_list,
|
||||
|
@ -115,11 +115,13 @@ class TestBaseModel(TestCase):
|
||||
inputs = torch.randn(3, 1, 1).cuda()
|
||||
data = dict(inputs=inputs)
|
||||
model = ToyModel().cuda()
|
||||
model.val_step([data])
|
||||
out = model.val_step([data])
|
||||
self.assertEqual(out.device.type, 'cuda')
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
|
||||
def test_to(self):
|
||||
inputs = torch.randn(3, 1, 1).cuda()
|
||||
inputs = torch.randn(3, 1, 1).to('cuda:0')
|
||||
data = dict(inputs=inputs)
|
||||
model = ToyModel().to(torch.cuda.current_device())
|
||||
model.val_step([data])
|
||||
out = model.val_step([data])
|
||||
self.assertEqual(out.device.type, 'cuda')
|
||||
|
@ -13,7 +13,7 @@ class TestBaseDataPreprocessor(TestCase):
|
||||
|
||||
def test_init(self):
|
||||
base_data_preprocessor = BaseDataPreprocessor()
|
||||
self.assertEqual(base_data_preprocessor.device, 'cpu')
|
||||
self.assertEqual(base_data_preprocessor._device.type, 'cpu')
|
||||
|
||||
def test_forward(self):
|
||||
base_data_preprocessor = BaseDataPreprocessor()
|
||||
@ -35,6 +35,19 @@ class TestBaseDataPreprocessor(TestCase):
|
||||
assert_allclose(label1, batch_labels[0])
|
||||
assert_allclose(label2, batch_labels[1])
|
||||
|
||||
if torch.cuda.is_available():
|
||||
base_data_preprocessor = base_data_preprocessor.cuda()
|
||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||
|
||||
base_data_preprocessor = base_data_preprocessor.cpu()
|
||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||
self.assertEqual(batch_inputs.device.type, 'cpu')
|
||||
|
||||
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
|
||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||
|
||||
|
||||
class TestImageDataPreprocessor(TestBaseDataPreprocessor):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user