mmengine/tests/test_model/test_wrappers/test_model_wrapper.py
Mashiro 8770c6c7fc
[Refactor] Refactor data flow to make the interface more natural (#468)
* [Refactor]: modify interface of Visualizer.add_datasample (#365)

* [Refactor] Refactor data flow: refine `data_preprocessor`. (#359)

* refine data_preprocessor

* remove unused BATCH_DATA alias

* Fix type hints

* rename move_data to cast_data

* [Refactor] Refactor data flow: collate data in `collate_fn` of `DataLoader`  (#323)

* collate data in dataloader

* fix docstring

* refine comment

* fix as comment

* refactor default collate and pseudo collate

* format test file

* fix docstring

* fix as comment

* rename elem to data_item

* minor fix

* fix as comment

* [Refactor] Refactor data flow: `data_batch` argument of `Evaluator.process` is a `dict` (#360)

* refine evaluator and metric

* compatible with new default collate

* replace default collate with pseudo

* Handle data_batch in metric

* fix unit test

* fix unit test

* fix unit test

* minor refine

* make data_batch optional

make data_batch optional

* rename outputs to predictions

* fix ut

* rename predictions to outputs

* fix docstring

* fix docstring

* fix unit test

* make outputs and data_batch to kwargs

* fix unit test

* keep signature of metric

* fix ut

* rename pred_sample arguments to data_sample(Visualizer)

* fix loop and ut

* [refactor]: Refactor model dataflow (#398)

* [Refactor] Refactor data flow: refine `data_preprocessor`. (#359)

* refine data_preprocessor

* remove unused BATCH_DATA alias

* Fix type hints

* rename move_data to cast_data

* refactor model data flow

tmp_commt

tmp commit

* make val_cfg and test_cfg optional

* roll back runner

* pass test mmdet

* fix as comment

fix as comment

fix ci in DataPreprocessor

* fix ut

* fix ut

* fix rebase main

* [Fix]: Fix test val ddp (#462)

* [Fix] Fix docstring and type hint of data flow (#463)

* Fix docstring of data flow

* change signature of hook

* fix unit test

* resolve conflicts

* fix lint
2022-08-24 22:04:55 +08:00

262 lines
9.9 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os
import unittest
from unittest.mock import MagicMock
import torch
import torch.distributed as torch_dist
import torch.nn as nn
from torch.optim import SGD
from mmengine.dist import all_gather
from mmengine.model import (BaseDataPreprocessor, BaseModel,
ExponentialMovingAverage,
MMDistributedDataParallel,
MMSeparateDistributedDataParallel)
from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
from mmengine.testing import assert_allclose
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version
# ``MMFullyShardedDataParallel`` is only imported when the installed PyTorch
# version is at least 1.11.0.
if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
    from mmengine.model import MMFullyShardedDataParallel  # noqa: F401
class ToyDataPreprocessor(BaseDataPreprocessor):
    """Data preprocessor that records whether ``forward`` has been called.

    The tests read ``called`` to check that a model wrapper routed the data
    through the wrapped model's preprocessor.
    """

    # Class-level default so that reading ``called`` before any forward pass
    # yields ``False`` (a clean assertion failure) rather than raising
    # ``AttributeError``.
    called: bool = False

    def forward(self, data: dict, training: bool = False):
        """Mark the preprocessor as called and delegate to the base class.

        Args:
            data (dict): Batch passed on unchanged to the parent ``forward``.
            training (bool): Flag passed on unchanged to the parent
                ``forward``.

        Returns:
            Whatever the parent ``forward`` returns.
        """
        self.called = True
        return super().forward(data, training)
class ToyModel(BaseModel):
    """Minimal model made of two chained 1x1 convolutions.

    It is built with a ``ToyDataPreprocessor`` so tests can check that the
    preprocessor was invoked.
    """

    def __init__(self):
        super().__init__(data_preprocessor=ToyDataPreprocessor())
        self.conv1 = nn.Conv2d(3, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)

    def forward(self, inputs, data_sample=None, mode='tensor'):
        """Run both convolutions and wrap the result according to ``mode``.

        Args:
            inputs: Input fed to ``conv1``.
            data_sample: Unused by this toy model.
            mode (str): ``'loss'`` returns ``dict(loss=<output>)``; any other
                value (including ``'predict'``) returns the raw output.

        Returns:
            dict or the output of ``conv2``, depending on ``mode``.
        """
        feats = self.conv2(self.conv1(inputs))
        # Every mode other than 'loss' yields the plain convolution output.
        return {'loss': feats} if mode == 'loss' else feats
class ComplexModel(BaseModel):
    """Model with two independent convolutions, each with its own optimizer.

    ``train_step`` expects ``optim_wrapper`` to be indexable by the keys
    ``'optim_wrapper1'`` and ``'optim_wrapper2'``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 1, 1)
        self.conv2 = nn.Conv2d(3, 1, 1)

    def train_step(self, data, optim_wrapper):
        """Update each convolution with its own optimizer wrapper.

        Args:
            data: Batch handed to ``self.data_preprocessor``; the result is
                indexed with ``'inputs'``.
            optim_wrapper: Mapping that provides ``'optim_wrapper1'`` and
                ``'optim_wrapper2'``.

        Returns:
            dict: ``loss1`` and ``loss2``, the outputs of ``conv1`` and
            ``conv2`` respectively.
        """
        inputs = self.data_preprocessor(data)['inputs']

        # First branch: conv1 output is used directly as the loss.
        first_loss = self.conv1(inputs)
        optim_wrapper['optim_wrapper1'].update_params(first_loss)

        # Second branch: evaluated only after the first update has run.
        second_loss = self.conv2(inputs)
        optim_wrapper['optim_wrapper2'].update_params(second_loss)

        return {'loss1': first_loss, 'loss2': second_loss}

    def val_step(self, data):
        """Return the constant ``1`` so tests can identify this method."""
        return 1

    def test_step(self, data):
        """Return the constant ``2`` so tests can identify this method."""
        return 2

    def forward(self):
        """Do nothing; the steps above never call ``forward``."""
        return None
class TestDistributedDataParallel(MultiProcessTestCase):
    """Multi-process tests for ``MMDistributedDataParallel``.

    ``setUp`` spawns the worker processes; every test then initializes its
    own process group through ``_init_dist_env``.
    """

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @unittest.skipIf(
        not torch.cuda.is_available(), reason='cuda should be available')
    def test_train_step(self):
        """Check mixed precision combined with gradient accumulation."""
        self._init_dist_env(self.rank, self.world_size)
        # Mixed precision training and gradient accumulation should be valid
        # at the same time.
        model = ToyModel().cuda()
        ddp_model = MMDistributedDataParallel(module=model)
        optimizer = SGD(ddp_model.parameters(), lr=0)
        optim_wrapper = AmpOptimWrapper(
            optimizer=optimizer, accumulative_counts=3)
        # Scaling by the rank gives each process a different input.
        inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255
        data = dict(inputs=inputs, data_sample=None)
        res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss']
        self.assertIs(res.dtype, torch.float16)
        grad = ddp_model.module.conv1.weight.grad
        all_grads = all_gather(grad)
        # After the first of three accumulation steps the gathered gradients
        # are expected to differ between ranks.
        with self.assertRaises(AssertionError):
            assert_allclose(all_grads[0], all_grads[1])

        # Gradient accumulation
        ddp_model.train_step(data, optim_wrapper=optim_wrapper)

        # Test update params and clean grads.
        ddp_model.train_step(data, optim_wrapper=optim_wrapper)
        grad = ddp_model.module.conv1.weight.grad
        all_grads = all_gather(grad)
        # Compare each rank's gradient against zeros built from that same
        # rank's tensor.
        assert_allclose(all_grads[0], torch.zeros_like(all_grads[0]))
        assert_allclose(all_grads[1], torch.zeros_like(all_grads[1]))

    def test_val_step(self):
        """``val_step`` returns a tensor and runs the data preprocessor."""
        self._init_dist_env(self.rank, self.world_size)
        model = ToyModel()
        ddp_model = MMDistributedDataParallel(module=model)
        inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
        data = dict(inputs=inputs, data_sample=None)
        # Test get predictions.
        predictions = ddp_model.val_step(data)
        self.assertIsInstance(predictions, torch.Tensor)
        self.assertTrue(model.data_preprocessor.called)

    def test_test_step(self):
        """``test_step`` returns a tensor and runs the data preprocessor."""
        self._init_dist_env(self.rank, self.world_size)
        model = ToyModel()
        ddp_model = MMDistributedDataParallel(module=model)
        inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
        data = dict(inputs=inputs, data_sample=None)
        predictions = ddp_model.test_step(data)
        self.assertIsInstance(predictions, torch.Tensor)
        self.assertTrue(model.data_preprocessor.called)

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29510'
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)
@unittest.skipIf(
    not torch.cuda.is_available(), reason='cuda should be available')
class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel):
    """Multi-process tests for ``MMSeparateDistributedDataParallel``.

    Reuses ``setUp`` from ``TestDistributedDataParallel`` and overrides every
    test as well as ``_init_dist_env`` (which uses a different master port).
    """

    def test_init(self):
        """Check which submodules get wrapped and which are left as-is."""
        self._init_dist_env(self.rank, self.world_size)
        model = ComplexModel()
        model.ema = ExponentialMovingAverage(nn.Conv2d(1, 1, 1))
        model.act = nn.ReLU()
        ddp_model = MMSeparateDistributedDataParallel(model.cuda())
        self.assertIsInstance(ddp_model.module.ema, ExponentialMovingAverage)
        self.assertIsInstance(ddp_model.module.conv1,
                              MMDistributedDataParallel)
        self.assertIsInstance(ddp_model.module.act, nn.ReLU)

    def test_train_step(self):
        """Run ``train_step`` with one optimizer wrapper per submodule."""
        self._init_dist_env(self.rank, self.world_size)
        # Test `optim_wrapper` is a dict. In this case,
        # There will be two independently updated `DistributedDataParallel`
        # submodules.
        model = ComplexModel()
        ddp_model = MMSeparateDistributedDataParallel(model.cuda())
        optimizer1 = SGD(model.conv1.parameters(), lr=0.1)
        # `ComplexModel.train_step` feeds the `conv2` loss to
        # `optim_wrapper2`, so this optimizer must own the `conv2` parameters.
        optimizer2 = SGD(model.conv2.parameters(), lr=0.2)
        optim_wrapper1 = OptimWrapper(optimizer1, 1)
        optim_wrapper2 = OptimWrapper(optimizer2, 1)
        optim_wrapper_dict = OptimWrapperDict(
            optim_wrapper1=optim_wrapper1, optim_wrapper2=optim_wrapper2)
        inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255
        data = dict(inputs=inputs, data_sample=None)
        # Both wrappers are built with a second positional argument of 1
        # (presumably `accumulative_counts` - confirm against `OptimWrapper`),
        # so parameters should be updated on every step.
        ddp_model.train()
        self.assertTrue(ddp_model.training)
        ddp_model.train_step(data, optim_wrapper=optim_wrapper_dict)

    def test_val_step(self):
        """``val_step`` is forwarded to the wrapped model (returns ``1``)."""
        self._init_dist_env(self.rank, self.world_size)
        model = ComplexModel()
        ddp_model = MMSeparateDistributedDataParallel(model)
        data = torch.randn(1, 3, 1, 1)
        # Test get predictions.
        ddp_model.eval()
        self.assertFalse(ddp_model.training)
        predictions = ddp_model.val_step(data)
        self.assertEqual(predictions, 1)

    def test_test_step(self):
        """``test_step`` is forwarded to the wrapped model (returns ``2``)."""
        self._init_dist_env(self.rank, self.world_size)
        model = ComplexModel()
        ddp_model = MMSeparateDistributedDataParallel(model)
        data = torch.randn(1, 3, 1, 1)
        # Test get predictions.
        ddp_model.eval()
        self.assertFalse(ddp_model.training)
        predictions = ddp_model.test_step(data)
        self.assertEqual(predictions, 2)

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29515'
        os.environ['RANK'] = str(rank)
        torch_dist.init_process_group(
            backend='gloo', rank=rank, world_size=world_size)
@unittest.skipIf(
    torch.cuda.device_count() < 2, reason='need 2 gpu to test fsdp')
@unittest.skipIf(
    digit_version(TORCH_VERSION) < digit_version('1.11.0'),
    reason='fsdp needs Pytorch 1.11 or higher')
class TestMMFullyShardedDataParallel(MultiProcessTestCase):
    """Multi-process tests for ``MMFullyShardedDataParallel``.

    Skipped unless at least two CUDA devices are visible and PyTorch is at
    least 1.11.0.
    """

    def _init_dist_env(self, rank, world_size):
        """Initialize the distributed environment."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29520'
        os.environ['RANK'] = str(rank)

        # Bind this process to one GPU before creating the NCCL group.
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
        torch_dist.init_process_group(
            backend='nccl', rank=rank, world_size=world_size)

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def test_train_step(self):
        """Run one ``train_step`` through the FSDP wrapper."""
        self._init_dist_env(self.rank, self.world_size)
        # Test `optim_wrapper` is an instance of `OptimWrapper`
        model = ToyModel()
        fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
        optimizer = SGD(fsdp_model.parameters(), lr=0)
        # Use the same keyword as the `AmpOptimWrapper` call elsewhere in
        # this file (`accumulative_counts`).
        optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1)
        inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
        # NOTE(review): `inputs` is wrapped in a list here and in
        # `test_val_step` but passed as a bare tensor in `test_test_step`;
        # confirm which form the data preprocessor / `ToyModel` expects.
        data = dict(inputs=[inputs], data_sample=MagicMock())
        fsdp_model.train()
        self.assertTrue(fsdp_model.training)
        fsdp_model.train_step(data, optim_wrapper=optim_wrapper)

    def test_val_step(self):
        """``val_step`` through the FSDP wrapper returns a tensor."""
        self._init_dist_env(self.rank, self.world_size)
        model = ToyModel()
        fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
        inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
        data = dict(inputs=[inputs], data_sample=MagicMock())
        # Test get predictions.
        predictions = fsdp_model.val_step(data)
        self.assertIsInstance(predictions, torch.Tensor)

    def test_test_step(self):
        """``test_step`` through the FSDP wrapper returns a tensor."""
        self._init_dist_env(self.rank, self.world_size)
        model = ToyModel()
        fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
        inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
        data = dict(inputs=inputs, data_sample=MagicMock())
        predictions = fsdp_model.test_step(data)
        self.assertIsInstance(predictions, torch.Tensor)