[Experimental] Add support for FSDP (#1213)

Mashiro 2023-06-28 16:50:52 +08:00 committed by Zaida Zhou
parent ccd5dc8b18
commit 399f76ffa8
18 changed files with 1317 additions and 144 deletions

View File

@ -15,3 +15,4 @@ mmengine._strategy
SingleDeviceStrategy
DDPStrategy
DeepSpeedStrategy
FSDPStrategy

View File

@ -15,3 +15,4 @@ mmengine._strategy
SingleDeviceStrategy
DDPStrategy
DeepSpeedStrategy
FSDPStrategy

View File

@ -4,7 +4,6 @@ import argparse
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD
from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
@ -43,8 +42,8 @@ class Accuracy(BaseMetric):
def parse_args():
parser = argparse.ArgumentParser(description='Distributed Training')
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
parser.add_argument('--use-fsdp', action='store_true')
parser.add_argument('--use-deepspeed', action='store_true')
args = parser.parse_args()
return args
@ -94,20 +93,33 @@ def main():
),
inputs_to_half=[0],
zero_optimization=dict(
stage=0,
stage=3,
allgather_partitions=True,
reduce_scatter=True,
allgather_bucket_size=50000000,
reduce_bucket_size=50000000,
overlap_comm=True,
contiguous_gradients=True,
cpu_offload=False))
cpu_offload=False),
)
optim_wrapper = dict(
type='DeepSpeedOptimWrapper',
optimizer=dict(type=SGD, lr=0.001, momentum=0.9))
optimizer=dict(type='AdamW', lr=1e-3))
elif args.use_fsdp:
from functools import partial
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
size_based_auto_wrap_policy = partial(
size_based_auto_wrap_policy, min_num_params=1e7)
strategy = dict(
type='FSDPStrategy',
model_wrapper=dict(auto_wrap_policy=size_based_auto_wrap_policy))
optim_wrapper = dict(
type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3))
else:
strategy = None
optim_wrapper = dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9))
optim_wrapper = dict(
type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3))
runner = FlexibleRunner(
model=MMResNet50(),
@ -124,4 +136,8 @@ def main():
if __name__ == '__main__':
# torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py --use-fsdp # noqa: 501
# torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py --use-deepspeed # noqa: 501
# torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py
# python distributed_training_with_flexible_runner.py
main()
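
Side note: the new MMFullyShardedDataParallel introduced later in this diff also resolves ``auto_wrap_policy`` given as a string or dict through the ``FUNCTIONS`` registry, so the ``functools.partial`` above could be replaced by pure config. A minimal sketch, assuming the user registers ``size_based_auto_wrap_policy`` themselves (this diff does not show a default registration):

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

from mmengine.registry import FUNCTIONS

# Hypothetical registration so the policy can be referenced by name.
FUNCTIONS.register_module(module=size_based_auto_wrap_policy)

strategy = dict(
    type='FSDPStrategy',
    model_wrapper=dict(
        auto_wrap_policy=dict(
            type='size_based_auto_wrap_policy',
            min_num_params=int(1e7))))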

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils import is_installed
from mmengine.utils import digit_version, is_installed
from mmengine.utils.dl_utils import TORCH_VERSION
from .base import BaseStrategy
from .distributed import DDPStrategy
from .single_device import SingleDeviceStrategy
@ -9,3 +10,10 @@ __all__ = ['BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy']
if is_installed('deepspeed'):
from .deepspeed import DeepSpeedStrategy # noqa: F401
__all__.append('DeepSpeedStrategy')
if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
try:
from .fsdp import FSDPStrategy # noqa:F401
__all__.append('FSDPStrategy')
except: # noqa: E722
pass

View File

@ -91,6 +91,7 @@ class BaseStrategy(metaclass=ABCMeta):
self._auto_scale_lr = auto_scale_lr
self.dispatch_kwargs: dict = {}
self._prepared = False
@property
def work_dir(self):
@ -342,7 +343,8 @@ class BaseStrategy(metaclass=ABCMeta):
def _init_model_weights(self, model: nn.Module) -> nn.Module:
"""Initialize the model weights if the model has
:meth:`init_weights`."""
if hasattr(model, 'init_weights'):
if (hasattr(model, 'init_weights') and self.dispatch_kwargs.get(
'init_weights_for_test_or_val', True)):
model.init_weights()
# sync params and buffers
for _, params in model.state_dict().items():
@ -633,9 +635,9 @@ class BaseStrategy(metaclass=ABCMeta):
"""
if default_args is None:
default_args = {}
if 'num_batches_per_epoch' in self.dispatch_kwargs:
if 'epoch_length' in self.dispatch_kwargs:
default_args['epoch_length'] = self.dispatch_kwargs[
'num_batches_per_epoch']
'epoch_length']
if 'max_epochs' in self.dispatch_kwargs:
default_args['max_epochs'] = self.dispatch_kwargs['max_epochs']
if 'max_iters' in self.dispatch_kwargs:
@ -962,3 +964,13 @@ class BaseStrategy(metaclass=ABCMeta):
runtime_env['GPU number'] = self.world_size
return system_env, runtime_env
def _prepared_components(self):
return_items = [self.model]
if hasattr(self, 'optim_wrapper'):
return_items.append(self.optim_wrapper)
if hasattr(self, 'param_schedulers'):
return_items.append(self.param_schedulers)
return return_items[0] if len(return_items) == 1 else return_items
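
A note on the new ``_prepared`` / ``_prepared_components`` pair: ``prepare()`` now caches what it built and returns either the wrapped model alone or a ``(model, optim_wrapper, param_schedulers)`` tuple, and a second call short-circuits to the cached components. A minimal sketch of the calling convention (``ToyModel`` is a placeholder module, the optimizer settings are illustrative):

from mmengine._strategy import SingleDeviceStrategy

strategy = SingleDeviceStrategy()
# With an optimizer config, prepare() returns (model, optim_wrapper);
# with only a model it would return the wrapped model by itself.
model, optim_wrapper = strategy.prepare(
    ToyModel(), optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)))
# Calling prepare() again returns the same cached components.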

View File

@ -64,6 +64,8 @@ class DeepSpeedStrategy(BaseStrategy):
amp: Optional[dict] = None,
activation_checkpointing: Optional[dict] = None,
aio: Optional[dict] = None,
train_micro_batch_size_per_gpu: Optional[int] = None,
gradient_accumulation_steps: int = 1,
# disable the log printed by deepspeed
steps_per_print: int = 10000000000000,
# the following args are for BaseStrategy
@ -86,8 +88,21 @@ class DeepSpeedStrategy(BaseStrategy):
if aio is not None:
self.config['aio'] = aio
self.config['steps_per_print'] = steps_per_print
if ('train_micro_batch_size_per_gpu' not in self.config
and 'train_batch_size' not in self.config):
assert train_micro_batch_size_per_gpu is not None, (
'`train_micro_batch_size_per_gpu` or `train_batch_size` '
'should be set!')
self.config['train_micro_batch_size_per_gpu'] = \
train_micro_batch_size_per_gpu
if train_micro_batch_size_per_gpu is not None:
self.config['train_micro_batch_size_per_gpu'] = \
train_micro_batch_size_per_gpu
self.config['gradient_accumulation_steps'] = \
gradient_accumulation_steps
self.config['steps_per_print'] = steps_per_print
self._inputs_to_half = inputs_to_half
def _parse_config(self, config):
@ -145,11 +160,11 @@ class DeepSpeedStrategy(BaseStrategy):
dispatch_kwargs (dict, optional): Kwargs to be passed to other
methods of Strategy. Defaults to None.
"""
if self._prepared:
return self._prepared_components()
assert dispatch_kwargs is not None
self.dispatch_kwargs.update(dispatch_kwargs)
return_items = []
model = self.build_model(model)
model = self._init_model_weights(model)
@ -159,23 +174,16 @@ class DeepSpeedStrategy(BaseStrategy):
self.optim_wrapper.model = self.model # type: ignore
return_items.append(self.model)
return_items.append(self.optim_wrapper)
else:
self.model = self._wrap_model(model)
return_items.append(self.model)
if param_scheduler is not None:
self.param_schedulers = self.build_param_scheduler(
param_scheduler, self.optim_wrapper)
return_items.append(self.param_schedulers)
return return_items[0] if len(return_items) == 1 else return_items
self._prepared = True
return self._prepared_components()
def _wrap_model(self, model: nn.Module) -> nn.Module:
self.config['train_micro_batch_size_per_gpu'] = self.dispatch_kwargs[
'train_micro_batch_size_per_gpu']
if hasattr(self, 'optim_wrapper'):
engine, self.optim_wrapper.optimizer, *_ = deepspeed.initialize(
model=model,
@ -246,7 +254,7 @@ class DeepSpeedStrategy(BaseStrategy):
if resume_optimizer:
self.load_optim_state_dict(extra_ckpt.pop('optim_wrapper'))
if resume_param_scheduler:
if resume_param_scheduler and hasattr(self, 'param_schedulers'):
param_schedulers = extra_ckpt.pop('param_schedulers')
self.load_scheduler_state_dict(param_schedulers)
@ -297,12 +305,12 @@ class DeepSpeedStrategy(BaseStrategy):
mmengine=mmengine.__version__ + get_git_hash(),
)
if save_optimizer:
if save_optimizer and hasattr(self, 'optim_wrapper'):
# The key can not be 'optimizer', otherwise error will be thrown
# when loading or resuming checkpoint.
extra_ckpt['optim_wrapper'] = self.optim_state_dict()
if save_param_scheduler:
if save_param_scheduler and hasattr(self, 'param_schedulers'):
extra_ckpt['param_schedulers'] = self.scheduler_state_dict()
dirname, basename = osp.split(filename)
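
With this change the micro batch size is part of the DeepSpeed strategy configuration instead of a dispatch kwarg: either ``train_micro_batch_size_per_gpu`` or a ``train_batch_size`` entry in the raw config must be provided, and FlexibleRunner now fills it in from the train dataloader automatically. A minimal sketch for standalone use, assuming deepspeed is installed (the ZeRO settings are illustrative):

from mmengine._strategy import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    train_micro_batch_size_per_gpu=2,  # required unless train_batch_size
    gradient_accumulation_steps=1,     # is present in the config
    zero_optimization=dict(stage=3, overlap_comm=True),
)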

View File

@ -31,11 +31,6 @@ class DDPStrategy(SingleDeviceStrategy):
**kwargs,
):
super().__init__(**kwargs)
if model_wrapper is None:
# set broadcast_buffers as False to keep compatibility with
# OpenMMLab repos
model_wrapper = dict(
type='MMDistributedDataParallel', broadcast_buffers=False)
self.model_wrapper = model_wrapper
self.sync_bn = sync_bn
@ -95,8 +90,16 @@ class DDPStrategy(SingleDeviceStrategy):
model = self.convert_model(model)
default_args = dict(module=model)
default_args.setdefault('device_ids', [int(os.environ['LOCAL_RANK'])])
if self.model_wrapper is None:
# set broadcast_buffers as False to keep compatibility with
# OpenMMLab repos
self.model_wrapper = dict(
type='MMDistributedDataParallel', broadcast_buffers=False)
default_args = dict(
type='MMDistributedDataParallel',
module=model,
device_ids=[int(os.environ['LOCAL_RANK'])])
model = MODEL_WRAPPERS.build(
self.model_wrapper, default_args=default_args)
return model

mmengine/_strategy/fsdp.py (new file, 595 lines)
View File

@ -0,0 +1,595 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
import os
import os.path as osp
import time
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Sequence, Union
import torch.nn as nn
from torch.distributed.fsdp import (FullStateDictConfig,
FullyShardedDataParallel,
LocalStateDictConfig, StateDictType)
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig, LocalOptimStateDictConfig, OptimStateDictConfig,
StateDictConfig)
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import mmengine
from mmengine.config import Config, ConfigDict
from mmengine.dist import get_rank, is_main_process
from mmengine.model import is_model_wrapper
from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper,
OptimWrapperDict, _ParamScheduler,
build_optim_wrapper)
from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS,
PARAM_SCHEDULERS, STRATEGIES, Registry)
from mmengine.utils import get_git_hash, mkdir_or_exist
from .distributed import DDPStrategy
from .utils import MetaTensorContext
FSDP = FullyShardedDataParallel
FSDP_CONFIGS = Registry('fsdp configs')
FSDP_CONFIGS.register_module(module=FullOptimStateDictConfig)
FSDP_CONFIGS.register_module(module=LocalOptimStateDictConfig)
FSDP_CONFIGS.register_module(module=FullStateDictConfig)
FSDP_CONFIGS.register_module(module=LocalStateDictConfig)
@STRATEGIES.register_module()
class FSDPStrategy(DDPStrategy):
"""Support training model with FullyShardedDataParallel (FSDP).
Keyword Args:
model_wrapper (dict, optional): Config dict for model wrapper. The
default configuration is:
Examples:
>>> model_wrapper = dict(
>>> type='MMFullyShardedDataParallel',
>>> use_orig_params=True,
>>> )
See more configurable arguments in
:class:`MMFullyShardedDataParallel`. Defaults to None.
skip_init_weights (bool, optional): Whether to skip initialization of
weights. Defaults to False. This is useful when the parameters of
the large model are loaded from a checkpoint, since skipping the
initialization of weights can save a lot of time.
state_dict_cfg (str or dict): Configuration for
how to save and load the state dict of the model, optimizer, and
scheduler.
- "local": save and load the sharded state dict in all ranks.
- "full": save and load the full state dict in rank 0.
- `dict` object: save and load the state dict more flexibly. For
example, you can first offload the state dict to CPU and then
save it to disk. This helps to load the checkpoint in a non-GPU
environment:
Examples:
>>> state_dict_cfg=dict(
>>> state_dict_type='FULL_STATE_DICT',
>>> state_dict_config=dict(type='FullStateDictConfig', offload_to_cpu=True),
>>> optim_state_dict_config=dict(type='FullOptimStateDictConfig', offload_to_cpu=True),
See more configurable arguments for ``state_dict_cfg``,
``state_dict_config``, and ``optim_state_dict_config`` in
`FSDP official api documents`_
kwargs (dict): Additional arguments passed to :class:`DDPStrategy`:
- work_dir (str): The working directory to save checkpoints.
The logs will be saved in the subdirectory of `work_dir` named
:attr:`timestamp`. Defaults to 'work_dirs'.
- experiment_name (str, optional): Name of current experiment. If
not specified, timestamp will be used as :attr:`experiment_name`.
Defaults to None.
- env_kwargs (dict, optional): Environment config passed in
:meth:`setup_env`. Defaults to None.
- log_kwargs (dict, optional): Logger config passed in
:meth:`build_logger`. Defaults to None.
.. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
""" # noqa: E501
def __init__(self,
*,
model_wrapper: Optional[dict] = None,
skip_init_weights=False,
state_dict_cfg: Union[str, dict] = 'local',
**kwargs):
super().__init__(model_wrapper=model_wrapper, **kwargs)
self._init_state_dict_cfg(state_dict_cfg)
if not isinstance(skip_init_weights, bool):
raise TypeError('skip_init_weights must be a boolean, but got '
f'{type(skip_init_weights)}')
self.skip_init_weights = skip_init_weights
def _wrap_model(self, model: nn.Module) -> nn.Module:
"""Wrap the model with :obj:`MMFullyShardedDataParallel` or another
custom fully sharded data parallel module wrapper.
Args:
model (nn.Module): Model to be wrapped.
Returns:
FullyShardedDataParallel: ``MMFullyShardedDataParallel``
or subclass of ``FullyShardedDataParallel``.
"""
if is_model_wrapper(model):
return model
if self.model_wrapper is None:
self.model_wrapper = dict(type='MMFullyShardedDataParallel')
default_args = dict(
module=model,
device_id=int(os.environ['LOCAL_RANK']),
type='MMFullyShardedDataParallel')
model = MODEL_WRAPPERS.build(
self.model_wrapper, default_args=default_args)
model.set_state_dict_type(model, self.state_dict_type,
self.state_dict_config,
self.optim_state_dict_config)
return model
def _is_full_state_dict(self):
"""Whether to save and load the full state_dict in rank 0."""
return self.state_dict_type == StateDictType.FULL_STATE_DICT
def build_model(self, model: Union[nn.Module, dict]) -> nn.Module:
"""Build model.
If ``skip_init_weights`` is True, the model will be built with empty
weights, and :meth:`load_checkpoint` must be called to fill the
weights before training.
Args:
model (nn.Module or dict): A ``nn.Module`` object or a dict to
build ``nn.Module`` object. If ``model`` is a ``nn.Module``
object, just returns itself.
Returns:
nn.Module: Model build from ``model``.
"""
if self.skip_init_weights:
if isinstance(model, dict):
# Accelerate initialization by skipping init weights
with MetaTensorContext():
model = super().build_model(model)
model.to_empty(device='cpu')
else:
model = super().build_model(model)
# `id_to_name` will be used to convert the `optim_state_dict` of the
# raw optimizer to the `optim_state_dict`
# returned by `FSDP.optim_state_dict` in
# `StateDictType.FULL_STATE_DICT` mode.
self.id_to_name = dict()
for name, param in model.named_parameters():
self.id_to_name[id(param)] = name
return model
def save_checkpoint(self,
filename: str,
*,
save_optimizer: bool = True,
save_param_scheduler: bool = True,
extra_ckpt: Optional[dict] = None,
callback: Optional[Callable] = None) -> None:
"""Save checkpoint to given ``filename``.
If ``state_dict_type`` is `full`, the checkpoint will only be saved in
rank 0. The structure of the saved checkpoint is the same as the one
saved by ``DDPStrategy``.
If ``state_dict_type`` is `local`, each rank will save the sharded
state dict to a directory, which means the saved structure will look
like this:
.. code-block:: bash
epoch_0.pth
rank0.pth
rank1.pth
...
rank8.pth
Args:
filename (str): Filename to save checkpoint.
Keyword Args:
save_optimizer (bool): Whether to save the optimizer to
the checkpoint. Defaults to True.
save_param_scheduler (bool): Whether to save the param_scheduler
to the checkpoint. Defaults to True.
extra_ckpt (dict, optional): Extra checkpoint to save.
Defaults to None.
callback (callable, optional): Callback function to modify the
checkpoint before saving it.
Defaults to None.
"""
from mmengine.runner.checkpoint import save_checkpoint
state_dict: dict = dict()
state_dict['state_dict'] = self.model_state_dict()
# save optimizer state dict
if save_optimizer and hasattr(self, 'optim_wrapper'):
state_dict['optimizer'] = self.optim_state_dict()
# save param scheduler state dict
if save_param_scheduler and hasattr(self, 'param_schedulers'):
state_dict['param_schedulers'] = self.scheduler_state_dict()
# save extra checkpoint passed by users
if extra_ckpt is None:
extra_ckpt = dict()
if 'meta' not in extra_ckpt:
extra_ckpt['meta'] = dict()
extra_ckpt['meta'].update(
seed=self.seed,
time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
mmengine=mmengine.__version__ + get_git_hash(),
)
state_dict.update(extra_ckpt)
# users can do some modification before saving checkpoint
if callback is not None:
callback(state_dict)
# In non-FULL_STATE_DICT mode, FSDPStrategy saves the checkpoint of
# each rank in a separate file.
if not self._is_full_state_dict():
rank = get_rank()
mkdir_or_exist(filename)
ckpt_name = f'rank{rank}.pth'
filename = osp.join(filename, ckpt_name)
save_checkpoint(state_dict, filename)
else:
if is_main_process():
save_checkpoint(state_dict, filename)
def model_state_dict(self) -> dict:
"""Get model state dict based on the ``state_dict_type``.
If ``state_dict_type`` is ``full``, the model state dict will be the
same as that of the original unsharded model.
If ``state_dict_type`` is ``local`` and ``use_orig_params`` is ``True``
in ``model_wrapper``, the keys of the state dict will be the same as
those of the original unsharded model, but the values will be sharded.
If ``state_dict_type`` is ``local`` and ``use_orig_params`` is ``False``
in ``model_wrapper``, the flattened and sharded state dict will be
returned.
See more details in the `official api documents`_
.. _official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict
""" # noqa: E501
# We've set state_dict by `FSDP.set_state_dict_type`, therefore we
# should get model state dict by `FSDP.state_dict`
return self.model.state_dict()
def optim_state_dict(self) -> dict:
"""Get model state dict based on the ``state_dict_type``.
If ``state_dict_type`` is ``full``, the optimizer state dict can be
loaded by the original unsharded optimizer.
Otherwise, the optimizer state dict could only be loaded by the
optimizer with sharded parameters.
Note:
The optimizer state dict is not the same as that of the original
optimizer even in ``full`` mode, although it can still be loaded
correctly.
See more details in the `official api documents`_
.. _official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict
""" # noqa: E501
return FSDP.optim_state_dict(self.model, self.optim_wrapper)
def load_checkpoint(self, filename: str, **kwargs) -> dict:
"""Load checkpoint from given ``filename``.
Note:
If ``state_dict_type`` is `local`, the filename should be a
directory containing ``rank{i}.pth``.
Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``.
Keyword Args:
map_location (str or callable): A string or a callable function
specifying how to remap storage locations.
Defaults to 'cpu'.
strict (bool): Whether to allow different params for the model and
checkpoint.
revise_keys (list): A list of customized keywords to modify the
state_dict in checkpoint. Each item is a (pattern, replacement)
pair of the regular expression operations. Defaults to strip
the prefix 'module.' by [(r'^module\\.', '')].
callback (callable, optional): Callback function to modify the
checkpoint after loading it.
Defaults to None.
"""
if self._is_full_state_dict():
return super(DDPStrategy, self).load_checkpoint(filename, **kwargs)
else:
rank = get_rank()
filename = osp.join(filename, f'rank{rank}.pth')
return super(DDPStrategy, self).load_checkpoint(filename, **kwargs)
def load_model_state_dict(
self,
state_dict: dict,
*,
strict: bool = False,
revise_keys: list = [(r'^module.', '')],
) -> None: # type: ignore
"""Load model state from dict.
Warning:
`revise_keys` is not supported yet.
Args:
state_dict (dict): Model state dict returned by
:meth:`FSDPStrategy.model_state_dict`. If ``state_dict_type``
is ``full``, ``state_dict`` could be the result of
``model.state_dict()``.
strict (bool): Whether to load model state dict strictly.
Defaults to False.
"""
# We should load state dict by `FSDP.load_state_dict`
self.model.load_state_dict(state_dict, strict=strict)
def load_optim_state_dict(self, state_dict: dict) -> None:
"""Load optimizer state from dict.
Args:
state_dict (dict): The optimizer state dict. If ``state_dict_type``
is ``full``, ``state_dict`` could be the result of
``optimizer.state_dict()``.
"""
optim_state_dict = FSDP.optim_state_dict_to_load(
state_dict, self.model, self.optim_wrapper.optimizer)
self.optim_wrapper.load_state_dict(optim_state_dict)
def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None:
"""Make ``state_dict_type`` and ``state_dict_config`` can be configured
with string."""
if isinstance(state_dict_cfg, str):
if state_dict_cfg == 'full':
self.state_dict_type = StateDictType.FULL_STATE_DICT
self.state_dict_config = FullStateDictConfig(
rank0_only=True, offload_to_cpu=True)
self.optim_state_dict_config = FullOptimStateDictConfig(
rank0_only=True, offload_to_cpu=True)
elif state_dict_cfg == 'local':
self.state_dict_type = StateDictType.LOCAL_STATE_DICT
self.state_dict_config = LocalStateDictConfig()
self.optim_state_dict_config = LocalOptimStateDictConfig()
else:
raise ValueError('FSDP only supports `full` and `local` '
f'state_dict_type, but got {state_dict_cfg}')
elif isinstance(state_dict_cfg, dict):
if 'state_dict_type' not in state_dict_cfg:
self.state_dict_type = StateDictType.LOCAL_STATE_DICT
else:
state_dict_type = state_dict_cfg['state_dict_type']
if isinstance(state_dict_type, str):
self.state_dict_type = StateDictType[
state_dict_cfg['state_dict_type']]
else:
self.state_dict_type = state_dict_type
state_dict_config = state_dict_cfg.get('state_dict_config')
if state_dict_config is None:
self.state_dict_config = LocalStateDictConfig()
elif isinstance(state_dict_config, dict):
self.state_dict_config = FSDP_CONFIGS.build(
state_dict_cfg['state_dict_config'])
else:
self.state_dict_config = state_dict_config
optim_state_dict_config = state_dict_cfg.get(
'optim_state_dict_config')
if optim_state_dict_config is None:
self.optim_state_dict_config = LocalOptimStateDictConfig()
elif isinstance(optim_state_dict_config, dict):
self.optim_state_dict_config = FSDP_CONFIGS.build(
state_dict_cfg['optim_state_dict_config'])
else:
self.optim_state_dict_config = optim_state_dict_config
else:
raise TypeError('state_dict_cfg should be a `str` or a `dict`, '
f'but got {type(state_dict_cfg)}')
if not isinstance(self.state_dict_type, StateDictType):
raise TypeError('state_dict_type must be StateDictType, but got '
f'{type(self.state_dict_type)}')
if not isinstance(self.state_dict_config, StateDictConfig):
raise TypeError('state_dict_config must be StateDictConfig, but '
f'got {type(self.state_dict_config)}')
if not isinstance(self.optim_state_dict_config, OptimStateDictConfig):
raise TypeError('optim_state_dict_config must be '
'OptimStateDictConfig, but got '
f'{type(self.optim_state_dict_config)}')
def build_optim_wrapper(
self,
optim_wrapper: Union[Optimizer, OptimWrapper, dict],
model: Optional[nn.Module] = None,
) -> BaseOptimWrapper:
"""Support sharding the optimizer state dict given a built optimizer or
optim_wrapper.
See specific usage in :meth:`BaseStrategy.build_optim_wrapper`.
"""
if isinstance(optim_wrapper, Optimizer):
optim_wrapper = OptimWrapper(optim_wrapper)
if isinstance(optim_wrapper, BaseOptimWrapper):
assert model is not None
# NOTE: The only difference is that FSDPStrategy will shard
# the built OptimWrapper
optimizer = optim_wrapper.optimizer
param_groups = optimizer.param_groups
optim_state_dict = optimizer.state_dict()
assert not optim_state_dict['state'], (
'Optimizer state_dict should be empty when giving a built '
'optim_wrapper to FSDPStrategy')
# Align the state_dict with state_dict generated by
# FSDP.full_optim_state_dict
new_param_groups = []
for group in param_groups:
new_group = {
key: value
for key, value in group.items() if key != 'params'
}
new_group['params'] = [
self.id_to_name[id(param)] for param in group['params']
]
new_param_groups.append(new_group)
optim_state_dict['param_groups'] = new_param_groups
defaults = {
k: v
for k, v in optimizer.defaults.items() if k != 'differentiable'
}
params_dict = {}
for k, v in model.named_parameters():
if '_fsdp_wrapped_module' in k:
k = k.replace('_fsdp_wrapped_module.', '')
params_dict[k] = v
params = []
for param_group in new_param_groups:
_params = []
for param_name in param_group['params']:
if param_name not in params_dict:
raise RuntimeError(
'Failed to reconstruct the sharded optimizer. '
'You can try to set `use_orig_params=True` in '
'`model_wrapper`')
_params.append(params_dict[param_name])
param_group = {
k: v
for k, v in param_group.items() if k != 'params'
}
param_group['params'] = _params
params.append(param_group)
new_optimizer = optimizer.__class__(params, **defaults)
# Force to load the converted optim_state_dict in full mode.
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
optim_state_dict = FSDP.optim_state_dict_to_load(
optim_state_dict, model, new_optimizer)
new_optimizer.load_state_dict(optim_state_dict)
optim_wrapper.optimizer = new_optimizer
return optim_wrapper
if isinstance(optim_wrapper, (dict, ConfigDict, Config)):
assert model is not None
# optimizer must be defined for single optimizer training.
optimizer = optim_wrapper.get('optimizer', None)
optim_wrapper.setdefault('type', 'OptimWrapper')
if optim_wrapper.get('type',
'AmpOptimWrapper') in ('AmpOptimWrapper',
AmpOptimWrapper):
optim_wrapper.setdefault('use_fsdp', True)
# If optimizer is a built `Optimizer` instance, the optimizer
# wrapper should be built by `OPTIM_WRAPPERS` registry.
if isinstance(optimizer, Optimizer):
return OPTIM_WRAPPERS.build(optim_wrapper) # type: ignore
# If `optimizer` is not None or `constructor` is defined, the
# optimizer wrapper will be built by the optimizer wrapper
# constructor. Therefore, `build_optim_wrapper` should be called.
if optimizer is not None or 'constructor' in optim_wrapper:
return build_optim_wrapper(model, optim_wrapper)
else:
# if `optimizer` is not defined, it should be the case of
# training with multiple optimizers. If `constructor` is not
# defined either, each value of `optim_wrapper` must be an
# `OptimWrapper` instance since `DefaultOptimizerConstructor`
# will not handle the case of training with multiple
# optimizers. `build_optim_wrapper` will directly build the
# `OptimWrapperDict` instance from `optim_wrapper.`
optim_wrappers = OrderedDict()
for name, optim in optim_wrapper.items():
if not isinstance(optim, OptimWrapper):
raise ValueError(
'each item must be an optimizer object when '
'"type" and "constructor" are not in '
f'optimizer, but got {name}={optim}')
optim_wrappers[name] = optim
return OptimWrapperDict(**optim_wrappers)
else:
raise TypeError('optimizer wrapper should be an OptimWrapper '
f'object or dict, but got {optim_wrapper}')
def _build_param_scheduler(
self,
scheduler: Union[_ParamScheduler, Dict, List],
optim_wrapper: BaseOptimWrapper,
default_args: dict,
) -> List[_ParamScheduler]:
"""Override this method to update the scheduler with the reconstructed
sharded optimizer."""
if not isinstance(scheduler, Sequence):
schedulers = [scheduler]
else:
schedulers = scheduler
max_epochs = default_args.pop('max_epochs', None)
max_iters = default_args.pop('max_iters', None)
param_schedulers = []
for scheduler in schedulers:
# Update the built scheduler with the sharded optimizer
if isinstance(scheduler, (_ParamScheduler, LRScheduler)):
parameter_keys = inspect.signature(
scheduler.__class__).parameters.keys()
kwargs = {
k: v
for k, v in scheduler.state_dict().items()
if k in parameter_keys
}
scheduler = scheduler.__class__(optim_wrapper, **kwargs)
elif isinstance(scheduler, dict):
_scheduler = copy.deepcopy(scheduler)
# Set default end
if _scheduler.get('by_epoch', True):
if max_epochs is None:
raise ValueError(
'max_epochs must be specified in default_args')
default_end = max_epochs
else:
if max_iters is None:
raise ValueError(
'max_iters must be specified in default_args')
default_end = max_iters
_scheduler.setdefault('end', default_end)
self.logger.debug(
f'The `end` of {_scheduler["type"]} is not set. '
'Use the max epochs/iters of train loop as default.')
param_schedulers.append(
PARAM_SCHEDULERS.build(
_scheduler,
default_args=dict(
optimizer=optim_wrapper, **default_args)))
else:
raise TypeError(
'scheduler should be a _ParamScheduler object or dict, '
f'but got {scheduler}')
return param_schedulers
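
To tie the ``state_dict_cfg`` documentation above to concrete usage, a minimal sketch of the two checkpointing modes (constructor defaults are illustrative; a distributed launch is assumed for actual training):

from mmengine._strategy import FSDPStrategy

# 'full': rank 0 gathers the unsharded state dict and writes a single
# checkpoint that can be loaded like a plain DDP checkpoint.
strategy = FSDPStrategy(state_dict_cfg='full')

# dict form: offload the gathered state dicts to CPU before saving so
# the checkpoint can be reloaded in a non-GPU environment.
strategy = FSDPStrategy(
    state_dict_cfg=dict(
        state_dict_type='FULL_STATE_DICT',
        state_dict_config=dict(
            type='FullStateDictConfig', offload_to_cpu=True),
        optim_state_dict_config=dict(
            type='FullOptimStateDictConfig', offload_to_cpu=True),
    ))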

View File

@ -49,26 +49,24 @@ class SingleDeviceStrategy(BaseStrategy):
If ``accumulative_counts`` is set in ``optim_wrapper``, you
need to provide ``max_iters`` in ``dispatch_kwargs``.
"""
if self._prepared:
return self._prepared_components()
if dispatch_kwargs is not None:
self.dispatch_kwargs.update(dispatch_kwargs)
return_items = []
model = self.build_model(model)
model = self._init_model_weights(model)
model = self._wrap_model(model)
model = self.compile_model(model, compile=compile)
return_items.append(model)
self.model = model
if optim_wrapper is not None:
self.optim_wrapper = self.build_optim_wrapper(optim_wrapper, model)
return_items.append(self.optim_wrapper)
if param_scheduler is not None:
self.param_schedulers = self.build_param_scheduler(
param_scheduler, self.optim_wrapper)
return_items.append(self.param_schedulers)
if optim_wrapper is not None:
self._scale_lr()
@ -84,8 +82,8 @@ class SingleDeviceStrategy(BaseStrategy):
self.optim_wrapper.initialize_count_status( # type: ignore
self.model, 0, self.dispatch_kwargs['max_iters'])
return return_items[0] if len(return_items) == 1 else return_items
self._prepared = True
return self._prepared_components()
def _wrap_model(self, model: nn.Module) -> nn.Module:
model = self.convert_model(model)
@ -200,7 +198,7 @@ class SingleDeviceStrategy(BaseStrategy):
if resume_optimizer:
self.load_optim_state_dict(checkpoint.pop('optimizer'))
if resume_param_scheduler:
if resume_param_scheduler and hasattr(self, 'param_schedulers'):
self.load_scheduler_state_dict(checkpoint.pop('param_schedulers'))
# resume random seed
@ -267,15 +265,7 @@ class SingleDeviceStrategy(BaseStrategy):
if save_optimizer and hasattr(self, 'optim_wrapper'):
state_dict['optimizer'] = self.optim_state_dict()
# save param scheduler state dict
if save_param_scheduler and not hasattr(self, 'param_schedulers'):
self.logger.warning(
'`save_param_scheduler` is True but strategy has no '
'param_schedulers attribute, so skip saving parameter '
'schedulers')
save_param_scheduler = False
if save_param_scheduler:
if save_param_scheduler and hasattr(self, 'param_schedulers'):
state_dict['param_schedulers'] = self.scheduler_state_dict()
# save extra checkpoint passed by users

View File

@ -0,0 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch._subclasses.fake_tensor import _is_tensor_constructor
from torch.utils._python_dispatch import TorchDispatchMode
class MetaTensorContext(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if _is_tensor_constructor(func):
device_idx = [arg.name
for arg in func._schema.arguments].index('device')
if len(args) > device_idx:
args = list(args)
args[device_idx] = 'meta'
else:
kwargs['device'] = 'meta'
return func(*args, **kwargs)
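
A quick illustration of what this context is for: tensor constructors executed inside it are redirected to the ``meta`` device, which is how ``FSDPStrategy.build_model`` avoids allocating real weights when ``skip_init_weights=True``. A minimal sketch (the linear layer is only a stand-in):

import torch.nn as nn

from mmengine._strategy.utils import MetaTensorContext

with MetaTensorContext():
    layer = nn.Linear(1024, 1024)  # parameters live on the meta device
assert layer.weight.is_meta
layer.to_empty(device='cpu')       # materialize uninitialized storage later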

View File

@ -33,6 +33,6 @@ __all__ = [
'convert_sync_batchnorm', 'BaseTTAModel'
]
if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
from .wrappers import MMFullyShardedDataParallel # noqa:F401
__all__.append('MMFullyShardedDataParallel')

View File

@ -10,7 +10,7 @@ __all__ = [
'MMSeparateDistributedDataParallel'
]
if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
from .fully_sharded_distributed import \
MMFullyShardedDataParallel # noqa:F401
__all__.append('MMFullyShardedDataParallel')

View File

@ -1,18 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, List, Optional, Union
from functools import partial
from itertools import chain
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
# yapf: disable
from torch.distributed.fsdp.api import (FullStateDictConfig,
LocalOptimStateDictConfig,
LocalStateDictConfig,
OptimStateDictConfig,
ShardedOptimStateDictConfig,
ShardedStateDictConfig,
ShardingStrategy, StateDictConfig,
StateDictSettings, StateDictType)
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch, CPUOffload, FullyShardedDataParallel)
BackwardPrefetch, CPUOffload, FullOptimStateDictConfig,
FullyShardedDataParallel, MixedPrecision)
# yapf: enable
from mmengine.optim import OptimWrapper
from mmengine.registry import MODEL_WRAPPERS, Registry
from mmengine.registry import FUNCTIONS, MODEL_WRAPPERS
from mmengine.structures import BaseDataElement
# support customize fsdp policy
FSDP_WRAP_POLICIES = Registry('fsdp wrap policy')
from mmengine.utils import digit_version, is_seq_of
@MODEL_WRAPPERS.register_module()
@ -40,8 +52,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
Args:
module (nn.Module): module to be wrapped with FSDP.
process_group (Optional[ProcessGroup]): process group for sharding.
cpu_offload (Optional[Union[bool,CPUOffload]]):
process_group (ProcessGroup, optional): process group for sharding.
cpu_offload (bool, CPUOffload, optional):
CPU offloading config.
Different from FullyShardedDataParallel, since it can be set by
users' pre-defined config in MMEngine, its type is expected to be
@ -54,7 +66,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
for params and grads to be on same device to work with optimizer.
This API is subject to change. Default is ``None`` in which case
there will be no offloading.
fsdp_auto_wrap_policy: (Optional[Union[str,Callable]]):
auto_wrap_policy (str or Callable, optional):
Specifying a policy to recursively wrap layers with FSDP.
Different from FullyShardedDataParallel, since it can be set by
users' pre-defined config in MMEngine, its type is expected to be
@ -68,12 +80,12 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
the returned FSDP root instance.
``default_auto_wrap_policy`` written in
``torch.distributed.fsdp.wrap`` is an example of
``fsdp_auto_wrap_policy`` callable, this policy wraps layers with
``auto_wrap_policy`` callable, this policy wraps layers with
parameter sizes larger than 100M. Users can supply the customized
``fsdp_auto_wrap_policy`` callable that should accept following
``auto_wrap_policy`` callable that should accept following
arguments: ``module: nn.Module``, ``recurse: bool``,
``unwrapped_params: int``, extra customized arguments could be
added to the customized ``fsdp_auto_wrap_policy`` callable as well.
added to the customized ``auto_wrap_policy`` callable as well.
Example::
@ -86,18 +98,23 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
>>> ) -> bool:
>>> return unwrapped_params >= min_num_params
backward_prefetch: (Optional[Union[str,BackwardPrefetch]]):
Different from FullyShardedDataParallel, Since it will be set by
users' pre-defined config in MMEngine,its type is expected to be
`None`, `str` or `BackwardPrefetch`.
backward_prefetch (str or BackwardPrefetch, optional):
Different from FullyShardedDataParallel, this argument could be a
string or a BackwardPrefetch instance. If it's a string, then
it should be ``BACKWARD_PRE`` or ``BACKWARD_POST``.
mixed_precision (dict or MixedPrecision, optional):
This configures native mixed precision for FSDP. If this is set to
``None``, no mixed precision is used. Different from the native FSDP,
this argument can be a dict like this:
This is an experimental feature that is subject to change in the
near future. It allows users to enable two different
backward_prefetch algorithms to help backward communication and
computation overlapping.
Pros and cons of each algorithm is explained in class
``BackwardPrefetch``.
Examples:
>>> mixed_precision=dict(param_dtype='float16',
>>> buffer_dtype='float32',
>>> reduce_dtype='float32')
Defaults to None.
use_orig_params (bool): Different from native
``FullyShardedDataParallel``, it defaults to True.
**kwargs: Keyword arguments passed to
:class:`FullyShardedDataParallel`.
"""
@ -105,61 +122,117 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
def __init__(
self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
cpu_offload: Optional[Union[bool, CPUOffload]] = None,
fsdp_auto_wrap_policy: Optional[Union[str, Callable]] = None,
backward_prefetch: Optional[Union[str, BackwardPrefetch]] = None,
process_group: Union[dict, ProcessGroup, None] = None,
sharding_strategy: Union[str, ShardingStrategy] = None,
cpu_offload: Union[bool, CPUOffload, None] = None,
auto_wrap_policy: Union[str, Callable, None] = None,
backward_prefetch: Union[str, BackwardPrefetch, None] = None,
mixed_precision: Union[dict, MixedPrecision, None] = None,
ignored_modules: Union[Iterable[str], Iterable[nn.Module],
None] = None,
param_init_fn: Union[str, Callable[[nn.Module], None]] = None,
use_orig_params: bool = True,
**kwargs,
):
if isinstance(sharding_strategy, str):
sharding_strategy = ShardingStrategy[sharding_strategy]
if not (isinstance(sharding_strategy, ShardingStrategy)
or sharding_strategy is None):
raise TypeError(
'sharding_strategy must be str or enum of `ShardingStrategy` '
f', but got {sharding_strategy}')
if cpu_offload is not None:
if isinstance(cpu_offload, bool):
cpu_offload = CPUOffload(offload_params=cpu_offload)
elif not isinstance(cpu_offload, CPUOffload):
if isinstance(cpu_offload, bool):
cpu_offload = CPUOffload(offload_params=cpu_offload)
if not (isinstance(cpu_offload, CPUOffload) or cpu_offload is None):
raise TypeError(
'`cpu_offload` should be `None`, `bool`'
f'or `CPUOffload`, but has type {type(cpu_offload)}')
if isinstance(auto_wrap_policy, str):
auto_wrap_policy = FUNCTIONS.get( # type: ignore
auto_wrap_policy)
if auto_wrap_policy is None:
raise ValueError('`auto_wrap_policy` is not registered!')
elif isinstance(auto_wrap_policy, dict):
ori_func = FUNCTIONS.get( # type: ignore
auto_wrap_policy.pop('type'))
if ori_func is None:
raise ValueError('`auto_wrap_policy` is not registered!')
auto_wrap_policy = partial(ori_func, **auto_wrap_policy)
if not (auto_wrap_policy is None
or callable(auto_wrap_policy)): # type: ignore
raise TypeError('`auto_wrap_policy` should be a str, a '
'callable, a dict or None, but has type '
f'{type(auto_wrap_policy)}')
if isinstance(backward_prefetch, str):
backward_prefetch = BackwardPrefetch[backward_prefetch]
if not (isinstance(backward_prefetch, BackwardPrefetch)
or backward_prefetch is None):
raise TypeError(
'`backward_prefetch` should be `None`, string of '
'"BACKWARD_PRE" and "BACKWARD_POST", or '
f'`BackwardPrefetch`, but has type {type(backward_prefetch)}')
if isinstance(param_init_fn, str):
param_init_fn = FUNCTIONS.get( # type: ignore
param_init_fn)
if param_init_fn is None:
raise ValueError('`param_init_fn` is not registered!')
elif isinstance(param_init_fn, dict):
init_fn = FUNCTIONS.get(param_init_fn.pop('type'))
if init_fn is None:
raise ValueError('`param_init_fn` is not registered!')
param_init_fn = partial(init_fn, **param_init_fn)
if not (callable(param_init_fn) or param_init_fn is None):
raise TypeError('`param_init_fn` should be a str, a '
'callable, a dict or None, but has type '
f'{type(param_init_fn)}')
def parse_dtype(dtype):
if dtype is None:
return None
elif isinstance(dtype, str):
return getattr(torch, dtype)
elif isinstance(dtype, torch.dtype):
return dtype
else:
raise TypeError(
'`cpu_offload` should be `None`, `bool`'
f'or `CPUOffload`, but has type {type(cpu_offload)}')
'`dtype` should be `None`, `str` or `torch.dtype`, '
f'but has type {type(dtype)}')
if fsdp_auto_wrap_policy is not None:
if isinstance(fsdp_auto_wrap_policy, str):
assert fsdp_auto_wrap_policy in FSDP_WRAP_POLICIES, \
'`FSDP_WRAP_POLICIES` has no ' \
f'function {fsdp_auto_wrap_policy}'
fsdp_auto_wrap_policy = FSDP_WRAP_POLICIES.get( # type: ignore
fsdp_auto_wrap_policy)
if not isinstance(fsdp_auto_wrap_policy,
Callable): # type: ignore
raise TypeError(
'Registered `fsdp_auto_wrap_policy` needs to be '
'`Callable`, but has type '
f'{type(fsdp_auto_wrap_policy)}')
elif not isinstance(fsdp_auto_wrap_policy,
Callable): # type: ignore
raise TypeError(
'`fsdp_auto_wrap_policy` should be `None`, `str` '
'or `Callable`, but has type '
f'{type(fsdp_auto_wrap_policy)}')
if backward_prefetch is not None:
if isinstance(backward_prefetch, str):
assert backward_prefetch in ['pre', 'post'], \
'`backward_prefetch` should be either `pre` or `post`,' \
f' but get {backward_prefetch}'
if backward_prefetch == 'pre':
backward_prefetch = BackwardPrefetch.BACKWARD_PRE
else:
backward_prefetch = BackwardPrefetch.BACKWARD_POST
elif not isinstance(backward_prefetch, BackwardPrefetch):
raise TypeError('`backward_prefetch` should be `None`, `str` '
'or `BackwardPrefetch`, but has type '
f'{type(backward_prefetch)}')
if isinstance(mixed_precision, dict):
mixed_precision['param_dtype'] = parse_dtype(
mixed_precision.get('param_dtype', None))
mixed_precision['reduce_dtype'] = parse_dtype(
mixed_precision.get('reduce_dtype', None))
mixed_precision['buffer_dtype'] = parse_dtype(
mixed_precision.get('buffer_dtype', None))
mixed_precision = MixedPrecision(**mixed_precision)
elif isinstance(mixed_precision, MixedPrecision):
mixed_precision = mixed_precision
elif mixed_precision is not None:
raise TypeError(
'`mixed_precision` should be `None`, `dict` or '
f'`MixedPrecision`, but has type {type(mixed_precision)}')
self._fixed_modules = self._get_fixed_module(module, ignored_modules)
ignored_modules = [] if ignored_modules is None else ignored_modules
ignored_modules = chain(ignored_modules, self._fixed_modules)
super().__init__(
module=module,
process_group=process_group,
auto_wrap_policy=fsdp_auto_wrap_policy,
sharding_strategy=sharding_strategy,
auto_wrap_policy=auto_wrap_policy,
cpu_offload=cpu_offload,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
use_orig_params=use_orig_params,
**kwargs)
def train_step(self, data: dict,
@ -207,8 +280,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
Returns:
List[BaseDataElement] or dict: The predictions of given data.
"""
inputs, data_sample = self.module.data_preprocessor(data, False)
return self(inputs, data_sample, mode='predict')
data = self.module.data_preprocessor(data, False)
return self._run_forward(data, mode='predict') # type: ignore
def test_step(self, data: dict) -> List[BaseDataElement]:
"""Gets the predictions of module during testing process.
@ -219,5 +292,156 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
Returns:
List[BaseDataElement]: The predictions of given data.
"""
inputs, data_sample = self.module.data_preprocessor(data, False)
return self(inputs, data_sample, mode='predict')
data = self.module.data_preprocessor(data, False)
return self._run_forward(data, mode='predict') # type: ignore
def _run_forward(self, data: Union[dict, tuple, list],
mode: str) -> Union[Dict[str, torch.Tensor], list]:
"""Unpacks data for :meth:`forward`
Args:
data (dict or tuple or list): Data sampled from dataset.
mode (str): Mode of forward.
Returns:
dict or list: Results of training or testing mode.
"""
if isinstance(data, dict):
results = self(**data, mode=mode)
elif isinstance(data, (list, tuple)):
results = self(*data, mode=mode)
else:
raise TypeError('Output of `data_preprocessor` should be '
f'list, tuple or dict, but got {type(data)}')
return results
def _get_fixed_module(self, module, ignored_modules):
module_dict = dict(module.named_modules())
if is_seq_of(ignored_modules, str):
ignored_modules = [module_dict[name] for name in ignored_modules]
if not is_seq_of(ignored_modules,
nn.Module) and ignored_modules is not None:
raise TypeError(
'`ignored_modules` should be `None`, `Iterable[str]` or '
f'`Iterable[nn.Module]`, but has type {type(ignored_modules)}')
def find_fixed_modules_recursively(
root_module: nn.Module) -> List[nn.Module]:
"""Helper function to find fixed modules whose parameters are all
untrainable, i.e. `requires_grad=False`.
This function performs
recursively.
Args:
root_module (nn.Module): root module for recursion
Returns:
List[nn.Module]: fixed modules in root_module
"""
if all(p.requires_grad for p in root_module.parameters()):
return []
if all(not p.requires_grad for p in root_module.parameters()):
return [root_module]
fixed_modules = []
for sub_module in root_module.children():
fixed_modules.extend(
find_fixed_modules_recursively(sub_module))
return fixed_modules
fixed_modules = find_fixed_modules_recursively(module)
return fixed_modules
if digit_version(torch.__version__) < digit_version('2.0.1'):
@staticmethod
def optim_state_dict(
model: torch.nn.Module,
optim: torch.optim.Optimizer,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""copied from pytorch 2.0.1 which has fixed some bugs."""
state_dict_settings = FullyShardedDataParallel.get_state_dict_type(
model)
return FullyShardedDataParallel._optim_state_dict_impl(
model=model,
optim=optim,
optim_state_dict=optim.state_dict(),
optim_input=None,
rank0_only=getattr(state_dict_settings.optim_state_dict_config,
'rank0_only', False),
full_state_dict=state_dict_settings.state_dict_type ==
StateDictType.FULL_STATE_DICT,
group=group,
)
@staticmethod
def set_state_dict_type(
module: nn.Module,
state_dict_type: StateDictType,
state_dict_config: Optional[StateDictConfig] = None,
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> StateDictSettings:
"""copied from pytorch 2.0.1 which has fixed some bugs."""
import torch.distributed.fsdp._traversal_utils as traversal_utils
_state_dict_type_to_config = {
StateDictType.FULL_STATE_DICT: FullStateDictConfig,
StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
}
_optim_state_dict_type_to_config = {
StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig,
StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig,
StateDictType.SHARDED_STATE_DICT: ShardedOptimStateDictConfig,
}
# Use the default config if a state_dict config is not set.
state_dict_config_type = _state_dict_type_to_config[
state_dict_type]
optim_state_dict_config_type = _optim_state_dict_type_to_config[
state_dict_type]
if state_dict_config is None:
state_dict_config = state_dict_config_type()
if optim_state_dict_config is None:
optim_state_dict_config = optim_state_dict_config_type()
if state_dict_config_type != type(state_dict_config):
raise RuntimeError('Expected state_dict_config of type '
f'{state_dict_config_type} '
f'but got {type(state_dict_config)}')
if optim_state_dict_config_type != type(optim_state_dict_config):
raise RuntimeError('Expected optim_state_dict_config of type '
f'{optim_state_dict_config_type} '
f'but got {type(optim_state_dict_config)}')
# Set the state_dict type and configurations.
prev_state_dict_type = None
prev_state_dict_config = None
prev_optim_state_dict_config = None
for submodule in traversal_utils._get_fsdp_states(module):
if prev_state_dict_type is None:
prev_state_dict_type = submodule._state_dict_type
else:
assert (
prev_state_dict_type == submodule._state_dict_type
), 'All FSDP modules should have the same state_dict_type.'
if prev_state_dict_config is None:
prev_state_dict_config = submodule._state_dict_config
else:
assert isinstance(
submodule._state_dict_config,
type(prev_state_dict_config)), (
'All FSDP modules must have the same type of '
'state_dict_config.')
if prev_optim_state_dict_config is None:
prev_optim_state_dict_config = \
submodule._optim_state_dict_config
else:
assert isinstance(
submodule._optim_state_dict_config,
type(prev_optim_state_dict_config),
), ('All FSDP modules must have the same type of '
'optim_state_dict_config.')
submodule._state_dict_type = state_dict_type
submodule._state_dict_config = state_dict_config
submodule._optim_state_dict_config = optim_state_dict_config
return StateDictSettings(prev_state_dict_type,
prev_state_dict_config,
prev_optim_state_dict_config)
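
Because the constructor above converts strings, bools and dicts into the corresponding FSDP objects, the wrapper can now be described entirely in config form. A minimal sketch (the concrete values are illustrative):

model_wrapper = dict(
    type='MMFullyShardedDataParallel',
    sharding_strategy='FULL_SHARD',    # str -> ShardingStrategy member
    backward_prefetch='BACKWARD_PRE',  # str -> BackwardPrefetch member
    cpu_offload=False,                 # bool -> CPUOffload
    mixed_precision=dict(              # str dtypes -> torch.dtype
        param_dtype='float16',
        reduce_dtype='float32',
        buffer_dtype='float32'),
    use_orig_params=True,
)
strategy = dict(type='FSDPStrategy', model_wrapper=model_wrapper)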

View File

@ -48,6 +48,10 @@ class AmpOptimWrapper(OptimWrapper):
`'float64'`. If set to ``None``, the default data type will be used.
Defaults to None.
`New in version 0.6.1.`
use_fsdp (bool): Use ``ShardedGradScaler`` when it is True. It should
be enabled when training with ``FullyShardedDataParallel``.
Defaults to False.
`New in version 0.8.0.`
**kwargs: Keyword arguments passed to OptimWrapper.
Warnings:
@ -65,6 +69,7 @@ class AmpOptimWrapper(OptimWrapper):
def __init__(self,
loss_scale: str = 'dynamic',
dtype: Union[str, torch.dtype] = None,
use_fsdp: bool = False,
**kwargs):
assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
'`torch.cuda.amp` is only available when pytorch version >= 1.6')
@ -73,17 +78,29 @@ class AmpOptimWrapper(OptimWrapper):
'on gpu, npu or mlu')
super().__init__(**kwargs)
self._scale_update_param = None
if use_fsdp:
if digit_version(torch.__version__) >= digit_version('2.0.0'):
from torch.distributed.fsdp.sharded_grad_scaler import \
ShardedGradScaler
scaler_type = ShardedGradScaler
else:
raise RuntimeError(
'PyTorch>=2.0.0 is required when setting `use_fsdp=True`')
else:
scaler_type = GradScaler
if loss_scale == 'dynamic':
# If loss_scale is a string, it must be 'dynamic', then dynamic
# loss scaling will be used.
self.loss_scaler = GradScaler()
self.loss_scaler = scaler_type()
elif isinstance(loss_scale, float):
# Static loss scaling
self._scale_update_param = loss_scale
self.loss_scaler = GradScaler(init_scale=loss_scale)
self.loss_scaler = scaler_type(init_scale=loss_scale)
elif isinstance(loss_scale, dict):
# More specific configuration.
self.loss_scaler = GradScaler(**loss_scale)
self.loss_scaler = scaler_type(**loss_scale)
else:
raise TypeError('loss_scale must be of type float, dict, or '
f'"dynamic", but got {loss_scale}')

View File

@ -632,6 +632,18 @@ class FlexibleRunner:
assert isinstance(strategy, dict)
# train_micro_batch_size_per_gpu is required by DeepSpeed
if isinstance(strategy['type'], str):
strategy_name = strategy['type']
else:
strategy_name = strategy['type'].__name__
if strategy_name == 'DeepSpeedStrategy':
if self._train_dataloader is None:
strategy['train_micro_batch_size_per_gpu'] = 1
else:
strategy['train_micro_batch_size_per_gpu'] = \
_get_batch_size(self._train_dataloader)
strategy.setdefault('work_dir', self._work_dir)
strategy.setdefault('experiment_name', experiment_name)
strategy.setdefault('auto_scale_lr', self._auto_scale_lr)
@ -1140,18 +1152,13 @@ class FlexibleRunner:
compile = copy.copy(self._compile)
compile.setdefault('target', 'train_step')
if self.train_dataloader.batch_size is not None:
micro_batch_size = self.train_dataloader.batch_size
else:
micro_batch_size = self.train_dataloader.batch_sampler.batch_size
dispatch_kwargs = dict(
train_micro_batch_size_per_gpu=micro_batch_size,
num_batches_per_epoch=len(self.train_dataloader),
epoch_length=len(self.train_dataloader),
max_epochs=self.max_epochs,
max_iters=self.max_iters,
)
result = self.strategy.prepare(
self.strategy.prepare(
self.model,
optim_wrapper=self.optim_wrapper,
param_scheduler=self.param_schedulers,
@ -1159,10 +1166,10 @@ class FlexibleRunner:
dispatch_kwargs=dispatch_kwargs,
)
self.model = self.strategy.model
self.optim_wrapper = self.strategy.optim_wrapper # type: ignore
if self.param_schedulers is not None:
self.model, self.optim_wrapper, self.param_schedulers, *_ = result
else:
self.model, self.optim_wrapper, *_ = result
self.param_schedulers = self.strategy.param_schedulers
self.load_or_resume()
@ -1187,7 +1194,11 @@ class FlexibleRunner:
self._val_loop = self.build_val_loop(self._val_loop) # type: ignore
self.model = self.strategy.prepare(self.model)
dispatch_kwargs = dict(
init_weights_for_test_or_val=self.cfg.get(
'init_weights_for_test_or_val', True))
self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs)
self.model = self.strategy.model
self.load_or_resume()
@ -1210,8 +1221,11 @@ class FlexibleRunner:
'`test_evaluator` arguments when initializing runner.')
self._test_loop = self.build_test_loop(self._test_loop) # type: ignore
self.model = self.strategy.prepare(self.model)
dispatch_kwargs = dict(
init_weights_for_test_or_val=self.cfg.get(
'init_weights_for_test_or_val', True))
self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs)
self.model = self.strategy.model
self.load_or_resume()
@ -1613,3 +1627,20 @@ class FlexibleRunner:
if self.cfg._cfg_dict:
self.logger.info(f'Config:\n{self.cfg.pretty_text}')
def _get_batch_size(dataloader):
if isinstance(dataloader, dict):
if 'batch_size' in dataloader:
return dataloader['batch_size']
elif ('batch_sampler' in dataloader
and 'batch_size' in dataloader['batch_sampler']):
return dataloader['batch_sampler']['batch_size']
else:
raise ValueError('Please set batch_size in `DataLoader` or '
'`batch_sampler`')
elif isinstance(dataloader, DataLoader):
return dataloader.batch_sampler.batch_size
else:
raise ValueError('dataloader should be a dict or a DataLoader '
f'instance, but got {type(dataloader)}')
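
A small sketch of the two code paths of ``_get_batch_size`` above, with the helper in scope (the tensor dataset is only a stand-in):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Config-dict path: the batch size is read straight from the dict.
assert _get_batch_size(dict(batch_size=4)) == 4

# Built-DataLoader path: the batch size comes from the batch_sampler.
loader = DataLoader(TensorDataset(torch.zeros(8, 2)), batch_size=4)
assert _get_batch_size(loader) == 4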

View File

@ -36,14 +36,14 @@ class ToyModel(BaseModel):
if isinstance(inputs, list):
inputs = torch.stack(inputs)
if isinstance(data_samples, list):
data_sample = torch.stack(data_samples)
data_samples = torch.stack(data_samples)
outputs = self.linear1(inputs)
outputs = self.linear2(outputs)
if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = (data_sample - outputs).sum()
loss = (data_samples - outputs).sum()
outputs = dict(loss=loss)
return outputs
elif mode == 'predict':

View File

@ -8,7 +8,7 @@ import torch.distributed as torch_dist
import torch.nn as nn
from torch.optim import SGD
from mmengine.dist import all_gather
from mmengine.dist import all_gather, broadcast
from mmengine.model import (BaseDataPreprocessor, BaseModel,
ExponentialMovingAverage,
MMDistributedDataParallel,
@ -222,8 +222,8 @@ class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel):
@unittest.skipIf(
torch.cuda.device_count() < 2, reason='need 2 gpu to test fsdp')
@unittest.skipIf(
digit_version(TORCH_VERSION) < digit_version('1.11.0'),
reason='fsdp needs Pytorch 1.11 or higher')
digit_version(TORCH_VERSION) < digit_version('2.0.0'),
reason='fsdp needs Pytorch 2.0.0 or higher')
class TestMMFullyShardedDataParallel(MultiProcessTestCase):
def _init_dist_env(self, rank, world_size):
@ -247,19 +247,38 @@ class TestMMFullyShardedDataParallel(MultiProcessTestCase):
model = ToyModel()
fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
optimizer = SGD(fsdp_model.parameters(), lr=0)
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1)
optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1)
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
data = dict(inputs=[inputs], data_sample=MagicMock())
data = dict(inputs=inputs, data_sample=MagicMock())
fsdp_model.train()
self.assertTrue(fsdp_model.training)
fsdp_model.train_step(data, optim_wrapper=optim_wrapper)
# require_grad=False
model = ToyModel()
for _, param in model.state_dict().items():
broadcast(param)
model.conv1.requires_grad_(False)
ori_weight = model.conv1.weight.clone()
fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
optimizer = SGD(fsdp_model.parameters(), lr=0.1)
optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1)
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
data = dict(inputs=inputs, data_sample=MagicMock())
fsdp_model.train()
self.assertTrue(fsdp_model.training)
fsdp_model.train_step(data, optim_wrapper=optim_wrapper)
with fsdp_model.summon_full_params(fsdp_model):
updated_weight = fsdp_model.module.conv1.weight.cpu()
assert_allclose(ori_weight, updated_weight)
def test_val_step(self):
self._init_dist_env(self.rank, self.world_size)
model = ToyModel()
fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
data = dict(inputs=[inputs], data_sample=MagicMock())
data = dict(inputs=inputs, data_sample=MagicMock())
# Test get predictions.
predictions = fsdp_model.val_step(data)
self.assertIsInstance(predictions, torch.Tensor)

View File

@ -0,0 +1,231 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from tempfile import TemporaryDirectory
from unittest import TestCase, skipIf
import torch
import torch.nn as nn
try:
from torch.distributed.fsdp import (FullStateDictConfig,
FullyShardedDataParallel,
LocalStateDictConfig, StateDictType)
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig, LocalOptimStateDictConfig)
from mmengine._strategy import FSDPStrategy
except: # noqa: E722
pass
from torch.multiprocessing.spawn import start_processes
from torch.optim import SGD
from mmengine.dist import (all_gather_object, broadcast_object_list,
is_main_process)
from mmengine.optim import LinearLR, OptimWrapper
from mmengine.testing.runner_test_case import ToyModel
from mmengine.utils import digit_version
def linear_wrap_policy(
module,
recurse,
nonwrapped_numel,
) -> bool:
if recurse:
return True # always recurse
return isinstance(module, nn.Linear)
@skipIf(
digit_version(torch.__version__) < digit_version('2.0.0')
or not torch.cuda.is_available(),
'Only test FSDP with CUDA and PyTorch >= 2.0.0')
class TestStrategy(TestCase):
def setUp(self):
self.world_size = 2
self.temp_dir = TemporaryDirectory()
def tearDown(self) -> None:
self.temp_dir.cleanup()
def test_init(self):
strategy = FSDPStrategy()
self.assertFalse(strategy.skip_init_weights)
strategy = FSDPStrategy(state_dict_cfg='local')
self._assert_local(strategy)
strategy = FSDPStrategy(state_dict_cfg='full')
self._assert_full(strategy)
strategy = FSDPStrategy(
state_dict_cfg=dict(
state_dict_type=StateDictType.LOCAL_STATE_DICT))
self._assert_local(strategy)
strategy = FSDPStrategy(
state_dict_cfg=dict(
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=FullStateDictConfig(),
optim_state_dict_config=FullOptimStateDictConfig(),
))
self._assert_full(strategy)
strategy = FSDPStrategy(
state_dict_cfg=dict(
state_dict_type='FULL_STATE_DICT',
state_dict_config=dict(type='FullStateDictConfig'),
optim_state_dict_config=dict(type='FullOptimStateDictConfig'),
))
self._assert_full(strategy)
strategy = FSDPStrategy(
state_dict_cfg=dict(
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=dict(type=FullStateDictConfig),
optim_state_dict_config=dict(type=FullOptimStateDictConfig),
))
self._assert_full(strategy)
with self.assertRaises(ValueError):
strategy = FSDPStrategy(state_dict_cfg='error-str')
# state_dict_cfg should be a str or a dict
with self.assertRaises(TypeError):
strategy = FSDPStrategy(state_dict_cfg=[])
# state_dict_type must be a str or a member of StateDictType
with self.assertRaises(TypeError):
strategy = FSDPStrategy(
state_dict_cfg=dict(
state_dict_type=[],
state_dict_config=dict(type=FullStateDictConfig),
optim_state_dict_config=dict(
type=FullOptimStateDictConfig),
))
# state_dict_config should be a dict or a subclass of StateDictConfig
with self.assertRaises(TypeError):
strategy = FSDPStrategy(
state_dict_cfg=dict(
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=[],
optim_state_dict_config=dict(
type=FullOptimStateDictConfig),
))
# optim_state_dict_config should be a dict or a subclass of
# OptimStateDictConfig
with self.assertRaises(TypeError):
strategy = FSDPStrategy(
state_dict_cfg=dict(
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=dict(type=FullStateDictConfig),
optim_state_dict_config=[],
))
def run_strategy(self):
# Strategy can run with the built model, optimizer and schedulers.
for skip_init_weights, state_dict_cfg in [(True, 'local'),
(False, 'full')]:
strategy = FSDPStrategy(
skip_init_weights=skip_init_weights,
state_dict_cfg=state_dict_cfg,
model_wrapper=dict(auto_wrap_policy=linear_wrap_policy))
model = ToyModel()
optim = OptimWrapper(SGD(model.parameters(), lr=0.1, momentum=0.9))
lr_scheduler = LinearLR(optimizer=optim)
model, optim, lr_scheduler = strategy.prepare(
model=model, optim_wrapper=optim, param_scheduler=lr_scheduler)
self.assertIsInstance(model, FullyShardedDataParallel)
self.assertIsInstance(model.linear1, FullyShardedDataParallel)
self.assertIsInstance(model.linear2, FullyShardedDataParallel)
data = torch.ones(2, 2).cuda()
data_samples = torch.zeros(2, 2).cuda()
loss = model(data, data_samples=data_samples, mode='loss')['loss']
loss.backward()
optim.step()
[scheduler.step() for scheduler in lr_scheduler]
ckpt_path = osp.join(self.temp_dir.name,
f'checkpoint_{state_dict_cfg}.pth')
strategy.save_checkpoint(ckpt_path)
if state_dict_cfg == 'full':
if not is_main_process():
self.assertFalse(osp.exists(ckpt_path))
ckpt_path = [ckpt_path]
broadcast_object_list(ckpt_path)
ckpt_path = ckpt_path[0]
strategy.load_checkpoint(ckpt_path)
loss = model(data, data_samples=data_samples, mode='loss')['loss']
loss.backward()
optim.step()
[scheduler.step() for scheduler in lr_scheduler]
# optimizer with multiple param_groups can be reconstructed.
model = ToyModel()
strategy = FSDPStrategy(
model_wrapper=dict(auto_wrap_policy=linear_wrap_policy))
param_groups = []
for param in model.parameters():
param_groups.append(dict(params=[param], lr=0.1))
optim = SGD(param_groups, lr=0.1, momentum=0.9)
lr_scheduler = LinearLR(optimizer=optim)
model, optim, lr_scheduler = strategy.prepare(
model=model, optim_wrapper=optim, param_scheduler=lr_scheduler)
data = torch.ones(2, 2).cuda()
data_samples = torch.zeros(2, 2).cuda()
loss = model(data, data_samples=data_samples, mode='loss')['loss']
loss.backward()
optim.step()
[scheduler.step() for scheduler in lr_scheduler]
optim_state = optim.state_dict()['state']
optim_state = all_gather_object(optim_state)
@classmethod
def _worker(cls, rank, func):
# local mode
self = cls()
self.setUp()
self.rank = rank
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(self.world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(12123)
torch.cuda.set_device(f'cuda:{rank}')
getattr(self, func)()
self.tearDown()
def test_run_strategy(self):
start_processes(
TestStrategy._worker,
args=('run_strategy', ),
nprocs=self.world_size)
def test_build_model(self):
...
# TODO
# strategy = FSDPStrategy()
# model = ToyModel()
# state_dict = dict()
def _assert_local(self, strategy):
self.assertEqual(strategy.state_dict_type,
StateDictType.LOCAL_STATE_DICT)
self.assertIsInstance(strategy.state_dict_config, LocalStateDictConfig)
self.assertIsInstance(strategy.optim_state_dict_config,
LocalOptimStateDictConfig)
def _assert_full(self, strategy):
self.assertEqual(strategy.state_dict_type,
StateDictType.FULL_STATE_DICT)
self.assertIsInstance(strategy.state_dict_config, FullStateDictConfig)
self.assertIsInstance(strategy.optim_state_dict_config,
FullOptimStateDictConfig)