mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Fix typo in FSDP (#569)
This commit is contained in:
parent
8864bd88d7
commit
da38b4113d
@ -12,7 +12,7 @@ from mmengine.registry import MODEL_WRAPPERS, Registry
|
|||||||
from mmengine.structures import BaseDataElement
|
from mmengine.structures import BaseDataElement
|
||||||
|
|
||||||
# support customize fsdp policy
|
# support customize fsdp policy
|
||||||
FSDP_WRAP_POLICYS = Registry('fsdp wrap policy')
|
FSDP_WRAP_POLICIES = Registry('fsdp wrap policy')
|
||||||
|
|
||||||
|
|
||||||
@MODEL_WRAPPERS.register_module()
|
@MODEL_WRAPPERS.register_module()
|
||||||
@ -60,7 +60,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
|||||||
users' pre-defined config in MMEngine, its type is expected to be
|
users' pre-defined config in MMEngine, its type is expected to be
|
||||||
`None`, `str` or `Callable`. If it's `str`, then
|
`None`, `str` or `Callable`. If it's `str`, then
|
||||||
MMFullyShardedDataParallel will try to get specified method in
|
MMFullyShardedDataParallel will try to get specified method in
|
||||||
``FSDP_WRAP_POLICYS`` registry,and this method will be passed to
|
``FSDP_WRAP_POLICIES`` registry,and this method will be passed to
|
||||||
FullyShardedDataParallel to finally initialize model.
|
FullyShardedDataParallel to finally initialize model.
|
||||||
|
|
||||||
Note that this policy currently will only apply to child modules of
|
Note that this policy currently will only apply to child modules of
|
||||||
@ -122,10 +122,10 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
|||||||
|
|
||||||
if fsdp_auto_wrap_policy is not None:
|
if fsdp_auto_wrap_policy is not None:
|
||||||
if isinstance(fsdp_auto_wrap_policy, str):
|
if isinstance(fsdp_auto_wrap_policy, str):
|
||||||
assert fsdp_auto_wrap_policy in FSDP_WRAP_POLICYS, \
|
assert fsdp_auto_wrap_policy in FSDP_WRAP_POLICIES, \
|
||||||
'`FSDP_WRAP_POLICYS` has no ' \
|
'`FSDP_WRAP_POLICIES` has no ' \
|
||||||
f'function {fsdp_auto_wrap_policy}'
|
f'function {fsdp_auto_wrap_policy}'
|
||||||
fsdp_auto_wrap_policy = FSDP_WRAP_POLICYS.get( # type: ignore
|
fsdp_auto_wrap_policy = FSDP_WRAP_POLICIES.get( # type: ignore
|
||||||
fsdp_auto_wrap_policy)
|
fsdp_auto_wrap_policy)
|
||||||
if not isinstance(fsdp_auto_wrap_policy,
|
if not isinstance(fsdp_auto_wrap_policy,
|
||||||
Callable): # type: ignore
|
Callable): # type: ignore
|
||||||
|
Loading…
x
Reference in New Issue
Block a user