[Fix] Regist pytorch ddp and dp to `MODEL_WRAPPERS`, add unit test to `is_model_wrapper` (#474)
* regist pytorch ddp and dp, add unit test * minor refine * Support check custom wrapper * enhance utpull/442/head
parent
d0a74f9af6
commit
576e5c8f91
|
@ -2,12 +2,15 @@
|
|||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.registry import MODEL_WRAPPERS
|
||||
from ..utils import detect_anomalous_params
|
||||
|
||||
MODEL_WRAPPERS.register_module(module=DistributedDataParallel)
|
||||
MODEL_WRAPPERS.register_module(module=DataParallel)
|
||||
|
||||
|
||||
@MODEL_WRAPPERS.register_module()
|
||||
class MMDistributedDataParallel(DistributedDataParallel):
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.registry import MODEL_WRAPPERS
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.registry import MODEL_WRAPPERS, Registry
|
||||
|
||||
|
||||
def is_model_wrapper(model):
|
||||
def is_model_wrapper(model: nn.Module, registry: Registry = MODEL_WRAPPERS):
|
||||
"""Check if a module is a model wrapper.
|
||||
|
||||
The following 4 model in MMEngine (and their subclasses) are regarded as
|
||||
|
@ -12,9 +14,17 @@ def is_model_wrapper(model):
|
|||
|
||||
Args:
|
||||
model (nn.Module): The model to be checked.
|
||||
registry (Registry): The parent registry to search for model wrappers.
|
||||
|
||||
Returns:
|
||||
bool: True if the input model is a model wrapper.
|
||||
"""
|
||||
model_wrappers = tuple(MODEL_WRAPPERS.module_dict.values())
|
||||
return isinstance(model, model_wrappers)
|
||||
module_wrappers = tuple(registry.module_dict.values())
|
||||
if isinstance(model, module_wrappers):
|
||||
return True
|
||||
|
||||
if not registry.children:
|
||||
return False
|
||||
|
||||
for child in registry.children.values():
|
||||
return is_model_wrapper(model, child)
|
||||
|
|
|
@ -1,9 +1,16 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributed import destroy_process_group, init_process_group
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
|
||||
from mmengine.model import revert_sync_batchnorm
|
||||
from mmengine.model import (MMDistributedDataParallel,
|
||||
MMSeparateDistributedDataParallel,
|
||||
is_model_wrapper, revert_sync_batchnorm)
|
||||
from mmengine.registry import MODEL_WRAPPERS, Registry
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
@ -18,3 +25,53 @@ def test_revert_syncbn():
|
|||
conv = revert_sync_batchnorm(conv)
|
||||
y = conv(x)
|
||||
assert y.shape == (1, 8, 9, 9)
|
||||
|
||||
|
||||
def test_is_model_wrapper():
|
||||
# Test basic module wrapper.
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = '29510'
|
||||
os.environ['RANK'] = str(0)
|
||||
init_process_group(backend='gloo', rank=0, world_size=1)
|
||||
model = nn.Linear(1, 1)
|
||||
|
||||
for wrapper in [
|
||||
DistributedDataParallel, MMDistributedDataParallel,
|
||||
MMSeparateDistributedDataParallel, DataParallel
|
||||
]:
|
||||
wrapper_model = wrapper(model)
|
||||
assert is_model_wrapper(wrapper_model)
|
||||
|
||||
# Test `is_model_wrapper` can check model wrapper registered in custom
|
||||
# registry.
|
||||
CHILD_REGISTRY = Registry('test_is_model_wrapper', parent=MODEL_WRAPPERS)
|
||||
|
||||
class CustomModelWrapper(nn.Module):
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.module = model
|
||||
|
||||
pass
|
||||
|
||||
CHILD_REGISTRY.register_module(module=CustomModelWrapper)
|
||||
|
||||
for wrapper in [
|
||||
DistributedDataParallel, MMDistributedDataParallel,
|
||||
MMSeparateDistributedDataParallel, DataParallel, CustomModelWrapper
|
||||
]:
|
||||
wrapper_model = wrapper(model)
|
||||
assert is_model_wrapper(wrapper_model)
|
||||
|
||||
# Test `is_model_wrapper` will not check model wrapper in parent
|
||||
# registry from a child registry.
|
||||
for wrapper in [
|
||||
DistributedDataParallel, MMDistributedDataParallel,
|
||||
MMSeparateDistributedDataParallel, DataParallel
|
||||
]:
|
||||
wrapper_model = wrapper(model)
|
||||
assert not is_model_wrapper(wrapper_model, registry=CHILD_REGISTRY)
|
||||
|
||||
wrapper_model = CustomModelWrapper(model)
|
||||
assert is_model_wrapper(wrapper_model, registry=CHILD_REGISTRY)
|
||||
destroy_process_group()
|
||||
|
|
Loading…
Reference in New Issue