diff --git a/mmengine/_strategy/base.py b/mmengine/_strategy/base.py index bb4cf2df..7b91d30f 100644 --- a/mmengine/_strategy/base.py +++ b/mmengine/_strategy/base.py @@ -826,13 +826,14 @@ class BaseStrategy(metaclass=ABCMeta): ) -> Optional[dict]: """Load checkpoint or resume from checkpoint. - Keyword Args: load_from (str, optional): The checkpoint file to - load from. Defaults to None. resume (bool or str): Whether - to resume training. Defaults to False. If ``resume`` is True - and ``load_from`` is None, automatically to find latest - checkpoint from ``work_dir``. If not found, resuming does - nothing. If ``resume`` is a string, it will be treated as the - checkpoint file to resume from. + Args: + load_from (str, optional): The checkpoint file to load from. + Defaults to None. + resume (bool or str): Whether to resume training. Defaults to + False. If ``resume`` is True and ``load_from`` is None, + automatically to find latest checkpoint from ``work_dir``. + If not found, resuming does nothing. If ``resume`` is a string, + it will be treated as the checkpoint file to resume from. """ from mmengine.runner import find_latest_checkpoint diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py index 1f856ed7..a14001a0 100644 --- a/mmengine/_strategy/fsdp.py +++ b/mmengine/_strategy/fsdp.py @@ -42,7 +42,7 @@ FSDP_CONFIGS.register_module(module=LocalStateDictConfig) class FSDPStrategy(DDPStrategy): """Support training model with FullyShardedDataParallel (FSDP). - Keyword Args:: + Keyword Args: model_wrapper (dict, optional): Config dict for model wrapper. The default configuration is: diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py index cd3e539f..791630c5 100644 --- a/tests/test_model/test_wrappers/test_model_wrapper.py +++ b/tests/test_model/test_wrappers/test_model_wrapper.py @@ -19,7 +19,7 @@ from mmengine.testing._internal import MultiProcessTestCase from mmengine.utils.dl_utils import TORCH_VERSION from mmengine.utils.version_utils import digit_version -if digit_version(TORCH_VERSION) >= digit_version('1.11.0'): +if digit_version(TORCH_VERSION) >= digit_version('2.0.0'): from mmengine.model import MMFullyShardedDataParallel # noqa: F401