mmengine/tests/test_model/test_wrappers/test_data_parallel.py

128 lines
3.3 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
2022-02-27 20:34:52 +08:00
from unittest import TestCase
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
from mmengine.model.wrappers import (MMDataParallel, MMDistributedDataParallel,
is_model_wrapper)
from mmengine.registry import MODEL_WRAPPERS
def mock(*args, **kwargs):
    """Accept any positional/keyword arguments, do nothing, return ``None``.

    Used in this module as the replacement object handed to ``patch`` so the
    patched callables become no-ops.
    """
    return None
@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_model_wrapper():
    """Check ``is_model_wrapper`` on a plain model and on wrapped models.

    Covers four cases: a bare ``nn.Module`` (not a wrapper), the same module
    wrapped in ``MMDataParallel`` and in ``MMDistributedDataParallel`` (both
    wrappers), and a custom class registered in ``MODEL_WRAPPERS`` (also a
    wrapper).

    The decorators replace several ``torch.distributed`` /
    ``DistributedDataParallel`` callables with the no-op ``mock``.
    """
    # Local import: keeps this function self-contained without touching the
    # file-level import block.
    from contextlib import ExitStack

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    # _verify_model_across_ranks is added in torch1.9.0 and
    # _verify_params_across_processes in torch1.11.0, so we should check
    # whether each one is a member of torch.distributed before mocking it.
    optional_hooks = (
        '_verify_model_across_ranks',
        '_verify_params_across_processes',
    )

    # Patch through ``patch.object`` instead of assigning to the module
    # attribute directly, so the original functions are restored when the
    # test finishes and the stub does not leak into other tests.
    with ExitStack() as stack:
        for hook_name in optional_hooks:
            if hasattr(torch.distributed, hook_name):
                stack.enter_context(
                    patch.object(torch.distributed, hook_name, mock))

        model = Model()
        assert not is_model_wrapper(model)

        mmdp = MMDataParallel(model)
        assert is_model_wrapper(mmdp)

        mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
        assert is_model_wrapper(mmddp)

        # test model wrapper registry
        @MODEL_WRAPPERS.register_module()
        class ModelWrapper(object):

            def __init__(self, module):
                self.module = module

            def forward(self, *args, **kwargs):
                return self.module(*args, **kwargs)

        model_wrapper = ModelWrapper(model)
        assert is_model_wrapper(model_wrapper)
2022-02-27 20:34:52 +08:00
class TestMMDataParallel(TestCase):
    """Tests for the ``train_step`` / ``val_step`` methods of MMDataParallel."""

    def setUp(self):
        """Create the toy model used by every test method.

        ``unittest`` runs ``setUp()`` before each test method, so every test
        gets a fresh ``self.model``: a single 1x1 convolution (1 -> 2
        channels) whose ``train_step`` and ``val_step`` both delegate to
        ``forward``.
        """

        class StepModel(nn.Module):

            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 2, 1)

            def forward(self, x):
                return self.conv(x)

            def train_step(self, x):
                return self.forward(x)

            def val_step(self, x):
                return self.forward(x)

        self.model = StepModel()

    @staticmethod
    def _wrap_model_without_steps():
        """Return ``MMDataParallel`` around a model lacking both step methods.

        The wrapped module only defines ``forward``; it has neither
        ``train_step`` nor ``val_step``.
        """

        class ForwardOnlyModel(nn.Module):

            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 2, 1)

            def forward(self, x):
                return self.conv(x)

        return MMDataParallel(ForwardOnlyModel())

    def test_train_step(self):
        """``train_step`` must fail on a step-less module and work otherwise."""
        wrapper = self._wrap_model_without_steps()

        # test without train_step attribute
        with pytest.raises(AssertionError):
            wrapper.train_step(torch.zeros([1, 1, 3, 3]))

        # NOTE(review): this calls ``train_step`` on the unwrapped model from
        # ``setUp``, not through ``MMDataParallel`` -- confirm that is intended.
        result = self.model.train_step(torch.zeros([1, 1, 3, 3]))
        assert tuple(result.shape) == (1, 2, 3, 3)

    def test_val_step(self):
        """``val_step`` must fail on a step-less module and work otherwise."""
        wrapper = self._wrap_model_without_steps()

        # test without val_step attribute
        with pytest.raises(AssertionError):
            wrapper.val_step(torch.zeros([1, 1, 3, 3]))

        # NOTE(review): this calls ``val_step`` on the unwrapped model from
        # ``setUp``, not through ``MMDataParallel`` -- confirm that is intended.
        result = self.model.val_step(torch.zeros([1, 1, 3, 3]))
        assert tuple(result.shape) == (1, 2, 3, 3)