mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix is_module_wrapper (#1900)
* fix is_module_wrapper * test is_module_wrapper * fix code stylepull/1998/head
parent
e9f48a4f8e
commit
c90f2be0be
|
@ -8,7 +8,8 @@ def is_module_wrapper(module):
|
|||
The following 3 modules in MMCV (and their subclasses) are regarded as
|
||||
module wrappers: DataParallel, DistributedDataParallel,
|
||||
MMDistributedDataParallel (the deprecated version). You may add you own
|
||||
module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
|
||||
module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS or
|
||||
its children registries.
|
||||
|
||||
Args:
|
||||
module (nn.Module): The module to be checked.
|
||||
|
@ -16,5 +17,14 @@ def is_module_wrapper(module):
|
|||
Returns:
|
||||
bool: True if the input module is a module wrapper.
|
||||
"""
|
||||
module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
|
||||
return isinstance(module, module_wrappers)
|
||||
|
||||
def is_module_in_wrapper(module, module_wrapper):
|
||||
module_wrappers = tuple(module_wrapper.module_dict.values())
|
||||
if isinstance(module, module_wrappers):
|
||||
return True
|
||||
for child in module_wrapper.children.values():
|
||||
if is_module_in_wrapper(module, child):
|
||||
return True
|
||||
return False
|
||||
|
||||
return is_module_in_wrapper(module, MODULE_WRAPPERS)
|
||||
|
|
|
@ -11,6 +11,7 @@ from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel,
|
|||
from mmcv.parallel._functions import Scatter, get_input_device, scatter
|
||||
from mmcv.parallel.distributed_deprecated import \
|
||||
MMDistributedDataParallel as DeprecatedMMDDP
|
||||
from mmcv.utils import Registry
|
||||
|
||||
|
||||
def mock(*args, **kwargs):
|
||||
|
@ -74,6 +75,36 @@ def test_is_module_wrapper():
|
|||
module_wraper = ModuleWrapper(model)
|
||||
assert is_module_wrapper(module_wraper)
|
||||
|
||||
# test module wrapper registry in downstream repo
|
||||
MMRAZOR_MODULE_WRAPPERS = Registry(
|
||||
'mmrazor module wrapper', parent=MODULE_WRAPPERS, scope='mmrazor')
|
||||
MMPOSE_MODULE_WRAPPERS = Registry(
|
||||
'mmpose module wrapper', parent=MODULE_WRAPPERS, scope='mmpose')
|
||||
|
||||
@MMRAZOR_MODULE_WRAPPERS.register_module()
|
||||
class ModuleWrapperInRazor(object):
|
||||
|
||||
def __init__(self, module):
|
||||
self.module = module
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.module(*args, **kwargs)
|
||||
|
||||
@MMPOSE_MODULE_WRAPPERS.register_module()
|
||||
class ModuleWrapperInPose(object):
|
||||
|
||||
def __init__(self, module):
|
||||
self.module = module
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.module(*args, **kwargs)
|
||||
|
||||
wrapped_module = ModuleWrapperInRazor(model)
|
||||
assert is_module_wrapper(wrapped_module)
|
||||
|
||||
wrapped_module = ModuleWrapperInPose(model)
|
||||
assert is_module_wrapper(wrapped_module)
|
||||
|
||||
|
||||
def test_get_input_device():
|
||||
# if the device is CPU, return -1
|
||||
|
|
Loading…
Reference in New Issue