Move data_preprocessor to target device in FSDPStrategy (#1261)
parent f4f2555324
commit 6187595677
mmengine/_strategy/fsdp.py

@@ -19,8 +19,9 @@ from torch.optim.lr_scheduler import LRScheduler
 
 import mmengine
 from mmengine.config import Config, ConfigDict
+from mmengine.device import get_device
 from mmengine.dist import get_rank, is_main_process
-from mmengine.model import is_model_wrapper
+from mmengine.model import BaseDataPreprocessor, is_model_wrapper
 from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper,
                             OptimWrapperDict, _ParamScheduler,
                             build_optim_wrapper)
@@ -118,6 +119,10 @@ class FSDPStrategy(DDPStrategy):
             FullyShardedDataParallel: ``MMFullyShardedDataParallel``
             or subclass of ``FullyShardedDataParallel``.
         """
+        for module in model.modules():
+            if isinstance(module, BaseDataPreprocessor):
+                module.to(get_device())
+
         if is_model_wrapper(model):
             return
 