# mmengine/tests/test_strategies/test_fsdp.py
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from tempfile import TemporaryDirectory
from unittest import TestCase, skipIf

import torch
import torch.nn as nn

try:
    from torch.distributed.fsdp import (FullStateDictConfig,
                                        FullyShardedDataParallel,
                                        LocalStateDictConfig, StateDictType)
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        FullOptimStateDictConfig, LocalOptimStateDictConfig)

    from mmengine._strategy import FSDPStrategy
except:  # noqa: E722
    pass

from torch.multiprocessing.spawn import start_processes
from torch.optim import SGD

from mmengine.dist import (all_gather_object, broadcast_object_list,
                           is_main_process)
from mmengine.optim import LinearLR, OptimWrapper
from mmengine.testing.runner_test_case import ToyModel
from mmengine.utils import digit_version


def linear_wrap_policy(
    module,
    recurse,
    nonwrapped_numel,
) -> bool:
    if recurse:
        return True  # always recurse
    return isinstance(module, nn.Linear)
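
# ``linear_wrap_policy`` above matches the callable form of FSDP's
# ``auto_wrap_policy`` argument, whose signature is
# ``(module, recurse, nonwrapped_numel) -> bool``. A minimal usage sketch,
# assuming a CUDA device is available:
#
#   model = FullyShardedDataParallel(
#       ToyModel().cuda(), auto_wrap_policy=linear_wrap_policy)
#
# which puts each ``nn.Linear`` into its own flattened FSDP unit.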


@skipIf(
    digit_version(torch.__version__) < digit_version('2.0.0')
    or not torch.cuda.is_available(),
    'Only test FSDP with CUDA and PyTorch >= 2.0.0')
class TestStrategy(TestCase):

    def setUp(self):
        self.world_size = 2
        self.temp_dir = TemporaryDirectory()

    def tearDown(self) -> None:
        self.temp_dir.cleanup()

    def test_init(self):
        strategy = FSDPStrategy()
        self.assertFalse(strategy.skip_init_weights)

        strategy = FSDPStrategy(state_dict_cfg='local')
        self._assert_local(strategy)

        strategy = FSDPStrategy(state_dict_cfg='full')
        self._assert_full(strategy)

        strategy = FSDPStrategy(
            state_dict_cfg=dict(
                state_dict_type=StateDictType.LOCAL_STATE_DICT))
        self._assert_local(strategy)

        strategy = FSDPStrategy(
            state_dict_cfg=dict(
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(),
                optim_state_dict_config=FullOptimStateDictConfig(),
            ))
        self._assert_full(strategy)

        strategy = FSDPStrategy(
            state_dict_cfg=dict(
                state_dict_type='FULL_STATE_DICT',
                state_dict_config=dict(type='FullStateDictConfig'),
                optim_state_dict_config=dict(type='FullOptimStateDictConfig'),
            ))
        self._assert_full(strategy)

        strategy = FSDPStrategy(
            state_dict_cfg=dict(
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=dict(type=FullStateDictConfig),
                optim_state_dict_config=dict(type=FullOptimStateDictConfig),
            ))
        self._assert_full(strategy)

        with self.assertRaises(ValueError):
            strategy = FSDPStrategy(state_dict_cfg='error-str')

        # state_dict_cfg should be a str or a dict
        with self.assertRaises(TypeError):
            strategy = FSDPStrategy(state_dict_cfg=[])

        # state_dict_type must be a str or a member of StateDictType
        with self.assertRaises(TypeError):
            strategy = FSDPStrategy(
                state_dict_cfg=dict(
                    state_dict_type=[],
                    state_dict_config=dict(type=FullStateDictConfig),
                    optim_state_dict_config=dict(
                        type=FullOptimStateDictConfig),
                ))

        # state_dict_config should be a dict or an instance of
        # StateDictConfig
        with self.assertRaises(TypeError):
            strategy = FSDPStrategy(
                state_dict_cfg=dict(
                    state_dict_type=StateDictType.FULL_STATE_DICT,
                    state_dict_config=[],
                    optim_state_dict_config=dict(
                        type=FullOptimStateDictConfig),
                ))

        # optim_state_dict_config should be a dict or an instance of
        # OptimStateDictConfig
        with self.assertRaises(TypeError):
            strategy = FSDPStrategy(
                state_dict_cfg=dict(
                    state_dict_type=StateDictType.FULL_STATE_DICT,
                    state_dict_config=dict(type=FullStateDictConfig),
                    optim_state_dict_config=[],
                ))
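
    # In summary, ``state_dict_cfg`` accepts the shorthand strings 'local'
    # and 'full', or a dict whose values may be ``StateDictType`` members
    # (or their names), config instances, or registry-style
    # ``dict(type=...)`` specs; anything else raises, as exercised above.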

    def run_strategy(self):
        # Strategy can run with the built model, optimizer and schedulers.
        for skip_init_weights, state_dict_cfg in [(True, 'local'),
                                                  (False, 'full')]:
            strategy = FSDPStrategy(
                skip_init_weights=skip_init_weights,
                state_dict_cfg=state_dict_cfg,
                model_wrapper=dict(auto_wrap_policy=linear_wrap_policy))
            model = ToyModel()
            optim = OptimWrapper(
                SGD(model.parameters(), lr=0.1, momentum=0.9))
            lr_scheduler = LinearLR(optimizer=optim)
            model, optim, lr_scheduler = strategy.prepare(
                model=model, optim_wrapper=optim,
                param_scheduler=lr_scheduler)
            self.assertIsInstance(model, FullyShardedDataParallel)
            self.assertIsInstance(model.linear1, FullyShardedDataParallel)
            self.assertIsInstance(model.linear2, FullyShardedDataParallel)

            data = torch.ones(2, 2).cuda()
            data_samples = torch.zeros(2, 2).cuda()
            loss = model(data, data_samples=data_samples, mode='loss')['loss']
            loss.backward()
            optim.step()
            for scheduler in lr_scheduler:
                scheduler.step()
            ckpt_path = osp.join(self.temp_dir.name,
                                 f'checkpoint_{state_dict_cfg}.pth')
            strategy.save_checkpoint(ckpt_path)

            if state_dict_cfg == 'full':
                if not is_main_process():
                    self.assertFalse(osp.exists(ckpt_path))
                ckpt_path = [ckpt_path]
                broadcast_object_list(ckpt_path)
                ckpt_path = ckpt_path[0]

            strategy.load_checkpoint(ckpt_path)
            loss = model(data, data_samples=data_samples, mode='loss')['loss']
            loss.backward()
            optim.step()
            for scheduler in lr_scheduler:
                scheduler.step()
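
            # Note: with ``full`` state dicts FSDP gathers the whole state
            # to rank 0, so only the main process writes the checkpoint and
            # the path is broadcast for every rank to load; with ``local``
            # state dicts each rank saves and reloads its own shard.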

        # optimizer with multiple param_groups can be reconstructed.
        model = ToyModel()
        strategy = FSDPStrategy(
            model_wrapper=dict(auto_wrap_policy=linear_wrap_policy))
        param_groups = []
        for param in model.parameters():
            param_groups.append(dict(params=[param], lr=0.1))
        optim = SGD(param_groups, lr=0.1, momentum=0.9)
        lr_scheduler = LinearLR(optimizer=optim)
        model, optim, lr_scheduler = strategy.prepare(
            model=model, optim_wrapper=optim, param_scheduler=lr_scheduler)

        data = torch.ones(2, 2).cuda()
        data_samples = torch.zeros(2, 2).cuda()
        loss = model(data, data_samples=data_samples, mode='loss')['loss']
        loss.backward()
        optim.step()
        for scheduler in lr_scheduler:
            scheduler.step()

        optim_state = optim.state_dict()['state']
        optim_state = all_gather_object(optim_state)
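
        # ``all_gather_object`` returns one entry per rank, letting each
        # process inspect the sharded optimizer state of every rank. A
        # stricter variant could assert on the gathered list, e.g. (sketch,
        # not part of the original check):
        #
        #   self.assertEqual(len(optim_state), self.world_size)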

    @classmethod
    def _worker(cls, rank, func):
        # Run ``func`` on this rank as a standalone TestCase instance.
        self = cls()
        self.setUp()
        self.rank = rank
        os.environ['RANK'] = str(rank)
        os.environ['LOCAL_RANK'] = str(rank)
        os.environ['WORLD_SIZE'] = str(self.world_size)
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(12123)
        torch.cuda.set_device(f'cuda:{rank}')
        getattr(self, func)()
        self.tearDown()
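
    # ``_worker`` mimics a ``torchrun``-style launch: each spawned process
    # sets the standard distributed environment variables and binds to its
    # own CUDA device before running the requested test method.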

    def test_run_strategy(self):
        start_processes(
            TestStrategy._worker,
            args=('run_strategy', ),
            nprocs=self.world_size)

    def test_build_model(self):
        ...
        # TODO
        # strategy = FSDPStrategy()
        # model = ToyModel()
        # state_dict = dict()

    def _assert_local(self, strategy):
        self.assertEqual(strategy.state_dict_type,
                         StateDictType.LOCAL_STATE_DICT)
        self.assertIsInstance(strategy.state_dict_config,
                              LocalStateDictConfig)
        self.assertIsInstance(strategy.optim_state_dict_config,
                              LocalOptimStateDictConfig)

    def _assert_full(self, strategy):
        self.assertEqual(strategy.state_dict_type,
                         StateDictType.FULL_STATE_DICT)
        self.assertIsInstance(strategy.state_dict_config, FullStateDictConfig)
        self.assertIsInstance(strategy.optim_state_dict_config,
                              FullOptimStateDictConfig)