# mmengine/tests/test_model/test_wrappers/test_data_parallel.py
#
# 136 lines
# 3.6 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from mmengine.model.wrappers import (MMDataParallel, MMDistributedDataParallel,
is_model_wrapper)
from mmengine.registry import MODEL_WRAPPERS
def mock(*args, **kwargs):
    """Accept any arguments, do nothing, and return ``None``.

    This function is handed to ``patch`` as the replacement object for
    several ``torch.distributed`` / ``DistributedDataParallel`` internals
    in this module's tests.
    """
    return None
@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_model_wrapper():
    """Check ``is_model_wrapper`` on a bare module and on every wrapper kind.

    A plain ``nn.Module`` must not be reported as a wrapper, while
    ``MMDataParallel``, ``MMDistributedDataParallel``, torch's
    ``DataParallel`` / ``DistributedDataParallel`` and a class registered in
    ``MODEL_WRAPPERS`` must all be reported as wrappers.

    The decorators swap ``torch.distributed._broadcast_coalesced``,
    ``torch.distributed.broadcast`` and
    ``DistributedDataParallel._ddp_init_helper`` for the no-op ``mock`` while
    the test runs; the distributed wrappers are built with a ``MagicMock``
    process group.
    """

    class Model(nn.Module):
        """Module holding a single 1x1 convolution (2 -> 2 channels)."""

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    # _verify_model_across_ranks is added in torch1.9.0 so we should check
    # whether _verify_model_across_ranks is the member of torch.distributed
    # before mocking.
    # The stub is installed through a patcher instead of a bare assignment so
    # that the real function is put back when the test ends, whether it
    # passes or fails, and the stub does not leak into other tests.
    verify_patcher = None
    if hasattr(torch.distributed, '_verify_model_across_ranks'):
        verify_patcher = patch.object(torch.distributed,
                                      '_verify_model_across_ranks', mock)
        verify_patcher.start()

    try:
        model = Model()
        assert not is_model_wrapper(model)

        mmdp = MMDataParallel(model)
        assert is_model_wrapper(mmdp)

        mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
        assert is_model_wrapper(mmddp)

        torch_dp = DataParallel(model)
        assert is_model_wrapper(torch_dp)

        torch_ddp = DistributedDataParallel(model, process_group=MagicMock())
        assert is_model_wrapper(torch_ddp)

        # test model wrapper registry
        # NOTE(review): this registration is not undone at the end of the
        # test; verify whether the registry tolerates the same name being
        # registered again if the test runs twice in one process.
        @MODEL_WRAPPERS.register_module()
        class ModelWrapper(object):
            """Minimal wrapper that forwards calls to ``self.module``."""

            def __init__(self, module):
                self.module = module

            def forward(self, *args, **kwargs):
                return self.module(*args, **kwargs)

        model_wrapper = ModelWrapper(model)
        assert is_model_wrapper(model_wrapper)
    finally:
        # Restore the original torch.distributed attribute, if it was patched.
        if verify_patcher is not None:
            verify_patcher.stop()
class TestMMDataParallel(TestCase):
    """Tests around ``train_step`` and ``val_step`` of ``MMDataParallel``."""

    @staticmethod
    def _build_stepless_model():
        """Return a conv module defining neither ``train_step`` nor
        ``val_step``."""

        class Model(nn.Module):

            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 2, 1)

            def forward(self, x):
                return self.conv(x)

        return Model()

    def setUp(self):
        """Build a fresh model with ``train_step``/``val_step`` for every
        test method and store it on ``self.model``.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> doCleanups()
        """

        class Model(nn.Module):

            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 2, 1)

            def forward(self, x):
                return self.conv(x)

            def train_step(self, x):
                return self.forward(x)

            def val_step(self, x):
                return self.forward(x)

        self.model = Model()

    def test_train_step(self):
        """``train_step`` on a wrapped step-less model raises
        ``AssertionError``; the fixture model's ``train_step`` returns the
        conv output."""
        mmdp = MMDataParallel(self._build_stepless_model())

        # test without train_step attribute
        with pytest.raises(AssertionError):
            mmdp.train_step(torch.zeros([1, 1, 3, 3]))

        # NOTE(review): this calls the bare fixture model, not a model
        # wrapped in MMDataParallel - confirm whether wrapping was intended.
        out = self.model.train_step(torch.zeros([1, 1, 3, 3]))
        assert out.shape == (1, 2, 3, 3)

    def test_val_step(self):
        """``val_step`` on a wrapped step-less model raises
        ``AssertionError``; the fixture model's ``val_step`` returns the
        conv output."""
        mmdp = MMDataParallel(self._build_stepless_model())

        # test without val_step attribute
        with pytest.raises(AssertionError):
            mmdp.val_step(torch.zeros([1, 1, 3, 3]))

        # NOTE(review): this calls the bare fixture model, not a model
        # wrapped in MMDataParallel - confirm whether wrapping was intended.
        out = self.model.val_step(torch.zeros([1, 1, 3, 3]))
        assert out.shape == (1, 2, 3, 3)