128 lines
3.3 KiB
Python
128 lines
3.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmengine.model.wrappers import (MMDataParallel, MMDistributedDataParallel,
|
|
is_model_wrapper)
|
|
from mmengine.registry import MODEL_WRAPPERS
|
|
|
|
|
|
def mock(*args, **kwargs):
    """No-op stand-in used to neutralise torch distributed calls in tests.

    Accepts and ignores any positional or keyword arguments.

    Returns:
        None
    """
    return None
|
|
|
|
|
|
@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_model_wrapper():
    """Check that ``is_model_wrapper`` recognises wrapped models only.

    A plain ``nn.Module`` must not be reported as a wrapper, while
    ``MMDataParallel``, ``MMDistributedDataParallel`` and any class
    registered in ``MODEL_WRAPPERS`` must be. The decorators stub out the
    broadcast helpers and ``_ddp_init_helper`` so that DDP can be
    constructed in a single process with a ``MagicMock`` process group.
    """
    from contextlib import ExitStack

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    with ExitStack() as stack:
        # DDP's constructor verifies parameters across ranks. That helper
        # was added as ``_verify_model_across_ranks`` in torch 1.9.0 and
        # appears as ``_verify_params_across_processes`` from torch 1.11.0,
        # so mock whichever one this torch build provides. ``patch.object``
        # restores the original attribute when the ``with`` block exits,
        # instead of leaving torch.distributed monkeypatched for every test
        # that runs afterwards in the same process.
        for name in ('_verify_model_across_ranks',
                     '_verify_params_across_processes'):
            if hasattr(torch.distributed, name):
                stack.enter_context(
                    patch.object(torch.distributed, name, mock))

        model = Model()
        assert not is_model_wrapper(model)

        mmdp = MMDataParallel(model)
        assert is_model_wrapper(mmdp)

        mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
        assert is_model_wrapper(mmddp)

    # test model wrapper registry
    # ``force=True`` lets this test body run more than once in the same
    # process without failing on an already-registered ``ModelWrapper`` name.
    @MODEL_WRAPPERS.register_module(force=True)
    class ModelWrapper(object):

        def __init__(self, module):
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    model_wrapper = ModelWrapper(model)
    assert is_model_wrapper(model_wrapper)
|
|
|
|
|
|
class TestMMDataParallel(TestCase):
    """Tests for the ``train_step``/``val_step`` methods of MMDataParallel."""

    def setUp(self):
        """Build a tiny demo model before every test method.

        The model's ``train_step`` and ``val_step`` both just call
        ``forward``. TestCase calls functions in this order: setUp() ->
        testMethod() -> tearDown() -> cleanUp()
        """

        class Model(nn.Module):

            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 2, 1)

            def forward(self, x):
                return self.conv(x)

            def train_step(self, x):
                return self.forward(x)

            def val_step(self, x):
                return self.forward(x)

        self.model = Model()

    @staticmethod
    def _build_model_without_steps():
        """Return a tiny conv model that defines neither step method.

        Shared by the negative-path checks of ``test_train_step`` and
        ``test_val_step``, which previously each declared an identical class.
        """

        class Model(nn.Module):

            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 2, 1)

            def forward(self, x):
                return self.conv(x)

        return Model()

    def test_train_step(self):
        mmdp = MMDataParallel(self._build_model_without_steps())

        # test without train_step attribute
        with pytest.raises(AssertionError):
            mmdp.train_step(torch.zeros([1, 1, 3, 3]))

        out = self.model.train_step(torch.zeros([1, 1, 3, 3]))
        assert out.shape == (1, 2, 3, 3)

        # The assertion above only exercises the raw model. Also run the
        # step through the wrapper. This is restricted to CPU-only machines:
        # with CUDA available, DataParallel moves the module to a GPU, and
        # input placement for this wrapper is not covered here.
        # TODO(review): extend to CUDA once device handling is pinned down.
        if not torch.cuda.is_available():
            wrapped = MMDataParallel(self.model)
            out = wrapped.train_step(torch.zeros([1, 1, 3, 3]))
            assert out.shape == (1, 2, 3, 3)

    def test_val_step(self):
        mmdp = MMDataParallel(self._build_model_without_steps())

        # test without val_step attribute
        with pytest.raises(AssertionError):
            mmdp.val_step(torch.zeros([1, 1, 3, 3]))

        out = self.model.val_step(torch.zeros([1, 1, 3, 3]))
        assert out.shape == (1, 2, 3, 3)

        # Same as in test_train_step: also go through the wrapper, but only
        # on CPU-only machines where the module stays on the CPU.
        if not torch.cuda.is_available():
            wrapped = MMDataParallel(self.model)
            out = wrapped.val_step(torch.zeros([1, 1, 3, 3]))
            assert out.shape == (1, 2, 3, 3)
|