mmengine/tests/test_model/test_wrappers/test_data_parallel.py
Mashiro 3e3866c1b9
[Feature] Add optimizer wrapper (#265)
* Support multiple optimizers

* minor refinement

* improve unit tests

* minor fix

* Update unit tests for resuming or saving ckpt for multiple optimizers

* refine docstring

* refine docstring

* fix typo

* update docstring

* refactor the logic to build multiple optimizers

* resolve comments

* ParamSchedulers supports multiple optimizers

* add optimizer_wrapper

* fix comment and docstring

* fix unit test

* add unit test

* refine docstring

* RuntimeInfoHook supports printing multi learning rates

* resolve comments

* add optimizer_wrapper

* fix mypy

* fix lint

* fix OptimizerWrapperDict docstring and add unit test

* rename OptimizerWrapper to OptimWrapper, OptimWrapperDict inherit OptimWrapper, and fix as comment

* Fix AmpOptimizerWrapper

* rename build_optmizer_wrapper to build_optim_wrapper

* refine optimizer wrapper

* fix AmpOptimWrapper.step, docstring

* resolve conflict

* rename DefaultOptimConstructor

* fix as comment

* rename clip grad arguments

* refactor optim_wrapper config

* fix docstring of DefaultOptimWrapperConstructor

fix docstring of DefaultOptimWrapperConstructor

* add get_lr method to OptimWrapper and OptimWrapperDict

* skip some amp unit test

* fix unit test

* fix get_lr, get_momentum docstring

* refactor get_lr, get_momentum, fix as comment

* fix error message

Co-authored-by: zhouzaida <zhouzaida@163.com>
2022-06-01 18:04:38 +08:00

142 lines
3.9 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from mmengine.model.wrappers import (MMDataParallel, MMDistributedDataParallel,
is_model_wrapper)
from mmengine.registry import MODEL_WRAPPERS
def mock(*args, **kwargs):
    """No-op stand-in that accepts any arguments and returns ``None``.

    It is patched over torch distributed helpers in this module so that
    distributed wrappers can be constructed without real communication.
    """
    return None
@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_model_wrapper():
    """Check ``is_model_wrapper`` against a plain module and wrappers.

    A bare ``nn.Module`` must not be reported as a wrapper. The following
    must be: ``MMDataParallel``, ``MMDistributedDataParallel``,
    ``DataParallel``, ``DistributedDataParallel`` and any class registered
    in ``MODEL_WRAPPERS``. The decorators replace the broadcast helpers and
    ``DistributedDataParallel._ddp_init_helper`` with ``mock``. This lets
    the distributed wrappers be built with a ``MagicMock`` process group.
    """
    from contextlib import ExitStack

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    # Which rank-verification helper exists depends on the torch version:
    # ``_verify_model_across_ranks`` was added in torch1.9.0 and
    # ``_verify_params_across_processes`` in torch1.11.0. Presumably DDP
    # construction calls whichever one is present, so each is mocked only
    # when this torch build has it. ``patch.object`` (rather than plain
    # attribute assignment) restores the originals when the ``with`` block
    # exits. Otherwise the mock would leak into every later test.
    verify_helpers = ('_verify_model_across_ranks',
                      '_verify_params_across_processes')
    with ExitStack() as stack:
        for helper_name in verify_helpers:
            if hasattr(torch.distributed, helper_name):
                stack.enter_context(
                    patch.object(torch.distributed, helper_name, mock))

        model = Model()
        assert not is_model_wrapper(model)

        mmdp = MMDataParallel(model)
        assert is_model_wrapper(mmdp)

        mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
        assert is_model_wrapper(mmddp)

        torch_dp = DataParallel(model)
        assert is_model_wrapper(torch_dp)

        torch_ddp = DistributedDataParallel(model, process_group=MagicMock())
        assert is_model_wrapper(torch_ddp)

        # test model wrapper registry: any class registered in
        # MODEL_WRAPPERS is treated as a wrapper.
        @MODEL_WRAPPERS.register_module()
        class ModelWrapper:

            def __init__(self, module):
                self.module = module

            def forward(self, *args, **kwargs):
                return self.module(*args, **kwargs)

        model_wrapper = ModelWrapper(model)
        assert is_model_wrapper(model_wrapper)
class TestMMDataParallel(TestCase):
    """Unit tests for ``MMDataParallel.train_step`` and ``val_step``."""

    @staticmethod
    def _build_model_without_steps():
        """Build a 1x1-conv model that defines no ``train_step``/``val_step``.

        The tests wrap it in ``MMDataParallel`` and expect the step methods
        to raise ``AssertionError``.
        """

        class Model(nn.Module):

            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 2, 1)

            def forward(self, x):
                return self.conv(x)

        return Model()

    def setUp(self):
        """Set up the demo model used in every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> cleanUp()
        """

        class Model(nn.Module):

            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 2, 1)

            def forward(self, x):
                return self.conv(x)

            def train_step(self, x):
                return self.forward(x)

            def val_step(self, x):
                return self.forward(x)

        self.model = Model()

    def test_train_step(self):
        mmdp = MMDataParallel(self._build_model_without_steps())
        # test without train_step attribute
        with pytest.raises(AssertionError):
            mmdp.train_step(torch.zeros([1, 1, 3, 3]))

        # NOTE(review): this calls train_step on the bare model, so it does
        # not exercise MMDataParallel's forwarding. Consider
        # ``MMDataParallel(self.model).train_step(...)`` once that path is
        # confirmed to work on CPU.
        out = self.model.train_step(torch.zeros([1, 1, 3, 3]))
        assert out.shape == (1, 2, 3, 3)

    def test_val_step(self):
        mmdp = MMDataParallel(self._build_model_without_steps())
        # test without val_step attribute
        with pytest.raises(AssertionError):
            mmdp.val_step(torch.zeros([1, 1, 3, 3]))

        # NOTE(review): as in test_train_step, this bypasses the wrapper and
        # only checks the bare model's val_step output shape.
        out = self.model.val_step(torch.zeros([1, 1, 3, 3]))
        assert out.shape == (1, 2, 3, 3)