[Experimental] Add support for FSDP (#1213)
This commit is contained in:
parent
ccd5dc8b18
commit
399f76ffa8
@@ -15,3 +15,4 @@ mmengine._strategy
   SingleDeviceStrategy
   DDPStrategy
   DeepSpeedStrategy
   FSDPStrategy

@@ -15,3 +15,4 @@ mmengine._strategy
   SingleDeviceStrategy
   DDPStrategy
   DeepSpeedStrategy
   FSDPStrategy
@@ -4,7 +4,6 @@ import argparse
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel

@@ -43,8 +42,8 @@ class Accuracy(BaseMetric):
def parse_args():
    parser = argparse.ArgumentParser(description='Distributed Training')
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument('--use-fsdp', action='store_true')
    parser.add_argument('--use-deepspeed', action='store_true')

    args = parser.parse_args()
    return args

@@ -94,20 +93,33 @@ def main():
            ),
            inputs_to_half=[0],
            zero_optimization=dict(
                stage=0,
                stage=3,
                allgather_partitions=True,
                reduce_scatter=True,
                allgather_bucket_size=50000000,
                reduce_bucket_size=50000000,
                overlap_comm=True,
                contiguous_gradients=True,
                cpu_offload=False))
                cpu_offload=False),
        )
        optim_wrapper = dict(
            type='DeepSpeedOptimWrapper',
            optimizer=dict(type=SGD, lr=0.001, momentum=0.9))
            optimizer=dict(type='AdamW', lr=1e-3))
    elif args.use_fsdp:
        from functools import partial

        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
        size_based_auto_wrap_policy = partial(
            size_based_auto_wrap_policy, min_num_params=1e7)
        strategy = dict(
            type='FSDPStrategy',
            model_wrapper=dict(auto_wrap_policy=size_based_auto_wrap_policy))
        optim_wrapper = dict(
            type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3))
    else:
        strategy = None
        optim_wrapper = dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9))
        optim_wrapper = dict(
            type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3))

    runner = FlexibleRunner(
        model=MMResNet50(),
@@ -124,4 +136,8 @@ def main():


if __name__ == '__main__':
    # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py --use-fsdp  # noqa: 501
    # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py --use-deepspeed  # noqa: 501
    # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py
    # python distributed_training_with_flexible_runner.py
    main()
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils import is_installed
from mmengine.utils import digit_version, is_installed
from mmengine.utils.dl_utils import TORCH_VERSION
from .base import BaseStrategy
from .distributed import DDPStrategy
from .single_device import SingleDeviceStrategy
@@ -9,3 +10,10 @@ __all__ = ['BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy']
if is_installed('deepspeed'):
    from .deepspeed import DeepSpeedStrategy  # noqa: F401
    __all__.append('DeepSpeedStrategy')

if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
    try:
        from .fsdp import FSDPStrategy  # noqa:F401
        __all__.append('FSDPStrategy')
    except:  # noqa: E722
        pass
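A minimal sketch of the intended effect of this guarded export (assumed usage, not shown in the diff): because `FSDPStrategy` is only imported on PyTorch >= 2.0.0, configs can simply refer to it by name and let the registry fail loudly on older versions.

    from mmengine.registry import STRATEGIES

    # Builds FSDPStrategy only when the conditional import above succeeded.
    strategy = STRATEGIES.build(dict(type='FSDPStrategy'))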
@@ -91,6 +91,7 @@ class BaseStrategy(metaclass=ABCMeta):
        self._auto_scale_lr = auto_scale_lr

        self.dispatch_kwargs: dict = {}
        self._prepared = False

    @property
    def work_dir(self):
@@ -342,7 +343,8 @@ class BaseStrategy(metaclass=ABCMeta):
    def _init_model_weights(self, model: nn.Module) -> nn.Module:
        """Initialize the model weights if the model has
        :meth:`init_weights`"""
        if hasattr(model, 'init_weights'):
        if (hasattr(model, 'init_weights') and self.dispatch_kwargs.get(
                'init_weights_for_test_or_val', True)):
            model.init_weights()
            # sync params and buffers
            for _, params in model.state_dict().items():
@@ -633,9 +635,9 @@ class BaseStrategy(metaclass=ABCMeta):
        """
        if default_args is None:
            default_args = {}
        if 'num_batches_per_epoch' in self.dispatch_kwargs:
        if 'epoch_length' in self.dispatch_kwargs:
            default_args['epoch_length'] = self.dispatch_kwargs[
                'num_batches_per_epoch']
                'epoch_length']
        if 'max_epochs' in self.dispatch_kwargs:
            default_args['max_epochs'] = self.dispatch_kwargs['max_epochs']
        if 'max_iters' in self.dispatch_kwargs:
@@ -962,3 +964,13 @@ class BaseStrategy(metaclass=ABCMeta):
        runtime_env['GPU number'] = self.world_size

        return system_env, runtime_env

    def _prepared_components(self):
        return_items = [self.model]
        if hasattr(self, 'optim_wrapper'):
            return_items.append(self.optim_wrapper)

        if hasattr(self, 'param_schedulers'):
            return_items.append(self.param_schedulers)

        return return_items[0] if len(return_items) == 1 else return_items
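A hedged sketch of how a caller might feed the renamed key through a concrete strategy's `prepare()` entry point; the exact keys and values are assumptions chosen for illustration, the point is only that `epoch_length` replaces `num_batches_per_epoch`.

    strategy.prepare(
        model,
        optim_wrapper=dict(type='AmpOptimWrapper',
                           optimizer=dict(type='AdamW', lr=1e-3)),
        dispatch_kwargs=dict(
            epoch_length=len(train_dataloader),   # was `num_batches_per_epoch`
            max_epochs=10,
            max_iters=10 * len(train_dataloader)),
    )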
@@ -64,6 +64,8 @@ class DeepSpeedStrategy(BaseStrategy):
                 amp: Optional[dict] = None,
                 activation_checkpointing: Optional[dict] = None,
                 aio: Optional[dict] = None,
                 train_micro_batch_size_per_gpu: Optional[int] = None,
                 gradient_accumulation_steps: int = 1,
                 # disable the log printed by deepspeed
                 steps_per_print: int = 10000000000000,
                 # the following args are for BaseStrategy
@@ -86,8 +88,21 @@ class DeepSpeedStrategy(BaseStrategy):
        if aio is not None:
            self.config['aio'] = aio

        self.config['steps_per_print'] = steps_per_print
        if ('train_micro_batch_size_per_gpu' not in self.config
                and 'train_batch_size' not in self.config):
            assert train_micro_batch_size_per_gpu is not None, (
                '`train_micro_batch_size_per_gpu` or `train_batch_size` '
                'should be set!')
            self.config['train_micro_batch_size_per_gpu'] = \
                train_micro_batch_size_per_gpu

        if train_micro_batch_size_per_gpu is not None:
            self.config['train_micro_batch_size_per_gpu'] = \
                train_micro_batch_size_per_gpu

        self.config['gradient_accumulation_steps'] = \
            gradient_accumulation_steps
        self.config['steps_per_print'] = steps_per_print
        self._inputs_to_half = inputs_to_half

    def _parse_config(self, config):
@@ -145,11 +160,11 @@ class DeepSpeedStrategy(BaseStrategy):
            dispatch_kwargs (dict, optional): Kwargs to be passed to other
                methods of Strategy. Defaults to None.
        """
        if self._prepared:
            return self._prepared_components()
        assert dispatch_kwargs is not None
        self.dispatch_kwargs.update(dispatch_kwargs)

        return_items = []

        model = self.build_model(model)
        model = self._init_model_weights(model)

@@ -159,23 +174,16 @@ class DeepSpeedStrategy(BaseStrategy):

            self.optim_wrapper.model = self.model  # type: ignore

            return_items.append(self.model)
            return_items.append(self.optim_wrapper)
        else:
            self.model = self._wrap_model(model)
            return_items.append(self.model)

        if param_scheduler is not None:
            self.param_schedulers = self.build_param_scheduler(
                param_scheduler, self.optim_wrapper)
            return_items.append(self.param_schedulers)

        return return_items[0] if len(return_items) == 1 else return_items
        self._prepared = True
        return self._prepared_components()

    def _wrap_model(self, model: nn.Module) -> nn.Module:
        self.config['train_micro_batch_size_per_gpu'] = self.dispatch_kwargs[
            'train_micro_batch_size_per_gpu']

        if hasattr(self, 'optim_wrapper'):
            engine, self.optim_wrapper.optimizer, *_ = deepspeed.initialize(
                model=model,
@@ -246,7 +254,7 @@ class DeepSpeedStrategy(BaseStrategy):
        if resume_optimizer:
            self.load_optim_state_dict(extra_ckpt.pop('optim_wrapper'))

        if resume_param_scheduler:
        if resume_param_scheduler and hasattr(self, 'param_schedulers'):
            param_schedulers = extra_ckpt.pop('param_schedulers')
            self.load_scheduler_state_dict(param_schedulers)

@@ -297,12 +305,12 @@ class DeepSpeedStrategy(BaseStrategy):
            mmengine=mmengine.__version__ + get_git_hash(),
        )

        if save_optimizer:
        if save_optimizer and hasattr(self, 'optim_wrapper'):
            # The key can not be 'optimizer', otherwise an error will be
            # thrown when loading or resuming the checkpoint.
            extra_ckpt['optim_wrapper'] = self.optim_state_dict()

        if save_param_scheduler:
        if save_param_scheduler and hasattr(self, 'param_schedulers'):
            extra_ckpt['param_schedulers'] = self.scheduler_state_dict()

        dirname, basename = osp.split(filename)
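With these new arguments the micro batch size is declared on the strategy itself rather than inferred when the model is wrapped. A hedged config sketch (values are illustrative only):

    strategy = dict(
        type='DeepSpeedStrategy',
        # Required unless `train_batch_size` or
        # `train_micro_batch_size_per_gpu` already appears in the config.
        train_micro_batch_size_per_gpu=2,
        gradient_accumulation_steps=1,
    )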
@@ -31,11 +31,6 @@ class DDPStrategy(SingleDeviceStrategy):
        **kwargs,
    ):
        super().__init__(**kwargs)
        if model_wrapper is None:
            # set broadcast_buffers as False to keep compatibility with
            # OpenMMLab repos
            model_wrapper = dict(
                type='MMDistributedDataParallel', broadcast_buffers=False)
        self.model_wrapper = model_wrapper
        self.sync_bn = sync_bn

@@ -95,8 +90,8 @@ class DDPStrategy(SingleDeviceStrategy):

        model = self.convert_model(model)

        default_args = dict(module=model)
        default_args.setdefault('device_ids', [int(os.environ['LOCAL_RANK'])])
        if self.model_wrapper is None:
            # set broadcast_buffers as False to keep compatibility with
            # OpenMMLab repos
            self.model_wrapper = dict(
                type='MMDistributedDataParallel', broadcast_buffers=False)

        default_args = dict(
            type='MMDistributedDataParallel',
            module=model,
            device_ids=[int(os.environ['LOCAL_RANK'])])
        model = MODEL_WRAPPERS.build(
            self.model_wrapper, default_args=default_args)
        return model
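Resolving the default wrapper inside `_wrap_model` keeps `self.model_wrapper` as `None` until wrapping time, so subclasses such as `FSDPStrategy` can substitute their own default. An explicit override might look like this (the extra `find_unused_parameters` flag is an assumption, not part of this diff):

    strategy = dict(
        type='DDPStrategy',
        model_wrapper=dict(
            type='MMDistributedDataParallel',
            broadcast_buffers=False,
            find_unused_parameters=True),
    )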
mmengine/_strategy/fsdp.py (new file, 595 lines)
@@ -0,0 +1,595 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
import os
import os.path as osp
import time
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Sequence, Union

import torch.nn as nn
from torch.distributed.fsdp import (FullStateDictConfig,
                                    FullyShardedDataParallel,
                                    LocalStateDictConfig, StateDictType)
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig, LocalOptimStateDictConfig, OptimStateDictConfig,
    StateDictConfig)
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

import mmengine
from mmengine.config import Config, ConfigDict
from mmengine.dist import get_rank, is_main_process
from mmengine.model import is_model_wrapper
from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper,
                            OptimWrapperDict, _ParamScheduler,
                            build_optim_wrapper)
from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS,
                               PARAM_SCHEDULERS, STRATEGIES, Registry)
from mmengine.utils import get_git_hash, mkdir_or_exist
from .distributed import DDPStrategy
from .utils import MetaTensorContext

FSDP = FullyShardedDataParallel
FSDP_CONFIGS = Registry('fsdp configs')
FSDP_CONFIGS.register_module(module=FullOptimStateDictConfig)
FSDP_CONFIGS.register_module(module=LocalOptimStateDictConfig)
FSDP_CONFIGS.register_module(module=FullStateDictConfig)
FSDP_CONFIGS.register_module(module=LocalStateDictConfig)


@STRATEGIES.register_module()
class FSDPStrategy(DDPStrategy):
    """Support training models with FullyShardedDataParallel (FSDP).

    Keyword Args:
        model_wrapper (dict, optional): Config dict for the model wrapper. The
            default configuration is:

            Examples:
                >>> model_wrapper = dict(
                >>>     type='MMFullyShardedDataParallel',
                >>>     use_orig_params=True,
                >>> )

            See more configurable arguments in
            :class:`MMFullyShardedDataParallel`. Defaults to None.
        skip_init_weights (bool, optional): Whether to skip initialization of
            weights. Defaults to False. This is useful when the parameters of
            the large model are loaded from a checkpoint, since skipping the
            initialization of weights can save a lot of time.
        state_dict_cfg (str or dict): Configuration for
            how to save and load the state dict of the model, optimizer, and
            scheduler.

            - "local": save and load the sharded state dict in all ranks.
            - "full": save and load the full state dict in rank 0.
            - `dict` object: save and load the state dict more flexibly. For
              example, you can first offload the state dict to the 'cpu' and
              then save it to the disk. This can help you to load the
              checkpoint in a non-gpu environment:

              Examples:
                >>> state_dict_cfg=dict(
                >>>     state_dict_type='FULL_STATE_DICT',
                >>>     state_dict_config=dict(type='FullStateDictConfig', offload_to_cpu=True),
                >>>     optim_state_dict_config=dict(type='FullOptimStateDictConfig', offload_to_cpu=True),
                >>> )

              See more configurable arguments for ``state_dict_cfg``,
              ``state_dict_config``, and ``optim_state_dict_config`` in the
              `FSDP official api documents`_
        kwargs (dict): Additional arguments passed to :class:`DDPStrategy`:

            - work_dir (str): The working directory to save checkpoints.
              The logs will be saved in the subdirectory of `work_dir` named
              :attr:`timestamp`. Defaults to 'work_dirs'.
            - experiment_name (str, optional): Name of current experiment. If
              not specified, timestamp will be used as :attr:`experiment_name`.
              Defaults to None.
            - env_kwargs (dict, optional): Environment config passed in
              :meth:`setup_env`. Defaults to None.
            - log_kwargs (dict, optional): Logger config passed in
              :meth:`build_logger`. Defaults to None.

    .. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
    """  # noqa: E501

    def __init__(self,
                 *,
                 model_wrapper: Optional[dict] = None,
                 skip_init_weights=False,
                 state_dict_cfg: Union[str, dict] = 'local',
                 **kwargs):
        super().__init__(model_wrapper=model_wrapper, **kwargs)
        self._init_state_dict_cfg(state_dict_cfg)
        if not isinstance(skip_init_weights, bool):
            raise TypeError('skip_init_weights must be a boolean, but got '
                            f'{type(skip_init_weights)}')
        self.skip_init_weights = skip_init_weights

    def _wrap_model(self, model: nn.Module) -> None:
        """Wrap the model to :obj:``MMFullyShardedDataParallel`` or other
        custom fully sharded data parallel module wrappers.

        Args:
            model (nn.Module): Model to be wrapped.

        Returns:
            FullyShardedDataParallel: ``MMFullyShardedDataParallel``
            or subclass of ``FullyShardedDataParallel``.
        """
        if is_model_wrapper(model):
            return

        if self.model_wrapper is None:
            self.model_wrapper = dict(type='MMFullyShardedDataParallel')

        default_args = dict(
            module=model,
            device_id=int(os.environ['LOCAL_RANK']),
            type='MMFullyShardedDataParallel')
        model = MODEL_WRAPPERS.build(
            self.model_wrapper, default_args=default_args)
        model.set_state_dict_type(model, self.state_dict_type,
                                  self.state_dict_config,
                                  self.optim_state_dict_config)
        return model

    def _is_full_state_dict(self):
        """Whether to save and load the full state_dict in rank 0."""
        return self.state_dict_type == StateDictType.FULL_STATE_DICT

    def build_model(self, model: Union[nn.Module, dict]) -> nn.Module:
        """Build model.

        If skip_init_weights is True, the model will be built with empty
        weights. It means that :meth:`load_checkpoint` must be called to fill
        the weights before training.

        Args:
            model (nn.Module or dict): A ``nn.Module`` object or a dict to
                build an ``nn.Module`` object. If ``model`` is an
                ``nn.Module`` object, just returns itself.

        Returns:
            nn.Module: Model built from ``model``.
        """
        if self.skip_init_weights:
            if isinstance(model, dict):
                # Accelerate initialization by skipping init weights
                with MetaTensorContext():
                    model = super().build_model(model)
                model.to_empty(device='cpu')
        else:
            model = super().build_model(model)

        # `id_to_name` will be used to convert the `optim_state_dict` of the
        # raw optimizer to the `optim_state_dict`
        # returned by `FSDP.optim_state_dict` in
        # `StateDictType.FULL_STATE_DICT` mode.
        self.id_to_name = dict()
        for name, param in model.named_parameters():
            self.id_to_name[id(param)] = name
        return model

    def save_checkpoint(self,
                        filename: str,
                        *,
                        save_optimizer: bool = True,
                        save_param_scheduler: bool = True,
                        extra_ckpt: Optional[dict] = None,
                        callback: Optional[Callable] = None) -> None:
        """Save checkpoint to given ``filename``.

        If ``state_dict_type`` is `full`, the checkpoint will only be saved in
        rank0. The structure of the saved checkpoint is the same as the one
        saved by ``DDPStrategy``.

        If ``state_dict_type`` is `local`, each rank will save the sharded
        state dict to a directory, which means the saved structure will look
        like this:

        .. code-block:: bash

            ── epoch_0.pth
                ├── rank0.pth
                ├── rank1.pth
                ├── ...
                └── rank8.pth

        Args:
            filename (str): Filename to save checkpoint.

        Keyword Args:
            save_optimizer (bool): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            save_param_scheduler (bool): Whether to save the param_scheduler
                to the checkpoint. Defaults to True.
            extra_ckpt (dict, optional): Extra checkpoint to save.
                Defaults to None.
            callback (callable, optional): Callback function to modify the
                checkpoint before saving the checkpoint.
                Defaults to None.
        """
        from mmengine.runner.checkpoint import save_checkpoint

        state_dict: dict = dict()
        state_dict['state_dict'] = self.model_state_dict()

        # save optimizer state dict
        if save_optimizer and hasattr(self, 'optim_wrapper'):
            state_dict['optimizer'] = self.optim_state_dict()

        # save param scheduler state dict
        if save_param_scheduler and hasattr(self, 'param_schedulers'):
            state_dict['param_schedulers'] = self.scheduler_state_dict()

        # save extra checkpoint passed by users
        if extra_ckpt is None:
            extra_ckpt = dict()
        if 'meta' not in extra_ckpt:
            extra_ckpt['meta'] = dict()

        extra_ckpt['meta'].update(
            seed=self.seed,
            time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
            mmengine=mmengine.__version__ + get_git_hash(),
        )
        state_dict.update(extra_ckpt)

        # users can do some modification before saving checkpoint
        if callback is not None:
            callback(state_dict)

        # In non-FULL_STATE_DICT mode, FSDPStrategy will save checkpoints
        # of different ranks in different files.
        if not self._is_full_state_dict():
            rank = get_rank()
            mkdir_or_exist(filename)
            ckpt_name = f'rank{rank}.pth'
            filename = osp.join(filename, ckpt_name)
            save_checkpoint(state_dict, filename)

        if is_main_process():
            save_checkpoint(state_dict, filename)

    def model_state_dict(self) -> dict:
        """Get the model state dict based on the ``state_dict_type``.

        If ``state_dict_type`` is `full`, the model state dict will be the
        same as the one of the original unsharded model.

        If ``state_dict_type`` is ``local`` and ``use_orig_params`` is
        ``True`` in ``model_wrapper``, the keys of the state dict will be the
        same as the ones of the original unsharded model, but the values will
        be the sharded ones.

        If ``state_dict_type`` is `local` and ``use_orig_params`` is
        ``False`` in ``model_wrapper``, the flattened and sharded state dict
        will be returned.

        See more details in the `official api documents`_

        .. _official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict
        """  # noqa: E501
        # We've set the state dict type by `FSDP.set_state_dict_type`,
        # therefore we should get the model state dict by `FSDP.state_dict`
        return self.model.state_dict()

    def optim_state_dict(self) -> dict:
        """Get the optimizer state dict based on the ``state_dict_type``.

        If ``state_dict_type`` is ``full``, the optimizer state dict can be
        loaded by the original unsharded optimizer.

        Otherwise, the optimizer state dict could only be loaded by the
        optimizer with sharded parameters.

        Note:
            The optimizer state dict is not the same as the one of the
            original optimizer even in ``full`` mode, although they can be
            loaded correctly.

        See more details in the `official api documents`_

        .. _official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict
        """  # noqa: E501
        return FSDP.optim_state_dict(self.model, self.optim_wrapper)

    def load_checkpoint(self, filename: str, **kwargs) -> dict:
        """Load checkpoint from given ``filename``.

        Note:
            If ``state_dict_type`` is `local`, the filename should be a
            directory containing ``rank{i}.pth``.

        Args:
            filename (str): Accept local filepath, URL, ``torchvision://xxx``,
                ``open-mmlab://xxx``.

        Keyword Args:
            map_location (str or callable): A string or a callable function to
                specify how to remap storage locations.
                Defaults to 'cpu'.
            strict (bool): Whether to allow different params for the model and
                checkpoint.
            revise_keys (list): A list of customized keywords to modify the
                state_dict in checkpoint. Each item is a (pattern, replacement)
                pair of the regular expression operations. Defaults to strip
                the prefix 'module.' by [(r'^module\\.', '')].
            callback (callable, optional): Callback function to modify the
                checkpoint after loading the checkpoint.
                Defaults to None.
        """
        if self._is_full_state_dict():
            return super(DDPStrategy, self).load_checkpoint(filename, **kwargs)
        else:
            rank = get_rank()
            filename = osp.join(filename, f'rank{rank}.pth')
            return super(DDPStrategy, self).load_checkpoint(filename, **kwargs)

    def load_model_state_dict(
        self,
        state_dict: dict,
        *,
        strict: bool = False,
        revise_keys: list = [(r'^module.', '')],
    ) -> None:  # type: ignore
        """Load model state from dict.

        Warning:
            `revise_keys` is not supported yet.

        Args:
            state_dict (dict): Model state dict returned by
                :meth:`FSDPStrategy.model_state_dict`. If ``state_dict_type``
                is ``full``, ``state_dict`` could be the result of
                ``model.state_dict()``
            strict (bool): Whether to load model state dict strictly.
                Defaults to False.
        """
        # We should load the state dict by `FSDP.load_state_dict`
        self.model.load_state_dict(state_dict, strict=strict)

    def load_optim_state_dict(self, state_dict: dict) -> None:
        """Load optimizer state from dict.

        Args:
            state_dict (dict): The optimizer state dict. If
                ``state_dict_type`` is ``full``, ``state_dict`` could be the
                result of ``optimizer.state_dict()``
        """
        optim_state_dict = FSDP.optim_state_dict_to_load(
            state_dict, self.model, self.optim_wrapper.optimizer)
        self.optim_wrapper.load_state_dict(optim_state_dict)

    def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None:
        """Allow ``state_dict_type`` and ``state_dict_config`` to be
        configured with a string."""
        if isinstance(state_dict_cfg, str):
            if state_dict_cfg == 'full':
                self.state_dict_type = StateDictType.FULL_STATE_DICT
                self.state_dict_config = FullStateDictConfig(
                    rank0_only=True, offload_to_cpu=True)
                self.optim_state_dict_config = FullOptimStateDictConfig(
                    rank0_only=True, offload_to_cpu=True)
            elif state_dict_cfg == 'local':
                self.state_dict_type = StateDictType.LOCAL_STATE_DICT
                self.state_dict_config = LocalStateDictConfig()
                self.optim_state_dict_config = LocalOptimStateDictConfig()
            else:
                raise ValueError('FSDP only supports `full` and `local` '
                                 f'state_dict_type, but got {state_dict_cfg}')
        elif isinstance(state_dict_cfg, dict):
            if 'state_dict_type' not in state_dict_cfg:
                self.state_dict_type = StateDictType.LOCAL_STATE_DICT
            else:
                state_dict_type = state_dict_cfg['state_dict_type']
                if isinstance(state_dict_type, str):
                    self.state_dict_type = StateDictType[
                        state_dict_cfg['state_dict_type']]
                else:
                    self.state_dict_type = state_dict_type
            state_dict_config = state_dict_cfg.get('state_dict_config')
            if state_dict_config is None:
                self.state_dict_config = LocalStateDictConfig()
            elif isinstance(state_dict_config, dict):
                self.state_dict_config = FSDP_CONFIGS.build(
                    state_dict_cfg['state_dict_config'])
            else:
                self.state_dict_config = state_dict_config

            optim_state_dict_config = state_dict_cfg.get(
                'optim_state_dict_config')
            if optim_state_dict_config is None:
                self.optim_state_dict_config = LocalOptimStateDictConfig()
            elif isinstance(optim_state_dict_config, dict):
                self.optim_state_dict_config = FSDP_CONFIGS.build(
                    state_dict_cfg['optim_state_dict_config'])
            else:
                self.optim_state_dict_config = optim_state_dict_config
        else:
            raise TypeError('state_dict_cfg should be a `str` or a `dict`, '
                            f'but got {type(state_dict_cfg)}')

        if not isinstance(self.state_dict_type, StateDictType):
            raise TypeError('state_dict_type must be StateDictType, but got '
                            f'{type(self.state_dict_type)}')
        if not isinstance(self.state_dict_config, StateDictConfig):
            raise TypeError('state_dict_config must be StateDictConfig, but '
                            f'got {type(self.state_dict_config)}')
        if not isinstance(self.optim_state_dict_config, OptimStateDictConfig):
            raise TypeError('optim_state_dict_config must be '
                            'OptimStateDictConfig, but got '
                            f'{type(self.optim_state_dict_config)}')

    def build_optim_wrapper(
        self,
        optim_wrapper: Union[Optimizer, OptimWrapper, dict],
        model: Optional[nn.Module] = None,
    ) -> BaseOptimWrapper:
        """Support sharding the optimizer state dict given a built optimizer
        or optim_wrapper.

        See specific usage in :meth:`BaseStrategy.build_optim_wrapper`.
        """
        if isinstance(optim_wrapper, Optimizer):
            optim_wrapper = OptimWrapper(optim_wrapper)
        if isinstance(optim_wrapper, BaseOptimWrapper):
            assert model is not None
            # NOTE: The only difference is that FSDPStrategy will shard
            # the built OptimWrapper
            optimizer = optim_wrapper.optimizer
            param_groups = optimizer.param_groups
            optim_state_dict = optimizer.state_dict()
            assert not optim_state_dict['state'], (
                'Optimizer state_dict should be empty when giving a built '
                'optim_wrapper to FSDPStrategy')
            # Align the state_dict with the state_dict generated by
            # FSDP.full_optim_state_dict
            new_param_groups = []
            for group in param_groups:
                new_group = {
                    key: value
                    for key, value in group.items() if key != 'param'
                }
                new_group['params'] = [
                    self.id_to_name[id(param)] for param in group['params']
                ]
                new_param_groups.append(new_group)
            optim_state_dict['param_groups'] = new_param_groups
            defaults = {
                k: v
                for k, v in optimizer.defaults.items() if k != 'differentiable'
            }

            params_dict = {}
            for k, v in model.named_parameters():
                if '_fsdp_wrapped_module' in k:
                    k = k.replace('_fsdp_wrapped_module.', '')
                params_dict[k] = v

            params = []
            for param_group in new_param_groups:
                _params = []
                for param_name in param_group['params']:
                    if param_name not in params_dict:
                        raise RuntimeError(
                            'Failed to reconstruct the sharded optimizer. '
                            'You can try to set `use_orig_params=True` in '
                            '`model_wrapper`')
                    _params.append(params_dict[param_name])
                param_group = {
                    k: v
                    for k, v in param_group.items() if k != 'param'
                }
                param_group['params'] = _params
                params.append(param_group)

            new_optimizer = optimizer.__class__(params, **defaults)

            # Force to load the converted optim_state_dict in full mode.
            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
                optim_state_dict = FSDP.optim_state_dict_to_load(
                    optim_state_dict, model, new_optimizer)
            new_optimizer.load_state_dict(optim_state_dict)
            optim_wrapper.optimizer = new_optimizer
            return optim_wrapper
        if isinstance(optim_wrapper, (dict, ConfigDict, Config)):
            assert model is not None
            # optimizer must be defined for single optimizer training.
            optimizer = optim_wrapper.get('optimizer', None)
            optim_wrapper.setdefault('type', 'OptimWrapper')
            if optim_wrapper.get('type',
                                 'AmpOptimWrapper') in ('AmpOptimWrapper',
                                                        AmpOptimWrapper):
                optim_wrapper.setdefault('use_fsdp', True)

            # If optimizer is a built `Optimizer` instance, the optimizer
            # wrapper should be built by the `OPTIM_WRAPPERS` registry.
            if isinstance(optimizer, Optimizer):
                return OPTIM_WRAPPERS.build(optim_wrapper)  # type: ignore

            # If `optimizer` is not None or `constructor` is defined, it means
            # the optimizer wrapper will be built by the optimizer wrapper
            # constructor. Therefore, `build_optim_wrapper` should be called.
            if optimizer is not None or 'constructor' in optim_wrapper:
                return build_optim_wrapper(model, optim_wrapper)
            else:
                # if `optimizer` is not defined, it should be the case of
                # training with multiple optimizers. If `constructor` is not
                # defined either, each value of `optim_wrapper` must be an
                # `OptimWrapper` instance since `DefaultOptimizerConstructor`
                # will not handle the case of training with multiple
                # optimizers. `build_optim_wrapper` will directly build the
                # `OptimWrapperDict` instance from `optim_wrapper.`
                optim_wrappers = OrderedDict()
                for name, optim in optim_wrapper.items():
                    if not isinstance(optim, OptimWrapper):
                        raise ValueError(
                            'each item must be an optimizer object when '
                            '"type" and "constructor" are not in '
                            f'optimizer, but got {name}={optim}')
                    optim_wrappers[name] = optim
                return OptimWrapperDict(**optim_wrappers)
        else:
            raise TypeError('optimizer wrapper should be an OptimWrapper '
                            f'object or dict, but got {optim_wrapper}')

    def _build_param_scheduler(
        self,
        scheduler: Union[_ParamScheduler, Dict, List],
        optim_wrapper: BaseOptimWrapper,
        default_args: dict,
    ) -> List[_ParamScheduler]:
        """Override this method to update the scheduler with the
        reconstructed sharded optimizer."""
        if not isinstance(scheduler, Sequence):
            schedulers = [scheduler]
        else:
            schedulers = scheduler

        max_epochs = default_args.pop('max_epochs', None)
        max_iters = default_args.pop('max_iters', None)

        param_schedulers = []
        for scheduler in schedulers:
            # Update the built scheduler with the sharded optimizer
            if isinstance(scheduler, (_ParamScheduler, LRScheduler)):
                parameter_keys = inspect.signature(
                    scheduler.__class__).parameters.keys()
                kwargs = {
                    k: v
                    for k, v in scheduler.state_dict().items()
                    if k in parameter_keys
                }
                scheduler = scheduler.__class__(optim_wrapper, **kwargs)
            elif isinstance(scheduler, dict):
                _scheduler = copy.deepcopy(scheduler)

                # Set default end
                if _scheduler.get('by_epoch', True):
                    if max_epochs is None:
                        raise ValueError(
                            'max_epochs must be specified in default_args')
                    default_end = max_epochs
                else:
                    if max_iters is None:
                        raise ValueError(
                            'max_iters must be specified in default_args')
                    default_end = max_iters
                _scheduler.setdefault('end', default_end)
                self.logger.debug(
                    f'The `end` of {_scheduler["type"]} is not set. '
                    'Use the max epochs/iters of train loop as default.')

                param_schedulers.append(
                    PARAM_SCHEDULERS.build(
                        _scheduler,
                        default_args=dict(
                            optimizer=optim_wrapper, **default_args)))
            else:
                raise TypeError(
                    'scheduler should be a _ParamScheduler object or dict, '
                    f'but got {scheduler}')
        return param_schedulers
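A hedged configuration sketch pulling the pieces of this file together; the exact values are illustrative only and are not taken from the diff:

    strategy = dict(
        type='FSDPStrategy',
        model_wrapper=dict(type='MMFullyShardedDataParallel',
                           use_orig_params=True),
        state_dict_cfg='full',       # gather a full checkpoint on rank 0
        skip_init_weights=True,      # weights will come from load_checkpoint()
    )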
@@ -49,26 +49,24 @@ class SingleDeviceStrategy(BaseStrategy):
            If ``accumulative_counts`` is set in ``optim_wrapper``, you
            need to provide ``max_iters`` in ``dispatch_kwargs``.
        """
        if self._prepared:
            return self._prepared_components()
        if dispatch_kwargs is not None:
            self.dispatch_kwargs.update(dispatch_kwargs)

        return_items = []
        model = self.build_model(model)
        model = self._init_model_weights(model)
        model = self._wrap_model(model)
        model = self.compile_model(model, compile=compile)
        return_items.append(model)

        self.model = model

        if optim_wrapper is not None:
            self.optim_wrapper = self.build_optim_wrapper(optim_wrapper, model)
            return_items.append(self.optim_wrapper)

        if param_scheduler is not None:
            self.param_schedulers = self.build_param_scheduler(
                param_scheduler, self.optim_wrapper)
            return_items.append(self.param_schedulers)

        if optim_wrapper is not None:
            self._scale_lr()
@@ -84,8 +82,8 @@ class SingleDeviceStrategy(BaseStrategy):

            self.optim_wrapper.initialize_count_status(  # type: ignore
                self.model, 0, self.dispatch_kwargs['max_iters'])

        return return_items[0] if len(return_items) == 1 else return_items
        self._prepared = True
        return self._prepared_components()

    def _wrap_model(self, model: nn.Module) -> nn.Module:
        model = self.convert_model(model)
@@ -200,7 +198,7 @@ class SingleDeviceStrategy(BaseStrategy):
        if resume_optimizer:
            self.load_optim_state_dict(checkpoint.pop('optimizer'))

        if resume_param_scheduler:
        if resume_param_scheduler and hasattr(self, 'param_schedulers'):
            self.load_scheduler_state_dict(checkpoint.pop('param_schedulers'))

        # resume random seed
@@ -267,15 +265,7 @@ class SingleDeviceStrategy(BaseStrategy):
        if save_optimizer and hasattr(self, 'optim_wrapper'):
            state_dict['optimizer'] = self.optim_state_dict()

        # save param scheduler state dict
        if save_param_scheduler and not hasattr(self, 'param_schedulers'):
            self.logger.warning(
                '`save_param_scheduler` is True but strategy has no '
                'param_schedulers attribute, so skip saving parameter '
                'schedulers')
            save_param_scheduler = False

        if save_param_scheduler:
        if save_param_scheduler and hasattr(self, 'param_schedulers'):
            state_dict['param_schedulers'] = self.scheduler_state_dict()

        # save extra checkpoint passed by users
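A hedged illustration of the new idempotent behaviour of `prepare` (variable names and the exact argument values are assumptions):

    model, optim_wrapper = strategy.prepare(
        model_cfg,
        optim_wrapper=optim_cfg,
        dispatch_kwargs=dict(max_iters=1000))
    # A second call no longer rebuilds or re-wraps anything; it returns the
    # cached components recorded by `_prepared_components`.
    model_again, _ = strategy.prepare(model_cfg, optim_wrapper=optim_cfg)
    assert model_again is model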
mmengine/_strategy/utils.py (new file, 17 lines)

@@ -0,0 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch._subclasses.fake_tensor import _is_tensor_constructor
from torch.utils._python_dispatch import TorchDispatchMode


class MetaTensorContext(TorchDispatchMode):

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if _is_tensor_constructor(func):
            device_idx = [arg.name
                          for arg in func._schema.arguments].index('device')
            if len(args) > device_idx:
                args = list(args)
                args[device_idx] = 'meta'
            else:
                kwargs['device'] = 'meta'
        return func(*args, **kwargs)
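A minimal sketch of how this context is meant to be used, mirroring `FSDPStrategy.build_model` above: tensor constructors are redirected to the `meta` device so no real memory is allocated, and storage is materialized afterwards with `to_empty`.

    import torch.nn as nn

    with MetaTensorContext():
        model = nn.Linear(4096, 4096)   # parameters are created on the meta device
    model.to_empty(device='cpu')        # allocate (uninitialized) CPU storage later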
@@ -33,6 +33,6 @@ __all__ = [
    'convert_sync_batchnorm', 'BaseTTAModel'
]

if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
    from .wrappers import MMFullyShardedDataParallel  # noqa:F401
    __all__.append('MMFullyShardedDataParallel')
@@ -10,7 +10,7 @@ __all__ = [
    'MMSeparateDistributedDataParallel'
]

if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
    from .fully_sharded_distributed import \
        MMFullyShardedDataParallel  # noqa:F401
    __all__.append('MMFullyShardedDataParallel')
@ -1,18 +1,30 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
# yapf: disable
|
||||
from torch.distributed.fsdp.api import (FullStateDictConfig,
|
||||
LocalOptimStateDictConfig,
|
||||
LocalStateDictConfig,
|
||||
OptimStateDictConfig,
|
||||
ShardedOptimStateDictConfig,
|
||||
ShardedStateDictConfig,
|
||||
ShardingStrategy, StateDictConfig,
|
||||
StateDictSettings, StateDictType)
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
||||
BackwardPrefetch, CPUOffload, FullyShardedDataParallel)
|
||||
BackwardPrefetch, CPUOffload, FullOptimStateDictConfig,
|
||||
FullyShardedDataParallel, MixedPrecision)
|
||||
|
||||
# yapf: enable
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.registry import MODEL_WRAPPERS, Registry
|
||||
from mmengine.registry import FUNCTIONS, MODEL_WRAPPERS
|
||||
from mmengine.structures import BaseDataElement
|
||||
|
||||
# support customize fsdp policy
|
||||
FSDP_WRAP_POLICIES = Registry('fsdp wrap policy')
|
||||
from mmengine.utils import digit_version, is_seq_of
|
||||
|
||||
|
||||
@MODEL_WRAPPERS.register_module()
|
||||
@ -40,8 +52,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
|
||||
Args:
|
||||
module (nn.Module): module to be wrapped with FSDP.
|
||||
process_group (Optional[ProcessGroup]): process group for sharding.
|
||||
cpu_offload (Optional[Union[bool,CPUOffload]]):
|
||||
process_group (ProcessGroup, optional): process group for sharding.
|
||||
cpu_offload (bool, CPUOffload, optional):
|
||||
CPU offloading config.
|
||||
Different from FullyShardedDataParallel,Since it can be set by
|
||||
users' pre-defined config in MMEngine,its type is expected to be
|
||||
@ -54,7 +66,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
for params and grads to be on same device to work with optimizer.
|
||||
This API is subject to change. Default is ``None`` in which case
|
||||
there will be no offloading.
|
||||
fsdp_auto_wrap_policy: (Optional[Union[str,Callable]]):
|
||||
auto_wrap_policy (str or Callable, optional):
|
||||
Specifying a policy to recursively wrap layers with FSDP.
|
||||
Different from FullyShardedDataParallel, Since it can be set by
|
||||
users' pre-defined config in MMEngine, its type is expected to be
|
||||
@ -68,12 +80,12 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
the returned FSDP root instance.
|
||||
``default_auto_wrap_policy`` written in
|
||||
``torch.distributed.fsdp.wrap`` is an example of
|
||||
``fsdp_auto_wrap_policy`` callable, this policy wraps layers with
|
||||
``auto_wrap_policy`` callable, this policy wraps layers with
|
||||
parameter sizes larger than 100M. Users can supply the customized
|
||||
``fsdp_auto_wrap_policy`` callable that should accept following
|
||||
``auto_wrap_policy`` callable that should accept following
|
||||
arguments: ``module: nn.Module``, ``recurse: bool``,
|
||||
``unwrapped_params: int``, extra customized arguments could be
|
||||
added to the customized ``fsdp_auto_wrap_policy`` callable as well.
|
||||
added to the customized ``auto_wrap_policy`` callable as well.
|
||||
|
||||
Example::
|
||||
|
||||
@ -86,18 +98,23 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
>>> ) -> bool:
|
||||
>>> return unwrapped_params >= min_num_params
|
||||
|
||||
backward_prefetch: (Optional[Union[str,BackwardPrefetch]]):
|
||||
Different from FullyShardedDataParallel, Since it will be set by
|
||||
users' pre-defined config in MMEngine,its type is expected to be
|
||||
`None`, `str` or `BackwardPrefetch`.
|
||||
backward_prefetch (str or BackwardPrefetch, optional):
|
||||
Different from FullyShardedDataParallel, this argument could be a
|
||||
string or a BackwardPrefetch instance. If it's a string, then
|
||||
it should be ``BACKWARD_PRE`` or ``BACKWARD_POST``
|
||||
mixed_precision (dict or MixedPrecision, optional):
|
||||
This configures native mixed precision for FSDP. If this is set to
|
||||
``None``. Different from the native FSDP, this argument can a dict
|
||||
like this:
|
||||
|
||||
This is an experimental feature that is subject to change in the
|
||||
the near future. It allows users to enable two different
|
||||
backward_prefetch algorithms to help backward communication and
|
||||
computation overlapping.
|
||||
Pros and cons of each algorithm is explained in class
|
||||
``BackwardPrefetch``.
|
||||
Examples:
|
||||
>>> mixed_precision=dict(param_dtype='float16',
|
||||
>>> buffer_dtype='float32',
|
||||
>>> reduce_dtype='float32')
|
||||
|
||||
Defaults to None.
|
||||
use_orig_params (bool): Different from native
|
||||
``FullyShardedDataParallel``, it defaults to True.
|
||||
**kwargs: Keyword arguments passed to
|
||||
:class:`FullyShardedDataParallel`.
|
||||
"""
|
||||
@ -105,61 +122,117 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
def __init__(
|
||||
self,
|
||||
module: nn.Module,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
cpu_offload: Optional[Union[bool, CPUOffload]] = None,
|
||||
fsdp_auto_wrap_policy: Optional[Union[str, Callable]] = None,
|
||||
backward_prefetch: Optional[Union[str, BackwardPrefetch]] = None,
|
||||
process_group: Union[dict, ProcessGroup, None] = None,
|
||||
sharding_strategy: Union[str, ShardingStrategy] = None,
|
||||
cpu_offload: Union[bool, CPUOffload, None] = None,
|
||||
auto_wrap_policy: Union[str, Callable, None] = None,
|
||||
backward_prefetch: Union[str, BackwardPrefetch, None] = None,
|
||||
mixed_precision: Union[dict, MixedPrecision, None] = None,
|
||||
ignored_modules: Union[Iterable[str], Iterable[nn.Module],
|
||||
None] = None,
|
||||
param_init_fn: Union[str, Callable[[nn.Module], None]] = None,
|
||||
use_orig_params: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(sharding_strategy, str):
|
||||
sharding_strategy = ShardingStrategy[sharding_strategy]
|
||||
if not (isinstance(sharding_strategy, ShardingStrategy)
|
||||
or sharding_strategy is None):
|
||||
raise TypeError(
|
||||
'sharding_strategy must be str or enum of `ShardingStrategy` '
|
||||
f', but got {sharding_strategy}')
|
||||
|
||||
if cpu_offload is not None:
|
||||
if isinstance(cpu_offload, bool):
|
||||
cpu_offload = CPUOffload(offload_params=cpu_offload)
|
||||
elif not isinstance(cpu_offload, CPUOffload):
|
||||
if not (isinstance(cpu_offload, CPUOffload) or cpu_offload is None):
|
||||
raise TypeError(
|
||||
'`cpu_offload` should be `None`, `bool`'
|
||||
f'or `CPUOffload`, but has type {type(cpu_offload)}')
|
||||
|
||||
if fsdp_auto_wrap_policy is not None:
|
||||
if isinstance(fsdp_auto_wrap_policy, str):
|
||||
assert fsdp_auto_wrap_policy in FSDP_WRAP_POLICIES, \
|
||||
'`FSDP_WRAP_POLICIES` has no ' \
|
||||
f'function {fsdp_auto_wrap_policy}'
|
||||
fsdp_auto_wrap_policy = FSDP_WRAP_POLICIES.get( # type: ignore
|
||||
fsdp_auto_wrap_policy)
|
||||
if not isinstance(fsdp_auto_wrap_policy,
|
||||
Callable): # type: ignore
|
||||
raise TypeError(
|
||||
'Registered `fsdp_auto_wrap_policy` needs to be '
|
||||
'`Callable`, but has type '
|
||||
f'{type(fsdp_auto_wrap_policy)}')
|
||||
elif not isinstance(fsdp_auto_wrap_policy,
|
||||
Callable): # type: ignore
|
||||
raise TypeError(
|
||||
'`fsdp_auto_wrap_policy` should be `None`, `str` '
|
||||
'or `Callable`, but has type '
|
||||
f'{type(fsdp_auto_wrap_policy)}')
|
||||
if isinstance(auto_wrap_policy, str):
|
||||
auto_wrap_policy = FUNCTIONS.get( # type: ignore
|
||||
auto_wrap_policy)
|
||||
if auto_wrap_policy is None:
|
||||
raise ValueError('`auto_wrap_policy` is not registered!')
|
||||
elif isinstance(auto_wrap_policy, dict):
|
||||
ori_func = FUNCTIONS.get( # type: ignore
|
||||
auto_wrap_policy.pop('type'))
|
||||
if auto_wrap_policy is None:
|
||||
raise ValueError('`auto_wrap_policy` is not registered!')
|
||||
auto_wrap_policy = partial(ori_func, **auto_wrap_policy)
|
||||
|
||||
if not (auto_wrap_policy is None
|
||||
or callable(auto_wrap_policy)): # type: ignore
|
||||
raise TypeError('`auto_wrap_policy` should be a str, a '
|
||||
'callable, a dict or None, but has type '
|
||||
f'{type(auto_wrap_policy)}')
|
||||
|
||||
if backward_prefetch is not None:
|
||||
if isinstance(backward_prefetch, str):
|
||||
assert backward_prefetch in ['pre', 'post'], \
|
||||
'`backward_prefetch` should be either `pre` or `post`,' \
|
||||
f' but get {backward_prefetch}'
|
||||
if backward_prefetch == 'pre':
|
||||
backward_prefetch = BackwardPrefetch.BACKWARD_PRE
|
||||
else:
|
||||
backward_prefetch = BackwardPrefetch.BACKWARD_POST
|
||||
elif not isinstance(backward_prefetch, BackwardPrefetch):
|
||||
raise TypeError('`backward_prefetch` should be `None`, `str` '
|
||||
'or `BackwardPrefetch`, but has type '
|
||||
f'{type(backward_prefetch)}')
|
||||
backward_prefetch = BackwardPrefetch[backward_prefetch]
|
||||
if not (isinstance(backward_prefetch, BackwardPrefetch)
|
||||
or backward_prefetch is None):
|
||||
raise TypeError(
|
||||
'`backward_prefetch` should be `None`, string of '
|
||||
'"BACKWARD_PRE" and "BACKWARD_POST", or '
|
||||
f'`BackwardPrefetch`, but has type {type(backward_prefetch)}')
|
||||
|
||||
if isinstance(param_init_fn, str):
|
||||
param_init_fn = FUNCTIONS.get( # type: ignore
|
||||
param_init_fn)
|
||||
if param_init_fn is None:
|
||||
raise ValueError('`param_init_fn` is not registered!')
|
||||
elif isinstance(param_init_fn, dict):
|
||||
param_init_fn = FUNCTIONS.get(param_init_fn.pop('type'))
|
||||
if param_init_fn is None:
|
||||
raise ValueError('`param_init_fn` is not registered!')
|
||||
param_init_fn = partial(param_init_fn, **param_init_fn)
|
||||
|
||||
if not (callable(param_init_fn) or param_init_fn is None):
|
||||
raise TypeError('`param_init_fn` should be a str, a '
|
||||
'callable, a dict or None, but has type '
|
||||
f'{type(param_init_fn)}')
|
||||
|
||||
def parse_dtype(dtype):
|
||||
if dtype is None:
|
||||
return None
|
||||
elif isinstance(dtype, str):
|
||||
return getattr(torch, dtype)
|
||||
elif isinstance(dtype, torch.dtype):
|
||||
return dtype
|
||||
else:
|
||||
raise TypeError(
|
||||
'`dtype` should be `None`, `str` or `torch.dtype`, '
|
||||
f'but has type {type(dtype)}')
|
||||
|
||||
if isinstance(mixed_precision, dict):
|
||||
mixed_precision['param_dtype'] = parse_dtype(
|
||||
mixed_precision.get('param_dtype', None))
|
||||
mixed_precision['reduce_dtype'] = parse_dtype(
|
||||
mixed_precision.get('reduce_dtype', None))
|
||||
mixed_precision['buffer_dtype'] = parse_dtype(
|
||||
mixed_precision.get('buffer_dtype', None))
|
||||
mixed_precision = MixedPrecision(**mixed_precision)
|
||||
elif isinstance(mixed_precision, MixedPrecision):
|
||||
mixed_precision = mixed_precision
|
||||
elif mixed_precision is not None:
|
||||
raise TypeError(
|
||||
'`mixed_precision` should be `None`, `dict` or '
|
||||
f'`MixedPrecision`, but has type {type(mixed_precision)}')
|
||||
|
||||
self._fixed_modules = self._get_fixed_module(module, ignored_modules)
|
||||
ignored_modules = [] if ignored_modules is None else ignored_modules
|
||||
ignored_modules = chain(ignored_modules, self._fixed_modules)
|
||||
super().__init__(
|
||||
module=module,
|
||||
process_group=process_group,
|
||||
auto_wrap_policy=fsdp_auto_wrap_policy,
|
||||
sharding_strategy=sharding_strategy,
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
cpu_offload=cpu_offload,
|
||||
backward_prefetch=backward_prefetch,
|
||||
mixed_precision=mixed_precision,
|
||||
ignored_modules=ignored_modules,
|
||||
param_init_fn=param_init_fn,
|
||||
use_orig_params=use_orig_params,
|
||||
**kwargs)
|
||||
|
||||
def train_step(self, data: dict,
|
||||
@ -207,8 +280,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
Returns:
|
||||
List[BaseDataElement] or dict: The predictions of given data.
|
||||
"""
|
||||
inputs, data_sample = self.module.data_preprocessor(data, False)
|
||||
return self(inputs, data_sample, mode='predict')
|
||||
data = self.module.data_preprocessor(data, False)
|
||||
return self._run_forward(data, mode='predict') # type: ignore
|
||||
|
||||
def test_step(self, data: dict) -> List[BaseDataElement]:
|
||||
"""Gets the predictions of module during testing process.
|
||||
@ -219,5 +292,156 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
Returns:
|
||||
List[BaseDataElement]: The predictions of given data.
|
||||
"""
|
||||
inputs, data_sample = self.module.data_preprocessor(data, False)
|
||||
return self(inputs, data_sample, mode='predict')
|
||||
data = self.module.data_preprocessor(data, False)
|
||||
return self._run_forward(data, mode='predict') # type: ignore
|
||||
|
||||
def _run_forward(self, data: Union[dict, tuple, list],
|
||||
mode: str) -> Union[Dict[str, torch.Tensor], list]:
|
||||
"""Unpacks data for :meth:`forward`
|
||||
Args:
|
||||
data (dict or tuple or list): Data sampled from dataset.
|
||||
mode (str): Mode of forward.
|
||||
Returns:
|
||||
dict or list: Results of training or testing mode.
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
results = self(**data, mode=mode)
|
||||
elif isinstance(data, (list, tuple)):
|
||||
results = self(*data, mode=mode)
|
||||
else:
|
||||
raise TypeError('Output of `data_preprocessor` should be '
|
||||
f'list, tuple or dict, but got {type(data)}')
|
||||
return results
|
||||
|
||||
def _get_fixed_module(self, module, ignored_modules):
|
||||
module_dict = dict(module.named_modules())
|
||||
if is_seq_of(ignored_modules, str):
|
||||
ignored_modules = [module_dict[name] for name in ignored_modules]
|
||||
if not is_seq_of(ignored_modules,
|
||||
nn.Module) and ignored_modules is not None:
|
||||
raise TypeError(
|
||||
'`ignored_modules` should be `None`, `Iterable[str]` or '
|
||||
f'`Iterable[nn.Module]`, but has type {type(ignored_modules)}')
|
||||
|
||||
def find_fixed_modules_recursively(
|
||||
root_module: nn.Module) -> List[nn.Module]:
|
||||
"""Helper function to find fixed modules whose parameters are all
|
||||
untrainable, i.e. `requires_grad=False`.
|
||||
|
||||
This function performs
|
||||
recursively.
|
||||
Args:
|
||||
root_module (nn.Module): root module for recursion
|
||||
Returns:
|
||||
List[nn.Module]: fixed modules in root_module
|
||||
"""
|
||||
if all(p.requires_grad for p in root_module.parameters()):
|
||||
return []
|
||||
if all(not p.requires_grad for p in root_module.parameters()):
|
||||
return [root_module]
|
||||
fixed_modules = []
|
||||
for sub_module in root_module.children():
|
||||
fixed_modules.extend(
|
||||
find_fixed_modules_recursively(sub_module))
|
||||
return fixed_modules
|
||||
|
||||
fixed_modules = find_fixed_modules_recursively(module)
|
||||
return fixed_modules
|
||||
|
||||
if digit_version(torch.__version__) < digit_version('2.0.1'):

    @staticmethod
    def optim_state_dict(
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        group: Optional[dist.ProcessGroup] = None,
    ) -> Dict[str, Any]:
        """Copied from PyTorch 2.0.1 which has fixed some bugs."""
        state_dict_settings = FullyShardedDataParallel.get_state_dict_type(
            model)
        return FullyShardedDataParallel._optim_state_dict_impl(
            model=model,
            optim=optim,
            optim_state_dict=optim.state_dict(),
            optim_input=None,
            rank0_only=getattr(state_dict_settings.optim_state_dict_config,
                               'rank0_only', False),
            full_state_dict=state_dict_settings.state_dict_type ==
            StateDictType.FULL_STATE_DICT,
            group=group,
        )

    @staticmethod
    def set_state_dict_type(
        module: nn.Module,
        state_dict_type: StateDictType,
        state_dict_config: Optional[StateDictConfig] = None,
        optim_state_dict_config: Optional[OptimStateDictConfig] = None,
    ) -> StateDictSettings:
        """Copied from PyTorch 2.0.1 which has fixed some bugs."""
        import torch.distributed.fsdp._traversal_utils as traversal_utils
        _state_dict_type_to_config = {
            StateDictType.FULL_STATE_DICT: FullStateDictConfig,
            StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
            StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
        }
        _optim_state_dict_type_to_config = {
            StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig,
            StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig,
            StateDictType.SHARDED_STATE_DICT: ShardedOptimStateDictConfig,
        }

        # Use the default config if a state_dict config is not set.
        state_dict_config_type = _state_dict_type_to_config[
            state_dict_type]
        optim_state_dict_config_type = _optim_state_dict_type_to_config[
            state_dict_type]
        if state_dict_config is None:
            state_dict_config = state_dict_config_type()
        if optim_state_dict_config is None:
            optim_state_dict_config = optim_state_dict_config_type()
        if state_dict_config_type != type(state_dict_config):
            raise RuntimeError('Expected state_dict_config of type '
                               f'{state_dict_config_type} '
                               f'but got {type(state_dict_config)}')
        if optim_state_dict_config_type != type(optim_state_dict_config):
            raise RuntimeError('Expected optim_state_dict_config of type '
                               f'{optim_state_dict_config_type} '
                               f'but got {type(optim_state_dict_config)}')

        # Set the state_dict type and configurations.
        prev_state_dict_type = None
        prev_state_dict_config = None
        prev_optim_state_dict_config = None
        for submodule in traversal_utils._get_fsdp_states(module):
            if prev_state_dict_type is None:
                prev_state_dict_type = submodule._state_dict_type
            else:
                assert (
                    prev_state_dict_type == submodule._state_dict_type
                ), 'All FSDP modules should have the same state_dict_type.'
            if prev_state_dict_config is None:
                prev_state_dict_config = submodule._state_dict_config
            else:
                assert isinstance(
                    submodule._state_dict_config,
                    type(prev_state_dict_config)), (
                        'All FSDP modules must have the same type of '
                        'state_dict_config.')
            if prev_optim_state_dict_config is None:
                prev_optim_state_dict_config = \
                    submodule._optim_state_dict_config
            else:
                assert isinstance(
                    submodule._optim_state_dict_config,
                    type(prev_optim_state_dict_config),
                ), ('All FSDP modules must have the same type of '
                    'optim_state_dict_config.')

            submodule._state_dict_type = state_dict_type
            submodule._state_dict_config = state_dict_config
            submodule._optim_state_dict_config = optim_state_dict_config

        return StateDictSettings(prev_state_dict_type,
                                 prev_state_dict_config,
                                 prev_optim_state_dict_config)
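# Illustrative sketch: on PyTorch >= 2.0.1 the upstream API that the backported
# helpers above mirror can be called directly. Assumes a process group is
# already initialized and a GPU is available.
import torch.nn as nn
from torch.distributed.fsdp import (FullStateDictConfig,
                                    FullyShardedDataParallel, StateDictType)
from torch.distributed.fsdp.fully_sharded_data_parallel import \
    FullOptimStateDictConfig

model = FullyShardedDataParallel(nn.Linear(8, 8).cuda())

# Ask every FSDP submodule to produce a full (unsharded) state dict, gathered
# on rank 0 and offloaded to CPU to limit GPU memory pressure.
FullyShardedDataParallel.set_state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
    optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True))
state_dict = model.state_dict()  # full parameters materialized on rank 0 only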
@ -48,6 +48,10 @@ class AmpOptimWrapper(OptimWrapper):
            `'float64'`. If set to ``None``, the default data type will be used.
            Defaults to None.
            `New in version 0.6.1.`
        use_fsdp (bool): Use ``ShardedGradScaler`` when it is True. It should
            be enabled when using ``FullyShardedDataParallel``.
            Defaults to False.
            `New in version 0.8.0.`
        **kwargs: Keyword arguments passed to OptimWrapper.

    Warnings:
@ -65,6 +69,7 @@ class AmpOptimWrapper(OptimWrapper):
    def __init__(self,
                 loss_scale: str = 'dynamic',
                 dtype: Union[str, torch.dtype] = None,
                 use_fsdp: bool = False,
                 **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
            '`torch.cuda.amp` is only available when pytorch version >= 1.6')
@ -73,17 +78,29 @@ class AmpOptimWrapper(OptimWrapper):
            'on gpu, npu or mlu')
        super().__init__(**kwargs)
        self._scale_update_param = None

        if use_fsdp:
            if digit_version(torch.__version__) >= digit_version('2.0.0'):
                from torch.distributed.fsdp.sharded_grad_scaler import \
                    ShardedGradScaler
                scaler_type = ShardedGradScaler
            else:
                raise RuntimeError(
                    'PyTorch>=2.0.0 is required when setting `use_fsdp=True`')
        else:
            scaler_type = GradScaler

        if loss_scale == 'dynamic':
            # If loss_scale is a string, it must be 'dynamic', then dynamic
            # loss scaling will be used.
            self.loss_scaler = GradScaler()
            self.loss_scaler = scaler_type()
        elif isinstance(loss_scale, float):
            # Static loss scaling
            self._scale_update_param = loss_scale
            self.loss_scaler = GradScaler(init_scale=loss_scale)
            self.loss_scaler = scaler_type(init_scale=loss_scale)
        elif isinstance(loss_scale, dict):
            # More specific configuration.
            self.loss_scaler = GradScaler(**loss_scale)
            self.loss_scaler = scaler_type(**loss_scale)
        else:
            raise TypeError('loss_scale must be of type float, dict, or '
                            f'"dynamic", but got {loss_scale}')
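# Illustrative sketch of the new `use_fsdp` flag; hyper-parameters are made up,
# and the wrapper still requires a GPU/NPU/MLU device as asserted above. In a
# real run `model` would already be wrapped by FullyShardedDataParallel.
import torch.nn as nn
from torch.optim import AdamW

from mmengine.optim import AmpOptimWrapper

model = nn.Linear(8, 8)
optim_wrapper = AmpOptimWrapper(
    use_fsdp=True,         # selects ShardedGradScaler (needs PyTorch >= 2.0.0)
    loss_scale='dynamic',  # dynamic loss scaling, the default
    optimizer=AdamW(model.parameters(), lr=1e-3))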
@ -632,6 +632,18 @@ class FlexibleRunner:

        assert isinstance(strategy, dict)

        # train_micro_batch_size_per_gpu is required by DeepSpeed
        if isinstance(strategy['type'], str):
            strategy_name = strategy['type']
        else:
            strategy_name = strategy['type'].__name__
        if strategy_name == 'DeepSpeedStrategy':
            if self._train_dataloader is None:
                strategy['train_micro_batch_size_per_gpu'] = 1
            else:
                strategy['train_micro_batch_size_per_gpu'] = \
                    _get_batch_size(self._train_dataloader)

        strategy.setdefault('work_dir', self._work_dir)
        strategy.setdefault('experiment_name', experiment_name)
        strategy.setdefault('auto_scale_lr', self._auto_scale_lr)
@ -1140,18 +1152,13 @@ class FlexibleRunner:
        compile = copy.copy(self._compile)
        compile.setdefault('target', 'train_step')

        if self.train_dataloader.batch_size is not None:
            micro_batch_size = self.train_dataloader.batch_size
        else:
            micro_batch_size = self.train_dataloader.batch_sampler.batch_size
        dispatch_kwargs = dict(
            train_micro_batch_size_per_gpu=micro_batch_size,
            num_batches_per_epoch=len(self.train_dataloader),
            epoch_length=len(self.train_dataloader),
            max_epochs=self.max_epochs,
            max_iters=self.max_iters,
        )

        result = self.strategy.prepare(
        self.strategy.prepare(
            self.model,
            optim_wrapper=self.optim_wrapper,
            param_scheduler=self.param_schedulers,
@ -1159,10 +1166,10 @@ class FlexibleRunner:
            dispatch_kwargs=dispatch_kwargs,
        )

        self.model = self.strategy.model
        self.optim_wrapper = self.strategy.optim_wrapper  # type: ignore
        if self.param_schedulers is not None:
            self.model, self.optim_wrapper, self.param_schedulers, *_ = result
        else:
            self.model, self.optim_wrapper, *_ = result
            self.param_schedulers = self.strategy.param_schedulers

        self.load_or_resume()

@ -1187,7 +1194,11 @@ class FlexibleRunner:

        self._val_loop = self.build_val_loop(self._val_loop)  # type: ignore

        self.model = self.strategy.prepare(self.model)
        dispatch_kwargs = dict(
            init_weights_for_test_or_val=self.cfg.get(
                'init_weights_for_test_or_val', True))
        self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs)
        self.model = self.strategy.model

        self.load_or_resume()

@ -1210,8 +1221,11 @@ class FlexibleRunner:
            '`test_evaluator` arguments when initializing runner.')

        self._test_loop = self.build_test_loop(self._test_loop)  # type: ignore

        self.model = self.strategy.prepare(self.model)
        dispatch_kwargs = dict(
            init_weights_for_test_or_val=self.cfg.get(
                'init_weights_for_test_or_val', True))
        self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs)
        self.model = self.strategy.model

        self.load_or_resume()

@ -1613,3 +1627,20 @@ class FlexibleRunner:

        if self.cfg._cfg_dict:
            self.logger.info(f'Config:\n{self.cfg.pretty_text}')


def _get_batch_size(dataloader):
    if isinstance(dataloader, dict):
        if 'batch_size' in dataloader:
            return dataloader['batch_size']
        elif ('batch_sampler' in dataloader
              and 'batch_size' in dataloader['batch_sampler']):
            return dataloader['batch_sampler']['batch_size']
        else:
            raise ValueError('Please set batch_size in `Dataloader` or '
                             '`batch_sampler`')
    elif isinstance(dataloader, DataLoader):
        return dataloader.batch_sampler.batch_size
    else:
        raise ValueError('dataloader should be a dict or a Dataloader '
                         f'instance, but got {type(dataloader)}')
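# Illustrative sketch of how `_get_batch_size` resolves the micro batch size;
# config keys other than `batch_size` are made up for the example.
import torch
from torch.utils.data import DataLoader, TensorDataset

cfg_a = dict(batch_size=16, dataset=dict(type='ToyDataset'))
cfg_b = dict(
    batch_sampler=dict(type='DefaultSampler', batch_size=8),
    dataset=dict(type='ToyDataset'))
loader = DataLoader(TensorDataset(torch.zeros(32, 2)), batch_size=4)

assert _get_batch_size(cfg_a) == 16   # taken directly from the config
assert _get_batch_size(cfg_b) == 8    # taken from the nested batch_sampler
assert _get_batch_size(loader) == 4   # read from DataLoader.batch_sampler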
@ -36,14 +36,14 @@ class ToyModel(BaseModel):
        if isinstance(inputs, list):
            inputs = torch.stack(inputs)
        if isinstance(data_samples, list):
            data_sample = torch.stack(data_samples)
            data_samples = torch.stack(data_samples)
        outputs = self.linear1(inputs)
        outputs = self.linear2(outputs)

        if mode == 'tensor':
            return outputs
        elif mode == 'loss':
            loss = (data_sample - outputs).sum()
            loss = (data_samples - outputs).sum()
            outputs = dict(loss=loss)
            return outputs
        elif mode == 'predict':
@ -8,7 +8,7 @@ import torch.distributed as torch_dist
import torch.nn as nn
from torch.optim import SGD

from mmengine.dist import all_gather
from mmengine.dist import all_gather, broadcast
from mmengine.model import (BaseDataPreprocessor, BaseModel,
                            ExponentialMovingAverage,
                            MMDistributedDataParallel,
@ -222,8 +222,8 @@ class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel):
@unittest.skipIf(
    torch.cuda.device_count() < 2, reason='need 2 gpu to test fsdp')
@unittest.skipIf(
    digit_version(TORCH_VERSION) < digit_version('1.11.0'),
    reason='fsdp needs Pytorch 1.11 or higher')
    digit_version(TORCH_VERSION) < digit_version('2.0.0'),
    reason='fsdp needs Pytorch 2.0.0 or higher')
class TestMMFullyShardedDataParallel(MultiProcessTestCase):

    def _init_dist_env(self, rank, world_size):
@ -247,19 +247,38 @@ class TestMMFullyShardedDataParallel(MultiProcessTestCase):
        model = ToyModel()
        fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
        optimizer = SGD(fsdp_model.parameters(), lr=0)
        optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1)
        optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1)
        inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
        data = dict(inputs=[inputs], data_sample=MagicMock())
        data = dict(inputs=inputs, data_sample=MagicMock())
        fsdp_model.train()
        self.assertTrue(fsdp_model.training)
        fsdp_model.train_step(data, optim_wrapper=optim_wrapper)

        # requires_grad=False
        model = ToyModel()
        for _, param in model.state_dict().items():
            broadcast(param)
        model.conv1.requires_grad_(False)
        ori_weight = model.conv1.weight.clone()
        fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
        optimizer = SGD(fsdp_model.parameters(), lr=0.1)
        optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1)
        inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
        data = dict(inputs=inputs, data_sample=MagicMock())
        fsdp_model.train()
        self.assertTrue(fsdp_model.training)
        fsdp_model.train_step(data, optim_wrapper=optim_wrapper)

        with fsdp_model.summon_full_params(fsdp_model):
            updated_weight = fsdp_model.module.conv1.weight.cpu()
        assert_allclose(ori_weight, updated_weight)

    def test_val_step(self):
        self._init_dist_env(self.rank, self.world_size)
        model = ToyModel()
        fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
        inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
        data = dict(inputs=[inputs], data_sample=MagicMock())
        data = dict(inputs=inputs, data_sample=MagicMock())
        # Test get predictions.
        predictions = fsdp_model.val_step(data)
        self.assertIsInstance(predictions, torch.Tensor)
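# Illustrative sketch of `accumulative_counts`, the OptimWrapper argument used
# in the updated tests above (model and hyper-parameters are made up).
import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper

model = nn.Linear(8, 8)
optim_wrapper = OptimWrapper(
    SGD(model.parameters(), lr=0.01),
    accumulative_counts=2)  # parameters are updated once every 2 iterations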
tests/test_strategies/test_fsdp.py (new file, 231 lines)
@ -0,0 +1,231 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from tempfile import TemporaryDirectory
from unittest import TestCase, skipIf

import torch
import torch.nn as nn

try:
    from torch.distributed.fsdp import (FullStateDictConfig,
                                        FullyShardedDataParallel,
                                        LocalStateDictConfig, StateDictType)
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        FullOptimStateDictConfig, LocalOptimStateDictConfig)

    from mmengine._strategy import FSDPStrategy
except:  # noqa: E722
    pass
from torch.multiprocessing.spawn import start_processes
from torch.optim import SGD

from mmengine.dist import (all_gather_object, broadcast_object_list,
                           is_main_process)
from mmengine.optim import LinearLR, OptimWrapper
from mmengine.testing.runner_test_case import ToyModel
from mmengine.utils import digit_version


def linear_wrap_policy(
    module,
    recurse,
    nonwrapped_numel,
) -> bool:
    if recurse:
        return True  # always recurse
    return isinstance(module, nn.Linear)
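# Illustrative sketch: a policy with this signature can also be handed straight
# to FullyShardedDataParallel, which is what
# `model_wrapper=dict(auto_wrap_policy=linear_wrap_policy)` asks FSDPStrategy
# to do below. Process-group setup is assumed and omitted.
from torch.distributed.fsdp import FullyShardedDataParallel

toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)).cuda()
fsdp_toy = FullyShardedDataParallel(toy, auto_wrap_policy=linear_wrap_policy)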
@skipIf(
    digit_version(torch.__version__) < digit_version('2.0.0')
    or not torch.cuda.is_available(),
    'Only test FSDP with CUDA and PyTorch >= 2.0.0')
class TestStrategy(TestCase):

    def setUp(self):
        self.world_size = 2
        self.temp_dir = TemporaryDirectory()

    def tearDown(self) -> None:
        self.temp_dir.cleanup()

    def test_init(self):
        strategy = FSDPStrategy()
        self.assertFalse(strategy.skip_init_weights)
        strategy = FSDPStrategy(state_dict_cfg='local')
        self._assert_local(strategy)

        strategy = FSDPStrategy(state_dict_cfg='full')
        self._assert_full(strategy)

        strategy = FSDPStrategy(
            state_dict_cfg=dict(
                state_dict_type=StateDictType.LOCAL_STATE_DICT))
        self._assert_local(strategy)

        strategy = FSDPStrategy(
            state_dict_cfg=dict(
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(),
                optim_state_dict_config=FullOptimStateDictConfig(),
            ))
        self._assert_full(strategy)

        strategy = FSDPStrategy(
            state_dict_cfg=dict(
                state_dict_type='FULL_STATE_DICT',
                state_dict_config=dict(type='FullStateDictConfig'),
                optim_state_dict_config=dict(type='FullOptimStateDictConfig'),
            ))
        self._assert_full(strategy)

        strategy = FSDPStrategy(
            state_dict_cfg=dict(
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=dict(type=FullStateDictConfig),
                optim_state_dict_config=dict(type=FullOptimStateDictConfig),
            ))
        self._assert_full(strategy)

        with self.assertRaises(ValueError):
            strategy = FSDPStrategy(state_dict_cfg='error-str')

        # state_dict_cfg should be a str or a dict
        with self.assertRaises(TypeError):
            strategy = FSDPStrategy(state_dict_cfg=[])

        # state_dict_type must be a str or an instance of StateDictType
        with self.assertRaises(TypeError):
            strategy = FSDPStrategy(
                state_dict_cfg=dict(
                    state_dict_type=[],
                    state_dict_config=dict(type=FullStateDictConfig),
                    optim_state_dict_config=dict(
                        type=FullOptimStateDictConfig),
                ))

        # state_dict_config should be a dict or a subclass of StateDictConfig
        with self.assertRaises(TypeError):
            strategy = FSDPStrategy(
                state_dict_cfg=dict(
                    state_dict_type=StateDictType.FULL_STATE_DICT,
                    state_dict_config=[],
                    optim_state_dict_config=dict(
                        type=FullOptimStateDictConfig),
                ))

        # optim_state_dict_config should be a dict or a subclass of
        # OptimStateDictConfig
        with self.assertRaises(TypeError):
            strategy = FSDPStrategy(
                state_dict_cfg=dict(
                    state_dict_type=StateDictType.FULL_STATE_DICT,
                    state_dict_config=dict(type=FullStateDictConfig),
                    optim_state_dict_config=[],
                ))
    def run_strategy(self):
        # Strategy can run with the built model, optimizer and schedulers.
        for skip_init_weights, state_dict_cfg in [(True, 'local'),
                                                  (False, 'full')]:
            strategy = FSDPStrategy(
                skip_init_weights=skip_init_weights,
                state_dict_cfg=state_dict_cfg,
                model_wrapper=dict(auto_wrap_policy=linear_wrap_policy))
            model = ToyModel()
            optim = OptimWrapper(SGD(model.parameters(), lr=0.1, momentum=0.9))
            lr_scheduler = LinearLR(optimizer=optim)
            model, optim, lr_scheduler = strategy.prepare(
                model=model, optim_wrapper=optim, param_scheduler=lr_scheduler)
            self.assertIsInstance(model, FullyShardedDataParallel)
            self.assertIsInstance(model.linear1, FullyShardedDataParallel)
            self.assertIsInstance(model.linear2, FullyShardedDataParallel)

            data = torch.ones(2, 2).cuda()
            data_samples = torch.zeros(2, 2).cuda()
            loss = model(data, data_samples=data_samples, mode='loss')['loss']
            loss.backward()
            optim.step()
            [scheduler.step() for scheduler in lr_scheduler]

            ckpt_path = osp.join(self.temp_dir.name,
                                 f'checkpoint_{state_dict_cfg}.pth')
            strategy.save_checkpoint(ckpt_path)

            if state_dict_cfg == 'full':
                if not is_main_process():
                    self.assertFalse(osp.exists(ckpt_path))
                ckpt_path = [ckpt_path]
                broadcast_object_list(ckpt_path)
                ckpt_path = ckpt_path[0]

            strategy.load_checkpoint(ckpt_path)
            loss = model(data, data_samples=data_samples, mode='loss')['loss']
            loss.backward()
            optim.step()
            [scheduler.step() for scheduler in lr_scheduler]

        # optimizer with multiple param_groups can be reconstructed.
        model = ToyModel()
        strategy = FSDPStrategy(
            model_wrapper=dict(auto_wrap_policy=linear_wrap_policy))
        param_groups = []
        for param in model.parameters():
            param_groups.append(dict(params=[param], lr=0.1))
        optim = SGD(param_groups, lr=0.1, momentum=0.9)
        lr_scheduler = LinearLR(optimizer=optim)
        model, optim, lr_scheduler = strategy.prepare(
            model=model, optim_wrapper=optim, param_scheduler=lr_scheduler)
        data = torch.ones(2, 2).cuda()
        data_samples = torch.zeros(2, 2).cuda()
        loss = model(data, data_samples=data_samples, mode='loss')['loss']
        loss.backward()
        optim.step()
        [scheduler.step() for scheduler in lr_scheduler]
        optim_state = optim.state_dict()['state']
        optim_state = all_gather_object(optim_state)
    @classmethod
    def _worker(cls, rank, func):
        # local mode
        self = cls()
        self.setUp()
        self.rank = rank

        os.environ['RANK'] = str(rank)
        os.environ['LOCAL_RANK'] = str(rank)
        os.environ['WORLD_SIZE'] = str(self.world_size)
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(12123)
        torch.cuda.set_device(f'cuda:{rank}')

        getattr(self, func)()
        self.tearDown()

    def test_run_strategy(self):
        start_processes(
            TestStrategy._worker,
            args=('run_strategy', ),
            nprocs=self.world_size)

    def test_build_model(self):
        ...
        # TODO
        # strategy = FSDPStrategy()
        # model = ToyModel()
        # state_dict = dict()

    def _assert_local(self, strategy):
        self.assertEqual(strategy.state_dict_type,
                         StateDictType.LOCAL_STATE_DICT)
        self.assertIsInstance(strategy.state_dict_config, LocalStateDictConfig)
        self.assertIsInstance(strategy.optim_state_dict_config,
                              LocalOptimStateDictConfig)

    def _assert_full(self, strategy):
        self.assertEqual(strategy.state_dict_type,
                         StateDictType.FULL_STATE_DICT)
        self.assertIsInstance(strategy.state_dict_config, FullStateDictConfig)
        self.assertIsInstance(strategy.optim_state_dict_config,
                              FullOptimStateDictConfig)