From 399f76ffa8e3459e87f0f0cfbbe1e09cfc25bc7a Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Wed, 28 Jun 2023 16:50:52 +0800 Subject: [PATCH] [Experimental] Add support for FSDP (#1213) --- docs/en/api/strategy.rst | 1 + docs/zh_cn/api/strategy.rst | 1 + ...stributed_training_with_flexible_runner.py | 28 +- mmengine/_strategy/__init__.py | 10 +- mmengine/_strategy/base.py | 18 +- mmengine/_strategy/deepspeed.py | 38 +- mmengine/_strategy/distributed.py | 17 +- mmengine/_strategy/fsdp.py | 595 ++++++++++++++++++ mmengine/_strategy/single_device.py | 22 +- mmengine/_strategy/utils.py | 17 + mmengine/model/__init__.py | 2 +- mmengine/model/wrappers/__init__.py | 2 +- .../wrappers/fully_sharded_distributed.py | 364 ++++++++--- .../optim/optimizer/amp_optimizer_wrapper.py | 23 +- mmengine/runner/_flexible_runner.py | 57 +- mmengine/testing/runner_test_case.py | 4 +- .../test_wrappers/test_model_wrapper.py | 31 +- tests/test_strategies/test_fsdp.py | 231 +++++++ 18 files changed, 1317 insertions(+), 144 deletions(-) create mode 100644 mmengine/_strategy/fsdp.py create mode 100644 mmengine/_strategy/utils.py create mode 100644 tests/test_strategies/test_fsdp.py diff --git a/docs/en/api/strategy.rst b/docs/en/api/strategy.rst index f96eeba5..24238577 100644 --- a/docs/en/api/strategy.rst +++ b/docs/en/api/strategy.rst @@ -15,3 +15,4 @@ mmengine._strategy SingleDeviceStrategy DDPStrategy DeepSpeedStrategy + FSDPStrategy diff --git a/docs/zh_cn/api/strategy.rst b/docs/zh_cn/api/strategy.rst index f96eeba5..24238577 100644 --- a/docs/zh_cn/api/strategy.rst +++ b/docs/zh_cn/api/strategy.rst @@ -15,3 +15,4 @@ mmengine._strategy SingleDeviceStrategy DDPStrategy DeepSpeedStrategy + FSDPStrategy diff --git a/examples/distributed_training_with_flexible_runner.py b/examples/distributed_training_with_flexible_runner.py index d14f1eb7..8a824c33 100644 --- a/examples/distributed_training_with_flexible_runner.py +++ b/examples/distributed_training_with_flexible_runner.py @@ -4,7 +4,6 @@ import argparse import torch.nn.functional as F import torchvision import torchvision.transforms as transforms -from torch.optim import SGD from mmengine.evaluator import BaseMetric from mmengine.model import BaseModel @@ -43,8 +42,8 @@ class Accuracy(BaseMetric): def parse_args(): parser = argparse.ArgumentParser(description='Distributed Training') parser.add_argument('--local_rank', '--local-rank', type=int, default=0) + parser.add_argument('--use-fsdp', action='store_true') parser.add_argument('--use-deepspeed', action='store_true') - args = parser.parse_args() return args @@ -94,20 +93,33 @@ def main(): ), inputs_to_half=[0], zero_optimization=dict( - stage=0, + stage=3, allgather_partitions=True, reduce_scatter=True, allgather_bucket_size=50000000, reduce_bucket_size=50000000, overlap_comm=True, contiguous_gradients=True, - cpu_offload=False)) + cpu_offload=False), + ) optim_wrapper = dict( type='DeepSpeedOptimWrapper', - optimizer=dict(type=SGD, lr=0.001, momentum=0.9)) + optimizer=dict(type='AdamW', lr=1e-3)) + elif args.use_fsdp: + from functools import partial + + from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy + size_based_auto_wrap_policy = partial( + size_based_auto_wrap_policy, min_num_params=1e7) + strategy = dict( + type='FSDPStrategy', + model_wrapper=dict(auto_wrap_policy=size_based_auto_wrap_policy)) + optim_wrapper = dict( + type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3)) else: strategy = None - optim_wrapper = 
dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)) + optim_wrapper = dict( + type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3)) runner = FlexibleRunner( model=MMResNet50(), @@ -124,4 +136,8 @@ def main(): if __name__ == '__main__': + # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py --use-fsdp # noqa: 501 + # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py --use-deepspeed # noqa: 501 + # torchrun --nproc-per-node 2 distributed_training_with_flexible_runner.py + # python distributed_training_with_flexible_runner.py main() diff --git a/mmengine/_strategy/__init__.py b/mmengine/_strategy/__init__.py index 0610d6e7..0b201aef 100644 --- a/mmengine/_strategy/__init__.py +++ b/mmengine/_strategy/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmengine.utils import is_installed +from mmengine.utils import digit_version, is_installed +from mmengine.utils.dl_utils import TORCH_VERSION from .base import BaseStrategy from .distributed import DDPStrategy from .single_device import SingleDeviceStrategy @@ -9,3 +10,10 @@ __all__ = ['BaseStrategy', 'DDPStrategy', 'SingleDeviceStrategy'] if is_installed('deepspeed'): from .deepspeed import DeepSpeedStrategy # noqa: F401 __all__.append('DeepSpeedStrategy') + +if digit_version(TORCH_VERSION) >= digit_version('2.0.0'): + try: + from .fsdp import FSDPStrategy # noqa:F401 + __all__.append('FSDPStrategy') + except: # noqa: E722 + pass diff --git a/mmengine/_strategy/base.py b/mmengine/_strategy/base.py index 08508473..bb4cf2df 100644 --- a/mmengine/_strategy/base.py +++ b/mmengine/_strategy/base.py @@ -91,6 +91,7 @@ class BaseStrategy(metaclass=ABCMeta): self._auto_scale_lr = auto_scale_lr self.dispatch_kwargs: dict = {} + self._prepared = False @property def work_dir(self): @@ -342,7 +343,8 @@ class BaseStrategy(metaclass=ABCMeta): def _init_model_weights(self, model: nn.Module) -> nn.Module: """Initialize the model weights if the model has :meth:`init_weights`""" - if hasattr(model, 'init_weights'): + if (hasattr(model, 'init_weights') and self.dispatch_kwargs.get( + 'init_weights_for_test_or_val', True)): model.init_weights() # sync params and buffers for _, params in model.state_dict().items(): @@ -633,9 +635,9 @@ class BaseStrategy(metaclass=ABCMeta): """ if default_args is None: default_args = {} - if 'num_batches_per_epoch' in self.dispatch_kwargs: + if 'epoch_length' in self.dispatch_kwargs: default_args['epoch_length'] = self.dispatch_kwargs[ - 'num_batches_per_epoch'] + 'epoch_length'] if 'max_epochs' in self.dispatch_kwargs: default_args['max_epochs'] = self.dispatch_kwargs['max_epochs'] if 'max_iters' in self.dispatch_kwargs: @@ -962,3 +964,13 @@ class BaseStrategy(metaclass=ABCMeta): runtime_env['GPU number'] = self.world_size return system_env, runtime_env + + def _prepared_components(self): + return_items = [self.model] + if hasattr(self, 'optim_wrapper'): + return_items.append(self.optim_wrapper) + + if hasattr(self, 'param_schedulers'): + return_items.append(self.param_schedulers) + + return return_items[0] if len(return_items) == 1 else return_items diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py index a96bd0fc..0c1d1e42 100644 --- a/mmengine/_strategy/deepspeed.py +++ b/mmengine/_strategy/deepspeed.py @@ -64,6 +64,8 @@ class DeepSpeedStrategy(BaseStrategy): amp: Optional[dict] = None, activation_checkpointing: Optional[dict] = None, aio: Optional[dict] = None, + train_micro_batch_size_per_gpu: Optional[int] = None, 
+ gradient_accumulation_steps: int = 1, # disable the log printed by deepseed steps_per_print: int = 10000000000000, # the following args are for BaseStrategy @@ -86,8 +88,21 @@ class DeepSpeedStrategy(BaseStrategy): if aio is not None: self.config['aio'] = aio - self.config['steps_per_print'] = steps_per_print + if ('train_micro_batch_size_per_gpu' not in self.config + and 'train_batch_size' not in self.config): + assert train_micro_batch_size_per_gpu is not None, ( + '`train_micro_batch_size_per_gpu` or `train_batch_size` ' + 'should be set!') + self.config['train_micro_batch_size_per_gpu'] = \ + train_micro_batch_size_per_gpu + if train_micro_batch_size_per_gpu is not None: + self.config['train_micro_batch_size_per_gpu'] = \ + train_micro_batch_size_per_gpu + + self.config['gradient_accumulation_steps'] = \ + gradient_accumulation_steps + self.config['steps_per_print'] = steps_per_print self._inputs_to_half = inputs_to_half def _parse_config(self, config): @@ -145,11 +160,11 @@ class DeepSpeedStrategy(BaseStrategy): dispatch_kwargs (dict, optional): Kwargs to be passed to other methods of Strategy. Defaults to None. """ + if self._prepared: + return self._prepared_components() assert dispatch_kwargs is not None self.dispatch_kwargs.update(dispatch_kwargs) - return_items = [] - model = self.build_model(model) model = self._init_model_weights(model) @@ -159,23 +174,16 @@ class DeepSpeedStrategy(BaseStrategy): self.optim_wrapper.model = self.model # type: ignore - return_items.append(self.model) - return_items.append(self.optim_wrapper) else: self.model = self._wrap_model(model) - return_items.append(self.model) if param_scheduler is not None: self.param_schedulers = self.build_param_scheduler( param_scheduler, self.optim_wrapper) - return_items.append(self.param_schedulers) - - return return_items[0] if len(return_items) == 1 else return_items + self._prepared = True + return self._prepared_components() def _wrap_model(self, model: nn.Module) -> nn.Module: - self.config['train_micro_batch_size_per_gpu'] = self.dispatch_kwargs[ - 'train_micro_batch_size_per_gpu'] - if hasattr(self, 'optim_wrapper'): engine, self.optim_wrapper.optimizer, *_ = deepspeed.initialize( model=model, @@ -246,7 +254,7 @@ class DeepSpeedStrategy(BaseStrategy): if resume_optimizer: self.load_optim_state_dict(extra_ckpt.pop('optim_wrapper')) - if resume_param_scheduler: + if resume_param_scheduler and hasattr(self, 'param_schedulers'): param_schedulers = extra_ckpt.pop('param_schedulers') self.load_scheduler_state_dict(param_schedulers) @@ -297,12 +305,12 @@ class DeepSpeedStrategy(BaseStrategy): mmengine=mmengine.__version__ + get_git_hash(), ) - if save_optimizer: + if save_optimizer and hasattr(self, 'optim_wrapper'): # The key can not be 'optimizer', otherwise error will be thrown # when loading or resuming checkpoint. 
extra_ckpt['optim_wrapper'] = self.optim_state_dict() - if save_param_scheduler: + if save_param_scheduler and hasattr(self, 'param_schedulers'): extra_ckpt['param_schedulers'] = self.scheduler_state_dict() dirname, basename = osp.split(filename) diff --git a/mmengine/_strategy/distributed.py b/mmengine/_strategy/distributed.py index 9c6f614f..6c969b85 100644 --- a/mmengine/_strategy/distributed.py +++ b/mmengine/_strategy/distributed.py @@ -31,11 +31,6 @@ class DDPStrategy(SingleDeviceStrategy): **kwargs, ): super().__init__(**kwargs) - if model_wrapper is None: - # set broadcast_buffers as False to keep compatibility with - # OpenMMLab repos - model_wrapper = dict( - type='MMDistributedDataParallel', broadcast_buffers=False) self.model_wrapper = model_wrapper self.sync_bn = sync_bn @@ -95,8 +90,16 @@ class DDPStrategy(SingleDeviceStrategy): model = self.convert_model(model) - default_args = dict(module=model) - default_args.setdefault('device_ids', [int(os.environ['LOCAL_RANK'])]) + if self.model_wrapper is None: + # set broadcast_buffers as False to keep compatibility with + # OpenMMLab repos + self.model_wrapper = dict( + type='MMDistributedDataParallel', broadcast_buffers=False) + + default_args = dict( + type='MMDistributedDataParallel', + module=model, + device_ids=[int(os.environ['LOCAL_RANK'])]) model = MODEL_WRAPPERS.build( self.model_wrapper, default_args=default_args) return model diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py new file mode 100644 index 00000000..1f856ed7 --- /dev/null +++ b/mmengine/_strategy/fsdp.py @@ -0,0 +1,595 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import inspect +import os +import os.path as osp +import time +from collections import OrderedDict +from typing import Callable, Dict, List, Optional, Sequence, Union + +import torch.nn as nn +from torch.distributed.fsdp import (FullStateDictConfig, + FullyShardedDataParallel, + LocalStateDictConfig, StateDictType) +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullOptimStateDictConfig, LocalOptimStateDictConfig, OptimStateDictConfig, + StateDictConfig) +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler + +import mmengine +from mmengine.config import Config, ConfigDict +from mmengine.dist import get_rank, is_main_process +from mmengine.model import is_model_wrapper +from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper, + OptimWrapperDict, _ParamScheduler, + build_optim_wrapper) +from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, + PARAM_SCHEDULERS, STRATEGIES, Registry) +from mmengine.utils import get_git_hash, mkdir_or_exist +from .distributed import DDPStrategy +from .utils import MetaTensorContext + +FSDP = FullyShardedDataParallel +FSDP_CONFIGS = Registry('fsdp configs') +FSDP_CONFIGS.register_module(module=FullOptimStateDictConfig) +FSDP_CONFIGS.register_module(module=LocalOptimStateDictConfig) +FSDP_CONFIGS.register_module(module=FullStateDictConfig) +FSDP_CONFIGS.register_module(module=LocalStateDictConfig) + + +@STRATEGIES.register_module() +class FSDPStrategy(DDPStrategy): + """Support training model with FullyShardedDataParallel (FSDP). + + Keyword Args:: + model_wrapper (dict, optional): Config dict for model wrapper. The + default configuration is: + + Examples: + >>> model_wrapper = dict( + >>> type='MMFullyShardedDataParallel', + >>> use_orig_params=True, + >>> ) + + See more configurable arguments in + :class:`MMFullyShardedDataParallel`. 
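(Editor's note: a minimal configuration sketch, not part of the patch, showing how the ``model_wrapper`` option documented above is typically filled in when driving ``FSDPStrategy`` through ``FlexibleRunner``. It mirrors the example script earlier in this patch; the 1e7-parameter threshold is purely illustrative.)

from functools import partial

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap every submodule holding more than ~10M parameters in its own FSDP unit.
strategy = dict(
    type='FSDPStrategy',
    model_wrapper=dict(
        type='MMFullyShardedDataParallel',
        use_orig_params=True,
        auto_wrap_policy=partial(
            size_based_auto_wrap_policy, min_num_params=int(1e7))))
optim_wrapper = dict(
    type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3))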
Defaults to None + skip_init_weights (bool, optional): Whether to skip initialization of + weights. Defaults to False. This is useful when the parameters of + the large model are loaded from a checkpoint, since skipping the + initialization of weights can save a lot of time. + state_dict_cfg (str or dict): Configuration for + how to save and load the state dict of the model, optimizer, and + scheduler. + + - "local": save and load the sharded state dict in all ranks. + - "full": save and load the full state dict in rank 0. + - `dict` object: save and load the state dict more flexibly. For + example, you can first offload the state dict to the 'cpu' and + then save it to the disk. This can help you to load the + checkpoint in a non-gpu environment: + + Examples: + >>> state_dict_cfg=dict( + >>> state_dict_type='FULL_STATE_DICT', + >>> state_dict_config=dict(type='FullStateDictConfig', offload_to_cpu=True), + >>> optim_state_dict_config=dict(type='FullOptimStateDictConfig', offload_to_cpu=True), + + See more configurable arguments for ``state_dict_cfg``, + ``state_dict_config``, and ``optim_state_dict_config``in + `FSDP official api documents`_ + kwargs (dict): Additional arguments passed to :class:`DDPStrategy`: + + - work_dir (str): The working directory to save checkpoints. + The logs will be saved in the subdirectory of `work_dir` named + :attr:`timestamp`. Defaults to 'work_dirs'. + - experiment_name (str, optional): Name of current experiment. If + not specified, timestamp will be used as :attr:`experiment_name`. + Defaults to None. + - env_kwargs (dict, optional): Environment config passed in + :meth:`setup_env`. Defaults to None. + - log_kwargs (dict, optional): Logger config passed in + :meth:`build_logger`. Defaults to None. + + .. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type + """ # noqa: E501 + + def __init__(self, + *, + model_wrapper: Optional[dict] = None, + skip_init_weights=False, + state_dict_cfg: Union[str, dict] = 'local', + **kwargs): + super().__init__(model_wrapper=model_wrapper, **kwargs) + self._init_state_dict_cfg(state_dict_cfg) + if not isinstance(skip_init_weights, bool): + raise TypeError('skip_init_weights must be a boolean, but got ' + f'{type(skip_init_weights)}') + self.skip_init_weights = skip_init_weights + + def _wrap_model(self, model: nn.Module) -> None: + """Wrap the model to :obj:``MMFullyShardedDataParallel`` or other + custom fully sharded data parallel module wrappers. + + Args: + model (nn.Module): Model to be wrapped. + + Returns: + FullyShardedDataParallel: ``MMFullyShardedDataParallel`` + or subclass of ``FullyShardedDataParallel``. + """ + if is_model_wrapper(model): + return + + if self.model_wrapper is None: + self.model_wrapper = dict(type='MMFullyShardedDataParallel') + + default_args = dict( + module=model, + device_id=int(os.environ['LOCAL_RANK']), + type='MMFullyShardedDataParallel') + model = MODEL_WRAPPERS.build( + self.model_wrapper, default_args=default_args) + model.set_state_dict_type(model, self.state_dict_type, + self.state_dict_config, + self.optim_state_dict_config) + return model + + def _is_full_state_dict(self): + """Whether to save and load the full state_dict in rank 0.""" + return self.state_dict_type == StateDictType.FULL_STATE_DICT + + def build_model(self, model: Union[nn.Module, dict]) -> nn.Module: + """Build model. + + If skip_init_weights is True, the model will be built with an empty + weights. 
It means that :meth:`load_checkpoint` must be called to fill + the weights before training. + + Args: + model (nn.Module or dict): A ``nn.Module`` object or a dict to + build ``nn.Module`` object. If ``model`` is a ``nn.Module`` + object, just returns itself. + + Returns: + nn.Module: Model build from ``model``. + """ + if self.skip_init_weights: + if isinstance(model, dict): + # Accelerate initialization by skipping init weights + with MetaTensorContext(): + model = super().build_model(model) + model.to_empty(device='cpu') + else: + model = super().build_model(model) + + # `id_to_name` will be used to convert the `optim_state_dict` of the + # raw optimizer to the `optim_state_dict` + # returned by `FSDP.optim_state_dict` in + # `StateDictType.FULL_STATE_DICT` mode. + self.id_to_name = dict() + for name, param in model.named_parameters(): + self.id_to_name[id(param)] = name + return model + + def save_checkpoint(self, + filename: str, + *, + save_optimizer: bool = True, + save_param_scheduler: bool = True, + extra_ckpt: Optional[dict] = None, + callback: Optional[Callable] = None) -> None: + """Save checkpoint to given ``filename``. + + If ``state_dict_type`` is `full`, the checkpoint will only be saved in + rank0. The structure of the saved checkpoint is the same as the one + saved by ``DDPStrategy`` + + If ``state_dict_type`` is `local`, each rank will save the sharded + state dict to a directory, which means the saved structure will look + like this: + + .. code-block:: bash + + ── epoch_0.pth + ├── rank0.pth + ├── rank1.pth + ├── ... + └── rank8.pth + + Args: + filename (str): Filename to save checkpoint. + + Keyword Args: + save_optimizer (bool): Whether to save the optimizer to + the checkpoint. Defaults to True. + save_param_scheduler (bool): Whether to save the param_scheduler + to the checkpoint. Defaults to True. + extra_ckpt (dict, optional): Extra checkpoint to save. + Defaults to None. + callback (callable, callable): Callback function to modify the + checkpoint before saving the checkpoint. + Defaults to None. + """ + from mmengine.runner.checkpoint import save_checkpoint + + state_dict: dict = dict() + state_dict['state_dict'] = self.model_state_dict() + + # save optimizer state dict + if save_optimizer and hasattr(self, 'optim_wrapper'): + state_dict['optimizer'] = self.optim_state_dict() + + # save param scheduler state dict + if save_param_scheduler and hasattr(self, 'param_schedulers'): + state_dict['param_schedulers'] = self.scheduler_state_dict() + + # save extra checkpoint passed by users + if extra_ckpt is None: + extra_ckpt = dict() + if 'meta' not in extra_ckpt: + extra_ckpt['meta'] = dict() + + extra_ckpt['meta'].update( + seed=self.seed, + time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), + mmengine=mmengine.__version__ + get_git_hash(), + ) + state_dict.update(extra_ckpt) + + # users can do some modification before saving checkpoint + if callback is not None: + callback(state_dict) + + # In non-FULL_STATE_DICT model, FSDPStrategy will save checkpoint + # of different ranks in different files. + if not self._is_full_state_dict(): + rank = get_rank() + mkdir_or_exist(filename) + ckpt_name = f'rank{rank}.pth' + filename = osp.join(filename, ckpt_name) + save_checkpoint(state_dict, filename) + + if is_main_process(): + save_checkpoint(state_dict, filename) + + def model_state_dict(self) -> dict: + """Get model state dict based on the ``state_dict_type``. 
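(Editor's note: a brief usage sketch, not part of the patch, of the checkpoint layouts described above. It assumes ``strategy`` is an already-prepared ``FSDPStrategy`` running under ``torchrun``; the paths are illustrative.)

# With the default state_dict_cfg='local', every rank saves its own shard, so
# the filename passed to save_checkpoint becomes a directory:
strategy.save_checkpoint('work_dirs/epoch_1.pth')
# -> work_dirs/epoch_1.pth/rank0.pth, work_dirs/epoch_1.pth/rank1.pth, ...
strategy.load_checkpoint('work_dirs/epoch_1.pth')  # each rank reads rank{i}.pth

# With state_dict_cfg='full', only rank 0 writes a single unsharded file whose
# structure matches a DDPStrategy checkpoint.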
+ + If ``state_dict_type`` is `full`, the model state dict will be the + same as the one of original unsharded model. + + If ``state_dict_type`` is ``local``, and ``use_orig_params`` is ``True`` + in ``model_wrapper``. The key of the state dict will be the same as + the one of original unsharded model, but its value will be the sharded + one + + If ``state_dict_type`` is `local`, and ```use_orig_params``` is + ``False`` in ``model_wrapper``, the flatten and sharded state dict will + be returned. + + See more details in the `official api documents`_ + + .. _official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict + """ # noqa: E501 + # We've set state_dict by `FSDP.set_state_dict_type`, therefore we + # should get model state dict by `FSDP.state_dict` + return self.model.state_dict() + + def optim_state_dict(self) -> dict: + """Get model state dict based on the ``state_dict_type``. + + If ``state_dict_type`` is ``full``, the optimizer state dict can be + loaded by the original unsharded optimizer. + + Otherwise, the optimizer state dict could only be loaded by the + optimizer with sharded parameters. + + Note: + The optimizer state dict is not the same as the one of original + optimizer even if in ``full`` mode, although they can be loaded + correctly. + + See more details in the `official api documents`_ + + .. _official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.optim_state_dict + """ # noqa: E501 + return FSDP.optim_state_dict(self.model, self.optim_wrapper) + + def load_checkpoint(self, filename: str, **kwargs) -> dict: + """Load checkpoint from given ``filename``. + + Note: + If ``state_dict_type`` is `local`, the filename should be a + directory contains ``rank{i}.pth``. + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. + + Keyword Args: + map_location (str or callable): A string or a callable function to + specifying how to remap storage locations. + Defaults to 'cpu'. + strict (bool): strict (bool): Whether to allow different params for + the model and checkpoint. + revise_keys (list): A list of customized keywords to modify the + state_dict in checkpoint. Each item is a (pattern, replacement) + pair of the regular expression operations. Defaults to strip + the prefix 'module.' by [(r'^module\\.', '')]. + callback (callable, callable): Callback function to modify the + checkpoint after loading the checkpoint. + Defaults to None. + """ + if self._is_full_state_dict(): + return super(DDPStrategy, self).load_checkpoint(filename, **kwargs) + else: + rank = get_rank() + filename = osp.join(filename, f'rank{rank}.pth') + return super(DDPStrategy, self).load_checkpoint(filename, **kwargs) + + def load_model_state_dict( + self, + state_dict: dict, + *, + strict: bool = False, + revise_keys: list = [(r'^module.', '')], + ) -> None: # type: ignore + """Load model state from dict. + + Warning: + `revise_keys` is not supported yet. + + Args: + state_dict (dict): Model state dict returned by + :meth:`FSDPStrategy.model_state_dict`. If ``state_dict_type`` + is ``full``. ``state_dict`` could be the result of + ``model.state_dict()`` + strict (bool): Whether to load model state dict strictly. + Defaults to False. 
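(Editor's note: a small sketch, not part of the patch, illustrating the point above: a ``full`` checkpoint keeps the original unsharded parameter names, so it can be reloaded into a plain, non-FSDP module for inference. The path is illustrative; ``ToyModel`` comes from ``mmengine.testing.runner_test_case`` in this patch.)

import torch

from mmengine.testing.runner_test_case import ToyModel

# 'state_dict' is the key used by FSDPStrategy.save_checkpoint above.
ckpt = torch.load('work_dirs/epoch_1.pth', map_location='cpu')
model = ToyModel()
model.load_state_dict(ckpt['state_dict'])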
+ """ + # We should load state dict by `FSDP.load_state_dict` + self.model.load_state_dict(state_dict, strict=strict) + + def load_optim_state_dict(self, state_dict: dict) -> None: + """Load optimizer state from dict. + + Args: + state_dict (dict): The optimizer state dict. If ``state_dict_type`` + is ``full``. ``state_dict`` could be the result of + ``optimizer.state_dict()`` + """ + optim_state_dict = FSDP.optim_state_dict_to_load( + state_dict, self.model, self.optim_wrapper.optimizer) + self.optim_wrapper.load_state_dict(optim_state_dict) + + def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None: + """Make ``state_dict_type`` and ``state_dict_config`` can be configured + with string.""" + if isinstance(state_dict_cfg, str): + if state_dict_cfg == 'full': + self.state_dict_type = StateDictType.FULL_STATE_DICT + self.state_dict_config = FullStateDictConfig( + rank0_only=True, offload_to_cpu=True) + self.optim_state_dict_config = FullOptimStateDictConfig( + rank0_only=True, offload_to_cpu=True) + elif state_dict_cfg == 'local': + self.state_dict_type = StateDictType.LOCAL_STATE_DICT + self.state_dict_config = LocalStateDictConfig() + self.optim_state_dict_config = LocalOptimStateDictConfig() + else: + raise ValueError('FSDP only supports `full` and `local` ' + f'state_dict_type, but got {state_dict_cfg}') + elif isinstance(state_dict_cfg, dict): + if 'state_dict_type' not in state_dict_cfg: + self.state_dict_type = StateDictType.LOCAL_STATE_DICT + else: + state_dict_type = state_dict_cfg['state_dict_type'] + if isinstance(state_dict_type, str): + self.state_dict_type = StateDictType[ + state_dict_cfg['state_dict_type']] + else: + self.state_dict_type = state_dict_type + state_dict_config = state_dict_cfg.get('state_dict_config') + if state_dict_config is None: + self.state_dict_config = LocalStateDictConfig() + elif isinstance(state_dict_config, dict): + self.state_dict_config = FSDP_CONFIGS.build( + state_dict_cfg['state_dict_config']) + else: + self.state_dict_config = state_dict_config + + optim_state_dict_config = state_dict_cfg.get( + 'optim_state_dict_config') + if optim_state_dict_config is None: + self.optim_state_dict_config = LocalOptimStateDictConfig() + elif isinstance(optim_state_dict_config, dict): + self.optim_state_dict_config = FSDP_CONFIGS.build( + state_dict_cfg['optim_state_dict_config']) + else: + self.optim_state_dict_config = optim_state_dict_config + else: + raise TypeError('state_dict_cfg should be a `str` or a `dict`, ' + f'but got {type(state_dict_cfg)}') + + if not isinstance(self.state_dict_type, StateDictType): + raise TypeError('state_dict_type must be StateDictType, but got ' + f'{type(self.state_dict_type)}') + if not isinstance(self.state_dict_config, StateDictConfig): + raise TypeError('state_dict_config must be StateDictConfig, but ' + f'got {type(self.state_dict_config)}') + if not isinstance(self.optim_state_dict_config, OptimStateDictConfig): + raise TypeError('optim_state_dict_config must be ' + 'OptimStateDictConfig, but got ' + f'{type(self.optim_state_dict_config)}') + + def build_optim_wrapper( + self, + optim_wrapper: Union[Optimizer, OptimWrapper, dict], + model: Optional[nn.Module] = None, + ) -> BaseOptimWrapper: + """Support sharding the optimizer state dict given a built optimizer or + optim_wrapper. + + See specific usage in :meth:`BaseStrategy.build_optim_wrapper`. 
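(Editor's note: a configuration sketch, not part of the patch, spelling out what the ``'full'`` shortcut handled by ``_init_state_dict_cfg`` above expands to; the explicit dict form below should be equivalent.)

strategy = dict(
    type='FSDPStrategy',
    state_dict_cfg=dict(
        state_dict_type='FULL_STATE_DICT',
        state_dict_config=dict(
            type='FullStateDictConfig', rank0_only=True, offload_to_cpu=True),
        optim_state_dict_config=dict(
            type='FullOptimStateDictConfig',
            rank0_only=True,
            offload_to_cpu=True)))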
+ """ + if isinstance(optim_wrapper, Optimizer): + optim_wrapper = OptimWrapper(optim_wrapper) + if isinstance(optim_wrapper, BaseOptimWrapper): + assert model is not None + # NOTE: The only difference is that FSDPStrategy will shard + # the the built OptimWrapper + optimizer = optim_wrapper.optimizer + param_groups = optimizer.param_groups + optim_state_dict = optimizer.state_dict() + assert not optim_state_dict['state'], ( + 'Optimizer state_dict should be empty when giving an built ' + 'optim_wrapper to FSDPStrategy') + # Align the state_dict with state_dict generated by + # FSDP.full_optim_state_dict + new_param_groups = [] + for group in param_groups: + new_group = { + key: value + for key, value in group.items() if key != 'param' + } + new_group['params'] = [ + self.id_to_name[id(param)] for param in group['params'] + ] + new_param_groups.append(new_group) + optim_state_dict['param_groups'] = new_param_groups + defaults = { + k: v + for k, v in optimizer.defaults.items() if k != 'differentiable' + } + + params_dict = {} + for k, v in model.named_parameters(): + if '_fsdp_wrapped_module' in k: + k = k.replace('_fsdp_wrapped_module.', '') + params_dict[k] = v + + params = [] + for param_group in new_param_groups: + _params = [] + for param_name in param_group['params']: + if param_name not in params_dict: + raise RuntimeError( + 'Failed to reconstruct the sharded optimizer. ' + 'You can try to set `use_orig_params=True` in ' + '`model_wrapper`') + _params.append(params_dict[param_name]) + param_group = { + k: v + for k, v in param_group.items() if k != 'param' + } + param_group['params'] = _params + params.append(param_group) + + new_optimizer = optimizer.__class__(params, **defaults) + + # Force to load the converted optim_state_dict in full mode. + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT): + optim_state_dict = FSDP.optim_state_dict_to_load( + optim_state_dict, model, new_optimizer) + new_optimizer.load_state_dict(optim_state_dict) + optim_wrapper.optimizer = new_optimizer + return optim_wrapper + if isinstance(optim_wrapper, (dict, ConfigDict, Config)): + assert model is not None + # optimizer must be defined for single optimizer training. + optimizer = optim_wrapper.get('optimizer', None) + optim_wrapper.setdefault('type', 'OptimWrapper') + if optim_wrapper.get('type', + 'AmpOptimWrapper') in ('AmpOptimWrapper', + AmpOptimWrapper): + optim_wrapper.setdefault('use_fsdp', True) + + # If optimizer is a built `Optimizer` instance, the optimizer + # wrapper should be built by `OPTIM_WRAPPERS` registry. + if isinstance(optimizer, Optimizer): + return OPTIM_WRAPPERS.build(optim_wrapper) # type: ignore + + # If `optimizer` is not None or `constructor` is defined, it means, + # optimizer wrapper will be built by optimizer wrapper + # constructor. Therefore, `build_optim_wrapper` should be called. + if optimizer is not None or 'constructor' in optim_wrapper: + return build_optim_wrapper(model, optim_wrapper) + else: + # if `optimizer` is not defined, it should be the case of + # training with multiple optimizers. If `constructor` is not + # defined either, each value of `optim_wrapper` must be an + # `OptimWrapper` instance since `DefaultOptimizerConstructor` + # will not handle the case of training with multiple + # optimizers. 
`build_optim_wrapper` will directly build the + # `OptimWrapperDict` instance from `optim_wrapper.` + optim_wrappers = OrderedDict() + for name, optim in optim_wrapper.items(): + if not isinstance(optim, OptimWrapper): + raise ValueError( + 'each item mush be an optimizer object when ' + '"type" and "constructor" are not in ' + f'optimizer, but got {name}={optim}') + optim_wrappers[name] = optim + return OptimWrapperDict(**optim_wrappers) + else: + raise TypeError('optimizer wrapper should be an OptimWrapper ' + f'object or dict, but got {optim_wrapper}') + + def _build_param_scheduler( + self, + scheduler: Union[_ParamScheduler, Dict, List], + optim_wrapper: BaseOptimWrapper, + default_args: dict, + ) -> List[_ParamScheduler]: + """Override this method to update the scheduler with the reconstructed + sharded optimzer.""" + if not isinstance(scheduler, Sequence): + schedulers = [scheduler] + else: + schedulers = scheduler + + max_epochs = default_args.pop('max_epochs', None) + max_iters = default_args.pop('max_iters', None) + + param_schedulers = [] + for scheduler in schedulers: + # Update the built scheduler with the sharded optimizer + if isinstance(scheduler, (_ParamScheduler, LRScheduler)): + parameter_keys = inspect.signature( + scheduler.__class__).parameters.keys() + kwargs = { + k: v + for k, v in scheduler.state_dict().items() + if k in parameter_keys + } + scheduler = scheduler.__class__(optim_wrapper, **kwargs) + elif isinstance(scheduler, dict): + _scheduler = copy.deepcopy(scheduler) + + # Set default end + if _scheduler.get('by_epoch', True): + if max_epochs is None: + raise ValueError( + 'max_epochs must be specified in default_args') + default_end = max_epochs + else: + if max_iters is None: + raise ValueError( + 'max_iters must be specified in default_args') + default_end = max_iters + _scheduler.setdefault('end', default_end) + self.logger.debug( + f'The `end` of {_scheduler["type"]} is not set. ' + 'Use the max epochs/iters of train loop as default.') + + param_schedulers.append( + PARAM_SCHEDULERS.build( + _scheduler, + default_args=dict( + optimizer=optim_wrapper, **default_args))) + else: + raise TypeError( + 'scheduler should be a _ParamScheduler object or dict, ' + f'but got {scheduler}') + return param_schedulers diff --git a/mmengine/_strategy/single_device.py b/mmengine/_strategy/single_device.py index 112f3ebd..180ad79c 100644 --- a/mmengine/_strategy/single_device.py +++ b/mmengine/_strategy/single_device.py @@ -49,26 +49,24 @@ class SingleDeviceStrategy(BaseStrategy): If ``accumulative_counts`` is set in ``optim_wrapper``, you need to provide ``max_iters`` in ``dispatch_kwargs``. 
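(Editor's note: a hedged sketch, not part of the patch, of calling ``prepare`` directly with the ``dispatch_kwargs`` that ``FlexibleRunner`` supplies later in this patch. ``strategy`` is assumed to be a built strategy instance, ``my_model`` an ``nn.Module``, and the numbers are illustrative.)

model, optim_wrapper, schedulers = strategy.prepare(
    my_model,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01, momentum=0.9)),
    param_scheduler=dict(type='LinearLR'),
    # epoch_length/max_epochs/max_iters mirror what FlexibleRunner passes.
    dispatch_kwargs=dict(epoch_length=100, max_epochs=10, max_iters=1000))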
""" + if self._prepared: + return self._prepared_components() if dispatch_kwargs is not None: self.dispatch_kwargs.update(dispatch_kwargs) - return_items = [] model = self.build_model(model) model = self._init_model_weights(model) model = self._wrap_model(model) model = self.compile_model(model, compile=compile) - return_items.append(model) self.model = model if optim_wrapper is not None: self.optim_wrapper = self.build_optim_wrapper(optim_wrapper, model) - return_items.append(self.optim_wrapper) if param_scheduler is not None: self.param_schedulers = self.build_param_scheduler( param_scheduler, self.optim_wrapper) - return_items.append(self.param_schedulers) if optim_wrapper is not None: self._scale_lr() @@ -84,8 +82,8 @@ class SingleDeviceStrategy(BaseStrategy): self.optim_wrapper.initialize_count_status( # type: ignore self.model, 0, self.dispatch_kwargs['max_iters']) - - return return_items[0] if len(return_items) == 1 else return_items + self._prepared = True + return self._prepared_components() def _wrap_model(self, model: nn.Module) -> nn.Module: model = self.convert_model(model) @@ -200,7 +198,7 @@ class SingleDeviceStrategy(BaseStrategy): if resume_optimizer: self.load_optim_state_dict(checkpoint.pop('optimizer')) - if resume_param_scheduler: + if resume_param_scheduler and hasattr(self, 'param_schedulers'): self.load_scheduler_state_dict(checkpoint.pop('param_schedulers')) # resume random seed @@ -267,15 +265,7 @@ class SingleDeviceStrategy(BaseStrategy): if save_optimizer and hasattr(self, 'optim_wrapper'): state_dict['optimizer'] = self.optim_state_dict() - # save param scheduler state dict - if save_param_scheduler and not hasattr(self, 'param_schedulers'): - self.logger.warning( - '`save_param_scheduler` is True but strategy has no ' - 'param_schedulers attribute, so skip saving parameter ' - 'schedulers') - save_param_scheduler = False - - if save_param_scheduler: + if save_param_scheduler and hasattr(self, 'param_schedulers'): state_dict['param_schedulers'] = self.scheduler_state_dict() # save extra checkpoint passed by users diff --git a/mmengine/_strategy/utils.py b/mmengine/_strategy/utils.py new file mode 100644 index 00000000..c691bd60 --- /dev/null +++ b/mmengine/_strategy/utils.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from torch._subclasses.fake_tensor import _is_tensor_constructor +from torch.utils._python_dispatch import TorchDispatchMode + + +class MetaTensorContext(TorchDispatchMode): + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if _is_tensor_constructor(func): + device_idx = [arg.name + for arg in func._schema.arguments].index('device') + if len(args) > device_idx: + args = list(args) + args[device_idx] = 'meta' + else: + kwargs['device'] = 'meta' + return func(*args, **kwargs) diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py index ac567786..033512a9 100644 --- a/mmengine/model/__init__.py +++ b/mmengine/model/__init__.py @@ -33,6 +33,6 @@ __all__ = [ 'convert_sync_batchnorm', 'BaseTTAModel' ] -if digit_version(TORCH_VERSION) >= digit_version('1.11.0'): +if digit_version(TORCH_VERSION) >= digit_version('2.0.0'): from .wrappers import MMFullyShardedDataParallel # noqa:F401 __all__.append('MMFullyShardedDataParallel') diff --git a/mmengine/model/wrappers/__init__.py b/mmengine/model/wrappers/__init__.py index 06eacc44..90eddabb 100644 --- a/mmengine/model/wrappers/__init__.py +++ b/mmengine/model/wrappers/__init__.py @@ -10,7 +10,7 @@ __all__ = [ 'MMSeparateDistributedDataParallel' ] -if digit_version(TORCH_VERSION) >= digit_version('1.11.0'): +if digit_version(TORCH_VERSION) >= digit_version('2.0.0'): from .fully_sharded_distributed import \ MMFullyShardedDataParallel # noqa:F401 __all__.append('MMFullyShardedDataParallel') diff --git a/mmengine/model/wrappers/fully_sharded_distributed.py b/mmengine/model/wrappers/fully_sharded_distributed.py index 87780b3b..b4667958 100644 --- a/mmengine/model/wrappers/fully_sharded_distributed.py +++ b/mmengine/model/wrappers/fully_sharded_distributed.py @@ -1,18 +1,30 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Callable, Dict, List, Optional, Union +from functools import partial +from itertools import chain +from typing import Any, Callable, Dict, Iterable, List, Optional, Union import torch +import torch.distributed as dist import torch.nn as nn from torch.distributed import ProcessGroup +# yapf: disable +from torch.distributed.fsdp.api import (FullStateDictConfig, + LocalOptimStateDictConfig, + LocalStateDictConfig, + OptimStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + ShardingStrategy, StateDictConfig, + StateDictSettings, StateDictType) from torch.distributed.fsdp.fully_sharded_data_parallel import ( - BackwardPrefetch, CPUOffload, FullyShardedDataParallel) + BackwardPrefetch, CPUOffload, FullOptimStateDictConfig, + FullyShardedDataParallel, MixedPrecision) +# yapf: enable from mmengine.optim import OptimWrapper -from mmengine.registry import MODEL_WRAPPERS, Registry +from mmengine.registry import FUNCTIONS, MODEL_WRAPPERS from mmengine.structures import BaseDataElement - -# support customize fsdp policy -FSDP_WRAP_POLICIES = Registry('fsdp wrap policy') +from mmengine.utils import digit_version, is_seq_of @MODEL_WRAPPERS.register_module() @@ -40,8 +52,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel): Args: module (nn.Module): module to be wrapped with FSDP. - process_group (Optional[ProcessGroup]): process group for sharding. - cpu_offload (Optional[Union[bool,CPUOffload]]): + process_group (ProcessGroup, optional): process group for sharding. + cpu_offload (bool, CPUOffload, optional): CPU offloading config. 
Different from FullyShardedDataParallel,Since it can be set by users' pre-defined config in MMEngine,its type is expected to be @@ -54,7 +66,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel): for params and grads to be on same device to work with optimizer. This API is subject to change. Default is ``None`` in which case there will be no offloading. - fsdp_auto_wrap_policy: (Optional[Union[str,Callable]]): + auto_wrap_policy (str or Callable, optional): Specifying a policy to recursively wrap layers with FSDP. Different from FullyShardedDataParallel, Since it can be set by users' pre-defined config in MMEngine, its type is expected to be @@ -68,12 +80,12 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel): the returned FSDP root instance. ``default_auto_wrap_policy`` written in ``torch.distributed.fsdp.wrap`` is an example of - ``fsdp_auto_wrap_policy`` callable, this policy wraps layers with + ``auto_wrap_policy`` callable, this policy wraps layers with parameter sizes larger than 100M. Users can supply the customized - ``fsdp_auto_wrap_policy`` callable that should accept following + ``auto_wrap_policy`` callable that should accept following arguments: ``module: nn.Module``, ``recurse: bool``, ``unwrapped_params: int``, extra customized arguments could be - added to the customized ``fsdp_auto_wrap_policy`` callable as well. + added to the customized ``auto_wrap_policy`` callable as well. Example:: @@ -86,18 +98,23 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel): >>> ) -> bool: >>> return unwrapped_params >= min_num_params - backward_prefetch: (Optional[Union[str,BackwardPrefetch]]): - Different from FullyShardedDataParallel, Since it will be set by - users' pre-defined config in MMEngine,its type is expected to be - `None`, `str` or `BackwardPrefetch`. + backward_prefetch (str or BackwardPrefetch, optional): + Different from FullyShardedDataParallel, this argument could be a + string or a BackwardPrefetch instance. If it's a string, then + it should be ``BACKWARD_PRE`` or ``BACKWARD_POST`` + mixed_precision (dict or MixedPrecision, optional): + This configures native mixed precision for FSDP. If this is set to + ``None``. Different from the native FSDP, this argument can a dict + like this: - This is an experimental feature that is subject to change in the - the near future. It allows users to enable two different - backward_prefetch algorithms to help backward communication and - computation overlapping. - Pros and cons of each algorithm is explained in class - ``BackwardPrefetch``. + Examples: + >>> mixed_precision=dict(param_dtype='float16', + >>> buffer_dtype='float32', + >>> reduce_dtype='float32') + Defaults to None. + use_orig_params (bool): Different from native + ``FullyShardedDataParallel``, it defaults to True. **kwargs: Keyword arguments passed to :class:`FullyShardedDataParallel`. 
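(Editor's note: a construction sketch, not part of the patch, assuming a CUDA device, an initialized process group and an ``nn.Module`` named ``my_model``. It shows the string and dict forms that the constructor below converts into ``ShardingStrategy``, ``BackwardPrefetch`` and ``MixedPrecision``.)

fsdp_model = MMFullyShardedDataParallel(
    module=my_model.cuda(),
    sharding_strategy='FULL_SHARD',
    backward_prefetch='BACKWARD_PRE',
    mixed_precision=dict(
        param_dtype='float16',
        reduce_dtype='float32',
        buffer_dtype='float32'))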
""" @@ -105,61 +122,117 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel): def __init__( self, module: nn.Module, - process_group: Optional[ProcessGroup] = None, - cpu_offload: Optional[Union[bool, CPUOffload]] = None, - fsdp_auto_wrap_policy: Optional[Union[str, Callable]] = None, - backward_prefetch: Optional[Union[str, BackwardPrefetch]] = None, + process_group: Union[dict, ProcessGroup, None] = None, + sharding_strategy: Union[str, ShardingStrategy] = None, + cpu_offload: Union[bool, CPUOffload, None] = None, + auto_wrap_policy: Union[str, Callable, None] = None, + backward_prefetch: Union[str, BackwardPrefetch, None] = None, + mixed_precision: Union[dict, MixedPrecision, None] = None, + ignored_modules: Union[Iterable[str], Iterable[nn.Module], + None] = None, + param_init_fn: Union[str, Callable[[nn.Module], None]] = None, + use_orig_params: bool = True, **kwargs, ): + if isinstance(sharding_strategy, str): + sharding_strategy = ShardingStrategy[sharding_strategy] + if not (isinstance(sharding_strategy, ShardingStrategy) + or sharding_strategy is None): + raise TypeError( + 'sharding_strategy must be str or enum of `ShardingStrategy` ' + f', but got {sharding_strategy}') - if cpu_offload is not None: - if isinstance(cpu_offload, bool): - cpu_offload = CPUOffload(offload_params=cpu_offload) - elif not isinstance(cpu_offload, CPUOffload): + if isinstance(cpu_offload, bool): + cpu_offload = CPUOffload(offload_params=cpu_offload) + if not (isinstance(cpu_offload, CPUOffload) or cpu_offload is None): + raise TypeError( + '`cpu_offload` should be `None`, `bool`' + f'or `CPUOffload`, but has type {type(cpu_offload)}') + + if isinstance(auto_wrap_policy, str): + auto_wrap_policy = FUNCTIONS.get( # type: ignore + auto_wrap_policy) + if auto_wrap_policy is None: + raise ValueError('`auto_wrap_policy` is not registered!') + elif isinstance(auto_wrap_policy, dict): + ori_func = FUNCTIONS.get( # type: ignore + auto_wrap_policy.pop('type')) + if auto_wrap_policy is None: + raise ValueError('`auto_wrap_policy` is not registered!') + auto_wrap_policy = partial(ori_func, **auto_wrap_policy) + + if not (auto_wrap_policy is None + or callable(auto_wrap_policy)): # type: ignore + raise TypeError('`auto_wrap_policy` should be a str, a ' + 'callable, a dict or None, but has type ' + f'{type(auto_wrap_policy)}') + + if isinstance(backward_prefetch, str): + backward_prefetch = BackwardPrefetch[backward_prefetch] + if not (isinstance(backward_prefetch, BackwardPrefetch) + or backward_prefetch is None): + raise TypeError( + '`backward_prefetch` should be `None`, string of ' + '"BACKWARD_PRE" and "BACKWARD_POST", or ' + f'`BackwardPrefetch`, but has type {type(backward_prefetch)}') + + if isinstance(param_init_fn, str): + param_init_fn = FUNCTIONS.get( # type: ignore + param_init_fn) + if param_init_fn is None: + raise ValueError('`param_init_fn` is not registered!') + elif isinstance(param_init_fn, dict): + param_init_fn = FUNCTIONS.get(param_init_fn.pop('type')) + if param_init_fn is None: + raise ValueError('`param_init_fn` is not registered!') + param_init_fn = partial(param_init_fn, **param_init_fn) + + if not (callable(param_init_fn) or param_init_fn is None): + raise TypeError('`param_init_fn` should be a str, a ' + 'callable, a dict or None, but has type ' + f'{type(param_init_fn)}') + + def parse_dtype(dtype): + if dtype is None: + return None + elif isinstance(dtype, str): + return getattr(torch, dtype) + elif isinstance(dtype, torch.dtype): + return dtype + else: raise TypeError( - 
'`cpu_offload` should be `None`, `bool`' - f'or `CPUOffload`, but has type {type(cpu_offload)}') + '`dtype` should be `None`, `str` or `torch.dtype`, ' + f'but has type {type(dtype)}') - if fsdp_auto_wrap_policy is not None: - if isinstance(fsdp_auto_wrap_policy, str): - assert fsdp_auto_wrap_policy in FSDP_WRAP_POLICIES, \ - '`FSDP_WRAP_POLICIES` has no ' \ - f'function {fsdp_auto_wrap_policy}' - fsdp_auto_wrap_policy = FSDP_WRAP_POLICIES.get( # type: ignore - fsdp_auto_wrap_policy) - if not isinstance(fsdp_auto_wrap_policy, - Callable): # type: ignore - raise TypeError( - 'Registered `fsdp_auto_wrap_policy` needs to be ' - '`Callable`, but has type ' - f'{type(fsdp_auto_wrap_policy)}') - elif not isinstance(fsdp_auto_wrap_policy, - Callable): # type: ignore - raise TypeError( - '`fsdp_auto_wrap_policy` should be `None`, `str` ' - 'or `Callable`, but has type ' - f'{type(fsdp_auto_wrap_policy)}') - - if backward_prefetch is not None: - if isinstance(backward_prefetch, str): - assert backward_prefetch in ['pre', 'post'], \ - '`backward_prefetch` should be either `pre` or `post`,' \ - f' but get {backward_prefetch}' - if backward_prefetch == 'pre': - backward_prefetch = BackwardPrefetch.BACKWARD_PRE - else: - backward_prefetch = BackwardPrefetch.BACKWARD_POST - elif not isinstance(backward_prefetch, BackwardPrefetch): - raise TypeError('`backward_prefetch` should be `None`, `str` ' - 'or `BackwardPrefetch`, but has type ' - f'{type(backward_prefetch)}') + if isinstance(mixed_precision, dict): + mixed_precision['param_dtype'] = parse_dtype( + mixed_precision.get('param_dtype', None)) + mixed_precision['reduce_dtype'] = parse_dtype( + mixed_precision.get('reduce_dtype', None)) + mixed_precision['buffer_dtype'] = parse_dtype( + mixed_precision.get('buffer_dtype', None)) + mixed_precision = MixedPrecision(**mixed_precision) + elif isinstance(mixed_precision, MixedPrecision): + mixed_precision = mixed_precision + elif mixed_precision is not None: + raise TypeError( + '`mixed_precision` should be `None`, `dict` or ' + f'`MixedPrecision`, but has type {type(mixed_precision)}') + self._fixed_modules = self._get_fixed_module(module, ignored_modules) + ignored_modules = [] if ignored_modules is None else ignored_modules + ignored_modules = chain(ignored_modules, self._fixed_modules) super().__init__( module=module, process_group=process_group, - auto_wrap_policy=fsdp_auto_wrap_policy, + sharding_strategy=sharding_strategy, + auto_wrap_policy=auto_wrap_policy, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch, + mixed_precision=mixed_precision, + ignored_modules=ignored_modules, + param_init_fn=param_init_fn, + use_orig_params=use_orig_params, **kwargs) def train_step(self, data: dict, @@ -207,8 +280,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel): Returns: List[BaseDataElement] or dict: The predictions of given data. """ - inputs, data_sample = self.module.data_preprocessor(data, False) - return self(inputs, data_sample, mode='predict') + data = self.module.data_preprocessor(data, False) + return self._run_forward(data, mode='predict') # type: ignore def test_step(self, data: dict) -> List[BaseDataElement]: """Gets the predictions of module during testing process. @@ -219,5 +292,156 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel): Returns: List[BaseDataElement]: The predictions of given data. 
""" - inputs, data_sample = self.module.data_preprocessor(data, False) - return self(inputs, data_sample, mode='predict') + data = self.module.data_preprocessor(data, False) + return self._run_forward(data, mode='predict') # type: ignore + + def _run_forward(self, data: Union[dict, tuple, list], + mode: str) -> Union[Dict[str, torch.Tensor], list]: + """Unpacks data for :meth:`forward` + Args: + data (dict or tuple or list): Data sampled from dataset. + mode (str): Mode of forward. + Returns: + dict or list: Results of training or testing mode. + """ + if isinstance(data, dict): + results = self(**data, mode=mode) + elif isinstance(data, (list, tuple)): + results = self(*data, mode=mode) + else: + raise TypeError('Output of `data_preprocessor` should be ' + f'list, tuple or dict, but got {type(data)}') + return results + + def _get_fixed_module(self, module, ignored_modules): + module_dict = dict(module.named_modules()) + if is_seq_of(ignored_modules, str): + ignored_modules = [module_dict[name] for name in ignored_modules] + if not is_seq_of(ignored_modules, + nn.Module) and ignored_modules is not None: + raise TypeError( + '`ignored_modules` should be `None`, `Iterable[str]` or ' + f'`Iterable[nn.Module]`, but has type {type(ignored_modules)}') + + def find_fixed_modules_recursively( + root_module: nn.Module) -> List[nn.Module]: + """Helper function to find fixed modules whose parameters are all + untrainable, i.e. `requires_grad=False`. + + This function performs + recursively. + Args: + root_module (nn.Module): root module for recursion + Returns: + List[nn.Module]: fixed modules in root_module + """ + if all(p.requires_grad for p in root_module.parameters()): + return [] + if all(not p.requires_grad for p in root_module.parameters()): + return [root_module] + fixed_modules = [] + for sub_module in root_module.children(): + fixed_modules.extend( + find_fixed_modules_recursively(sub_module)) + return fixed_modules + + fixed_modules = find_fixed_modules_recursively(module) + return fixed_modules + + if digit_version(torch.__version__) < digit_version('2.0.1'): + + @staticmethod + def optim_state_dict( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + group: Optional[dist.ProcessGroup] = None, + ) -> Dict[str, Any]: + """copied from pytorch 2.0.1 which has fixed some bugs.""" + state_dict_settings = FullyShardedDataParallel.get_state_dict_type( + model) + return FullyShardedDataParallel._optim_state_dict_impl( + model=model, + optim=optim, + optim_state_dict=optim.state_dict(), + optim_input=None, + rank0_only=getattr(state_dict_settings.optim_state_dict_config, + 'rank0_only', False), + full_state_dict=state_dict_settings.state_dict_type == + StateDictType.FULL_STATE_DICT, + group=group, + ) + + @staticmethod + def set_state_dict_type( + module: nn.Module, + state_dict_type: StateDictType, + state_dict_config: Optional[StateDictConfig] = None, + optim_state_dict_config: Optional[OptimStateDictConfig] = None, + ) -> StateDictSettings: + """copied from pytorch 2.0.1 which has fixed some bugs.""" + import torch.distributed.fsdp._traversal_utils as traversal_utils + _state_dict_type_to_config = { + StateDictType.FULL_STATE_DICT: FullStateDictConfig, + StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig, + StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig, + } + _optim_state_dict_type_to_config = { + StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig, + StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig, + StateDictType.SHARDED_STATE_DICT: 
ShardedOptimStateDictConfig, + } + + # Use the default config if a state_dict config is not set. + state_dict_config_type = _state_dict_type_to_config[ + state_dict_type] + optim_state_dict_config_type = _optim_state_dict_type_to_config[ + state_dict_type] + if state_dict_config is None: + state_dict_config = state_dict_config_type() + if optim_state_dict_config is None: + optim_state_dict_config = optim_state_dict_config_type() + if state_dict_config_type != type(state_dict_config): + raise RuntimeError('Expected state_dict_config of type ' + f'{state_dict_config_type} ' + f'but got {type(state_dict_config)}') + if optim_state_dict_config_type != type(optim_state_dict_config): + raise RuntimeError('Expected optim_state_dict_config of type ' + f'{optim_state_dict_config_type} ' + f'but got {type(optim_state_dict_config)}') + + # Set the state_dict type and configurations. + prev_state_dict_type = None + prev_state_dict_config = None + prev_optim_state_dict_config = None + for submodule in traversal_utils._get_fsdp_states(module): + if prev_state_dict_type is None: + prev_state_dict_type = submodule._state_dict_type + else: + assert ( + prev_state_dict_type == submodule._state_dict_type + ), 'All FSDP modules should have the same state_dict_type.' + if prev_state_dict_config is None: + prev_state_dict_config = submodule._state_dict_config + else: + assert isinstance( + submodule._state_dict_config, + type(prev_state_dict_config)), ( + 'All FSDP modules must have the same type of ' + 'state_dict_config.') + if prev_optim_state_dict_config is None: + prev_optim_state_dict_config = \ + submodule._optim_state_dict_config + else: + assert isinstance( + submodule._optim_state_dict_config, + type(prev_optim_state_dict_config), + ), ('All FSDP modules must have the same type of ' + 'optim_state_dict_config.') + + submodule._state_dict_type = state_dict_type + submodule._state_dict_config = state_dict_config + submodule._optim_state_dict_config = optim_state_dict_config + + return StateDictSettings(prev_state_dict_type, + prev_state_dict_config, + prev_optim_state_dict_config) diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py index 46eb96c0..7a82d166 100644 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -48,6 +48,10 @@ class AmpOptimWrapper(OptimWrapper): `'float64'`. If set to ``None``, the default data type will be used. Defaults to None. `New in version 0.6.1.` + use_fsdp (bool): Using ``ShardedGradScaler`` when it is True. It should + be enabled when using ``FullyShardedDataParallel``. + Defaults to False. + `New in version 0.8.0.` **kwargs: Keyword arguments passed to OptimWrapper. 
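(Editor's note: a configuration sketch, not part of the patch. ``FSDPStrategy`` in this patch sets ``use_fsdp=True`` automatically for ``AmpOptimWrapper`` configs, but it can also be enabled explicitly as below.)

optim_wrapper = dict(
    type='AmpOptimWrapper',
    use_fsdp=True,          # switches GradScaler to ShardedGradScaler
    loss_scale='dynamic',
    optimizer=dict(type='AdamW', lr=1e-3))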
Warnings: @@ -65,6 +69,7 @@ class AmpOptimWrapper(OptimWrapper): def __init__(self, loss_scale: str = 'dynamic', dtype: Union[str, torch.dtype] = None, + use_fsdp: bool = False, **kwargs): assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), ( '`torch.cuda.amp` is only available when pytorch version >= 1.6') @@ -73,17 +78,29 @@ class AmpOptimWrapper(OptimWrapper): 'on gpu, npu or mlu') super().__init__(**kwargs) self._scale_update_param = None + + if use_fsdp: + if digit_version(torch.__version__) >= digit_version('2.0.0'): + from torch.distributed.fsdp.sharded_grad_scaler import \ + ShardedGradScaler + scaler_type = ShardedGradScaler + else: + raise RuntimeError( + 'PyTorch>=2.0.0 is required when sets `use_fsdp=True`') + else: + scaler_type = GradScaler + if loss_scale == 'dynamic': # If loss_scale is a string, it must be 'dynamic', then dynamic # loss scaling will be used. - self.loss_scaler = GradScaler() + self.loss_scaler = scaler_type() elif isinstance(loss_scale, float): # Static loss scaling self._scale_update_param = loss_scale - self.loss_scaler = GradScaler(init_scale=loss_scale) + self.loss_scaler = scaler_type(init_scale=loss_scale) elif isinstance(loss_scale, dict): # More specific configuration. - self.loss_scaler = GradScaler(**loss_scale) + self.loss_scaler = scaler_type(**loss_scale) else: raise TypeError('loss_scale must be of type float, dict, or ' f'"dynamic", but got {loss_scale}') diff --git a/mmengine/runner/_flexible_runner.py b/mmengine/runner/_flexible_runner.py index 7a76ebbc..6a59d5c5 100644 --- a/mmengine/runner/_flexible_runner.py +++ b/mmengine/runner/_flexible_runner.py @@ -632,6 +632,18 @@ class FlexibleRunner: assert isinstance(strategy, dict) + # train_micro_batch_size_per_gpu is required by DeepSpeed + if isinstance(strategy['type'], str): + strategy_name = strategy['type'] + else: + strategy_name = strategy['type'].__name__ + if strategy_name == 'DeepSpeedStrategy': + if self._train_dataloader is None: + strategy['train_micro_batch_size_per_gpu'] = 1 + else: + strategy['train_micro_batch_size_per_gpu'] = \ + _get_batch_size(self._train_dataloader) + strategy.setdefault('work_dir', self._work_dir) strategy.setdefault('experiment_name', experiment_name) strategy.setdefault('auto_scale_lr', self._auto_scale_lr) @@ -1140,18 +1152,13 @@ class FlexibleRunner: compile = copy.copy(self._compile) compile.setdefault('target', 'train_step') - if self.train_dataloader.batch_size is not None: - micro_batch_size = self.train_dataloader.batch_size - else: - micro_batch_size = self.train_dataloader.batch_sampler.batch_size dispatch_kwargs = dict( - train_micro_batch_size_per_gpu=micro_batch_size, - num_batches_per_epoch=len(self.train_dataloader), + epoch_length=len(self.train_dataloader), max_epochs=self.max_epochs, max_iters=self.max_iters, ) - result = self.strategy.prepare( + self.strategy.prepare( self.model, optim_wrapper=self.optim_wrapper, param_scheduler=self.param_schedulers, @@ -1159,10 +1166,10 @@ class FlexibleRunner: dispatch_kwargs=dispatch_kwargs, ) + self.model = self.strategy.model + self.optim_wrapper = self.strategy.optim_wrapper # type: ignore if self.param_schedulers is not None: - self.model, self.optim_wrapper, self.param_schedulers, *_ = result - else: - self.model, self.optim_wrapper, *_ = result + self.param_schedulers = self.strategy.param_schedulers self.load_or_resume() @@ -1187,7 +1194,11 @@ class FlexibleRunner: self._val_loop = self.build_val_loop(self._val_loop) # type: ignore - self.model = 
self.strategy.prepare(self.model) + dispatch_kwargs = dict( + init_weights_for_test_or_val=self.cfg.get( + 'init_weights_for_test_or_val', True)) + self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs) + self.model = self.strategy.model self.load_or_resume() @@ -1210,8 +1221,11 @@ class FlexibleRunner: '`test_evaluator` arguments when initializing runner.') self._test_loop = self.build_test_loop(self._test_loop) # type: ignore - - self.model = self.strategy.prepare(self.model) + dispatch_kwargs = dict( + init_weights_for_test_or_val=self.cfg.get( + 'init_weights_for_test_or_val', True)) + self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs) + self.model = self.strategy.model self.load_or_resume() @@ -1613,3 +1627,20 @@ class FlexibleRunner: if self.cfg._cfg_dict: self.logger.info(f'Config:\n{self.cfg.pretty_text}') + + +def _get_batch_size(dataloader): + if isinstance(dataloader, dict): + if 'batch_size' in dataloader: + return dataloader['batch_size'] + elif ('batch_sampler' in dataloader + and 'batch_size' in dataloader['batch_sampler']): + return dataloader['batch_sampler']['batch_size'] + else: + raise ValueError('Please set batch_size in `Dataloader` or ' + '`batch_sampler`') + elif isinstance(dataloader, DataLoader): + return dataloader.batch_sampler.batch_size + else: + raise ValueError('dataloader should be a dict or a Dataloader ' + f'instance, but got {type(dataloader)}') diff --git a/mmengine/testing/runner_test_case.py b/mmengine/testing/runner_test_case.py index a05c41d3..f64594ac 100644 --- a/mmengine/testing/runner_test_case.py +++ b/mmengine/testing/runner_test_case.py @@ -36,14 +36,14 @@ class ToyModel(BaseModel): if isinstance(inputs, list): inputs = torch.stack(inputs) if isinstance(data_samples, list): - data_sample = torch.stack(data_samples) + data_samples = torch.stack(data_samples) outputs = self.linear1(inputs) outputs = self.linear2(outputs) if mode == 'tensor': return outputs elif mode == 'loss': - loss = (data_sample - outputs).sum() + loss = (data_samples - outputs).sum() outputs = dict(loss=loss) return outputs elif mode == 'predict': diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py index eabe10ea..cd3e539f 100644 --- a/tests/test_model/test_wrappers/test_model_wrapper.py +++ b/tests/test_model/test_wrappers/test_model_wrapper.py @@ -8,7 +8,7 @@ import torch.distributed as torch_dist import torch.nn as nn from torch.optim import SGD -from mmengine.dist import all_gather +from mmengine.dist import all_gather, broadcast from mmengine.model import (BaseDataPreprocessor, BaseModel, ExponentialMovingAverage, MMDistributedDataParallel, @@ -222,8 +222,8 @@ class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel): @unittest.skipIf( torch.cuda.device_count() < 2, reason='need 2 gpu to test fsdp') @unittest.skipIf( - digit_version(TORCH_VERSION) < digit_version('1.11.0'), - reason='fsdp needs Pytorch 1.11 or higher') + digit_version(TORCH_VERSION) < digit_version('2.0.0'), + reason='fsdp needs Pytorch 2.0.0 or higher') class TestMMFullyShardedDataParallel(MultiProcessTestCase): def _init_dist_env(self, rank, world_size): @@ -247,19 +247,38 @@ class TestMMFullyShardedDataParallel(MultiProcessTestCase): model = ToyModel() fsdp_model = MMFullyShardedDataParallel(module=model.cuda()) optimizer = SGD(fsdp_model.parameters(), lr=0) - optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1) + optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1) 
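        # lr=0 keeps the weights fixed, so this first train_step() call is a
        # smoke test of the forward/backward/step plumbing through the FSDP
        # wrapper rather than a check on updated parameter values.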
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255 - data = dict(inputs=[inputs], data_sample=MagicMock()) + data = dict(inputs=inputs, data_sample=MagicMock()) fsdp_model.train() self.assertTrue(fsdp_model.training) fsdp_model.train_step(data, optim_wrapper=optim_wrapper) + # require_grad=False + model = ToyModel() + for _, param in model.state_dict().items(): + broadcast(param) + model.conv1.requires_grad_(False) + ori_weight = model.conv1.weight.clone() + fsdp_model = MMFullyShardedDataParallel(module=model.cuda()) + optimizer = SGD(fsdp_model.parameters(), lr=0.1) + optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1) + inputs = torch.randn(1, 3, 1, 1) * self.rank * 255 + data = dict(inputs=inputs, data_sample=MagicMock()) + fsdp_model.train() + self.assertTrue(fsdp_model.training) + fsdp_model.train_step(data, optim_wrapper=optim_wrapper) + + with fsdp_model.summon_full_params(fsdp_model): + updated_weight = fsdp_model.module.conv1.weight.cpu() + assert_allclose(ori_weight, updated_weight) + def test_val_step(self): self._init_dist_env(self.rank, self.world_size) model = ToyModel() fsdp_model = MMFullyShardedDataParallel(module=model.cuda()) inputs = torch.randn(1, 3, 1, 1) * self.rank * 255 - data = dict(inputs=[inputs], data_sample=MagicMock()) + data = dict(inputs=inputs, data_sample=MagicMock()) # Test get predictions. predictions = fsdp_model.val_step(data) self.assertIsInstance(predictions, torch.Tensor) diff --git a/tests/test_strategies/test_fsdp.py b/tests/test_strategies/test_fsdp.py new file mode 100644 index 00000000..64b900d2 --- /dev/null +++ b/tests/test_strategies/test_fsdp.py @@ -0,0 +1,231 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +from tempfile import TemporaryDirectory +from unittest import TestCase, skipIf + +import torch +import torch.nn as nn + +try: + from torch.distributed.fsdp import (FullStateDictConfig, + FullyShardedDataParallel, + LocalStateDictConfig, StateDictType) + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullOptimStateDictConfig, LocalOptimStateDictConfig) + + from mmengine._strategy import FSDPStrategy +except: # noqa: E722 + pass +from torch.multiprocessing.spawn import start_processes +from torch.optim import SGD + +from mmengine.dist import (all_gather_object, broadcast_object_list, + is_main_process) +from mmengine.optim import LinearLR, OptimWrapper +from mmengine.testing.runner_test_case import ToyModel +from mmengine.utils import digit_version + + +def linear_wrap_policy( + module, + recurse, + nonwrapped_numel, +) -> bool: + if recurse: + return True # always recurse + return isinstance(module, nn.Linear) + + +@skipIf( + digit_version(torch.__version__) < digit_version('2.0.0') + or not torch.cuda.is_available(), + 'Only test FSDP with CUDA and PyTorch >= 2.0.0') +class TestStrategy(TestCase): + + def setUp(self): + self.world_size = 2 + self.temp_dir = TemporaryDirectory() + + def tearDown(self) -> None: + self.temp_dir.cleanup() + + def test_init(self): + strategy = FSDPStrategy() + self.assertFalse(strategy.skip_init_weights) + strategy = FSDPStrategy(state_dict_cfg='local') + self._assert_local(strategy) + + strategy = FSDPStrategy(state_dict_cfg='full') + self._assert_full(strategy) + + strategy = FSDPStrategy( + state_dict_cfg=dict( + state_dict_type=StateDictType.LOCAL_STATE_DICT)) + self._assert_local(strategy) + + strategy = FSDPStrategy( + state_dict_cfg=dict( + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig(), + 
optim_state_dict_config=FullOptimStateDictConfig(), + )) + self._assert_full(strategy) + + strategy = FSDPStrategy( + state_dict_cfg=dict( + state_dict_type='FULL_STATE_DICT', + state_dict_config=dict(type='FullStateDictConfig'), + optim_state_dict_config=dict(type='FullOptimStateDictConfig'), + )) + self._assert_full(strategy) + + strategy = FSDPStrategy( + state_dict_cfg=dict( + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=dict(type=FullStateDictConfig), + optim_state_dict_config=dict(type=FullOptimStateDictConfig), + )) + self._assert_full(strategy) + + with self.assertRaises(ValueError): + strategy = FSDPStrategy(state_dict_cfg='error-str') + + # state_dict_cfg should be a str or a dict + with self.assertRaises(TypeError): + strategy = FSDPStrategy(state_dict_cfg=[]) + + # state_dict_type must be a str or a enumerate of StateDictType + with self.assertRaises(TypeError): + strategy = FSDPStrategy( + state_dict_cfg=dict( + state_dict_type=[], + state_dict_config=dict(type=FullStateDictConfig), + optim_state_dict_config=dict( + type=FullOptimStateDictConfig), + )) + + # state_dict_config should be a dict or a subclass of StateDictConfig + with self.assertRaises(TypeError): + strategy = FSDPStrategy( + state_dict_cfg=dict( + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=[], + optim_state_dict_config=dict( + type=FullOptimStateDictConfig), + )) + + # optim_state_dict_config should be a dict or a subclass of + # OptimStateDictConfig + with self.assertRaises(TypeError): + strategy = FSDPStrategy( + state_dict_cfg=dict( + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=dict(type=FullStateDictConfig), + optim_state_dict_config=[], + )) + + def run_strategy(self): + # Strategy can run with the built model, optimizer and schedulers. + for skip_init_weights, state_dict_cfg in [(True, 'local'), + (False, 'full')]: + strategy = FSDPStrategy( + skip_init_weights=skip_init_weights, + state_dict_cfg=state_dict_cfg, + model_wrapper=dict(auto_wrap_policy=linear_wrap_policy)) + model = ToyModel() + optim = OptimWrapper(SGD(model.parameters(), lr=0.1, momentum=0.9)) + lr_scheduler = LinearLR(optimizer=optim) + model, optim, lr_scheduler = strategy.prepare( + model=model, optim_wrapper=optim, param_scheduler=lr_scheduler) + self.assertIsInstance(model, FullyShardedDataParallel) + self.assertIsInstance(model.linear1, FullyShardedDataParallel) + self.assertIsInstance(model.linear2, FullyShardedDataParallel) + + data = torch.ones(2, 2).cuda() + data_samples = torch.zeros(2, 2).cuda() + loss = model(data, data_samples=data_samples, mode='loss')['loss'] + loss.backward() + optim.step() + [scheduler.step() for scheduler in lr_scheduler] + + ckpt_path = osp.join(self.temp_dir.name, + f'checkpoint_{state_dict_cfg}.pth') + strategy.save_checkpoint(ckpt_path) + + if state_dict_cfg == 'full': + if not is_main_process(): + self.assertFalse(osp.exists(ckpt_path)) + ckpt_path = [ckpt_path] + broadcast_object_list(ckpt_path) + ckpt_path = ckpt_path[0] + + strategy.load_checkpoint(ckpt_path) + loss = model(data, data_samples=data_samples, mode='loss')['loss'] + loss.backward() + optim.step() + [scheduler.step() for scheduler in lr_scheduler] + + # optimizer with multiple param_groups can be reconstructed. 
+ model = ToyModel() + strategy = FSDPStrategy( + model_wrapper=dict(auto_wrap_policy=linear_wrap_policy)) + param_groups = [] + for param in model.parameters(): + param_groups.append(dict(params=[param], lr=0.1)) + optim = SGD(param_groups, lr=0.1, momentum=0.9) + lr_scheduler = LinearLR(optimizer=optim) + model, optim, lr_scheduler = strategy.prepare( + model=model, optim_wrapper=optim, param_scheduler=lr_scheduler) + data = torch.ones(2, 2).cuda() + data_samples = torch.zeros(2, 2).cuda() + loss = model(data, data_samples=data_samples, mode='loss')['loss'] + loss.backward() + optim.step() + [scheduler.step() for scheduler in lr_scheduler] + optim_state = optim.state_dict()['state'] + optim_state = all_gather_object(optim_state) + + @classmethod + def _worker(cls, rank, func): + # local mode + self = cls() + self.setUp() + self.rank = rank + + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(self.world_size) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(12123) + torch.cuda.set_device(f'cuda:{rank}') + + getattr(self, func)() + self.tearDown() + + def test_run_strategy(self): + start_processes( + TestStrategy._worker, + args=('run_strategy', ), + nprocs=self.world_size) + + def test_build_model(self): + ... + # TODO + # strategy = FSDPStrategy() + # model = ToyModel() + # state_dict = dict() + + def _assert_local(self, strategy): + self.assertEqual(strategy.state_dict_type, + StateDictType.LOCAL_STATE_DICT) + self.assertIsInstance(strategy.state_dict_config, LocalStateDictConfig) + self.assertIsInstance(strategy.optim_state_dict_config, + LocalOptimStateDictConfig) + + def _assert_full(self, strategy): + self.assertEqual(strategy.state_dict_type, + StateDictType.FULL_STATE_DICT) + self.assertIsInstance(strategy.state_dict_config, FullStateDictConfig) + self.assertIsInstance(strategy.optim_state_dict_config, + FullOptimStateDictConfig)
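
The test file above also works as a usage reference for FSDPStrategy outside the runner. The sketch below condenses the run_strategy flow into a standalone script: prepare a ToyModel, take one optimizer step, then round-trip a checkpoint saved with state_dict_cfg='full'. It is a minimal sketch, not part of this patch: the script name, checkpoint path and torchrun launch line are illustrative assumptions, while FSDPStrategy, OptimWrapper, LinearLR, ToyModel and the wrap policy mirror the test above, and the commented AmpOptimWrapper lines merely point at the use_fsdp/ShardedGradScaler branch added in amp_optimizer_wrapper.py.

# minimal_fsdp_strategy_demo.py (illustrative file name)
# Assumed launch: torchrun --nproc-per-node 2 minimal_fsdp_strategy_demo.py
import torch
import torch.nn as nn
from torch.optim import SGD

from mmengine._strategy import FSDPStrategy
from mmengine.dist import barrier, is_main_process
from mmengine.optim import LinearLR, OptimWrapper
from mmengine.testing.runner_test_case import ToyModel


def linear_wrap_policy(module, recurse, nonwrapped_numel) -> bool:
    # Same policy as the test above: always recurse, wrap every nn.Linear.
    if recurse:
        return True
    return isinstance(module, nn.Linear)


def main():
    strategy = FSDPStrategy(
        state_dict_cfg='full',
        model_wrapper=dict(auto_wrap_policy=linear_wrap_policy))

    model = ToyModel()
    optim = OptimWrapper(SGD(model.parameters(), lr=0.1, momentum=0.9))
    # AMP variant hitting the ShardedGradScaler branch (assumption, not
    # covered by the tests in this patch):
    #   from mmengine.optim import AmpOptimWrapper
    #   optim = AmpOptimWrapper(use_fsdp=True,
    #                           optimizer=SGD(model.parameters(), lr=0.1))
    scheduler = LinearLR(optimizer=optim)

    # prepare() wraps the model with FSDP and hands back the components in
    # the same way run_strategy above consumes them.
    model, optim, schedulers = strategy.prepare(
        model=model, optim_wrapper=optim, param_scheduler=scheduler)

    inputs = torch.ones(2, 2).cuda()
    targets = torch.zeros(2, 2).cuda()
    loss = model(inputs, data_samples=targets, mode='loss')['loss']
    loss.backward()
    optim.step()
    [s.step() for s in schedulers]

    # With state_dict_cfg='full' the shards are gathered to rank 0, so only
    # rank 0 writes the file; the barrier keeps the other ranks from
    # reloading before the write has finished.
    ckpt_path = 'fsdp_demo.pth'  # illustrative path
    strategy.save_checkpoint(ckpt_path)
    barrier()
    strategy.load_checkpoint(ckpt_path)
    if is_main_process():
        print('checkpoint round-trip finished')


if __name__ == '__main__':
    main()

Switching to state_dict_cfg='local' keeps a sharded checkpoint per rank instead of gathering everything to rank 0, which is the other configuration run_strategy exercises.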