mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix BaseDataPreprocessor and BaseModel (#285)
* fix BaseDataPreprocessor * fix BaseDataPreprocessor * change device type to torch.device * change device type to torch.device * fix cpu method of base model
This commit is contained in:
parent
6f321f88ee
commit
a9afdad7a8
@ -3,7 +3,7 @@ from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
|
|||||||
StochasticWeightAverage)
|
StochasticWeightAverage)
|
||||||
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
|
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
|
||||||
from .base_module import BaseModule
|
from .base_module import BaseModule
|
||||||
from .utils import detect_anomalous_params, merge_dict, stach_batch_imgs
|
from .utils import detect_anomalous_params, merge_dict, stack_batch
|
||||||
from .wrappers import (MMDistributedDataParallel,
|
from .wrappers import (MMDistributedDataParallel,
|
||||||
MMSeparateDistributedDataParallel, is_model_wrapper)
|
MMSeparateDistributedDataParallel, is_model_wrapper)
|
||||||
|
|
||||||
@ -11,6 +11,6 @@ __all__ = [
|
|||||||
'MMDistributedDataParallel', 'is_model_wrapper', 'StochasticWeightAverage',
|
'MMDistributedDataParallel', 'is_model_wrapper', 'StochasticWeightAverage',
|
||||||
'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
|
'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
|
||||||
'BaseDataPreprocessor', 'ImgDataPreprocessor',
|
'BaseDataPreprocessor', 'ImgDataPreprocessor',
|
||||||
'MMSeparateDistributedDataParallel', 'BaseModule', 'stach_batch_imgs',
|
'MMSeparateDistributedDataParallel', 'BaseModule', 'stack_batch',
|
||||||
'merge_dict', 'detect_anomalous_params'
|
'merge_dict', 'detect_anomalous_params'
|
||||||
]
|
]
|
||||||
|
@ -179,8 +179,8 @@ class BaseModel(BaseModule):
|
|||||||
|
|
||||||
def to(self, device: Optional[Union[int, torch.device]], *args,
|
def to(self, device: Optional[Union[int, torch.device]], *args,
|
||||||
**kwargs) -> nn.Module:
|
**kwargs) -> nn.Module:
|
||||||
"""Overrides this method to set the ``device`` attribute of
|
"""Overrides this method to call :meth:`BaseDataPreprocessor.to`
|
||||||
:obj:`BaseDataPreprocessor` additionally
|
additionally.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device (int or torch.device, optional): the desired device of the
|
device (int or torch.device, optional): the desired device of the
|
||||||
@ -189,19 +189,29 @@ class BaseModel(BaseModule):
|
|||||||
Returns:
|
Returns:
|
||||||
nn.Module: The model itself.
|
nn.Module: The model itself.
|
||||||
"""
|
"""
|
||||||
self.data_preprocessor.device = torch.device(device)
|
self.data_preprocessor.to(device)
|
||||||
return super().to(device)
|
return super().to(device)
|
||||||
|
|
||||||
def cuda(self, *args, **kwargs) -> nn.Module:
|
def cuda(self, *args, **kwargs) -> nn.Module:
|
||||||
"""Overrides this method to set the ``device`` attribute of
|
"""Overrides this method to call :meth:`BaseDataPreprocessor.cuda`
|
||||||
:obj:`BaseDataPreprocessor` additionally
|
additionally.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
nn.Module: The model itself.
|
nn.Module: The model itself.
|
||||||
"""
|
"""
|
||||||
self.data_preprocessor.device = torch.cuda.current_device()
|
self.data_preprocessor.cuda()
|
||||||
return super().cuda()
|
return super().cuda()
|
||||||
|
|
||||||
|
def cpu(self, *args, **kwargs) -> nn.Module:
|
||||||
|
"""Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
|
||||||
|
additionally.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: The model itself.
|
||||||
|
"""
|
||||||
|
self.data_preprocessor.cpu()
|
||||||
|
return super().cpu()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward(self,
|
def forward(self,
|
||||||
batch_inputs: torch.Tensor,
|
batch_inputs: torch.Tensor,
|
||||||
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
|||||||
|
|
||||||
from mmengine.data import BaseDataElement
|
from mmengine.data import BaseDataElement
|
||||||
from mmengine.registry import MODELS
|
from mmengine.registry import MODELS
|
||||||
from ..utils import stach_batch_imgs
|
from ..utils import stack_batch
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
@ -25,18 +25,15 @@ class BaseDataPreprocessor(nn.Module):
|
|||||||
forward method to implement custom data pre-processing, such as
|
forward method to implement custom data pre-processing, such as
|
||||||
batch-resize, MixUp, or CutMix.
|
batch-resize, MixUp, or CutMix.
|
||||||
|
|
||||||
Args:
|
|
||||||
device (int or torch.device): Target device.
|
|
||||||
|
|
||||||
Warnings:
|
Warnings:
|
||||||
Each item of data sampled from dataloader must be a dict and at least
|
Each item of data sampled from dataloader must be a dict and at least
|
||||||
contain the ``inputs`` key. Furthermore, the value of ``inputs``
|
contain the ``inputs`` key. Furthermore, the value of ``inputs``
|
||||||
must be a ``Tensor`` with the same shape.
|
must be a ``Tensor`` with the same shape.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, device: Union[int, torch.device] = 'cpu'):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self._device = torch.device('cpu')
|
||||||
|
|
||||||
def collate_data(
|
def collate_data(
|
||||||
self,
|
self,
|
||||||
@ -56,7 +53,7 @@ class BaseDataPreprocessor(nn.Module):
|
|||||||
Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
|
Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
|
||||||
tensor and list of labels at target device.
|
tensor and list of labels at target device.
|
||||||
"""
|
"""
|
||||||
inputs = [_data['inputs'].to(self.device) for _data in data]
|
inputs = [_data['inputs'].to(self._device) for _data in data]
|
||||||
batch_data_samples: List[BaseDataElement] = []
|
batch_data_samples: List[BaseDataElement] = []
|
||||||
# Model can get predictions without any data samples.
|
# Model can get predictions without any data samples.
|
||||||
for _data in data:
|
for _data in data:
|
||||||
@ -64,7 +61,7 @@ class BaseDataPreprocessor(nn.Module):
|
|||||||
batch_data_samples.append(_data['data_sample'])
|
batch_data_samples.append(_data['data_sample'])
|
||||||
# Move data from CPU to corresponding device.
|
# Move data from CPU to corresponding device.
|
||||||
batch_data_samples = [
|
batch_data_samples = [
|
||||||
data_sample.to(self.device) for data_sample in batch_data_samples
|
data_sample.to(self._device) for data_sample in batch_data_samples
|
||||||
]
|
]
|
||||||
|
|
||||||
if not batch_data_samples:
|
if not batch_data_samples:
|
||||||
@ -93,6 +90,10 @@ class BaseDataPreprocessor(nn.Module):
|
|||||||
batch_inputs = torch.stack(inputs, dim=0)
|
batch_inputs = torch.stack(inputs, dim=0)
|
||||||
return batch_inputs, batch_data_samples
|
return batch_inputs, batch_data_samples
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self._device
|
||||||
|
|
||||||
def to(self, device: Optional[Union[int, torch.device]], *args,
|
def to(self, device: Optional[Union[int, torch.device]], *args,
|
||||||
**kwargs) -> nn.Module:
|
**kwargs) -> nn.Module:
|
||||||
"""Overrides this method to set the :attr:`device`
|
"""Overrides this method to set the :attr:`device`
|
||||||
@ -104,7 +105,7 @@ class BaseDataPreprocessor(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
nn.Module: The model itself.
|
nn.Module: The model itself.
|
||||||
"""
|
"""
|
||||||
self.device = torch.device(device)
|
self._device = torch.device(device)
|
||||||
return super().to(device)
|
return super().to(device)
|
||||||
|
|
||||||
def cuda(self, *args, **kwargs) -> nn.Module:
|
def cuda(self, *args, **kwargs) -> nn.Module:
|
||||||
@ -113,9 +114,18 @@ class BaseDataPreprocessor(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
nn.Module: The model itself.
|
nn.Module: The model itself.
|
||||||
"""
|
"""
|
||||||
self.device = torch.cuda.current_device()
|
self._device = torch.device(torch.cuda.current_device())
|
||||||
return super().cuda()
|
return super().cuda()
|
||||||
|
|
||||||
|
def cpu(self, *args, **kwargs) -> nn.Module:
|
||||||
|
"""Overrides this method to set the :attr:`device` to CPU.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: The model itself.
|
||||||
|
"""
|
||||||
|
self._device = torch.device('cpu')
|
||||||
|
return super().cpu()
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class ImgDataPreprocessor(BaseDataPreprocessor):
|
class ImgDataPreprocessor(BaseDataPreprocessor):
|
||||||
@ -158,7 +168,6 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
|
|||||||
Defaults to False.
|
Defaults to False.
|
||||||
rgb_to_bgr (bool): whether to convert image from RGB to BGR.
|
rgb_to_bgr (bool): whether to convert image from RGB to BGR.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
device (int or torch.device): Target device.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -167,9 +176,8 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
|
|||||||
pad_size_divisor: int = 1,
|
pad_size_divisor: int = 1,
|
||||||
pad_value: Union[float, int] = 0,
|
pad_value: Union[float, int] = 0,
|
||||||
bgr_to_rgb: bool = False,
|
bgr_to_rgb: bool = False,
|
||||||
rgb_to_bgr: bool = False,
|
rgb_to_bgr: bool = False):
|
||||||
device: Union[int, torch.device] = 'cpu'):
|
super().__init__()
|
||||||
super().__init__(device)
|
|
||||||
assert len(mean) == 3 or len(mean) == 1, (
|
assert len(mean) == 3 or len(mean) == 1, (
|
||||||
'The length of mean should be 1 or 3 to be compatible with RGB '
|
'The length of mean should be 1 or 3 to be compatible with RGB '
|
||||||
f'or gray image, but got {len(mean)}')
|
f'or gray image, but got {len(mean)}')
|
||||||
@ -208,6 +216,6 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
|
|||||||
# Normalization.
|
# Normalization.
|
||||||
inputs = [(_input - self.mean) / self.std for _input in inputs]
|
inputs = [(_input - self.mean) / self.std for _input in inputs]
|
||||||
# Pad and stack Tensor.
|
# Pad and stack Tensor.
|
||||||
batch_inputs = stach_batch_imgs(inputs, self.pad_size_divisor,
|
batch_inputs = stack_batch(inputs, self.pad_size_divisor,
|
||||||
self.pad_value)
|
self.pad_value)
|
||||||
return batch_inputs, batch_data_samples
|
return batch_inputs, batch_data_samples
|
||||||
|
@ -673,10 +673,10 @@ def trunc_normal_(tensor: Tensor,
|
|||||||
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
||||||
|
|
||||||
|
|
||||||
def stach_batch_imgs(tensor_list: List[torch.Tensor],
|
def stack_batch(tensor_list: List[torch.Tensor],
|
||||||
pad_size_divisor: int = 1,
|
pad_size_divisor: int = 1,
|
||||||
pad_value: Union[int, float] = 0) -> torch.Tensor:
|
pad_value: Union[int, float] = 0) -> torch.Tensor:
|
||||||
"""Stack multiple tensors to form a batch and pad the images to the max
|
"""Stack multiple tensors to form a batch and pad the tensor to the max
|
||||||
shape, using the right-bottom padding mode. If
|
shape, using the right-bottom padding mode. If
|
||||||
``pad_size_divisor > 0``, add padding to ensure the shape of each dim is
|
``pad_size_divisor > 0``, add padding to ensure the shape of each dim is
|
||||||
divisible by ``pad_size_divisor``.
|
divisible by ``pad_size_divisor``.
|
||||||
@ -690,7 +690,7 @@ def stach_batch_imgs(tensor_list: List[torch.Tensor],
|
|||||||
pad_value (int, float): The padding value. Defaults to 0.
|
pad_value (int, float): The padding value. Defaults to 0.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: The 4D-tensor.
|
Tensor: The n-dimensional tensor.
|
||||||
"""
|
"""
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
tensor_list,
|
tensor_list,
|
||||||
|
@ -115,11 +115,13 @@ class TestBaseModel(TestCase):
|
|||||||
inputs = torch.randn(3, 1, 1).cuda()
|
inputs = torch.randn(3, 1, 1).cuda()
|
||||||
data = dict(inputs=inputs)
|
data = dict(inputs=inputs)
|
||||||
model = ToyModel().cuda()
|
model = ToyModel().cuda()
|
||||||
model.val_step([data])
|
out = model.val_step([data])
|
||||||
|
self.assertEqual(out.device.type, 'cuda')
|
||||||
|
|
||||||
@unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
|
@unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
|
||||||
def test_to(self):
|
def test_to(self):
|
||||||
inputs = torch.randn(3, 1, 1).cuda()
|
inputs = torch.randn(3, 1, 1).to('cuda:0')
|
||||||
data = dict(inputs=inputs)
|
data = dict(inputs=inputs)
|
||||||
model = ToyModel().to(torch.cuda.current_device())
|
model = ToyModel().to(torch.cuda.current_device())
|
||||||
model.val_step([data])
|
out = model.val_step([data])
|
||||||
|
self.assertEqual(out.device.type, 'cuda')
|
||||||
|
@ -13,7 +13,7 @@ class TestBaseDataPreprocessor(TestCase):
|
|||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
base_data_preprocessor = BaseDataPreprocessor()
|
base_data_preprocessor = BaseDataPreprocessor()
|
||||||
self.assertEqual(base_data_preprocessor.device, 'cpu')
|
self.assertEqual(base_data_preprocessor._device.type, 'cpu')
|
||||||
|
|
||||||
def test_forward(self):
|
def test_forward(self):
|
||||||
base_data_preprocessor = BaseDataPreprocessor()
|
base_data_preprocessor = BaseDataPreprocessor()
|
||||||
@ -35,6 +35,19 @@ class TestBaseDataPreprocessor(TestCase):
|
|||||||
assert_allclose(label1, batch_labels[0])
|
assert_allclose(label1, batch_labels[0])
|
||||||
assert_allclose(label2, batch_labels[1])
|
assert_allclose(label2, batch_labels[1])
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
base_data_preprocessor = base_data_preprocessor.cuda()
|
||||||
|
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||||
|
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||||
|
|
||||||
|
base_data_preprocessor = base_data_preprocessor.cpu()
|
||||||
|
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||||
|
self.assertEqual(batch_inputs.device.type, 'cpu')
|
||||||
|
|
||||||
|
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
|
||||||
|
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||||
|
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||||
|
|
||||||
|
|
||||||
class TestImageDataPreprocessor(TestBaseDataPreprocessor):
|
class TestImageDataPreprocessor(TestBaseDataPreprocessor):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user