mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Support using gradient checkpointing in FSDP (#1382)
This commit is contained in:
parent
bf30c444de
commit
8015d62202
@ -5,6 +5,7 @@ import os
|
||||
import os.path as osp
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import torch.nn as nn
|
||||
@ -25,7 +26,7 @@ from mmengine.model import BaseDataPreprocessor, is_model_wrapper
|
||||
from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper,
|
||||
OptimWrapperDict, _ParamScheduler,
|
||||
build_optim_wrapper)
|
||||
from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS,
|
||||
from mmengine.registry import (FUNCTIONS, MODEL_WRAPPERS, OPTIM_WRAPPERS,
|
||||
PARAM_SCHEDULERS, STRATEGIES, Registry)
|
||||
from mmengine.utils import get_git_hash, mkdir_or_exist
|
||||
from .distributed import DDPStrategy
|
||||
@ -91,6 +92,19 @@ class FSDPStrategy(DDPStrategy):
|
||||
:meth:`setup_env`. Defaults to None.
|
||||
- log_kwargs (dict, optional): Logger config passed in
|
||||
:meth:`build_logger`. Defaults to None.
|
||||
activation_checkpointing (dict, optional): Config dict for gradient
|
||||
checkpoint.
|
||||
|
||||
Examples:
|
||||
>>> activation_checkpointing = dict(check_fn='CustomCheckFn')
|
||||
>>> activation_checkpointing = dict(check_fn=dict(type='CustomCheckFn', arg1=arg1))
|
||||
|
||||
|
||||
``check_fn`` field should behave consistently with
|
||||
``auto_wrap_policy`` defined in `model_wrapper`, and other
|
||||
fields will be passed to ``apply_activation_checkpointing``
|
||||
|
||||
`New in version 0.9.0.`
|
||||
|
||||
.. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
|
||||
""" # noqa: E501
|
||||
@ -100,6 +114,7 @@ class FSDPStrategy(DDPStrategy):
|
||||
model_wrapper: Optional[dict] = None,
|
||||
skip_init_weights=False,
|
||||
state_dict_cfg: Union[str, dict] = 'local',
|
||||
activation_checkpointing: Optional[dict] = None,
|
||||
**kwargs):
|
||||
super().__init__(model_wrapper=model_wrapper, **kwargs)
|
||||
self._init_state_dict_cfg(state_dict_cfg)
|
||||
@ -107,6 +122,7 @@ class FSDPStrategy(DDPStrategy):
|
||||
raise TypeError('skip_init_weights must be a boolean, but got '
|
||||
f'{type(skip_init_weights)}')
|
||||
self.skip_init_weights = skip_init_weights
|
||||
self.activation_checkpointing = activation_checkpointing
|
||||
|
||||
def _wrap_model(self, model: nn.Module) -> None:
|
||||
"""Wrap the model to :obj:``MMFullyShardedDataParallel`` or other
|
||||
@ -119,6 +135,12 @@ class FSDPStrategy(DDPStrategy):
|
||||
FullyShardedDataParallel: ``MMFullyShardedDataParallel``
|
||||
or subclass of ``FullyShardedDataParallel``.
|
||||
"""
|
||||
try:
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
|
||||
apply_activation_checkpointing # noqa: E501
|
||||
except ImportError:
|
||||
apply_activation_checkpointing = None
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseDataPreprocessor):
|
||||
module.to(get_device())
|
||||
@ -138,6 +160,27 @@ class FSDPStrategy(DDPStrategy):
|
||||
model.set_state_dict_type(model, self.state_dict_type,
|
||||
self.state_dict_config,
|
||||
self.optim_state_dict_config)
|
||||
|
||||
if self.activation_checkpointing is not None:
|
||||
if apply_activation_checkpointing is None:
|
||||
raise RuntimeError(
|
||||
'activation_checkpointing maybe deprecated by current '
|
||||
'PyTorch version, maybe you could switch to PyTorch 2.0 '
|
||||
'or 2.1 to use `activation_checkpointing`.')
|
||||
cfg = copy.deepcopy(self.activation_checkpointing)
|
||||
with FUNCTIONS.switch_scope_and_registry(None):
|
||||
check_fn = cfg.pop('check_fn')
|
||||
if isinstance(check_fn, str):
|
||||
check_fn = FUNCTIONS.get(check_fn)
|
||||
elif isinstance(check_fn, dict):
|
||||
fn_type = check_fn.pop('type')
|
||||
if isinstance(fn_type, str):
|
||||
fn_type = FUNCTIONS.get(fn_type)
|
||||
check_fn = partial(fn_type, **cfg)
|
||||
|
||||
if not callable(check_fn):
|
||||
raise TypeError('`check_fn` must be a callable function')
|
||||
apply_activation_checkpointing(model, check_fn=check_fn, **cfg)
|
||||
return model
|
||||
|
||||
def _is_full_state_dict(self):
|
||||
|
@ -146,51 +146,53 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
'`cpu_offload` should be `None`, `bool`'
|
||||
f'or `CPUOffload`, but has type {type(cpu_offload)}')
|
||||
|
||||
if isinstance(auto_wrap_policy, str):
|
||||
auto_wrap_policy = FUNCTIONS.get( # type: ignore
|
||||
auto_wrap_policy)
|
||||
if auto_wrap_policy is None:
|
||||
raise ValueError('`auto_wrap_policy` is not registered!')
|
||||
elif isinstance(auto_wrap_policy, dict):
|
||||
policy = auto_wrap_policy.pop('type')
|
||||
if isinstance(policy, str):
|
||||
policy = FUNCTIONS.get(policy) # type: ignore
|
||||
if policy is None:
|
||||
raise ValueError('`auto_wrap_policy` is not registered!')
|
||||
auto_wrap_policy = partial(policy, **auto_wrap_policy)
|
||||
with FUNCTIONS.switch_scope_and_registry(None):
|
||||
if isinstance(auto_wrap_policy, str):
|
||||
auto_wrap_policy = FUNCTIONS.get( # type: ignore
|
||||
auto_wrap_policy)
|
||||
if auto_wrap_policy is None:
|
||||
raise ValueError('`auto_wrap_policy` is not registered!')
|
||||
elif isinstance(auto_wrap_policy, dict):
|
||||
policy = auto_wrap_policy.pop('type')
|
||||
if isinstance(policy, str):
|
||||
policy = FUNCTIONS.get(policy) # type: ignore
|
||||
if policy is None:
|
||||
raise ValueError('`auto_wrap_policy` is not registered!')
|
||||
auto_wrap_policy = partial(policy, **auto_wrap_policy)
|
||||
|
||||
if not (auto_wrap_policy is None
|
||||
or callable(auto_wrap_policy)): # type: ignore
|
||||
raise TypeError('`auto_wrap_policy` should be a str, a '
|
||||
'callable, a dict or None, but has type '
|
||||
f'{type(auto_wrap_policy)}')
|
||||
if not (auto_wrap_policy is None
|
||||
or callable(auto_wrap_policy)): # type: ignore
|
||||
raise TypeError('`auto_wrap_policy` should be a str, a '
|
||||
'callable, a dict or None, but has type '
|
||||
f'{type(auto_wrap_policy)}')
|
||||
|
||||
if isinstance(backward_prefetch, str):
|
||||
backward_prefetch = BackwardPrefetch[backward_prefetch]
|
||||
if not (isinstance(backward_prefetch, BackwardPrefetch)
|
||||
or backward_prefetch is None):
|
||||
raise TypeError(
|
||||
'`backward_prefetch` should be `None`, string of '
|
||||
'"BACKWARD_PRE" and "BACKWARD_POST", or '
|
||||
f'`BackwardPrefetch`, but has type {type(backward_prefetch)}')
|
||||
if isinstance(backward_prefetch, str):
|
||||
backward_prefetch = BackwardPrefetch[backward_prefetch]
|
||||
if not (isinstance(backward_prefetch, BackwardPrefetch)
|
||||
or backward_prefetch is None):
|
||||
raise TypeError(
|
||||
'`backward_prefetch` should be `None`, string of '
|
||||
'"BACKWARD_PRE" and "BACKWARD_POST", or '
|
||||
f'`BackwardPrefetch`, but has type {type(backward_prefetch)}' # noqa: E501
|
||||
)
|
||||
|
||||
if isinstance(param_init_fn, str):
|
||||
param_init_fn = FUNCTIONS.get( # type: ignore
|
||||
param_init_fn)
|
||||
if param_init_fn is None:
|
||||
raise ValueError('`param_init_fn` is not registered!')
|
||||
elif isinstance(param_init_fn, dict):
|
||||
init_fn = param_init_fn.pop('type')
|
||||
if isinstance(param_init_fn, str):
|
||||
init_fn = FUNCTIONS.get(init_fn) # type: ignore
|
||||
if init_fn is None:
|
||||
raise ValueError('`param_init_fn` is not registered!')
|
||||
param_init_fn = partial(init_fn, **param_init_fn)
|
||||
param_init_fn = FUNCTIONS.get( # type: ignore
|
||||
param_init_fn)
|
||||
if param_init_fn is None:
|
||||
raise ValueError('`param_init_fn` is not registered!')
|
||||
elif isinstance(param_init_fn, dict):
|
||||
init_fn = param_init_fn.pop('type')
|
||||
if isinstance(param_init_fn, str):
|
||||
init_fn = FUNCTIONS.get(init_fn) # type: ignore
|
||||
if init_fn is None:
|
||||
raise ValueError('`param_init_fn` is not registered!')
|
||||
param_init_fn = partial(init_fn, **param_init_fn)
|
||||
|
||||
if not (callable(param_init_fn) or param_init_fn is None):
|
||||
raise TypeError('`param_init_fn` should be a str, a '
|
||||
'callable, a dict or None, but has type '
|
||||
f'{type(param_init_fn)}')
|
||||
if not (callable(param_init_fn) or param_init_fn is None):
|
||||
raise TypeError('`param_init_fn` should be a str, a '
|
||||
'callable, a dict or None, but has type '
|
||||
f'{type(param_init_fn)}')
|
||||
|
||||
def parse_dtype(dtype):
|
||||
if dtype is None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user