# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from unittest.mock import MagicMock

import torch
import torch.distributed as torch_dist
import torch.nn as nn
from torch.optim import SGD

from mmengine.model import (BaseModel, MMDistributedDataParallel,
                            MMSeparateDistributedDataParallel)
from mmengine.optim import OptimWrapper, OptimWrapperDict
from mmengine.testing import assert_allclose
from mmengine.testing._internal import MultiProcessTestCase


class ToyModel(BaseModel):
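    """Minimal two-layer conv model used to exercise the DDP wrappers."""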

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)

    def forward(self, x, data_samples=None, mode='tensor'):
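        # 'loss' mode runs both convs and returns a loss dict; 'predict'
        # and 'tensor' modes return the input unchanged.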
        if mode == 'loss':
            x = self.conv1(x)
            x = self.conv2(x)
            return dict(loss=x)
        elif mode == 'predict':
            return x
        else:
            return x


class ComplexModel(BaseModel):
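    """Model with a custom `train_step` driving two optimizer wrappers,
    used to test `MMSeparateDistributedDataParallel`."""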

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 1, 1)
        self.conv2 = nn.Conv2d(3, 1, 1)

    def train_step(self, data, optim_wrapper):
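        # `optim_wrapper` is an `OptimWrapperDict`; each conv submodule is
        # updated with its own wrapper.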
        batch_inputs, _ = self.data_preprocessor(data)
        loss1 = self.conv1(batch_inputs)
        optim_wrapper['optim_wrapper1'].update_params(loss1)
        loss2 = self.conv2(batch_inputs)
        optim_wrapper['optim_wrapper2'].update_params(loss2)
        return dict(loss1=loss1, loss2=loss2)

    def val_step(self, data):
        return 1

    def test_step(self, data):
        return 2

    def forward(self):
        pass


class TestModelWrapper(MultiProcessTestCase):
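    """Tests for `MMDistributedDataParallel` run in spawned worker processes."""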

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def test_train_step(self):
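        # Run `train_step` with a single `OptimWrapper` and check the
        # resulting gradient of `conv1`.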
        self._init_dist_env(self.rank, self.world_size)
        # Test `optim_wrapper` is an instance of `OptimWrapper`.
        model = ToyModel()
        ddp_model = MMDistributedDataParallel(module=model)
        optimizer = SGD(ddp_model.parameters(), lr=0)
        optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1)
        inputs = torch.randn(3, 1, 1) * self.rank * 255
        data = dict(inputs=inputs, data_sample=MagicMock())
        ddp_model.train_step([data], optim_wrapper=optim_wrapper)
        grad = ddp_model.module.conv1.weight.grad
        assert_allclose(grad, torch.zeros_like(grad))

    def test_val_step(self):
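        # `val_step` should return predictions from the wrapped module.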
        self._init_dist_env(self.rank, self.world_size)
        model = ToyModel()
        ddp_model = MMDistributedDataParallel(module=model)
        inputs = torch.randn(3, 1, 1) * self.rank * 255
        data = dict(inputs=inputs, data_sample=MagicMock())
        # Test getting predictions.
        predictions = ddp_model.val_step([data])
        self.assertIsInstance(predictions, torch.Tensor)

    def test_test_step(self):
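        # `test_step` behaves like `val_step` and should also return
        # predictions from the wrapped module.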
        self._init_dist_env(self.rank, self.world_size)
        model = ToyModel()
        ddp_model = MMDistributedDataParallel(module=model)
        inputs = torch.randn(3, 1, 1) * self.rank * 255
        data = dict(inputs=inputs, data_sample=MagicMock())
        predictions = ddp_model.test_step([data])
        self.assertIsInstance(predictions, torch.Tensor)

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29510'
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)


@unittest.skipIf(
    not torch.cuda.is_available(), reason='cuda should be available')
class TestMMSeparateDistributedDataParallel(TestModelWrapper):
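    """Tests for `MMSeparateDistributedDataParallel`, which wraps each
    submodule of the model separately."""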

    def test_train_step(self):
        self._init_dist_env(self.rank, self.world_size)
        # Test `optim_wrapper` is a dict. In this case, there will be two
        # independently updated `DistributedDataParallel` submodules.
        model = ComplexModel()
        ddp_model = MMSeparateDistributedDataParallel(model.cuda())
        optimizer1 = SGD(model.conv1.parameters(), lr=0.1)
        optimizer2 = SGD(model.conv2.parameters(), lr=0.2)
        optim_wrapper1 = OptimWrapper(optimizer1, 1)
        optim_wrapper2 = OptimWrapper(optimizer2, 1)
        optim_wrapper_dict = OptimWrapperDict(
            optim_wrapper1=optim_wrapper1, optim_wrapper2=optim_wrapper2)
        inputs = torch.randn(3, 1, 1).cuda() * self.rank * 255
        data = dict(inputs=inputs)
        # Grads of `optim_wrapper1` are synced automatically since
        # `accumulative_iters` is 1.
        ddp_model.train()
        self.assertTrue(ddp_model.training)
        ddp_model.train_step([data], optim_wrapper=optim_wrapper_dict)

    def test_val_step(self):
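        # `ComplexModel.val_step` is overridden to return 1.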
        self._init_dist_env(self.rank, self.world_size)
        model = ComplexModel()
        ddp_model = MMSeparateDistributedDataParallel(model)
        data = torch.randn(3, 1, 1)
        # Test getting predictions.
        ddp_model.eval()
        self.assertFalse(ddp_model.training)
        predictions = ddp_model.val_step([data])
        self.assertEqual(predictions, 1)

    def test_test_step(self):
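        # `ComplexModel.test_step` is overridden to return 2.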
        self._init_dist_env(self.rank, self.world_size)
        model = ComplexModel()
        ddp_model = MMSeparateDistributedDataParallel(model)
        data = torch.randn(3, 1, 1)
        # Test getting predictions.
        ddp_model.eval()
        self.assertFalse(ddp_model.training)
        predictions = ddp_model.test_step(data)
        self.assertEqual(predictions, 2)

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29515'
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)