[Feature] Support using gradient checkpointing in FSDP (#1382)

Author: Mashiro, 2023-10-09 21:04:55 +08:00 (committed by GitHub)
parent bf30c444de
commit 8015d62202
2 changed files with 86 additions and 41 deletions

Changed file 1 of 2:

@@ -5,6 +5,7 @@ import os
 import os.path as osp
 import time
 from collections import OrderedDict
+from functools import partial
 from typing import Callable, Dict, List, Optional, Sequence, Union

 import torch.nn as nn
@@ -25,7 +26,7 @@ from mmengine.model import BaseDataPreprocessor, is_model_wrapper
 from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper,
                             OptimWrapperDict, _ParamScheduler,
                             build_optim_wrapper)
-from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS,
+from mmengine.registry import (FUNCTIONS, MODEL_WRAPPERS, OPTIM_WRAPPERS,
                                PARAM_SCHEDULERS, STRATEGIES, Registry)
 from mmengine.utils import get_git_hash, mkdir_or_exist
 from .distributed import DDPStrategy
@@ -91,6 +92,19 @@ class FSDPStrategy(DDPStrategy):
               :meth:`setup_env`. Defaults to None.
             - log_kwargs (dict, optional): Logger config passed in
               :meth:`build_logger`. Defaults to None.
+        activation_checkpointing (dict, optional): Config dict for gradient
+            checkpointing.
+
+            Examples:
+                >>> activation_checkpointing = dict(check_fn='CustomCheckFn')
+                >>> activation_checkpointing = dict(
+                ...     check_fn=dict(type='CustomCheckFn', arg1=arg1))
+
+            The ``check_fn`` field should behave consistently with the
+            ``auto_wrap_policy`` defined in ``model_wrapper``, and other
+            fields will be passed to ``apply_activation_checkpointing``.
+
+            `New in version 0.9.0.`

     .. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
     """  # noqa: E501
@@ -100,6 +114,7 @@ class FSDPStrategy(DDPStrategy):
                  model_wrapper: Optional[dict] = None,
                  skip_init_weights=False,
                  state_dict_cfg: Union[str, dict] = 'local',
+                 activation_checkpointing: Optional[dict] = None,
                  **kwargs):
         super().__init__(model_wrapper=model_wrapper, **kwargs)
         self._init_state_dict_cfg(state_dict_cfg)
@@ -107,6 +122,7 @@ class FSDPStrategy(DDPStrategy):
             raise TypeError('skip_init_weights must be a boolean, but got '
                             f'{type(skip_init_weights)}')
         self.skip_init_weights = skip_init_weights
+        self.activation_checkpointing = activation_checkpointing

     def _wrap_model(self, model: nn.Module) -> None:
         """Wrap the model to :obj:``MMFullyShardedDataParallel`` or other
@@ -119,6 +135,12 @@ class FSDPStrategy(DDPStrategy):
             FullyShardedDataParallel: ``MMFullyShardedDataParallel``
             or subclass of ``FullyShardedDataParallel``.
         """
+        try:
+            from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
+                apply_activation_checkpointing  # noqa: E501
+        except ImportError:
+            apply_activation_checkpointing = None
+
         for module in model.modules():
             if isinstance(module, BaseDataPreprocessor):
                 module.to(get_device())
@@ -138,6 +160,27 @@ class FSDPStrategy(DDPStrategy):
             model.set_state_dict_type(model, self.state_dict_type,
                                       self.state_dict_config,
                                       self.optim_state_dict_config)
+        if self.activation_checkpointing is not None:
+            if apply_activation_checkpointing is None:
+                raise RuntimeError(
+                    '`activation_checkpointing` is not available in the '
+                    'current PyTorch version; please upgrade to PyTorch 2.0 '
+                    'or 2.1 to use it.')
+            cfg = copy.deepcopy(self.activation_checkpointing)
+            with FUNCTIONS.switch_scope_and_registry(None):
+                check_fn = cfg.pop('check_fn')
+                if isinstance(check_fn, str):
+                    check_fn = FUNCTIONS.get(check_fn)
+                elif isinstance(check_fn, dict):
+                    fn_type = check_fn.pop('type')
+                    if isinstance(fn_type, str):
+                        fn_type = FUNCTIONS.get(fn_type)
+                    check_fn = partial(fn_type, **check_fn)
+
+                if not callable(check_fn):
+                    raise TypeError('`check_fn` must be a callable function')
+
+                apply_activation_checkpointing(model, check_fn=check_fn, **cfg)
+
         return model

     def _is_full_state_dict(self):
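To make the ``check_fn`` resolution above concrete, here is a hedged sketch of a custom check function; the function name and the ``nn.Linear`` criterion are illustrative choices, not part of this commit. ``apply_activation_checkpointing`` wraps every submodule for which the check function returns True.

import torch.nn as nn

from mmengine.registry import FUNCTIONS


@FUNCTIONS.register_module()
def custom_check_fn(submodule: nn.Module) -> bool:
    # Checkpoint only Linear layers (illustrative criterion). The strategy
    # resolves check_fn='custom_check_fn' to this function via FUNCTIONS.
    return isinstance(submodule, nn.Linear)


# Referenced by name in the strategy config:
# activation_checkpointing = dict(check_fn='custom_check_fn')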

Changed file 2 of 2:

@@ -146,6 +146,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
                 '`cpu_offload` should be `None`, `bool` '
                 f'or `CPUOffload`, but has type {type(cpu_offload)}')

+        with FUNCTIONS.switch_scope_and_registry(None):
             if isinstance(auto_wrap_policy, str):
                 auto_wrap_policy = FUNCTIONS.get(  # type: ignore
                     auto_wrap_policy)
@@ -172,7 +173,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
                 raise TypeError(
                     '`backward_prefetch` should be `None`, string of '
                     '"BACKWARD_PRE" and "BACKWARD_POST", or '
-                    f'`BackwardPrefetch`, but has type {type(backward_prefetch)}')
+                    f'`BackwardPrefetch`, but has type {type(backward_prefetch)}'  # noqa: E501
+                )

             if isinstance(param_init_fn, str):
                 param_init_fn = FUNCTIONS.get(  # type: ignore
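In this second file, the new ``switch_scope_and_registry(None)`` context makes the string lookups for ``auto_wrap_policy`` (and the other function fields resolved below it, such as ``param_init_fn``) scope-aware, so a function registered in the FUNCTIONS registry can be referenced by name. A hedged sketch of that usage follows; the registered name and the parameter threshold are placeholders chosen for this sketch, not values from this commit.

from functools import partial

from mmengine.registry import FUNCTIONS
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Register a wrap policy under a string name so it can be referenced from a
# model_wrapper config (name and threshold are illustrative).
FUNCTIONS.register_module(
    name='size_based_policy_1m',
    module=partial(size_based_auto_wrap_policy, min_num_params=int(1e6)))

# Passed to the strategy as FSDPStrategy(model_wrapper=model_wrapper).
model_wrapper = dict(
    type='MMFullyShardedDataParallel',
    auto_wrap_policy='size_based_policy_1m')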