Fix typo in FSDP (#569)
parent
8864bd88d7
commit
da38b4113d
|
@ -12,7 +12,7 @@ from mmengine.registry import MODEL_WRAPPERS, Registry
|
|||
from mmengine.structures import BaseDataElement
|
||||
|
||||
# support customize fsdp policy
|
||||
FSDP_WRAP_POLICYS = Registry('fsdp wrap policy')
|
||||
FSDP_WRAP_POLICIES = Registry('fsdp wrap policy')
|
||||
|
||||
|
||||
@MODEL_WRAPPERS.register_module()
|
||||
|
@ -60,7 +60,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
|||
users' pre-defined config in MMEngine, its type is expected to be
|
||||
`None`, `str` or `Callable`. If it's `str`, then
|
||||
MMFullyShardedDataParallel will try to get specified method in
|
||||
``FSDP_WRAP_POLICYS`` registry,and this method will be passed to
|
||||
``FSDP_WRAP_POLICIES`` registry,and this method will be passed to
|
||||
FullyShardedDataParallel to finally initialize model.
|
||||
|
||||
Note that this policy currently will only apply to child modules of
|
||||
|
@ -122,10 +122,10 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
|||
|
||||
if fsdp_auto_wrap_policy is not None:
|
||||
if isinstance(fsdp_auto_wrap_policy, str):
|
||||
assert fsdp_auto_wrap_policy in FSDP_WRAP_POLICYS, \
|
||||
'`FSDP_WRAP_POLICYS` has no ' \
|
||||
assert fsdp_auto_wrap_policy in FSDP_WRAP_POLICIES, \
|
||||
'`FSDP_WRAP_POLICIES` has no ' \
|
||||
f'function {fsdp_auto_wrap_policy}'
|
||||
fsdp_auto_wrap_policy = FSDP_WRAP_POLICYS.get( # type: ignore
|
||||
fsdp_auto_wrap_policy = FSDP_WRAP_POLICIES.get( # type: ignore
|
||||
fsdp_auto_wrap_policy)
|
||||
if not isinstance(fsdp_auto_wrap_policy,
|
||||
Callable): # type: ignore
|
||||
|
|
Loading…
Reference in New Issue