[Feature] Support using gradient checkpointing in FSDP (#1382)

Mashiro authored 2023-10-09 21:04:55 +08:00, committed by GitHub
parent bf30c444de
commit 8015d62202
2 changed files with 86 additions and 41 deletions


@@ -5,6 +5,7 @@ import os
 import os.path as osp
 import time
 from collections import OrderedDict
+from functools import partial
 from typing import Callable, Dict, List, Optional, Sequence, Union
 
 import torch.nn as nn
@@ -25,7 +26,7 @@ from mmengine.model import BaseDataPreprocessor, is_model_wrapper
 from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper,
                             OptimWrapperDict, _ParamScheduler,
                             build_optim_wrapper)
-from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS,
+from mmengine.registry import (FUNCTIONS, MODEL_WRAPPERS, OPTIM_WRAPPERS,
                                PARAM_SCHEDULERS, STRATEGIES, Registry)
 from mmengine.utils import get_git_hash, mkdir_or_exist
 from .distributed import DDPStrategy
@@ -91,6 +92,19 @@ class FSDPStrategy(DDPStrategy):
              :meth:`setup_env`. Defaults to None.
            - log_kwargs (dict, optional): Logger config passed in
              :meth:`build_logger`. Defaults to None.
+        activation_checkpointing (dict, optional): Config dict for gradient
+            checkpointing.
+
+            Examples:
+                >>> activation_checkpointing = dict(check_fn='CustomCheckFn')
+                >>> activation_checkpointing = dict(
+                ...     check_fn=dict(type='CustomCheckFn', arg1=arg1))
+
+            The ``check_fn`` field should behave consistently with the
+            ``auto_wrap_policy`` defined in ``model_wrapper``, and the other
+            fields are passed to ``apply_activation_checkpointing``.
+
+            `New in version 0.9.0.`
 
    .. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
    """  # noqa: E501
@@ -100,6 +114,7 @@ class FSDPStrategy(DDPStrategy):
                  model_wrapper: Optional[dict] = None,
                  skip_init_weights=False,
                  state_dict_cfg: Union[str, dict] = 'local',
+                 activation_checkpointing: Optional[dict] = None,
                  **kwargs):
         super().__init__(model_wrapper=model_wrapper, **kwargs)
         self._init_state_dict_cfg(state_dict_cfg)
@@ -107,6 +122,7 @@ class FSDPStrategy(DDPStrategy):
             raise TypeError('skip_init_weights must be a boolean, but got '
                             f'{type(skip_init_weights)}')
         self.skip_init_weights = skip_init_weights
+        self.activation_checkpointing = activation_checkpointing
 
     def _wrap_model(self, model: nn.Module) -> None:
         """Wrap the model to :obj:``MMFullyShardedDataParallel`` or other
@@ -119,6 +135,12 @@ class FSDPStrategy(DDPStrategy):
             FullyShardedDataParallel: ``MMFullyShardedDataParallel``
                 or subclass of ``FullyShardedDataParallel``.
         """
+        try:
+            from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
+                apply_activation_checkpointing  # noqa: E501
+        except ImportError:
+            apply_activation_checkpointing = None
+
         for module in model.modules():
             if isinstance(module, BaseDataPreprocessor):
                 module.to(get_device())
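The import above is guarded because ``apply_activation_checkpointing`` only exists in recent PyTorch releases. For context, a minimal standalone sketch of that underlying helper (assumes torch >= 2.0; the toy model is illustrative):

    # Sketch of the PyTorch helper this feature builds on (torch >= 2.0).
    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
        apply_activation_checkpointing

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

    # check_fn is called for each submodule; returning True wraps that
    # submodule so its activations are recomputed in the backward pass
    # instead of being stored.
    apply_activation_checkpointing(
        model, check_fn=lambda submodule: isinstance(submodule, nn.Linear))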
@@ -138,6 +160,27 @@ class FSDPStrategy(DDPStrategy):
         model.set_state_dict_type(model, self.state_dict_type,
                                   self.state_dict_config,
                                   self.optim_state_dict_config)
+
+        if self.activation_checkpointing is not None:
+            if apply_activation_checkpointing is None:
+                raise RuntimeError(
+                    'activation_checkpointing maybe deprecated by current '
+                    'PyTorch version, maybe you could switch to PyTorch 2.0 '
+                    'or 2.1 to use `activation_checkpointing`.')
+            cfg = copy.deepcopy(self.activation_checkpointing)
+            with FUNCTIONS.switch_scope_and_registry(None):
+                check_fn = cfg.pop('check_fn')
+                if isinstance(check_fn, str):
+                    check_fn = FUNCTIONS.get(check_fn)
+                elif isinstance(check_fn, dict):
+                    fn_type = check_fn.pop('type')
+                    if isinstance(fn_type, str):
+                        fn_type = FUNCTIONS.get(fn_type)
+                    check_fn = partial(fn_type, **cfg)
+
+                if not callable(check_fn):
+                    raise TypeError('`check_fn` must be a callable function')
+                apply_activation_checkpointing(model, check_fn=check_fn, **cfg)
         return model
 
     def _is_full_state_dict(self):
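Because ``check_fn`` is resolved through the ``FUNCTIONS`` registry (with the default scope temporarily cleared), a custom predicate can be registered once and then referenced by name from the strategy config. A hedged sketch; the function name and layer type are illustrative, not part of this commit:

    # Sketch: register a checkpointing predicate so it can be referenced as a
    # string in `activation_checkpointing`.
    import torch.nn as nn
    from mmengine.registry import FUNCTIONS

    @FUNCTIONS.register_module()
    def checkpoint_transformer_blocks(submodule: nn.Module) -> bool:
        # Recompute activations only for transformer encoder layers.
        return isinstance(submodule, nn.TransformerEncoderLayer)

    # strategy = FSDPStrategy(
    #     activation_checkpointing=dict(
    #         check_fn='checkpoint_transformer_blocks'))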


@@ -146,51 +146,53 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
                 '`cpu_offload` should be `None`, `bool`'
                 f'or `CPUOffload`, but has type {type(cpu_offload)}')
 
-        if isinstance(auto_wrap_policy, str):
-            auto_wrap_policy = FUNCTIONS.get(  # type: ignore
-                auto_wrap_policy)
-            if auto_wrap_policy is None:
-                raise ValueError('`auto_wrap_policy` is not registered!')
-        elif isinstance(auto_wrap_policy, dict):
-            policy = auto_wrap_policy.pop('type')
-            if isinstance(policy, str):
-                policy = FUNCTIONS.get(policy)  # type: ignore
-            if policy is None:
-                raise ValueError('`auto_wrap_policy` is not registered!')
-            auto_wrap_policy = partial(policy, **auto_wrap_policy)
-
-        if not (auto_wrap_policy is None
-                or callable(auto_wrap_policy)):  # type: ignore
-            raise TypeError('`auto_wrap_policy` should be a str, a '
-                            'callable, a dict or None, but has type '
-                            f'{type(auto_wrap_policy)}')
-
-        if isinstance(backward_prefetch, str):
-            backward_prefetch = BackwardPrefetch[backward_prefetch]
-        if not (isinstance(backward_prefetch, BackwardPrefetch)
-                or backward_prefetch is None):
-            raise TypeError(
-                '`backward_prefetch` should be `None`, string of '
-                '"BACKWARD_PRE" and "BACKWARD_POST", or '
-                f'`BackwardPrefetch`, but has type {type(backward_prefetch)}')
-
-        if isinstance(param_init_fn, str):
-            param_init_fn = FUNCTIONS.get(  # type: ignore
-                param_init_fn)
-            if param_init_fn is None:
-                raise ValueError('`param_init_fn` is not registered!')
-        elif isinstance(param_init_fn, dict):
-            init_fn = param_init_fn.pop('type')
-            if isinstance(param_init_fn, str):
-                init_fn = FUNCTIONS.get(init_fn)  # type: ignore
-            if init_fn is None:
-                raise ValueError('`param_init_fn` is not registered!')
-            param_init_fn = partial(init_fn, **param_init_fn)
-
-        if not (callable(param_init_fn) or param_init_fn is None):
-            raise TypeError('`param_init_fn` should be a str, a '
-                            'callable, a dict or None, but has type '
-                            f'{type(param_init_fn)}')
+        with FUNCTIONS.switch_scope_and_registry(None):
+            if isinstance(auto_wrap_policy, str):
+                auto_wrap_policy = FUNCTIONS.get(  # type: ignore
+                    auto_wrap_policy)
+                if auto_wrap_policy is None:
+                    raise ValueError('`auto_wrap_policy` is not registered!')
+            elif isinstance(auto_wrap_policy, dict):
+                policy = auto_wrap_policy.pop('type')
+                if isinstance(policy, str):
+                    policy = FUNCTIONS.get(policy)  # type: ignore
+                if policy is None:
+                    raise ValueError('`auto_wrap_policy` is not registered!')
+                auto_wrap_policy = partial(policy, **auto_wrap_policy)
+
+            if not (auto_wrap_policy is None
+                    or callable(auto_wrap_policy)):  # type: ignore
+                raise TypeError('`auto_wrap_policy` should be a str, a '
+                                'callable, a dict or None, but has type '
+                                f'{type(auto_wrap_policy)}')
+
+            if isinstance(backward_prefetch, str):
+                backward_prefetch = BackwardPrefetch[backward_prefetch]
+            if not (isinstance(backward_prefetch, BackwardPrefetch)
+                    or backward_prefetch is None):
+                raise TypeError(
+                    '`backward_prefetch` should be `None`, string of '
+                    '"BACKWARD_PRE" and "BACKWARD_POST", or '
+                    f'`BackwardPrefetch`, but has type {type(backward_prefetch)}'  # noqa: E501
+                )
+
+            if isinstance(param_init_fn, str):
+                param_init_fn = FUNCTIONS.get(  # type: ignore
+                    param_init_fn)
+                if param_init_fn is None:
+                    raise ValueError('`param_init_fn` is not registered!')
+            elif isinstance(param_init_fn, dict):
+                init_fn = param_init_fn.pop('type')
+                if isinstance(param_init_fn, str):
+                    init_fn = FUNCTIONS.get(init_fn)  # type: ignore
+                if init_fn is None:
+                    raise ValueError('`param_init_fn` is not registered!')
+                param_init_fn = partial(init_fn, **param_init_fn)
+
+            if not (callable(param_init_fn) or param_init_fn is None):
+                raise TypeError('`param_init_fn` should be a str, a '
+                                'callable, a dict or None, but has type '
+                                f'{type(param_init_fn)}')
 
         def parse_dtype(dtype):
             if dtype is None:
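With the lookup now wrapped in ``switch_scope_and_registry``, ``auto_wrap_policy`` and ``param_init_fn`` can be given as strings or config dicts regardless of the active default scope. A rough illustration of the accepted forms; it assumes ``size_based_auto_wrap_policy`` is resolvable from the ``FUNCTIONS`` registry in your environment (otherwise pass the callable directly):

    # Sketch of config forms accepted by MMFullyShardedDataParallel
    # (illustrative values, not taken from this commit).

    # str form: the name is looked up in FUNCTIONS
    model_wrapper = dict(
        type='MMFullyShardedDataParallel',
        auto_wrap_policy='size_based_auto_wrap_policy',
    )

    # dict form: 'type' is looked up and the remaining keys are bound via
    # functools.partial; backward_prefetch may also be given by the
    # BackwardPrefetch member name.
    model_wrapper = dict(
        type='MMFullyShardedDataParallel',
        auto_wrap_policy=dict(
            type='size_based_auto_wrap_policy', min_num_params=int(1e7)),
        backward_prefetch='BACKWARD_PRE',
    )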