[Feature] Support using gradient checkpointing in FSDP (#1382)

This commit is contained in:
Mashiro 2023-10-09 21:04:55 +08:00 committed by GitHub
parent bf30c444de
commit 8015d62202
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 86 additions and 41 deletions

View File

@ -5,6 +5,7 @@ import os
import os.path as osp
import time
from collections import OrderedDict
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Union
import torch.nn as nn
@ -25,7 +26,7 @@ from mmengine.model import BaseDataPreprocessor, is_model_wrapper
from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper,
OptimWrapperDict, _ParamScheduler,
build_optim_wrapper)
from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS,
from mmengine.registry import (FUNCTIONS, MODEL_WRAPPERS, OPTIM_WRAPPERS,
PARAM_SCHEDULERS, STRATEGIES, Registry)
from mmengine.utils import get_git_hash, mkdir_or_exist
from .distributed import DDPStrategy
@ -91,6 +92,19 @@ class FSDPStrategy(DDPStrategy):
:meth:`setup_env`. Defaults to None.
- log_kwargs (dict, optional): Logger config passed in
:meth:`build_logger`. Defaults to None.
activation_checkpointing (dict, optional): Config dict for gradient
checkpoint.
Examples:
>>> activation_checkpointing = dict(check_fn='CustomCheckFn')
>>> activation_checkpointing = dict(check_fn=dict(type='CustomCheckFn', arg1=arg1))
``check_fn`` field should behave consistently with
``auto_wrap_policy`` defined in `model_wrapper`, and other
fields will be passed to ``apply_activation_checkpointing``
`New in version 0.9.0.`
.. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
""" # noqa: E501
@ -100,6 +114,7 @@ class FSDPStrategy(DDPStrategy):
model_wrapper: Optional[dict] = None,
skip_init_weights=False,
state_dict_cfg: Union[str, dict] = 'local',
activation_checkpointing: Optional[dict] = None,
**kwargs):
super().__init__(model_wrapper=model_wrapper, **kwargs)
self._init_state_dict_cfg(state_dict_cfg)
@ -107,6 +122,7 @@ class FSDPStrategy(DDPStrategy):
raise TypeError('skip_init_weights must be a boolean, but got '
f'{type(skip_init_weights)}')
self.skip_init_weights = skip_init_weights
self.activation_checkpointing = activation_checkpointing
def _wrap_model(self, model: nn.Module) -> None:
"""Wrap the model to :obj:``MMFullyShardedDataParallel`` or other
@ -119,6 +135,12 @@ class FSDPStrategy(DDPStrategy):
FullyShardedDataParallel: ``MMFullyShardedDataParallel``
or subclass of ``FullyShardedDataParallel``.
"""
try:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
apply_activation_checkpointing # noqa: E501
except ImportError:
apply_activation_checkpointing = None
for module in model.modules():
if isinstance(module, BaseDataPreprocessor):
module.to(get_device())
@ -138,6 +160,27 @@ class FSDPStrategy(DDPStrategy):
model.set_state_dict_type(model, self.state_dict_type,
self.state_dict_config,
self.optim_state_dict_config)
if self.activation_checkpointing is not None:
if apply_activation_checkpointing is None:
raise RuntimeError(
'activation_checkpointing maybe deprecated by current '
'PyTorch version, maybe you could switch to PyTorch 2.0 '
'or 2.1 to use `activation_checkpointing`.')
cfg = copy.deepcopy(self.activation_checkpointing)
with FUNCTIONS.switch_scope_and_registry(None):
check_fn = cfg.pop('check_fn')
if isinstance(check_fn, str):
check_fn = FUNCTIONS.get(check_fn)
elif isinstance(check_fn, dict):
fn_type = check_fn.pop('type')
if isinstance(fn_type, str):
fn_type = FUNCTIONS.get(fn_type)
check_fn = partial(fn_type, **cfg)
if not callable(check_fn):
raise TypeError('`check_fn` must be a callable function')
apply_activation_checkpointing(model, check_fn=check_fn, **cfg)
return model
def _is_full_state_dict(self):

View File

@ -146,51 +146,53 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
'`cpu_offload` should be `None`, `bool`'
f'or `CPUOffload`, but has type {type(cpu_offload)}')
if isinstance(auto_wrap_policy, str):
auto_wrap_policy = FUNCTIONS.get( # type: ignore
auto_wrap_policy)
if auto_wrap_policy is None:
raise ValueError('`auto_wrap_policy` is not registered!')
elif isinstance(auto_wrap_policy, dict):
policy = auto_wrap_policy.pop('type')
if isinstance(policy, str):
policy = FUNCTIONS.get(policy) # type: ignore
if policy is None:
raise ValueError('`auto_wrap_policy` is not registered!')
auto_wrap_policy = partial(policy, **auto_wrap_policy)
with FUNCTIONS.switch_scope_and_registry(None):
if isinstance(auto_wrap_policy, str):
auto_wrap_policy = FUNCTIONS.get( # type: ignore
auto_wrap_policy)
if auto_wrap_policy is None:
raise ValueError('`auto_wrap_policy` is not registered!')
elif isinstance(auto_wrap_policy, dict):
policy = auto_wrap_policy.pop('type')
if isinstance(policy, str):
policy = FUNCTIONS.get(policy) # type: ignore
if policy is None:
raise ValueError('`auto_wrap_policy` is not registered!')
auto_wrap_policy = partial(policy, **auto_wrap_policy)
if not (auto_wrap_policy is None
or callable(auto_wrap_policy)): # type: ignore
raise TypeError('`auto_wrap_policy` should be a str, a '
'callable, a dict or None, but has type '
f'{type(auto_wrap_policy)}')
if not (auto_wrap_policy is None
or callable(auto_wrap_policy)): # type: ignore
raise TypeError('`auto_wrap_policy` should be a str, a '
'callable, a dict or None, but has type '
f'{type(auto_wrap_policy)}')
if isinstance(backward_prefetch, str):
backward_prefetch = BackwardPrefetch[backward_prefetch]
if not (isinstance(backward_prefetch, BackwardPrefetch)
or backward_prefetch is None):
raise TypeError(
'`backward_prefetch` should be `None`, string of '
'"BACKWARD_PRE" and "BACKWARD_POST", or '
f'`BackwardPrefetch`, but has type {type(backward_prefetch)}')
if isinstance(backward_prefetch, str):
backward_prefetch = BackwardPrefetch[backward_prefetch]
if not (isinstance(backward_prefetch, BackwardPrefetch)
or backward_prefetch is None):
raise TypeError(
'`backward_prefetch` should be `None`, string of '
'"BACKWARD_PRE" and "BACKWARD_POST", or '
f'`BackwardPrefetch`, but has type {type(backward_prefetch)}' # noqa: E501
)
if isinstance(param_init_fn, str):
param_init_fn = FUNCTIONS.get( # type: ignore
param_init_fn)
if param_init_fn is None:
raise ValueError('`param_init_fn` is not registered!')
elif isinstance(param_init_fn, dict):
init_fn = param_init_fn.pop('type')
if isinstance(param_init_fn, str):
init_fn = FUNCTIONS.get(init_fn) # type: ignore
if init_fn is None:
raise ValueError('`param_init_fn` is not registered!')
param_init_fn = partial(init_fn, **param_init_fn)
param_init_fn = FUNCTIONS.get( # type: ignore
param_init_fn)
if param_init_fn is None:
raise ValueError('`param_init_fn` is not registered!')
elif isinstance(param_init_fn, dict):
init_fn = param_init_fn.pop('type')
if isinstance(param_init_fn, str):
init_fn = FUNCTIONS.get(init_fn) # type: ignore
if init_fn is None:
raise ValueError('`param_init_fn` is not registered!')
param_init_fn = partial(init_fn, **param_init_fn)
if not (callable(param_init_fn) or param_init_fn is None):
raise TypeError('`param_init_fn` should be a str, a '
'callable, a dict or None, but has type '
f'{type(param_init_fn)}')
if not (callable(param_init_fn) or param_init_fn is None):
raise TypeError('`param_init_fn` should be a str, a '
'callable, a dict or None, but has type '
f'{type(param_init_fn)}')
def parse_dtype(dtype):
if dtype is None: