[Fix] Fix is_module_wrapper (#1900)

* fix is_module_wrapper

* test is_module_wrapper

* fix code style
pull/1998/head
whcao 2022-05-25 17:32:41 +08:00 committed by GitHub
parent e9f48a4f8e
commit c90f2be0be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 3 deletions

View File

@ -8,7 +8,8 @@ def is_module_wrapper(module):
The following 3 modules in MMCV (and their subclasses) are regarded as
module wrappers: DataParallel, DistributedDataParallel,
MMDistributedDataParallel (the deprecated version). You may add you own
module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS or
its children registries.
Args:
module (nn.Module): The module to be checked.
@ -16,5 +17,14 @@ def is_module_wrapper(module):
Returns:
bool: True if the input module is a module wrapper.
"""
module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
return isinstance(module, module_wrappers)
def is_module_in_wrapper(module, module_wrapper):
module_wrappers = tuple(module_wrapper.module_dict.values())
if isinstance(module, module_wrappers):
return True
for child in module_wrapper.children.values():
if is_module_in_wrapper(module, child):
return True
return False
return is_module_in_wrapper(module, MODULE_WRAPPERS)

View File

@ -11,6 +11,7 @@ from mmcv.parallel import (MODULE_WRAPPERS, MMDataParallel,
from mmcv.parallel._functions import Scatter, get_input_device, scatter
from mmcv.parallel.distributed_deprecated import \
MMDistributedDataParallel as DeprecatedMMDDP
from mmcv.utils import Registry
def mock(*args, **kwargs):
@ -74,6 +75,36 @@ def test_is_module_wrapper():
module_wraper = ModuleWrapper(model)
assert is_module_wrapper(module_wraper)
# test module wrapper registry in downstream repo
MMRAZOR_MODULE_WRAPPERS = Registry(
'mmrazor module wrapper', parent=MODULE_WRAPPERS, scope='mmrazor')
MMPOSE_MODULE_WRAPPERS = Registry(
'mmpose module wrapper', parent=MODULE_WRAPPERS, scope='mmpose')
@MMRAZOR_MODULE_WRAPPERS.register_module()
class ModuleWrapperInRazor(object):
def __init__(self, module):
self.module = module
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
@MMPOSE_MODULE_WRAPPERS.register_module()
class ModuleWrapperInPose(object):
def __init__(self, module):
self.module = module
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
wrapped_module = ModuleWrapperInRazor(model)
assert is_module_wrapper(wrapped_module)
wrapped_module = ModuleWrapperInPose(model)
assert is_module_wrapper(wrapped_module)
def test_get_input_device():
# if the device is CPU, return -1