Move data_preprocessor to target device in FSDPStrategy (#1261)

pull/1276/head
Mashiro 2023-07-24 10:42:53 +08:00 committed by GitHub
parent f4f2555324
commit 6187595677
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 1 deletions

View File

@ -19,8 +19,9 @@ from torch.optim.lr_scheduler import LRScheduler
import mmengine
from mmengine.config import Config, ConfigDict
from mmengine.device import get_device
from mmengine.dist import get_rank, is_main_process
from mmengine.model import is_model_wrapper
from mmengine.model import BaseDataPreprocessor, is_model_wrapper
from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper,
OptimWrapperDict, _ParamScheduler,
build_optim_wrapper)
@ -118,6 +119,10 @@ class FSDPStrategy(DDPStrategy):
FullyShardedDataParallel: ``MMFullyShardedDataParallel``
or subclass of ``FullyShardedDataParallel``.
"""
for module in model.modules():
if isinstance(module, BaseDataPreprocessor):
module.to(get_device())
if is_model_wrapper(model):
return