[Fix] BaseModel & BaseDataPreprocessor `to` method to be consistent with torch.nn.Module (#783)
* fix BaseModel `to` method to be consistent with torch.nn.Module * fix data_preprocessor as well * fix docstring alignment * fix docstring alignmentpull/791/head
parent
0dd0a22e75
commit
bd6791382f
|
@ -155,9 +155,9 @@ class BaseModel(BaseModule):
|
|||
|
||||
Returns:
|
||||
tuple[Tensor, dict]: There are two elements. The first is the
|
||||
loss tensor passed to optim_wrapper which may be a weighted sum of
|
||||
all losses, and the second is log_vars which will be sent to the
|
||||
logger.
|
||||
loss tensor passed to optim_wrapper which may be a weighted sum
|
||||
of all losses, and the second is log_vars which will be sent to
|
||||
the logger.
|
||||
"""
|
||||
log_vars = []
|
||||
for loss_name, loss_value in losses.items():
|
||||
|
@ -177,23 +177,17 @@ class BaseModel(BaseModule):
|
|||
|
||||
return loss, log_vars # type: ignore
|
||||
|
||||
def to(self,
|
||||
device: Optional[Union[int, str, torch.device]] = None,
|
||||
*args,
|
||||
**kwargs) -> nn.Module:
|
||||
def to(self, *args, **kwargs) -> nn.Module:
|
||||
"""Overrides this method to call :meth:`BaseDataPreprocessor.to`
|
||||
additionally.
|
||||
|
||||
Args:
|
||||
device (int, str or torch.device, optional): the desired device
|
||||
of the parameters and buffers in this module.
|
||||
|
||||
Returns:
|
||||
nn.Module: The model itself.
|
||||
"""
|
||||
device = torch._C._nn._parse_to(*args, **kwargs)[0]
|
||||
if device is not None:
|
||||
self._set_device(torch.device(device))
|
||||
return super().to(device)
|
||||
return super().to(*args, **kwargs)
|
||||
|
||||
def cuda(
|
||||
self,
|
||||
|
@ -244,7 +238,7 @@ class BaseModel(BaseModule):
|
|||
|
||||
Args:
|
||||
device (torch.device): the desired device of the parameters and
|
||||
buffers in this module.
|
||||
buffers in this module.
|
||||
"""
|
||||
|
||||
def apply_fn(module):
|
||||
|
|
|
@ -84,19 +84,16 @@ class BaseDataPreprocessor(nn.Module):
|
|||
def device(self):
|
||||
return self._device
|
||||
|
||||
def to(self, device: Optional[Union[int, torch.device]], *args,
|
||||
**kwargs) -> nn.Module:
|
||||
def to(self, *args, **kwargs) -> nn.Module:
|
||||
"""Overrides this method to set the :attr:`device`
|
||||
|
||||
Args:
|
||||
device (int or torch.device, optional): The desired device of the
|
||||
parameters and buffers in this module.
|
||||
|
||||
Returns:
|
||||
nn.Module: The model itself.
|
||||
"""
|
||||
self._device = torch.device(device)
|
||||
return super().to(device)
|
||||
device = torch._C._nn._parse_to(*args, **kwargs)[0]
|
||||
if device is not None:
|
||||
self._device = torch.device(device)
|
||||
return super().to(*args, **kwargs)
|
||||
|
||||
def cuda(self, *args, **kwargs) -> nn.Module:
|
||||
"""Overrides this method to set the :attr:`device`
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import itertools
|
||||
import unittest
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from parameterized import parameterized
|
||||
from torch.optim import SGD
|
||||
|
||||
from mmengine.model import BaseDataPreprocessor, BaseModel
|
||||
|
@ -11,6 +13,18 @@ from mmengine.optim import OptimWrapper
|
|||
from mmengine.registry import MODELS
|
||||
from mmengine.testing import assert_allclose
|
||||
|
||||
dtypes_to_test = [torch.float16, torch.float32, torch.float64, torch.half]
|
||||
|
||||
cpu_devices = ['cpu', torch.device('cpu')]
|
||||
cuda_devices = ['cuda', 0, torch.device('cuda')]
|
||||
devices_to_test = cpu_devices
|
||||
if torch.cuda.is_available():
|
||||
devices_to_test += cuda_devices
|
||||
|
||||
|
||||
def list_product(*args):
|
||||
return list(itertools.product(*args))
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CustomDataPreprocessor(BaseDataPreprocessor):
|
||||
|
@ -158,3 +172,32 @@ class TestBaseModel(TestCase):
|
|||
self.assertEqual(model.data_preprocessor._device, torch.device('cuda'))
|
||||
self.assertEqual(model.toy_model.data_preprocessor._device,
|
||||
torch.device('cuda'))
|
||||
|
||||
@parameterized.expand(list_product(devices_to_test))
|
||||
def test_to_device(self, device):
|
||||
model = ToyModel().to(device)
|
||||
self.assertTrue(
|
||||
all(p.device.type == torch.device(device).type
|
||||
for p in model.parameters())
|
||||
and model.data_preprocessor._device == torch.device(device))
|
||||
|
||||
@parameterized.expand(list_product(dtypes_to_test))
|
||||
def test_to_dtype(self, dtype):
|
||||
model = ToyModel().to(dtype)
|
||||
self.assertTrue(all(p.dtype == dtype for p in model.parameters()))
|
||||
|
||||
@parameterized.expand(
|
||||
list_product(devices_to_test, dtypes_to_test,
|
||||
['args', 'kwargs', 'hybrid']))
|
||||
def test_to_device_and_dtype(self, device, dtype, mode):
|
||||
if mode == 'args':
|
||||
model = ToyModel().to(device, dtype)
|
||||
elif mode == 'kwargs':
|
||||
model = ToyModel().to(device=device, dtype=dtype)
|
||||
elif mode == 'hybrid':
|
||||
model = ToyModel().to(device, dtype=dtype)
|
||||
self.assertTrue(
|
||||
all(p.dtype == dtype for p in model.parameters())
|
||||
and model.data_preprocessor._device == torch.device(device)
|
||||
and all(p.device.type == torch.device(device).type
|
||||
for p in model.parameters()))
|
||||
|
|
Loading…
Reference in New Issue