mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] BaseModel & BaseDataPreprocessor `to` method to be consistent with torch.nn.Module (#783)
* fix BaseModel `to` method to be consistent with torch.nn.Module
* fix data_preprocessor as well
* fix docstring alignment
* fix docstring alignment
This commit is contained in:
parent
0dd0a22e75
commit
bd6791382f
@ -155,9 +155,9 @@ class BaseModel(BaseModule):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[Tensor, dict]: There are two elements. The first is the
|
tuple[Tensor, dict]: There are two elements. The first is the
|
||||||
loss tensor passed to optim_wrapper which may be a weighted sum of
|
loss tensor passed to optim_wrapper which may be a weighted sum
|
||||||
all losses, and the second is log_vars which will be sent to the
|
of all losses, and the second is log_vars which will be sent to
|
||||||
logger.
|
the logger.
|
||||||
"""
|
"""
|
||||||
log_vars = []
|
log_vars = []
|
||||||
for loss_name, loss_value in losses.items():
|
for loss_name, loss_value in losses.items():
|
||||||
@ -177,23 +177,17 @@ class BaseModel(BaseModule):
|
|||||||
|
|
||||||
return loss, log_vars # type: ignore
|
return loss, log_vars # type: ignore
|
||||||
|
|
||||||
def to(self,
|
def to(self, *args, **kwargs) -> nn.Module:
|
||||||
device: Optional[Union[int, str, torch.device]] = None,
|
|
||||||
*args,
|
|
||||||
**kwargs) -> nn.Module:
|
|
||||||
"""Overrides this method to call :meth:`BaseDataPreprocessor.to`
|
"""Overrides this method to call :meth:`BaseDataPreprocessor.to`
|
||||||
additionally.
|
additionally.
|
||||||
|
|
||||||
Args:
|
|
||||||
device (int, str or torch.device, optional): the desired device
|
|
||||||
of the parameters and buffers in this module.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
nn.Module: The model itself.
|
nn.Module: The model itself.
|
||||||
"""
|
"""
|
||||||
|
device = torch._C._nn._parse_to(*args, **kwargs)[0]
|
||||||
if device is not None:
|
if device is not None:
|
||||||
self._set_device(torch.device(device))
|
self._set_device(torch.device(device))
|
||||||
return super().to(device)
|
return super().to(*args, **kwargs)
|
||||||
|
|
||||||
def cuda(
|
def cuda(
|
||||||
self,
|
self,
|
||||||
@ -244,7 +238,7 @@ class BaseModel(BaseModule):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
device (torch.device): the desired device of the parameters and
|
device (torch.device): the desired device of the parameters and
|
||||||
buffers in this module.
|
buffers in this module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def apply_fn(module):
|
def apply_fn(module):
|
||||||
|
@ -84,19 +84,16 @@ class BaseDataPreprocessor(nn.Module):
|
|||||||
def device(self):
|
def device(self):
|
||||||
return self._device
|
return self._device
|
||||||
|
|
||||||
def to(self, device: Optional[Union[int, torch.device]], *args,
|
def to(self, *args, **kwargs) -> nn.Module:
|
||||||
**kwargs) -> nn.Module:
|
|
||||||
"""Overrides this method to set the :attr:`device`
|
"""Overrides this method to set the :attr:`device`
|
||||||
|
|
||||||
Args:
|
|
||||||
device (int or torch.device, optional): The desired device of the
|
|
||||||
parameters and buffers in this module.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
nn.Module: The model itself.
|
nn.Module: The model itself.
|
||||||
"""
|
"""
|
||||||
self._device = torch.device(device)
|
device = torch._C._nn._parse_to(*args, **kwargs)[0]
|
||||||
return super().to(device)
|
if device is not None:
|
||||||
|
self._device = torch.device(device)
|
||||||
|
return super().to(*args, **kwargs)
|
||||||
|
|
||||||
def cuda(self, *args, **kwargs) -> nn.Module:
|
def cuda(self, *args, **kwargs) -> nn.Module:
|
||||||
"""Overrides this method to set the :attr:`device`
|
"""Overrides this method to set the :attr:`device`
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import itertools
|
||||||
import unittest
|
import unittest
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from parameterized import parameterized
|
||||||
from torch.optim import SGD
|
from torch.optim import SGD
|
||||||
|
|
||||||
from mmengine.model import BaseDataPreprocessor, BaseModel
|
from mmengine.model import BaseDataPreprocessor, BaseModel
|
||||||
@ -11,6 +13,18 @@ from mmengine.optim import OptimWrapper
|
|||||||
from mmengine.registry import MODELS
|
from mmengine.registry import MODELS
|
||||||
from mmengine.testing import assert_allclose
|
from mmengine.testing import assert_allclose
|
||||||
|
|
||||||
|
dtypes_to_test = [torch.float16, torch.float32, torch.float64, torch.half]
|
||||||
|
|
||||||
|
cpu_devices = ['cpu', torch.device('cpu')]
|
||||||
|
cuda_devices = ['cuda', 0, torch.device('cuda')]
|
||||||
|
devices_to_test = cpu_devices
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
devices_to_test += cuda_devices
|
||||||
|
|
||||||
|
|
||||||
|
def list_product(*args):
|
||||||
|
return list(itertools.product(*args))
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
class CustomDataPreprocessor(BaseDataPreprocessor):
|
class CustomDataPreprocessor(BaseDataPreprocessor):
|
||||||
@ -158,3 +172,32 @@ class TestBaseModel(TestCase):
|
|||||||
self.assertEqual(model.data_preprocessor._device, torch.device('cuda'))
|
self.assertEqual(model.data_preprocessor._device, torch.device('cuda'))
|
||||||
self.assertEqual(model.toy_model.data_preprocessor._device,
|
self.assertEqual(model.toy_model.data_preprocessor._device,
|
||||||
torch.device('cuda'))
|
torch.device('cuda'))
|
||||||
|
|
||||||
|
@parameterized.expand(list_product(devices_to_test))
|
||||||
|
def test_to_device(self, device):
|
||||||
|
model = ToyModel().to(device)
|
||||||
|
self.assertTrue(
|
||||||
|
all(p.device.type == torch.device(device).type
|
||||||
|
for p in model.parameters())
|
||||||
|
and model.data_preprocessor._device == torch.device(device))
|
||||||
|
|
||||||
|
@parameterized.expand(list_product(dtypes_to_test))
|
||||||
|
def test_to_dtype(self, dtype):
|
||||||
|
model = ToyModel().to(dtype)
|
||||||
|
self.assertTrue(all(p.dtype == dtype for p in model.parameters()))
|
||||||
|
|
||||||
|
@parameterized.expand(
|
||||||
|
list_product(devices_to_test, dtypes_to_test,
|
||||||
|
['args', 'kwargs', 'hybrid']))
|
||||||
|
def test_to_device_and_dtype(self, device, dtype, mode):
|
||||||
|
if mode == 'args':
|
||||||
|
model = ToyModel().to(device, dtype)
|
||||||
|
elif mode == 'kwargs':
|
||||||
|
model = ToyModel().to(device=device, dtype=dtype)
|
||||||
|
elif mode == 'hybrid':
|
||||||
|
model = ToyModel().to(device, dtype=dtype)
|
||||||
|
self.assertTrue(
|
||||||
|
all(p.dtype == dtype for p in model.parameters())
|
||||||
|
and model.data_preprocessor._device == torch.device(device)
|
||||||
|
and all(p.device.type == torch.device(device).type
|
||||||
|
for p in model.parameters()))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user