Fix docstring format (#1223)

This commit is contained in:
Mashiro 2023-06-30 10:39:19 +08:00 committed by GitHub
parent 399f76ffa8
commit f930b9fe53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 9 deletions

View File

@ -826,13 +826,14 @@ class BaseStrategy(metaclass=ABCMeta):
) -> Optional[dict]: ) -> Optional[dict]:
"""Load checkpoint or resume from checkpoint. """Load checkpoint or resume from checkpoint.
Keyword Args: load_from (str, optional): The checkpoint file to Args:
load from. Defaults to None. resume (bool or str): Whether load_from (str, optional): The checkpoint file to load from.
to resume training. Defaults to False. If ``resume`` is True Defaults to None.
and ``load_from`` is None, automatically to find latest resume (bool or str): Whether to resume training. Defaults to
checkpoint from ``work_dir``. If not found, resuming does False. If ``resume`` is True and ``load_from`` is None,
nothing. If ``resume`` is a string, it will be treated as the automatically to find latest checkpoint from ``work_dir``.
checkpoint file to resume from. If not found, resuming does nothing. If ``resume`` is a string,
it will be treated as the checkpoint file to resume from.
""" """
from mmengine.runner import find_latest_checkpoint from mmengine.runner import find_latest_checkpoint

View File

@ -42,7 +42,7 @@ FSDP_CONFIGS.register_module(module=LocalStateDictConfig)
class FSDPStrategy(DDPStrategy): class FSDPStrategy(DDPStrategy):
"""Support training model with FullyShardedDataParallel (FSDP). """Support training model with FullyShardedDataParallel (FSDP).
Keyword Args:: Keyword Args:
model_wrapper (dict, optional): Config dict for model wrapper. The model_wrapper (dict, optional): Config dict for model wrapper. The
default configuration is: default configuration is:

View File

@ -19,7 +19,7 @@ from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils.dl_utils import TORCH_VERSION from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version from mmengine.utils.version_utils import digit_version
if digit_version(TORCH_VERSION) >= digit_version('1.11.0'): if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
from mmengine.model import MMFullyShardedDataParallel # noqa: F401 from mmengine.model import MMFullyShardedDataParallel # noqa: F401